diff --git a/integration/test_collection_aggregate.py b/integration/test_collection_aggregate.py index 6733321e3..2d7954633 100644 --- a/integration/test_collection_aggregate.py +++ b/integration/test_collection_aggregate.py @@ -1,6 +1,7 @@ import pathlib import uuid from datetime import datetime, timezone +from typing import Union import pytest from _pytest.fixtures import SubRequest @@ -18,7 +19,11 @@ ) from weaviate.collections.classes.config import DataType, Property, ReferenceProperty, Configure from weaviate.collections.classes.filters import Filter, _Filters -from weaviate.exceptions import WeaviateInvalidInputError, WeaviateQueryError +from weaviate.exceptions import ( + WeaviateInvalidInputError, + WeaviateQueryError, + WeaviateNotImplementedError, +) from weaviate.util import file_encoder_b64 from weaviate.collections.classes.grpc import Move @@ -290,6 +295,132 @@ def test_near_object_missing_param(collection_factory: CollectionFactory) -> Non ) +@pytest.mark.parametrize( + "option,expected_len", + [ + ({"object_limit": 1}, 1), + ({"object_limit": 2}, 2), + ], +) +def test_hybrid_aggregation( + collection_factory: CollectionFactory, option: dict, expected_len: int +) -> None: + collection = collection_factory( + properties=[Property(name="text", data_type=DataType.TEXT)], + vectorizer_config=Configure.Vectorizer.text2vec_contextionary( + vectorize_collection_name=False + ), + ) + text_1 = "some text" + text_2 = "nothing like the other one at all, not even a little bit" + uuid = collection.data.insert({"text": text_1}) + obj = collection.query.fetch_object_by_id(uuid, include_vector=True) + assert "default" in obj.vector + collection.data.insert({"text": text_2}) + res: AggregateReturn = collection.aggregate.hybrid( + None, + alpha=1, + vector=obj.vector["default"], + return_metrics=[ + Metrics("text").text(count=True, top_occurrences_count=True, top_occurrences_value=True) + ], + **option, + ) + assert isinstance(res.properties["text"], AggregateText) + assert res.properties["text"].count == expected_len + assert len(res.properties["text"].top_occurrences) == expected_len + assert text_1 in [ + top_occurrence.value for top_occurrence in res.properties["text"].top_occurrences + ] + if expected_len == 2: + assert text_2 in [ + top_occurrence.value for top_occurrence in res.properties["text"].top_occurrences + ] + else: + assert text_2 not in [ + top_occurrence.value for top_occurrence in res.properties["text"].top_occurrences + ] + + +@pytest.mark.parametrize("group_by", ["text", GroupByAggregate(prop="text", limit=1)]) +def test_hybrid_aggregation_group_by( + collection_factory: CollectionFactory, group_by: Union[str, GroupByAggregate] +) -> None: + collection = collection_factory( + properties=[Property(name="text", data_type=DataType.TEXT)], + vectorizer_config=Configure.Vectorizer.text2vec_contextionary( + vectorize_collection_name=False + ), + ) + + text_1 = "some text" + text_2 = "nothing like the other one at all, not even a little bit" + collection.data.insert({"text": text_1}) + collection.data.insert({"text": text_2}) + + querier = lambda: collection.aggregate.hybrid( + "text", + alpha=0, + query_properties=["text"], + group_by=group_by, + total_count=True, + object_limit=2, # has no effect due to alpha=0 + ) + if collection._connection._weaviate_version.is_lower_than(1, 25, 0): + with pytest.raises(WeaviateNotImplementedError): + querier() + return + + res = querier() + assert res.groups[0].grouped_by.prop == "text" + assert res.groups[0].grouped_by.value == "some text" + assert res.groups[0].total_count == 1 + + +@pytest.mark.parametrize("group_by", ["text", GroupByAggregate(prop="text", limit=1)]) +def test_hybrid_aggregation_group_by_with_named_vectors( + collection_factory: CollectionFactory, group_by: Union[str, GroupByAggregate] +) -> None: + dummy = collection_factory("dummy") + collection_maker = lambda: collection_factory( + properties=[Property(name="text", data_type=DataType.TEXT)], + vectorizer_config=[ + Configure.NamedVectors.text2vec_contextionary( + name="all", vectorize_collection_name=False + ) + ], + ) + if dummy._connection._weaviate_version.is_lower_than(1, 24, 0): + with pytest.raises(WeaviateInvalidInputError): + collection_maker() + return + + collection = collection_maker() + text_1 = "some text" + text_2 = "nothing like the other one at all, not even a little bit" + collection.data.insert({"text": text_1}) + collection.data.insert({"text": text_2}) + + querier = lambda: collection.aggregate.hybrid( + "text", + alpha=0, + query_properties=["text"], + group_by=group_by, + total_count=True, + object_limit=2, # has no effect due to alpha=0 + target_vector="all", + ) + if dummy._connection._weaviate_version.is_lower_than(1, 25, 0): + with pytest.raises(WeaviateNotImplementedError): + querier() + return + + res = querier() + assert res.groups[0].grouped_by.prop == "text" + assert res.groups[0].grouped_by.value == "some text" + assert res.groups[0].total_count == 1 + + @pytest.mark.parametrize( "option,expected_len", [ @@ -304,12 +435,14 @@ def test_near_object_missing_param(collection_factory: CollectionFactory) -> Non def test_near_vector_aggregation( collection_factory: CollectionFactory, option: dict, expected_len: int ) -> None: - collection = collection_factory( + collection_maker = lambda: collection_factory( properties=[Property(name="text", data_type=DataType.TEXT)], vectorizer_config=Configure.Vectorizer.text2vec_contextionary( vectorize_collection_name=False ), ) + + collection = collection_maker() text_1 = "some text" text_2 = "nothing like the other one at all, not even a little bit" uuid = collection.data.insert({"text": text_1}) diff --git a/weaviate/collections/aggregate.py b/weaviate/collections/aggregate.py index 4218a9772..5d0e09004 100644 --- a/weaviate/collections/aggregate.py +++ b/weaviate/collections/aggregate.py @@ -1,9 +1,10 @@ -from weaviate.collections.aggregations.over_all import _OverAll +from weaviate.collections.aggregations.hybrid import _Hybrid from weaviate.collections.aggregations.near_image import _NearImage from weaviate.collections.aggregations.near_object import _NearObject from weaviate.collections.aggregations.near_text import _NearText from weaviate.collections.aggregations.near_vector import _NearVector +from weaviate.collections.aggregations.over_all import _OverAll -class _AggregateCollection(_OverAll, _NearImage, _NearObject, _NearText, _NearVector): +class _AggregateCollection(_Hybrid, _NearImage, _NearObject, _NearText, _NearVector, _OverAll): pass diff --git a/weaviate/collections/aggregations/base.py b/weaviate/collections/aggregations/base.py index 8fdc3fbd1..f512fdfac 100644 --- a/weaviate/collections/aggregations/base.py +++ b/weaviate/collections/aggregations/base.py @@ -52,13 +52,13 @@ def __init__( consistency_level: Optional[ConsistencyLevel], tenant: Optional[str], ): - self.__connection = connection + self._connection = connection self.__name = name self._tenant = tenant self._consistency_level = consistency_level def _query(self) -> AggregateBuilder: - return AggregateBuilder(self.__name, self.__connection) + return AggregateBuilder(self.__name, self._connection) def _to_aggregate_result( self, response: dict, metrics: Optional[List[_Metrics]] @@ -235,7 +235,33 @@ def _parse_near_options( ) @staticmethod - def _add_near_image( + def _add_hybrid_to_builder( + builder: AggregateBuilder, + query: Optional[str], + alpha: Optional[NUMBER], + vector: Optional[List[float]], + query_properties: Optional[List[str]], + object_limit: Optional[int], + target_vector: Optional[str], + ) -> AggregateBuilder: + payload: dict = {} + if query is not None: + payload["query"] = query + if alpha is not None: + payload["alpha"] = alpha + if vector is not None: + payload["vector"] = vector + if query_properties is not None: + payload["properties"] = query_properties + if target_vector is not None: + payload["targetVectors"] = [target_vector] + builder = builder.with_hybrid(payload) + if object_limit is not None: + builder = builder.with_object_limit(object_limit) + return builder + + @staticmethod + def _add_near_image_to_builder( builder: AggregateBuilder, near_image: Union[str, pathlib.Path, io.BufferedReader], certainty: Optional[NUMBER], @@ -265,7 +291,7 @@ def _add_near_image( return builder @staticmethod - def _add_near_object( + def _add_near_object_to_builder( builder: AggregateBuilder, near_object: UUID, certainty: Optional[NUMBER], @@ -293,7 +319,7 @@ def _add_near_object( return builder @staticmethod - def _add_near_text( + def _add_near_text_to_builder( builder: AggregateBuilder, query: Union[List[str], str], certainty: Optional[NUMBER], @@ -334,7 +360,7 @@ def _add_near_text( return builder @staticmethod - def _add_near_vector( + def _add_near_vector_to_builder( builder: AggregateBuilder, near_vector: List[float], certainty: Optional[NUMBER], diff --git a/weaviate/collections/aggregations/hybrid.py b/weaviate/collections/aggregations/hybrid.py new file mode 100644 index 000000000..41e0c529f --- /dev/null +++ b/weaviate/collections/aggregations/hybrid.py @@ -0,0 +1,114 @@ +from typing import List, Literal, Optional, Union, overload + +from weaviate.collections.aggregations.base import _Aggregate +from weaviate.collections.classes.aggregate import ( + PropertiesMetrics, + AggregateReturn, + AggregateGroupByReturn, + GroupByAggregate, +) +from weaviate.collections.classes.filters import _Filters +from weaviate.exceptions import WeaviateNotImplementedError +from weaviate.types import NUMBER + + +class _Hybrid(_Aggregate): + @overload + def hybrid( + self, + query: Optional[str], + *, + alpha: NUMBER = 0.7, + vector: Optional[List[float]] = None, + query_properties: Optional[List[str]] = None, + object_limit: Optional[int] = None, + filters: Optional[_Filters] = None, + group_by: Literal[None] = None, + target_vector: Optional[str] = None, + total_count: bool = True, + return_metrics: Optional[PropertiesMetrics] = None, + ) -> AggregateReturn: + ... + + @overload + def hybrid( + self, + query: Optional[str], + *, + alpha: NUMBER = 0.7, + vector: Optional[List[float]] = None, + query_properties: Optional[List[str]] = None, + object_limit: Optional[int] = None, + filters: Optional[_Filters] = None, + group_by: Union[str, GroupByAggregate], + target_vector: Optional[str] = None, + total_count: bool = True, + return_metrics: Optional[PropertiesMetrics] = None, + ) -> AggregateGroupByReturn: + ... + + def hybrid( + self, + query: Optional[str], + *, + alpha: NUMBER = 0.7, + vector: Optional[List[float]] = None, + query_properties: Optional[List[str]] = None, + object_limit: Optional[int] = None, + filters: Optional[_Filters] = None, + group_by: Optional[Union[str, GroupByAggregate]] = None, + target_vector: Optional[str] = None, + total_count: bool = True, + return_metrics: Optional[PropertiesMetrics] = None, + ) -> Union[AggregateReturn, AggregateGroupByReturn]: + """Aggregate metrics over all the objects in this collection using the hybrid algorithm blending keyword-based BM25 and vector-based similarity. + + Arguments: + `query` + The keyword-based query to search for, REQUIRED. If query and vector are both None, a normal search will be performed. + `alpha` + The weight of the BM25 score. If not specified, the default weight specified by the server is used. + `vector` + The specific vector to search for. If not specified, the query is vectorized and used in the similarity search. + `query_properties` + The properties to search in. If not specified, all properties are searched. + `object_limit` + The maximum number of objects to return from the hybrid vector search prior to the aggregation. + `filters` + The filters to apply to the search. + `group_by` + How to group the aggregation by. + `total_count` + Whether to include the total number of objects that match the query in the response. + `return_metrics` + A list of property metrics to aggregate together after the text search. + + Returns: + Depending on the presence of the `group_by` argument, either a `AggregateReturn` object or a `AggregateGroupByReturn that includes the aggregation objects. + + Raises: + `weaviate.exceptions.WeaviateQueryError`: + If an error occurs while performing the query against Weaviate. + `weaviate.exceptions.WeaviateInvalidInputError`: + If any of the input arguments are of the wrong type. + """ + if group_by is not None and self._connection._weaviate_version.is_lower_than(1, 25, 0): + raise WeaviateNotImplementedError( + "Hybrid aggregation", self._connection.server_version, "1.25.0" + ) + return_metrics = ( + return_metrics + if (return_metrics is None or isinstance(return_metrics, list)) + else [return_metrics] + ) + builder = self._base(return_metrics, filters, total_count) + builder = self._add_hybrid_to_builder( + builder, query, alpha, vector, query_properties, object_limit, target_vector + ) + builder = self._add_groupby_to_builder(builder, group_by) + res = self._do(builder) + return ( + self._to_aggregate_result(res, return_metrics) + if group_by is None + else self._to_group_by_result(res, return_metrics) + ) diff --git a/weaviate/collections/aggregations/near_image.py b/weaviate/collections/aggregations/near_image.py index f009aba33..591ff9e28 100644 --- a/weaviate/collections/aggregations/near_image.py +++ b/weaviate/collections/aggregations/near_image.py @@ -99,7 +99,7 @@ def near_image( ) builder = self._base(return_metrics, filters, total_count) builder = self._add_groupby_to_builder(builder, group_by) - builder = self._add_near_image( + builder = self._add_near_image_to_builder( builder, near_image, certainty, distance, object_limit, target_vector ) res = self._do(builder) diff --git a/weaviate/collections/aggregations/near_object.py b/weaviate/collections/aggregations/near_object.py index df3d1f6fd..c1f18d3c0 100644 --- a/weaviate/collections/aggregations/near_object.py +++ b/weaviate/collections/aggregations/near_object.py @@ -97,7 +97,7 @@ def near_object( ) builder = self._base(return_metrics, filters, total_count) builder = self._add_groupby_to_builder(builder, group_by) - builder = self._add_near_object( + builder = self._add_near_object_to_builder( builder, near_object, certainty, distance, object_limit, target_vector ) res = self._do(builder) diff --git a/weaviate/collections/aggregations/near_text.py b/weaviate/collections/aggregations/near_text.py index 9a8b41ad1..bb17f6d23 100644 --- a/weaviate/collections/aggregations/near_text.py +++ b/weaviate/collections/aggregations/near_text.py @@ -108,7 +108,7 @@ def near_text( ) builder = self._base(return_metrics, filters, total_count) builder = self._add_groupby_to_builder(builder, group_by) - builder = self._add_near_text( + builder = self._add_near_text_to_builder( builder=builder, query=query, certainty=certainty, diff --git a/weaviate/collections/aggregations/near_vector.py b/weaviate/collections/aggregations/near_vector.py index a8346576a..0b6d02bb9 100644 --- a/weaviate/collections/aggregations/near_vector.py +++ b/weaviate/collections/aggregations/near_vector.py @@ -97,7 +97,7 @@ def near_vector( ) builder = self._base(return_metrics, filters, total_count) builder = self._add_groupby_to_builder(builder, group_by) - builder = self._add_near_vector( + builder = self._add_near_vector_to_builder( builder, near_vector, certainty, distance, object_limit, target_vector ) res = self._do(builder) diff --git a/weaviate/collections/queries/hybrid/generate.py b/weaviate/collections/queries/hybrid/generate.py index cad94fc0b..d647c5a32 100644 --- a/weaviate/collections/queries/hybrid/generate.py +++ b/weaviate/collections/queries/hybrid/generate.py @@ -24,7 +24,7 @@ def hybrid( single_prompt: Optional[str] = None, grouped_task: Optional[str] = None, grouped_properties: Optional[List[str]] = None, - alpha: NUMBER = 0.5, + alpha: NUMBER = 0.7, vector: Optional[HybridVectorType] = None, query_properties: Optional[List[str]] = None, fusion_type: Optional[HybridFusion] = None, diff --git a/weaviate/collections/queries/hybrid/generate.pyi b/weaviate/collections/queries/hybrid/generate.pyi index f803ada7e..56a19fa7c 100644 --- a/weaviate/collections/queries/hybrid/generate.pyi +++ b/weaviate/collections/queries/hybrid/generate.pyi @@ -28,7 +28,7 @@ class _HybridGenerate(Generic[Properties, References], _BaseQuery[Properties, Re single_prompt: Optional[str] = None, grouped_task: Optional[str] = None, grouped_properties: Optional[List[str]] = None, - alpha: NUMBER = 0.5, + alpha: NUMBER = 0.7, vector: Optional[HybridVectorType] = None, query_properties: Optional[List[str]] = None, fusion_type: Optional[HybridFusion] = None, @@ -51,7 +51,7 @@ class _HybridGenerate(Generic[Properties, References], _BaseQuery[Properties, Re single_prompt: Optional[str] = None, grouped_task: Optional[str] = None, grouped_properties: Optional[List[str]] = None, - alpha: NUMBER = 0.5, + alpha: NUMBER = 0.7, vector: Optional[HybridVectorType] = None, query_properties: Optional[List[str]] = None, fusion_type: Optional[HybridFusion] = None, @@ -74,7 +74,7 @@ class _HybridGenerate(Generic[Properties, References], _BaseQuery[Properties, Re single_prompt: Optional[str] = None, grouped_task: Optional[str] = None, grouped_properties: Optional[List[str]] = None, - alpha: NUMBER = 0.5, + alpha: NUMBER = 0.7, vector: Optional[HybridVectorType] = None, query_properties: Optional[List[str]] = None, fusion_type: Optional[HybridFusion] = None, @@ -97,7 +97,7 @@ class _HybridGenerate(Generic[Properties, References], _BaseQuery[Properties, Re single_prompt: Optional[str] = None, grouped_task: Optional[str] = None, grouped_properties: Optional[List[str]] = None, - alpha: NUMBER = 0.5, + alpha: NUMBER = 0.7, vector: Optional[HybridVectorType] = None, query_properties: Optional[List[str]] = None, fusion_type: Optional[HybridFusion] = None, @@ -120,7 +120,7 @@ class _HybridGenerate(Generic[Properties, References], _BaseQuery[Properties, Re single_prompt: Optional[str] = None, grouped_task: Optional[str] = None, grouped_properties: Optional[List[str]] = None, - alpha: NUMBER = 0.5, + alpha: NUMBER = 0.7, vector: Optional[HybridVectorType] = None, query_properties: Optional[List[str]] = None, fusion_type: Optional[HybridFusion] = None, @@ -143,7 +143,7 @@ class _HybridGenerate(Generic[Properties, References], _BaseQuery[Properties, Re single_prompt: Optional[str] = None, grouped_task: Optional[str] = None, grouped_properties: Optional[List[str]] = None, - alpha: NUMBER = 0.5, + alpha: NUMBER = 0.7, vector: Optional[HybridVectorType] = None, query_properties: Optional[List[str]] = None, fusion_type: Optional[HybridFusion] = None, diff --git a/weaviate/collections/queries/hybrid/query.py b/weaviate/collections/queries/hybrid/query.py index c8961ca3b..c8caa12a2 100644 --- a/weaviate/collections/queries/hybrid/query.py +++ b/weaviate/collections/queries/hybrid/query.py @@ -20,7 +20,7 @@ def hybrid( self, query: Optional[str], *, - alpha: NUMBER = 0.5, + alpha: NUMBER = 0.7, vector: Optional[HybridVectorType] = None, query_properties: Optional[List[str]] = None, fusion_type: Optional[HybridFusion] = None, diff --git a/weaviate/collections/queries/hybrid/query.pyi b/weaviate/collections/queries/hybrid/query.pyi index 141312ea4..1b6013cd2 100644 --- a/weaviate/collections/queries/hybrid/query.pyi +++ b/weaviate/collections/queries/hybrid/query.pyi @@ -25,7 +25,7 @@ class _HybridQuery(Generic[Properties, References], _BaseQuery[Properties, Refer self, query: Optional[str], *, - alpha: NUMBER = 0.5, + alpha: NUMBER = 0.7, vector: Optional[HybridVectorType] = None, query_properties: Optional[List[str]] = None, fusion_type: Optional[HybridFusion] = None, @@ -45,7 +45,7 @@ class _HybridQuery(Generic[Properties, References], _BaseQuery[Properties, Refer self, query: Optional[str], *, - alpha: NUMBER = 0.5, + alpha: NUMBER = 0.7, vector: Optional[HybridVectorType] = None, query_properties: Optional[List[str]] = None, fusion_type: Optional[HybridFusion] = None, @@ -65,7 +65,7 @@ class _HybridQuery(Generic[Properties, References], _BaseQuery[Properties, Refer self, query: Optional[str], *, - alpha: NUMBER = 0.5, + alpha: NUMBER = 0.7, vector: Optional[HybridVectorType] = None, query_properties: Optional[List[str]] = None, fusion_type: Optional[HybridFusion] = None, @@ -85,7 +85,7 @@ class _HybridQuery(Generic[Properties, References], _BaseQuery[Properties, Refer self, query: Optional[str], *, - alpha: NUMBER = 0.5, + alpha: NUMBER = 0.7, vector: Optional[HybridVectorType] = None, query_properties: Optional[List[str]] = None, fusion_type: Optional[HybridFusion] = None, @@ -105,7 +105,7 @@ class _HybridQuery(Generic[Properties, References], _BaseQuery[Properties, Refer self, query: Optional[str], *, - alpha: NUMBER = 0.5, + alpha: NUMBER = 0.7, vector: Optional[HybridVectorType] = None, query_properties: Optional[List[str]] = None, fusion_type: Optional[HybridFusion] = None, @@ -125,7 +125,7 @@ class _HybridQuery(Generic[Properties, References], _BaseQuery[Properties, Refer self, query: Optional[str], *, - alpha: NUMBER = 0.5, + alpha: NUMBER = 0.7, vector: Optional[HybridVectorType] = None, query_properties: Optional[List[str]] = None, fusion_type: Optional[HybridFusion] = None, diff --git a/weaviate/gql/aggregate.py b/weaviate/gql/aggregate.py index 71f6198c4..445e46b99 100644 --- a/weaviate/gql/aggregate.py +++ b/weaviate/gql/aggregate.py @@ -3,13 +3,11 @@ """ import json +from dataclasses import dataclass from typing import List, Optional, Union from weaviate.connect import Connection, ConnectionV4 -from weaviate.util import ( - _capitalize_first_letter, - file_encoder_b64, -) +from weaviate.util import _capitalize_first_letter, file_encoder_b64, _sanitize_str from .filter import ( Where, GraphQL, @@ -27,6 +25,38 @@ ) +@dataclass +class Hybrid: + query: Optional[str] + alpha: Optional[float] + vector: Optional[List[float]] + properties: Optional[List[str]] + target_vectors: Optional[List[str]] + + def __init__(self, content: dict) -> None: + self.query = content.get("query") + self.alpha = content.get("alpha") + self.vector = content.get("vector") + self.properties = content.get("properties") + self.target_vectors = content.get("targetVectors") + + def __str__(self) -> str: + ret = "" + if self.query is not None: + ret += f"query: {_sanitize_str(self.query)}" + if self.vector is not None: + ret += f", vector: {self.vector}" + if self.alpha is not None: + ret += f", alpha: {self.alpha}" + if self.properties is not None and len(self.properties) > 0: + props = '","'.join(self.properties) + ret += f', properties: ["{props}"]' + if self.target_vectors is not None: + target_vectors = '","'.join(self.target_vectors) + ret += f', targetVectors: ["{target_vectors}"]' + return "hybrid:{" + ret + "}" + + class AggregateBuilder(GraphQL): """ AggregateBuilder class used to aggregate Weaviate objects. @@ -55,6 +85,7 @@ def __init__(self, class_name: str, connection: Union[Connection, ConnectionV4]) self._near: Optional[Filter] = None self._tenant: Optional[str] = None self._limit: Optional[int] = None + self._hybrid: Optional[Hybrid] = None def with_tenant(self, tenant: str) -> "AggregateBuilder": """Sets a tenant for the query.""" @@ -209,6 +240,20 @@ def with_where(self, content: dict) -> "AggregateBuilder": self._uses_filter = True return self + def with_hybrid(self, content: dict) -> "AggregateBuilder": + """Get objects using bm25 and vector, then combine the results using a reciprocal ranking algorithm. + + Parameters + ---------- + content : dict + The content of the `hybrid` filter to set. + """ + if self._near is not None: + raise AttributeError("Cannot use 'hybrid' and 'near' filters simultaneously.") + self._hybrid = Hybrid(content) + self._uses_filter = True + return self + def with_group_by_filter(self, properties: List[str]) -> "AggregateBuilder": """ Add a group by filter to the query. Might requires the user to set @@ -308,6 +353,8 @@ def with_near_text(self, content: dict) -> "AggregateBuilder": if self._near is not None: raise AttributeError("Cannot use multiple 'near' filters.") + if self._hybrid is not None: + raise AttributeError("Cannot use 'near' and 'hybrid' filters simultaneously.") self._near = NearText(content) self._uses_filter = True return self @@ -373,6 +420,8 @@ def with_near_vector(self, content: dict) -> "AggregateBuilder": if self._near is not None: raise AttributeError("Cannot use multiple 'near' filters.") + if self._hybrid is not None: + raise AttributeError("Cannot use 'near' and 'hybrid' filters simultaneously.") self._near = NearVector(content) self._uses_filter = True return self @@ -423,6 +472,8 @@ def with_near_object(self, content: dict) -> "AggregateBuilder": if self._near is not None: raise AttributeError("Cannot use multiple 'near' filters.") + if self._hybrid is not None: + raise AttributeError("Cannot use 'near' and 'hybrid' filters simultaneously.") self._near = NearObject(content, is_server_version_14) self._uses_filter = True return self @@ -534,6 +585,8 @@ def with_near_image(self, content: dict, encode: bool = True) -> "AggregateBuild "Cannot use multiple 'near' filters, or a 'near' filter along" " with a 'ask' filter!" ) + if self._hybrid is not None: + raise AttributeError("Cannot use 'near' and 'hybrid' filters simultaneously.") if encode: content["image"] = file_encoder_b64(content["image"]) self._near = NearImage(content) @@ -648,6 +701,8 @@ def with_near_audio(self, content: dict, encode: bool = True) -> "AggregateBuild "Cannot use multiple 'near' filters, or a 'near' filter along" " with a 'ask' filter!" ) + if self._hybrid is not None: + raise AttributeError("Cannot use 'near' and 'hybrid' filters simultaneously.") if encode: content[self._media_type.value] = file_encoder_b64(content[self._media_type.value]) self._near = NearAudio(content) @@ -762,6 +817,8 @@ def with_near_video(self, content: dict, encode: bool = True) -> "AggregateBuild "Cannot use multiple 'near' filters, or a 'near' filter along" " with a 'ask' filter!" ) + if self._hybrid is not None: + raise AttributeError("Cannot use 'near' and 'hybrid' filters simultaneously.") if encode: content[self._media_type.value] = file_encoder_b64(content[self._media_type.value]) self._near = NearVideo(content) @@ -876,6 +933,8 @@ def with_near_depth(self, content: dict, encode: bool = True) -> "AggregateBuild "Cannot use multiple 'near' filters, or a 'near' filter along" " with a 'ask' filter!" ) + if self._hybrid is not None: + raise AttributeError("Cannot use 'near' and 'hybrid' filters simultaneously.") if encode: content[self._media_type.value] = file_encoder_b64(content[self._media_type.value]) self._near = NearDepth(content) @@ -989,6 +1048,8 @@ def with_near_thermal(self, content: dict, encode: bool = True) -> "AggregateBui "Cannot use multiple 'near' filters, or a 'near' filter along" " with a 'ask' filter!" ) + if self._hybrid is not None: + raise AttributeError("Cannot use 'near' and 'hybrid' filters simultaneously.") if encode: content[self._media_type.value] = file_encoder_b64(content[self._media_type.value]) self._near = NearThermal(content) @@ -1103,6 +1164,8 @@ def with_near_imu(self, content: dict, encode: bool = True) -> "AggregateBuilder "Cannot use multiple 'near' filters, or a 'near' filter along" " with a 'ask' filter!" ) + if self._hybrid is not None: + raise AttributeError("Cannot use 'near' and 'hybrid' filters simultaneously.") if encode: content[self._media_type.value] = file_encoder_b64(content[self._media_type.value]) self._near = NearIMU(content) @@ -1137,7 +1200,8 @@ def build(self) -> str: query += f'tenant: "{self._tenant}"' if self._limit is not None: query += f"limit: {self._limit}" - + if self._hybrid is not None: + query += str(self._hybrid) query += ")" # Body