From 63fd20d0f16729296045ea1c2043959ab1cc4667 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Wed, 3 Jan 2024 19:20:09 +0400 Subject: [PATCH] Add a `knn` method to `elasticsearch_dsl.search.Search` (#1691) (#1693) * Add a `knn` method to `elasticsearch_dsl.search.Search` * add knn's boost option (cherry picked from commit baed085cd601468f1db014b5468c1601557c5469) Co-authored-by: Miguel Grinberg --- docs/search_dsl.rst | 27 ++++++++++++++ elasticsearch_dsl/search.py | 72 ++++++++++++++++++++++++++++++++++++- tests/test_search.py | 54 ++++++++++++++++++++++++++++ 3 files changed, 152 insertions(+), 1 deletion(-) diff --git a/docs/search_dsl.rst b/docs/search_dsl.rst index 5912f8274..2dc1388f4 100644 --- a/docs/search_dsl.rst +++ b/docs/search_dsl.rst @@ -14,6 +14,8 @@ The ``Search`` object represents the entire search request: * aggregations + * k-nearest neighbor searches + * sort * pagination @@ -352,6 +354,31 @@ As opposed to other methods on the ``Search`` objects, defining aggregations is done in-place (does not return a copy). +K-Nearest Neighbor Searches +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +To issue a kNN search, use the ``.knn()`` method: + +.. code:: python + + s = Search() + vector = get_embedding("search text") + + s = s.knn( + field="embedding", + k=5, + num_candidates=10, + query_vector=vector + ) + +The ``field``, ``k`` and ``num_candidates`` arguments can be given as +positional or keyword arguments and are required. In addition to these, +``query_vector`` or ``query_vector_builder`` must be given as well. + +The ``.knn()`` method can be invoked multiple times to include multiple kNN +searches in the request. + + Sorting ~~~~~~~ diff --git a/elasticsearch_dsl/search.py b/elasticsearch_dsl/search.py index 50df5597c..13b93bd81 100644 --- a/elasticsearch_dsl/search.py +++ b/elasticsearch_dsl/search.py @@ -24,7 +24,7 @@ from .aggs import A, AggBase from .connections import get_connection from .exceptions import IllegalOperation -from .query import Bool, Q +from .query import Bool, Q, Query from .response import Hit, Response from .utils import AttrDict, DslBase, recursive_to_dict @@ -319,6 +319,7 @@ def __init__(self, **kwargs): self.aggs = AggsProxy(self) self._sort = [] self._collapse = {} + self._knn = [] self._source = None self._highlight = {} self._highlight_opts = {} @@ -406,6 +407,7 @@ def _clone(self): s = super()._clone() s._response_class = self._response_class + s._knn = [knn.copy() for knn in self._knn] s._collapse = self._collapse.copy() s._sort = self._sort[:] s._source = copy.copy(self._source) if self._source is not None else None @@ -445,6 +447,10 @@ def update_from_dict(self, d): self.aggs._params = { "aggs": {name: A(value) for (name, value) in aggs.items()} } + if "knn" in d: + self._knn = d.pop("knn") + if isinstance(self._knn, dict): + self._knn = [self._knn] if "collapse" in d: self._collapse = d.pop("collapse") if "sort" in d: @@ -494,6 +500,64 @@ def script_fields(self, **kwargs): s._script_fields.update(kwargs) return s + def knn( + self, + field, + k, + num_candidates, + query_vector=None, + query_vector_builder=None, + boost=None, + filter=None, + similarity=None, + ): + """ + Add a k-nearest neighbor (kNN) search. + + :arg field: the name of the vector field to search against + :arg k: number of nearest neighbors to return as top hits + :arg num_candidates: number of nearest neighbor candidates to consider per shard + :arg query_vector: the vector to search for + :arg query_vector_builder: A dictionary indicating how to build a query vector + :arg boost: A floating-point boost factor for kNN scores + :arg filter: query to filter the documents that can match + :arg similarity: the minimum similarity required for a document to be considered a match, as a float value + + Example:: + + s = Search() + s = s.knn(field='embedding', k=5, num_candidates=10, query_vector=vector, + filter=Q('term', category='blog'))) + """ + s = self._clone() + s._knn.append( + { + "field": field, + "k": k, + "num_candidates": num_candidates, + } + ) + if query_vector is None and query_vector_builder is None: + raise ValueError("one of query_vector and query_vector_builder is required") + if query_vector is not None and query_vector_builder is not None: + raise ValueError( + "only one of query_vector and query_vector_builder must be given" + ) + if query_vector is not None: + s._knn[-1]["query_vector"] = query_vector + if query_vector_builder is not None: + s._knn[-1]["query_vector_builder"] = query_vector_builder + if boost is not None: + s._knn[-1]["boost"] = boost + if filter is not None: + if isinstance(filter, Query): + s._knn[-1]["filter"] = filter.to_dict() + else: + s._knn[-1]["filter"] = filter + if similarity is not None: + s._knn[-1]["similarity"] = similarity + return s + def source(self, fields=None, **kwargs): """ Selectively control how the _source field is returned. @@ -677,6 +741,12 @@ def to_dict(self, count=False, **kwargs): if self.query: d["query"] = self.query.to_dict() + if self._knn: + if len(self._knn) == 1: + d["knn"] = self._knn[0] + else: + d["knn"] = self._knn + # count request doesn't care for sorting and other things if not count: if self.post_filter: diff --git a/tests/test_search.py b/tests/test_search.py index 5cc84ff84..3b47b8216 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -234,6 +234,60 @@ class MyDocument(Document): assert s._doc_type_map == {} +def test_knn(): + s = search.Search() + + with raises(TypeError): + s.knn() + with raises(TypeError): + s.knn("field") + with raises(TypeError): + s.knn("field", 5) + with raises(ValueError): + s.knn("field", 5, 100) + with raises(ValueError): + s.knn("field", 5, 100, query_vector=[1, 2, 3], query_vector_builder={}) + + s = s.knn("field", 5, 100, query_vector=[1, 2, 3]) + assert { + "knn": { + "field": "field", + "k": 5, + "num_candidates": 100, + "query_vector": [1, 2, 3], + } + } == s.to_dict() + + s = s.knn( + k=4, + num_candidates=40, + boost=0.8, + field="name", + query_vector_builder={ + "text_embedding": {"model_id": "foo", "model_text": "search text"} + }, + ) + assert { + "knn": [ + { + "field": "field", + "k": 5, + "num_candidates": 100, + "query_vector": [1, 2, 3], + }, + { + "field": "name", + "k": 4, + "num_candidates": 40, + "query_vector_builder": { + "text_embedding": {"model_id": "foo", "model_text": "search text"} + }, + "boost": 0.8, + }, + ] + } == s.to_dict() + + def test_sort(): s = search.Search() s = s.sort("fielda", "-fieldb")