From 403072d991bce8ddac46b07a3d8f9e2fdc0c31b9 Mon Sep 17 00:00:00 2001 From: Tommy Smith Date: Tue, 2 Apr 2024 16:42:10 +0100 Subject: [PATCH 01/11] Add aggregate.hybrid method using GQL --- integration/test_collection_aggregate.py | 47 ++++++++ weaviate/collections/aggregate.py | 5 +- weaviate/collections/aggregations/base.py | 34 +++++- weaviate/collections/aggregations/hybrid.py | 109 ++++++++++++++++++ .../collections/aggregations/near_image.py | 2 +- .../collections/aggregations/near_object.py | 2 +- .../collections/aggregations/near_text.py | 2 +- .../collections/aggregations/near_vector.py | 2 +- .../collections/queries/hybrid/generate.py | 2 +- .../collections/queries/hybrid/generate.pyi | 12 +- weaviate/collections/queries/hybrid/query.py | 2 +- weaviate/collections/queries/hybrid/query.pyi | 12 +- weaviate/gql/aggregate.py | 74 +++++++++++- 13 files changed, 276 insertions(+), 29 deletions(-) create mode 100644 weaviate/collections/aggregations/hybrid.py diff --git a/integration/test_collection_aggregate.py b/integration/test_collection_aggregate.py index 6733321e3..ec53683cb 100644 --- a/integration/test_collection_aggregate.py +++ b/integration/test_collection_aggregate.py @@ -290,6 +290,53 @@ def test_near_object_missing_param(collection_factory: CollectionFactory) -> Non ) +@pytest.mark.parametrize( + "option,expected_len", + [ + ({"object_limit": 1}, 1), + ({"object_limit": 2}, 2), + ], +) +def test_hybrid_aggregation( + collection_factory: CollectionFactory, option: dict, expected_len: int +) -> None: + collection = collection_factory( + properties=[Property(name="text", data_type=DataType.TEXT)], + vectorizer_config=Configure.Vectorizer.text2vec_contextionary( + vectorize_collection_name=False + ), + ) + text_1 = "some text" + text_2 = "nothing like the other one at all, not even a little bit" + uuid = collection.data.insert({"text": text_1}) + obj = collection.query.fetch_object_by_id(uuid, include_vector=True) + assert "default" in obj.vector + 
collection.data.insert({"text": text_2}) + res: AggregateReturn = collection.aggregate.hybrid( + None, + alpha=1, + vector=obj.vector["default"], + return_metrics=[ + Metrics("text").text(count=True, top_occurrences_count=True, top_occurrences_value=True) + ], + **option, + ) + assert isinstance(res.properties["text"], AggregateText) + assert res.properties["text"].count == expected_len + assert len(res.properties["text"].top_occurrences) == expected_len + assert text_1 in [ + top_occurrence.value for top_occurrence in res.properties["text"].top_occurrences + ] + if expected_len == 2: + assert text_2 in [ + top_occurrence.value for top_occurrence in res.properties["text"].top_occurrences + ] + else: + assert text_2 not in [ + top_occurrence.value for top_occurrence in res.properties["text"].top_occurrences + ] + + @pytest.mark.parametrize( "option,expected_len", [ diff --git a/weaviate/collections/aggregate.py b/weaviate/collections/aggregate.py index 4218a9772..5d0e09004 100644 --- a/weaviate/collections/aggregate.py +++ b/weaviate/collections/aggregate.py @@ -1,9 +1,10 @@ -from weaviate.collections.aggregations.over_all import _OverAll +from weaviate.collections.aggregations.hybrid import _Hybrid from weaviate.collections.aggregations.near_image import _NearImage from weaviate.collections.aggregations.near_object import _NearObject from weaviate.collections.aggregations.near_text import _NearText from weaviate.collections.aggregations.near_vector import _NearVector +from weaviate.collections.aggregations.over_all import _OverAll -class _AggregateCollection(_OverAll, _NearImage, _NearObject, _NearText, _NearVector): +class _AggregateCollection(_Hybrid, _NearImage, _NearObject, _NearText, _NearVector, _OverAll): pass diff --git a/weaviate/collections/aggregations/base.py b/weaviate/collections/aggregations/base.py index 8fdc3fbd1..b0547c33b 100644 --- a/weaviate/collections/aggregations/base.py +++ b/weaviate/collections/aggregations/base.py @@ -235,7 +235,33 @@ 
def _parse_near_options( ) @staticmethod - def _add_near_image( + def _add_hybrid_to_builder( + builder: AggregateBuilder, + query: Optional[str], + alpha: Optional[NUMBER], + vector: Optional[List[float]], + query_properties: Optional[List[str]], + object_limit: Optional[int], + target_vector: Optional[str], + ) -> AggregateBuilder: + payload: dict = {} + if query is not None: + payload["query"] = query + if alpha is not None: + payload["alpha"] = alpha + if vector is not None: + payload["vector"] = vector + if query_properties is not None: + payload["properties"] = query_properties + if target_vector is not None: + payload["targetVectors"] = [target_vector] + builder = builder.with_hybrid(payload) + if object_limit is not None: + builder = builder.with_object_limit(object_limit) + return builder + + @staticmethod + def _add_near_image_to_builder( builder: AggregateBuilder, near_image: Union[str, pathlib.Path, io.BufferedReader], certainty: Optional[NUMBER], @@ -265,7 +291,7 @@ def _add_near_image( return builder @staticmethod - def _add_near_object( + def _add_near_object_to_builder( builder: AggregateBuilder, near_object: UUID, certainty: Optional[NUMBER], @@ -293,7 +319,7 @@ def _add_near_object( return builder @staticmethod - def _add_near_text( + def _add_near_text_to_builder( builder: AggregateBuilder, query: Union[List[str], str], certainty: Optional[NUMBER], @@ -334,7 +360,7 @@ def _add_near_text( return builder @staticmethod - def _add_near_vector( + def _add_near_vector_to_builder( builder: AggregateBuilder, near_vector: List[float], certainty: Optional[NUMBER], diff --git a/weaviate/collections/aggregations/hybrid.py b/weaviate/collections/aggregations/hybrid.py new file mode 100644 index 000000000..961f34ce4 --- /dev/null +++ b/weaviate/collections/aggregations/hybrid.py @@ -0,0 +1,109 @@ +from typing import List, Literal, Optional, Union, overload + +from weaviate.collections.aggregations.base import _Aggregate +from 
weaviate.collections.classes.aggregate import ( + PropertiesMetrics, + AggregateReturn, + AggregateGroupByReturn, + GroupByAggregate, +) +from weaviate.collections.classes.filters import _Filters +from weaviate.types import NUMBER + + +class _Hybrid(_Aggregate): + @overload + def hybrid( + self, + query: Optional[str], + *, + alpha: NUMBER = 0.7, + vector: Optional[List[float]] = None, + query_properties: Optional[List[str]] = None, + object_limit: Optional[int] = None, + filters: Optional[_Filters] = None, + group_by: Literal[None] = None, + target_vector: Optional[str] = None, + total_count: bool = True, + return_metrics: Optional[PropertiesMetrics] = None, + ) -> AggregateReturn: + ... + + @overload + def hybrid( + self, + query: Optional[str], + *, + alpha: NUMBER = 0.7, + vector: Optional[List[float]] = None, + query_properties: Optional[List[str]] = None, + object_limit: Optional[int] = None, + filters: Optional[_Filters] = None, + group_by: Union[str, GroupByAggregate], + target_vector: Optional[str] = None, + total_count: bool = True, + return_metrics: Optional[PropertiesMetrics] = None, + ) -> AggregateGroupByReturn: + ... + + def hybrid( + self, + query: Optional[str], + *, + alpha: NUMBER = 0.7, + vector: Optional[List[float]] = None, + query_properties: Optional[List[str]] = None, + object_limit: Optional[int] = None, + filters: Optional[_Filters] = None, + group_by: Optional[Union[str, GroupByAggregate]] = None, + target_vector: Optional[str] = None, + total_count: bool = True, + return_metrics: Optional[PropertiesMetrics] = None, + ) -> Union[AggregateReturn, AggregateGroupByReturn]: + """Aggregate metrics over all the objects in this collection without any vector search. + + Arguments: + `query` + The keyword-based query to search for, REQUIRED. If query and vector are both None, a normal search will be performed. + `alpha` + The weight of the BM25 score. If not specified, the default weight specified by the server is used. 
+ `vector` + The specific vector to search for. If not specified, the query is vectorized and used in the similarity search. + `query_properties` + The properties to search in. If not specified, all properties are searched. + `object_limit` + The maximum number of objects to return from the hybrid vector search prior to the aggregation. + `filters` + The filters to apply to the search. + `group_by` + The property, or `GroupByAggregate` configuration, to group the aggregation by. + `total_count` + Whether to include the total number of objects that match the query in the response. + `return_metrics` + A list of property metrics to aggregate together after the hybrid search. + + Returns: + Depending on the presence of the `group_by` argument, either an `AggregateReturn` object or an `AggregateGroupByReturn` object that includes the aggregation objects. + + Raises: + `weaviate.exceptions.WeaviateQueryError`: + If an error occurs while performing the query against Weaviate. + `weaviate.exceptions.WeaviateInvalidInputError`: + If any of the input arguments are of the wrong type. 
+ """ + return_metrics = ( + return_metrics + if (return_metrics is None or isinstance(return_metrics, list)) + else [return_metrics] + ) + builder = self._base(return_metrics, filters, total_count) + builder = self._add_hybrid_to_builder( + builder, query, alpha, vector, query_properties, object_limit, target_vector + ) + builder = self._add_groupby_to_builder(builder, group_by) + res = self._do(builder) + return ( + self._to_aggregate_result(res, return_metrics) + if group_by is None + else self._to_group_by_result(res, return_metrics) + ) diff --git a/weaviate/collections/aggregations/near_image.py b/weaviate/collections/aggregations/near_image.py index f009aba33..591ff9e28 100644 --- a/weaviate/collections/aggregations/near_image.py +++ b/weaviate/collections/aggregations/near_image.py @@ -99,7 +99,7 @@ def near_image( ) builder = self._base(return_metrics, filters, total_count) builder = self._add_groupby_to_builder(builder, group_by) - builder = self._add_near_image( + builder = self._add_near_image_to_builder( builder, near_image, certainty, distance, object_limit, target_vector ) res = self._do(builder) diff --git a/weaviate/collections/aggregations/near_object.py b/weaviate/collections/aggregations/near_object.py index df3d1f6fd..c1f18d3c0 100644 --- a/weaviate/collections/aggregations/near_object.py +++ b/weaviate/collections/aggregations/near_object.py @@ -97,7 +97,7 @@ def near_object( ) builder = self._base(return_metrics, filters, total_count) builder = self._add_groupby_to_builder(builder, group_by) - builder = self._add_near_object( + builder = self._add_near_object_to_builder( builder, near_object, certainty, distance, object_limit, target_vector ) res = self._do(builder) diff --git a/weaviate/collections/aggregations/near_text.py b/weaviate/collections/aggregations/near_text.py index 9a8b41ad1..bb17f6d23 100644 --- a/weaviate/collections/aggregations/near_text.py +++ b/weaviate/collections/aggregations/near_text.py @@ -108,7 +108,7 @@ def 
near_text( ) builder = self._base(return_metrics, filters, total_count) builder = self._add_groupby_to_builder(builder, group_by) - builder = self._add_near_text( + builder = self._add_near_text_to_builder( builder=builder, query=query, certainty=certainty, diff --git a/weaviate/collections/aggregations/near_vector.py b/weaviate/collections/aggregations/near_vector.py index a8346576a..0b6d02bb9 100644 --- a/weaviate/collections/aggregations/near_vector.py +++ b/weaviate/collections/aggregations/near_vector.py @@ -97,7 +97,7 @@ def near_vector( ) builder = self._base(return_metrics, filters, total_count) builder = self._add_groupby_to_builder(builder, group_by) - builder = self._add_near_vector( + builder = self._add_near_vector_to_builder( builder, near_vector, certainty, distance, object_limit, target_vector ) res = self._do(builder) diff --git a/weaviate/collections/queries/hybrid/generate.py b/weaviate/collections/queries/hybrid/generate.py index cad94fc0b..d647c5a32 100644 --- a/weaviate/collections/queries/hybrid/generate.py +++ b/weaviate/collections/queries/hybrid/generate.py @@ -24,7 +24,7 @@ def hybrid( single_prompt: Optional[str] = None, grouped_task: Optional[str] = None, grouped_properties: Optional[List[str]] = None, - alpha: NUMBER = 0.5, + alpha: NUMBER = 0.7, vector: Optional[HybridVectorType] = None, query_properties: Optional[List[str]] = None, fusion_type: Optional[HybridFusion] = None, diff --git a/weaviate/collections/queries/hybrid/generate.pyi b/weaviate/collections/queries/hybrid/generate.pyi index f803ada7e..56a19fa7c 100644 --- a/weaviate/collections/queries/hybrid/generate.pyi +++ b/weaviate/collections/queries/hybrid/generate.pyi @@ -28,7 +28,7 @@ class _HybridGenerate(Generic[Properties, References], _BaseQuery[Properties, Re single_prompt: Optional[str] = None, grouped_task: Optional[str] = None, grouped_properties: Optional[List[str]] = None, - alpha: NUMBER = 0.5, + alpha: NUMBER = 0.7, vector: Optional[HybridVectorType] = None, 
query_properties: Optional[List[str]] = None, fusion_type: Optional[HybridFusion] = None, @@ -51,7 +51,7 @@ class _HybridGenerate(Generic[Properties, References], _BaseQuery[Properties, Re single_prompt: Optional[str] = None, grouped_task: Optional[str] = None, grouped_properties: Optional[List[str]] = None, - alpha: NUMBER = 0.5, + alpha: NUMBER = 0.7, vector: Optional[HybridVectorType] = None, query_properties: Optional[List[str]] = None, fusion_type: Optional[HybridFusion] = None, @@ -74,7 +74,7 @@ class _HybridGenerate(Generic[Properties, References], _BaseQuery[Properties, Re single_prompt: Optional[str] = None, grouped_task: Optional[str] = None, grouped_properties: Optional[List[str]] = None, - alpha: NUMBER = 0.5, + alpha: NUMBER = 0.7, vector: Optional[HybridVectorType] = None, query_properties: Optional[List[str]] = None, fusion_type: Optional[HybridFusion] = None, @@ -97,7 +97,7 @@ class _HybridGenerate(Generic[Properties, References], _BaseQuery[Properties, Re single_prompt: Optional[str] = None, grouped_task: Optional[str] = None, grouped_properties: Optional[List[str]] = None, - alpha: NUMBER = 0.5, + alpha: NUMBER = 0.7, vector: Optional[HybridVectorType] = None, query_properties: Optional[List[str]] = None, fusion_type: Optional[HybridFusion] = None, @@ -120,7 +120,7 @@ class _HybridGenerate(Generic[Properties, References], _BaseQuery[Properties, Re single_prompt: Optional[str] = None, grouped_task: Optional[str] = None, grouped_properties: Optional[List[str]] = None, - alpha: NUMBER = 0.5, + alpha: NUMBER = 0.7, vector: Optional[HybridVectorType] = None, query_properties: Optional[List[str]] = None, fusion_type: Optional[HybridFusion] = None, @@ -143,7 +143,7 @@ class _HybridGenerate(Generic[Properties, References], _BaseQuery[Properties, Re single_prompt: Optional[str] = None, grouped_task: Optional[str] = None, grouped_properties: Optional[List[str]] = None, - alpha: NUMBER = 0.5, + alpha: NUMBER = 0.7, vector: Optional[HybridVectorType] = None, 
query_properties: Optional[List[str]] = None, fusion_type: Optional[HybridFusion] = None, diff --git a/weaviate/collections/queries/hybrid/query.py b/weaviate/collections/queries/hybrid/query.py index c8961ca3b..c8caa12a2 100644 --- a/weaviate/collections/queries/hybrid/query.py +++ b/weaviate/collections/queries/hybrid/query.py @@ -20,7 +20,7 @@ def hybrid( self, query: Optional[str], *, - alpha: NUMBER = 0.5, + alpha: NUMBER = 0.7, vector: Optional[HybridVectorType] = None, query_properties: Optional[List[str]] = None, fusion_type: Optional[HybridFusion] = None, diff --git a/weaviate/collections/queries/hybrid/query.pyi b/weaviate/collections/queries/hybrid/query.pyi index 141312ea4..1b6013cd2 100644 --- a/weaviate/collections/queries/hybrid/query.pyi +++ b/weaviate/collections/queries/hybrid/query.pyi @@ -25,7 +25,7 @@ class _HybridQuery(Generic[Properties, References], _BaseQuery[Properties, Refer self, query: Optional[str], *, - alpha: NUMBER = 0.5, + alpha: NUMBER = 0.7, vector: Optional[HybridVectorType] = None, query_properties: Optional[List[str]] = None, fusion_type: Optional[HybridFusion] = None, @@ -45,7 +45,7 @@ class _HybridQuery(Generic[Properties, References], _BaseQuery[Properties, Refer self, query: Optional[str], *, - alpha: NUMBER = 0.5, + alpha: NUMBER = 0.7, vector: Optional[HybridVectorType] = None, query_properties: Optional[List[str]] = None, fusion_type: Optional[HybridFusion] = None, @@ -65,7 +65,7 @@ class _HybridQuery(Generic[Properties, References], _BaseQuery[Properties, Refer self, query: Optional[str], *, - alpha: NUMBER = 0.5, + alpha: NUMBER = 0.7, vector: Optional[HybridVectorType] = None, query_properties: Optional[List[str]] = None, fusion_type: Optional[HybridFusion] = None, @@ -85,7 +85,7 @@ class _HybridQuery(Generic[Properties, References], _BaseQuery[Properties, Refer self, query: Optional[str], *, - alpha: NUMBER = 0.5, + alpha: NUMBER = 0.7, vector: Optional[HybridVectorType] = None, query_properties: Optional[List[str]] 
= None, fusion_type: Optional[HybridFusion] = None, @@ -105,7 +105,7 @@ class _HybridQuery(Generic[Properties, References], _BaseQuery[Properties, Refer self, query: Optional[str], *, - alpha: NUMBER = 0.5, + alpha: NUMBER = 0.7, vector: Optional[HybridVectorType] = None, query_properties: Optional[List[str]] = None, fusion_type: Optional[HybridFusion] = None, @@ -125,7 +125,7 @@ class _HybridQuery(Generic[Properties, References], _BaseQuery[Properties, Refer self, query: Optional[str], *, - alpha: NUMBER = 0.5, + alpha: NUMBER = 0.7, vector: Optional[HybridVectorType] = None, query_properties: Optional[List[str]] = None, fusion_type: Optional[HybridFusion] = None, diff --git a/weaviate/gql/aggregate.py b/weaviate/gql/aggregate.py index 71f6198c4..9d311169f 100644 --- a/weaviate/gql/aggregate.py +++ b/weaviate/gql/aggregate.py @@ -3,13 +3,11 @@ """ import json +from dataclasses import dataclass from typing import List, Optional, Union from weaviate.connect import Connection, ConnectionV4 -from weaviate.util import ( - _capitalize_first_letter, - file_encoder_b64, -) +from weaviate.util import _capitalize_first_letter, file_encoder_b64, _sanitize_str from .filter import ( Where, GraphQL, @@ -27,6 +25,38 @@ ) +@dataclass +class Hybrid: + query: Optional[str] + alpha: Optional[float] + vector: Optional[List[float]] + properties: Optional[List[str]] + target_vectors: Optional[List[str]] + + def __init__(self, content: dict) -> None: + self.query = content.get("query") + self.alpha = content.get("alpha") + self.vector = content.get("vector") + self.properties = content.get("properties") + self.target_vectors = content.get("targetVectors") + + def __str__(self) -> str: + ret = "" + if self.query is not None: + ret += f"query: {_sanitize_str(self.query)}" + if self.vector is not None: + ret += f", vector: {self.vector}" + if self.alpha is not None: + ret += f", alpha: {self.alpha}" + if self.properties is not None and len(self.properties) > 0: + props = 
'","'.join(self.properties) + ret += f', properties: ["{props}"]' + if self.target_vectors is not None: + target_vectors = '","'.join(self.target_vectors) + ret += f', targetVectors: ["{target_vectors}"]' + return "hybrid:{" + ret + "}" + + class AggregateBuilder(GraphQL): """ AggregateBuilder class used to aggregate Weaviate objects. @@ -55,6 +85,7 @@ def __init__(self, class_name: str, connection: Union[Connection, ConnectionV4]) self._near: Optional[Filter] = None self._tenant: Optional[str] = None self._limit: Optional[int] = None + self._hybrid: Optional[Hybrid] = None def with_tenant(self, tenant: str) -> "AggregateBuilder": """Sets a tenant for the query.""" @@ -209,6 +240,20 @@ def with_where(self, content: dict) -> "AggregateBuilder": self._uses_filter = True return self + def with_hybrid(self, content: dict) -> "AggregateBuilder": + """Get objects using bm25 and vector, then combine the results using a reciprocal ranking algorithm. + + Parameters + ---------- + content : dict + The content of the `hybrid` filter to set. + """ + if self._near is not None: + raise AttributeError("Cannot use 'hybrid' and 'near' filters simultaneously.") + self._hybrid = Hybrid(content) + self._uses_filter = True + return self + def with_group_by_filter(self, properties: List[str]) -> "AggregateBuilder": """ Add a group by filter to the query. 
Might requires the user to set @@ -308,6 +353,8 @@ def with_near_text(self, content: dict) -> "AggregateBuilder": if self._near is not None: raise AttributeError("Cannot use multiple 'near' filters.") + if self._hybrid is not None: + raise AttributeError("Cannot use 'near' and 'hybrid' filters simultaneously.") self._near = NearText(content) self._uses_filter = True return self @@ -373,6 +420,8 @@ def with_near_vector(self, content: dict) -> "AggregateBuilder": if self._near is not None: raise AttributeError("Cannot use multiple 'near' filters.") + if self._hybrid is not None: + raise AttributeError("Cannot use 'near' and 'hybrid' filters simultaneously.") self._near = NearVector(content) self._uses_filter = True return self @@ -423,6 +472,8 @@ def with_near_object(self, content: dict) -> "AggregateBuilder": if self._near is not None: raise AttributeError("Cannot use multiple 'near' filters.") + if self._hybrid is not None: + raise AttributeError("Cannot use 'near' and 'hybrid' filters simultaneously.") self._near = NearObject(content, is_server_version_14) self._uses_filter = True return self @@ -534,6 +585,8 @@ def with_near_image(self, content: dict, encode: bool = True) -> "AggregateBuild "Cannot use multiple 'near' filters, or a 'near' filter along" " with a 'ask' filter!" ) + if self._hybrid is not None: + raise AttributeError("Cannot use 'near' and 'hybrid' filters simultaneously.") if encode: content["image"] = file_encoder_b64(content["image"]) self._near = NearImage(content) @@ -648,6 +701,8 @@ def with_near_audio(self, content: dict, encode: bool = True) -> "AggregateBuild "Cannot use multiple 'near' filters, or a 'near' filter along" " with a 'ask' filter!" 
) + if self._hybrid is not None: + raise AttributeError("Cannot use 'near' and 'hybrid' filters simultaneously.") if encode: content[self._media_type.value] = file_encoder_b64(content[self._media_type.value]) self._near = NearAudio(content) @@ -762,6 +817,8 @@ def with_near_video(self, content: dict, encode: bool = True) -> "AggregateBuild "Cannot use multiple 'near' filters, or a 'near' filter along" " with a 'ask' filter!" ) + if self._hybrid is not None: + raise AttributeError("Cannot use 'near' and 'hybrid' filters simultaneously.") if encode: content[self._media_type.value] = file_encoder_b64(content[self._media_type.value]) self._near = NearVideo(content) @@ -876,6 +933,8 @@ def with_near_depth(self, content: dict, encode: bool = True) -> "AggregateBuild "Cannot use multiple 'near' filters, or a 'near' filter along" " with a 'ask' filter!" ) + if self._hybrid is not None: + raise AttributeError("Cannot use 'near' and 'hybrid' filters simultaneously.") if encode: content[self._media_type.value] = file_encoder_b64(content[self._media_type.value]) self._near = NearDepth(content) @@ -989,6 +1048,8 @@ def with_near_thermal(self, content: dict, encode: bool = True) -> "AggregateBui "Cannot use multiple 'near' filters, or a 'near' filter along" " with a 'ask' filter!" ) + if self._hybrid is not None: + raise AttributeError("Cannot use 'near' and 'hybrid' filters simultaneously.") if encode: content[self._media_type.value] = file_encoder_b64(content[self._media_type.value]) self._near = NearThermal(content) @@ -1103,6 +1164,8 @@ def with_near_imu(self, content: dict, encode: bool = True) -> "AggregateBuilder "Cannot use multiple 'near' filters, or a 'near' filter along" " with a 'ask' filter!" 
) + if self._hybrid is not None: + raise AttributeError("Cannot use 'near' and 'hybrid' filters simultaneously.") if encode: content[self._media_type.value] = file_encoder_b64(content[self._media_type.value]) self._near = NearIMU(content) @@ -1137,7 +1200,8 @@ def build(self) -> str: query += f'tenant: "{self._tenant}"' if self._limit is not None: query += f"limit: {self._limit}" - + if self._hybrid is not None: + query += f"hybrid: {self._hybrid}" query += ")" # Body From 8f79b2c686d28643a11ce87432d97515a31c5847 Mon Sep 17 00:00:00 2001 From: Tommy Smith Date: Tue, 2 Apr 2024 17:17:41 +0100 Subject: [PATCH 02/11] Fix hybrid query in AggregateBuilder --- weaviate/gql/aggregate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/weaviate/gql/aggregate.py b/weaviate/gql/aggregate.py index 9d311169f..445e46b99 100644 --- a/weaviate/gql/aggregate.py +++ b/weaviate/gql/aggregate.py @@ -1201,7 +1201,7 @@ def build(self) -> str: if self._limit is not None: query += f"limit: {self._limit}" if self._hybrid is not None: - query += f"hybrid: {self._hybrid}" + query += str(self._hybrid) query += ")" # Body From f3c65183b29eb01c741f77cf3138760875a8f3e3 Mon Sep 17 00:00:00 2001 From: Tommy Smith Date: Fri, 3 May 2024 12:50:25 +0300 Subject: [PATCH 03/11] Fix wrong docstring --- weaviate/collections/aggregations/hybrid.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/weaviate/collections/aggregations/hybrid.py b/weaviate/collections/aggregations/hybrid.py index 961f34ce4..07b6d194f 100644 --- a/weaviate/collections/aggregations/hybrid.py +++ b/weaviate/collections/aggregations/hybrid.py @@ -60,7 +60,7 @@ def hybrid( total_count: bool = True, return_metrics: Optional[PropertiesMetrics] = None, ) -> Union[AggregateReturn, AggregateGroupByReturn]: - """Aggregate metrics over all the objects in this collection without any vector search. 
+ """Aggregate metrics over all the objects in this collection using the hybrid algorithm blending keyword-based BM25 and vector-based similarity. Arguments: `query` From edc945fb230ec03c6fe5dd0734c5165b9272fa48 Mon Sep 17 00:00:00 2001 From: Tommy Smith Date: Fri, 3 May 2024 13:07:14 +0300 Subject: [PATCH 04/11] Add another test for additional aggregate.hybrid params --- integration/test_collection_aggregate.py | 32 ++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/integration/test_collection_aggregate.py b/integration/test_collection_aggregate.py index ec53683cb..1a9e4fcec 100644 --- a/integration/test_collection_aggregate.py +++ b/integration/test_collection_aggregate.py @@ -1,6 +1,7 @@ import pathlib import uuid from datetime import datetime, timezone +from typing import Union import pytest from _pytest.fixtures import SubRequest @@ -337,6 +338,37 @@ def test_hybrid_aggregation( ] +@pytest.mark.parametrize("group_by", ["text", GroupByAggregate(prop="text", limit=1)]) +def test_hybrid_aggregation_group_by( + collection_factory: CollectionFactory, group_by: Union[str, GroupByAggregate] +) -> None: + collection = collection_factory( + properties=[Property(name="text", data_type=DataType.TEXT)], + vectorizer_config=[ + Configure.NamedVectors.text2vec_contextionary( + name="all", vectorize_collection_name=False + ) + ], + ) + text_1 = "some text" + text_2 = "nothing like the other one at all, not even a little bit" + collection.data.insert({"text": text_1}) + collection.data.insert({"text": text_2}) + + res = collection.aggregate.hybrid( + "text", + alpha=0, + query_properties=["text"], + group_by=group_by, + total_count=True, + object_limit=2, # has no effect due to alpha=0 + target_vector="all", + ) + assert res.groups[0].grouped_by.prop == "text" + assert res.groups[0].grouped_by.value == "some text" + assert res.groups[0].total_count == 1 + + @pytest.mark.parametrize( "option,expected_len", [ From 6ec483fb14eb3ef3d4db6fe8eb28fe8182beb157 Mon 
Sep 17 00:00:00 2001 From: Tommy Smith Date: Fri, 3 May 2024 14:18:55 +0300 Subject: [PATCH 05/11] Fix tests for named vectors with <1.24 --- integration/test_collection_aggregate.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/integration/test_collection_aggregate.py b/integration/test_collection_aggregate.py index 1a9e4fcec..b67c76c96 100644 --- a/integration/test_collection_aggregate.py +++ b/integration/test_collection_aggregate.py @@ -342,7 +342,8 @@ def test_hybrid_aggregation( def test_hybrid_aggregation_group_by( collection_factory: CollectionFactory, group_by: Union[str, GroupByAggregate] ) -> None: - collection = collection_factory( + dummy = collection_factory("dummy") + collection_maker = lambda: collection_factory( properties=[Property(name="text", data_type=DataType.TEXT)], vectorizer_config=[ Configure.NamedVectors.text2vec_contextionary( @@ -350,6 +351,12 @@ def test_hybrid_aggregation_group_by( ) ], ) + if dummy._connection._weaviate_version.is_lower_than(1, 24, 0): + with pytest.raises(WeaviateInvalidInputError): + collection_maker() + return + + collection = collection_maker() text_1 = "some text" text_2 = "nothing like the other one at all, not even a little bit" collection.data.insert({"text": text_1}) @@ -383,12 +390,19 @@ def test_hybrid_aggregation_group_by( def test_near_vector_aggregation( collection_factory: CollectionFactory, option: dict, expected_len: int ) -> None: - collection = collection_factory( + dummy = collection_factory("dummy") + collection_maker = lambda: collection_factory( properties=[Property(name="text", data_type=DataType.TEXT)], vectorizer_config=Configure.Vectorizer.text2vec_contextionary( vectorize_collection_name=False ), ) + if dummy._connection._weaviate_version.is_lower_than(1, 24, 0): + with pytest.raises(WeaviateInvalidInputError): + collection_maker() + return + + collection = collection_maker() text_1 = "some text" text_2 = "nothing like the other one at all, not even 
a little bit" uuid = collection.data.insert({"text": text_1}) From 4dd930bdf27c99de3e00ae78c2fcf97e229cf93d Mon Sep 17 00:00:00 2001 From: Tommy Smith Date: Fri, 3 May 2024 14:31:19 +0300 Subject: [PATCH 06/11] Fix incorrect skip on version --- integration/test_collection_aggregate.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/integration/test_collection_aggregate.py b/integration/test_collection_aggregate.py index b67c76c96..450cabd87 100644 --- a/integration/test_collection_aggregate.py +++ b/integration/test_collection_aggregate.py @@ -390,17 +390,12 @@ def test_hybrid_aggregation_group_by( def test_near_vector_aggregation( collection_factory: CollectionFactory, option: dict, expected_len: int ) -> None: - dummy = collection_factory("dummy") collection_maker = lambda: collection_factory( properties=[Property(name="text", data_type=DataType.TEXT)], vectorizer_config=Configure.Vectorizer.text2vec_contextionary( vectorize_collection_name=False ), ) - if dummy._connection._weaviate_version.is_lower_than(1, 24, 0): - with pytest.raises(WeaviateInvalidInputError): - collection_maker() - return collection = collection_maker() text_1 = "some text" From 7916dbbec5c61dee89a283afeca26f79665f0191 Mon Sep 17 00:00:00 2001 From: Tommy Smith Date: Fri, 3 May 2024 14:31:35 +0300 Subject: [PATCH 07/11] Fix target vector passing in hybrid query --- weaviate/collections/aggregations/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/weaviate/collections/aggregations/base.py b/weaviate/collections/aggregations/base.py index b0547c33b..1929c3eac 100644 --- a/weaviate/collections/aggregations/base.py +++ b/weaviate/collections/aggregations/base.py @@ -254,7 +254,7 @@ def _add_hybrid_to_builder( if query_properties is not None: payload["properties"] = query_properties if target_vector is not None: - payload["targetVectors"] = [target_vector] + payload["targetVector"] = target_vector builder = builder.with_hybrid(payload) if object_limit is not None: 
builder = builder.with_object_limit(object_limit) From 6e9fc9595b426d6083c1988c3ed101dd7493ad95 Mon Sep 17 00:00:00 2001 From: Tommy Smith Date: Fri, 3 May 2024 14:46:26 +0300 Subject: [PATCH 08/11] Fix target vector passing in hybrid query --- weaviate/collections/aggregations/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/weaviate/collections/aggregations/base.py b/weaviate/collections/aggregations/base.py index 1929c3eac..b0547c33b 100644 --- a/weaviate/collections/aggregations/base.py +++ b/weaviate/collections/aggregations/base.py @@ -254,7 +254,7 @@ def _add_hybrid_to_builder( if query_properties is not None: payload["properties"] = query_properties if target_vector is not None: - payload["targetVector"] = target_vector + payload["targetVectors"] = [target_vector] builder = builder.with_hybrid(payload) if object_limit is not None: builder = builder.with_object_limit(object_limit) From b9b2a33062160ad05c5abdfed12802487627e3f4 Mon Sep 17 00:00:00 2001 From: Tommy Smith Date: Fri, 3 May 2024 15:05:29 +0300 Subject: [PATCH 09/11] Skip test on broken Weaviate `1.24.x` versions --- integration/test_collection_aggregate.py | 39 ++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/integration/test_collection_aggregate.py b/integration/test_collection_aggregate.py index 450cabd87..2265e2282 100644 --- a/integration/test_collection_aggregate.py +++ b/integration/test_collection_aggregate.py @@ -341,6 +341,41 @@ def test_hybrid_aggregation( @pytest.mark.parametrize("group_by", ["text", GroupByAggregate(prop="text", limit=1)]) def test_hybrid_aggregation_group_by( collection_factory: CollectionFactory, group_by: Union[str, GroupByAggregate] +) -> None: + dummy = collection_factory("dummy") + collection_maker = lambda: collection_factory( + properties=[Property(name="text", data_type=DataType.TEXT)], + vectorizer_config=Configure.Vectorizer.text2vec_contextionary( + vectorize_collection_name=False + ), + ) + if 
dummy._connection._weaviate_version.is_lower_than(1, 24, 0): + with pytest.raises(WeaviateInvalidInputError): + collection_maker() + return + + collection = collection_maker() + text_1 = "some text" + text_2 = "nothing like the other one at all, not even a little bit" + collection.data.insert({"text": text_1}) + collection.data.insert({"text": text_2}) + + res = collection.aggregate.hybrid( + "text", + alpha=0, + query_properties=["text"], + group_by=group_by, + total_count=True, + object_limit=2, # has no effect due to alpha=0 + ) + assert res.groups[0].grouped_by.prop == "text" + assert res.groups[0].grouped_by.value == "some text" + assert res.groups[0].total_count == 1 + + +@pytest.mark.parametrize("group_by", ["text", GroupByAggregate(prop="text", limit=1)]) +def test_hybrid_aggregation_group_by_with_named_vectors( + collection_factory: CollectionFactory, group_by: Union[str, GroupByAggregate] ) -> None: dummy = collection_factory("dummy") collection_maker = lambda: collection_factory( @@ -355,6 +390,10 @@ def test_hybrid_aggregation_group_by( with pytest.raises(WeaviateInvalidInputError): collection_maker() return + if dummy._connection._weaviate_version.is_lower_than( + 1, 24, 11 + ) and dummy._connection._weaviate_version.is_at_least(1, 24, 0): + pytest.skip("Currently bugged with 1.24.x <= 1.24.10") collection = collection_maker() text_1 = "some text" From 2f1185dafdd07447d8c748340e33ff33a35b1ae5 Mon Sep 17 00:00:00 2001 From: Tommy Smith Date: Fri, 3 May 2024 15:17:14 +0300 Subject: [PATCH 10/11] Fix bad exception catch in non named vector test --- integration/test_collection_aggregate.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/integration/test_collection_aggregate.py b/integration/test_collection_aggregate.py index 2265e2282..beb7b4687 100644 --- a/integration/test_collection_aggregate.py +++ b/integration/test_collection_aggregate.py @@ -342,19 +342,13 @@ def test_hybrid_aggregation( def test_hybrid_aggregation_group_by( 
collection_factory: CollectionFactory, group_by: Union[str, GroupByAggregate] ) -> None: - dummy = collection_factory("dummy") - collection_maker = lambda: collection_factory( + collection = collection_factory( properties=[Property(name="text", data_type=DataType.TEXT)], vectorizer_config=Configure.Vectorizer.text2vec_contextionary( vectorize_collection_name=False ), ) - if dummy._connection._weaviate_version.is_lower_than(1, 24, 0): - with pytest.raises(WeaviateInvalidInputError): - collection_maker() - return - collection = collection_maker() text_1 = "some text" text_2 = "nothing like the other one at all, not even a little bit" collection.data.insert({"text": text_1}) From e19cde818e92414c489625384e417cdadf1456b5 Mon Sep 17 00:00:00 2001 From: Tommy Smith Date: Fri, 3 May 2024 15:38:39 +0300 Subject: [PATCH 11/11] Catch usage of hybrid aggregation on unsupported weaviate --- integration/test_collection_aggregate.py | 26 +++++++++++++++------ weaviate/collections/aggregations/base.py | 4 ++-- weaviate/collections/aggregations/hybrid.py | 5 ++++ 3 files changed, 26 insertions(+), 9 deletions(-) diff --git a/integration/test_collection_aggregate.py b/integration/test_collection_aggregate.py index beb7b4687..2d7954633 100644 --- a/integration/test_collection_aggregate.py +++ b/integration/test_collection_aggregate.py @@ -19,7 +19,11 @@ ) from weaviate.collections.classes.config import DataType, Property, ReferenceProperty, Configure from weaviate.collections.classes.filters import Filter, _Filters -from weaviate.exceptions import WeaviateInvalidInputError, WeaviateQueryError +from weaviate.exceptions import ( + WeaviateInvalidInputError, + WeaviateQueryError, + WeaviateNotImplementedError, +) from weaviate.util import file_encoder_b64 from weaviate.collections.classes.grpc import Move @@ -354,7 +358,7 @@ def test_hybrid_aggregation_group_by( collection.data.insert({"text": text_1}) collection.data.insert({"text": text_2}) - res = collection.aggregate.hybrid( + 
querier = lambda: collection.aggregate.hybrid( "text", alpha=0, query_properties=["text"], @@ -362,6 +366,12 @@ def test_hybrid_aggregation_group_by( total_count=True, object_limit=2, # has no effect due to alpha=0 ) + if collection._connection._weaviate_version.is_lower_than(1, 25, 0): + with pytest.raises(WeaviateNotImplementedError): + querier() + return + + res = querier() assert res.groups[0].grouped_by.prop == "text" assert res.groups[0].grouped_by.value == "some text" assert res.groups[0].total_count == 1 @@ -384,10 +394,6 @@ def test_hybrid_aggregation_group_by_with_named_vectors( with pytest.raises(WeaviateInvalidInputError): collection_maker() return - if dummy._connection._weaviate_version.is_lower_than( - 1, 24, 11 - ) and dummy._connection._weaviate_version.is_at_least(1, 24, 0): - pytest.skip("Currently bugged with 1.24.x <= 1.24.10") collection = collection_maker() text_1 = "some text" @@ -395,7 +401,7 @@ def test_hybrid_aggregation_group_by_with_named_vectors( collection.data.insert({"text": text_1}) collection.data.insert({"text": text_2}) - res = collection.aggregate.hybrid( + querier = lambda: collection.aggregate.hybrid( "text", alpha=0, query_properties=["text"], @@ -404,6 +410,12 @@ def test_hybrid_aggregation_group_by_with_named_vectors( object_limit=2, # has no effect due to alpha=0 target_vector="all", ) + if dummy._connection._weaviate_version.is_lower_than(1, 25, 0): + with pytest.raises(WeaviateNotImplementedError): + querier() + return + + res = querier() assert res.groups[0].grouped_by.prop == "text" assert res.groups[0].grouped_by.value == "some text" assert res.groups[0].total_count == 1 diff --git a/weaviate/collections/aggregations/base.py b/weaviate/collections/aggregations/base.py index b0547c33b..f512fdfac 100644 --- a/weaviate/collections/aggregations/base.py +++ b/weaviate/collections/aggregations/base.py @@ -52,13 +52,13 @@ def __init__( consistency_level: Optional[ConsistencyLevel], tenant: Optional[str], ): - 
self.__connection = connection + self._connection = connection self.__name = name self._tenant = tenant self._consistency_level = consistency_level def _query(self) -> AggregateBuilder: - return AggregateBuilder(self.__name, self.__connection) + return AggregateBuilder(self.__name, self._connection) def _to_aggregate_result( self, response: dict, metrics: Optional[List[_Metrics]] diff --git a/weaviate/collections/aggregations/hybrid.py b/weaviate/collections/aggregations/hybrid.py index 07b6d194f..41e0c529f 100644 --- a/weaviate/collections/aggregations/hybrid.py +++ b/weaviate/collections/aggregations/hybrid.py @@ -8,6 +8,7 @@ GroupByAggregate, ) from weaviate.collections.classes.filters import _Filters +from weaviate.exceptions import WeaviateNotImplementedError from weaviate.types import NUMBER @@ -91,6 +92,10 @@ def hybrid( `weaviate.exceptions.WeaviateInvalidInputError`: If any of the input arguments are of the wrong type. """ + if group_by is not None and self._connection._weaviate_version.is_lower_than(1, 25, 0): + raise WeaviateNotImplementedError( + "Hybrid aggregation", self._connection.server_version, "1.25.0" + ) return_metrics = ( return_metrics if (return_metrics is None or isinstance(return_metrics, list))