Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add aggregate.hybrid method using GQL #992

Merged
merged 15 commits into from
May 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 135 additions & 2 deletions integration/test_collection_aggregate.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pathlib
import uuid
from datetime import datetime, timezone
from typing import Union

import pytest
from _pytest.fixtures import SubRequest
Expand All @@ -18,7 +19,11 @@
)
from weaviate.collections.classes.config import DataType, Property, ReferenceProperty, Configure
from weaviate.collections.classes.filters import Filter, _Filters
from weaviate.exceptions import WeaviateInvalidInputError, WeaviateQueryError
from weaviate.exceptions import (
WeaviateInvalidInputError,
WeaviateQueryError,
WeaviateNotImplementedError,
)
from weaviate.util import file_encoder_b64

from weaviate.collections.classes.grpc import Move
Expand Down Expand Up @@ -290,6 +295,132 @@ def test_near_object_missing_param(collection_factory: CollectionFactory) -> Non
)


@pytest.mark.parametrize(
    "option,expected_len",
    [
        ({"object_limit": 1}, 1),
        ({"object_limit": 2}, 2),
    ],
)
def test_hybrid_aggregation(
    collection_factory: CollectionFactory, option: dict, expected_len: int
) -> None:
    """Check that `object_limit` bounds what a vector-driven hybrid aggregation sees.

    Two objects with very different texts are inserted. The aggregation is run with
    `query=None`, `alpha=1` and the stored vector of the first object, so the first
    object must always be part of the result; the second one may only show up when
    `object_limit` is 2.
    """
    collection = collection_factory(
        properties=[Property(name="text", data_type=DataType.TEXT)],
        vectorizer_config=Configure.Vectorizer.text2vec_contextionary(
            vectorize_collection_name=False
        ),
    )
    text_1 = "some text"
    text_2 = "nothing like the other one at all, not even a little bit"

    # Deliberately not called `uuid`: that name would shadow the `uuid` module
    # imported at the top of this file.
    uuid_1 = collection.data.insert({"text": text_1})
    obj = collection.query.fetch_object_by_id(uuid_1, include_vector=True)
    assert "default" in obj.vector
    collection.data.insert({"text": text_2})

    res: AggregateReturn = collection.aggregate.hybrid(
        None,
        alpha=1,
        vector=obj.vector["default"],
        return_metrics=[
            Metrics("text").text(count=True, top_occurrences_count=True, top_occurrences_value=True)
        ],
        **option,
    )

    text_metrics = res.properties["text"]
    assert isinstance(text_metrics, AggregateText)
    assert text_metrics.count == expected_len
    assert len(text_metrics.top_occurrences) == expected_len

    # Collect the aggregated values once instead of rebuilding the list per assertion.
    top_values = [top_occurrence.value for top_occurrence in text_metrics.top_occurrences]
    assert text_1 in top_values
    if expected_len == 2:
        assert text_2 in top_values
    else:
        assert text_2 not in top_values


@pytest.mark.parametrize("group_by", ["text", GroupByAggregate(prop="text", limit=1)])
def test_hybrid_aggregation_group_by(
    collection_factory: CollectionFactory, group_by: Union[str, GroupByAggregate]
) -> None:
    """Check hybrid aggregation combined with `group_by`.

    On servers older than 1.25.0 the call is expected to raise
    `WeaviateNotImplementedError`; otherwise the first returned group must be the
    one for the object whose text is "some text".
    """
    collection = collection_factory(
        properties=[Property(name="text", data_type=DataType.TEXT)],
        vectorizer_config=Configure.Vectorizer.text2vec_contextionary(
            vectorize_collection_name=False
        ),
    )

    text_1 = "some text"
    text_2 = "nothing like the other one at all, not even a little bit"
    collection.data.insert({"text": text_1})
    collection.data.insert({"text": text_2})

    # A named inner function instead of a lambda bound to a variable (PEP 8):
    # the same query is issued on both the "unsupported" and the "supported" path.
    def run_query():
        return collection.aggregate.hybrid(
            "text",
            alpha=0,
            query_properties=["text"],
            group_by=group_by,
            total_count=True,
            object_limit=2,  # has no effect due to alpha=0
        )

    if collection._connection._weaviate_version.is_lower_than(1, 25, 0):
        with pytest.raises(WeaviateNotImplementedError):
            run_query()
        return

    first_group = run_query().groups[0]
    assert first_group.grouped_by.prop == "text"
    assert first_group.grouped_by.value == text_1
    assert first_group.total_count == 1


@pytest.mark.parametrize("group_by", ["text", GroupByAggregate(prop="text", limit=1)])
def test_hybrid_aggregation_group_by_with_named_vectors(
    collection_factory: CollectionFactory, group_by: Union[str, GroupByAggregate]
) -> None:
    """Check hybrid aggregation with `group_by` on a collection that uses a named vector.

    Two version gates are exercised:
    - below 1.24.0, creating the named-vector collection is expected to raise
      `WeaviateInvalidInputError`;
    - below 1.25.0, the grouped hybrid aggregation is expected to raise
      `WeaviateNotImplementedError`.
    On newer servers the first returned group must be the one for "some text".
    """
    # A throwaway collection whose only purpose is to expose the server version
    # before the named-vector collection is (possibly unsuccessfully) created.
    dummy = collection_factory("dummy")

    # Named inner functions instead of lambdas bound to variables (PEP 8).
    def make_collection():
        return collection_factory(
            properties=[Property(name="text", data_type=DataType.TEXT)],
            vectorizer_config=[
                Configure.NamedVectors.text2vec_contextionary(
                    name="all", vectorize_collection_name=False
                )
            ],
        )

    if dummy._connection._weaviate_version.is_lower_than(1, 24, 0):
        with pytest.raises(WeaviateInvalidInputError):
            make_collection()
        return

    collection = make_collection()
    text_1 = "some text"
    text_2 = "nothing like the other one at all, not even a little bit"
    collection.data.insert({"text": text_1})
    collection.data.insert({"text": text_2})

    def run_query():
        return collection.aggregate.hybrid(
            "text",
            alpha=0,
            query_properties=["text"],
            group_by=group_by,
            total_count=True,
            object_limit=2,  # has no effect due to alpha=0
            target_vector="all",
        )

    if dummy._connection._weaviate_version.is_lower_than(1, 25, 0):
        with pytest.raises(WeaviateNotImplementedError):
            run_query()
        return

    first_group = run_query().groups[0]
    assert first_group.grouped_by.prop == "text"
    assert first_group.grouped_by.value == text_1
    assert first_group.total_count == 1


@pytest.mark.parametrize(
"option,expected_len",
[
Expand All @@ -304,12 +435,14 @@ def test_near_object_missing_param(collection_factory: CollectionFactory) -> Non
def test_near_vector_aggregation(
collection_factory: CollectionFactory, option: dict, expected_len: int
) -> None:
collection = collection_factory(
collection_maker = lambda: collection_factory(
properties=[Property(name="text", data_type=DataType.TEXT)],
vectorizer_config=Configure.Vectorizer.text2vec_contextionary(
vectorize_collection_name=False
),
)

collection = collection_maker()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not a blocker, but what is the point of this change?

text_1 = "some text"
text_2 = "nothing like the other one at all, not even a little bit"
uuid = collection.data.insert({"text": text_1})
Expand Down
5 changes: 3 additions & 2 deletions weaviate/collections/aggregate.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from weaviate.collections.aggregations.over_all import _OverAll
from weaviate.collections.aggregations.hybrid import _Hybrid
from weaviate.collections.aggregations.near_image import _NearImage
from weaviate.collections.aggregations.near_object import _NearObject
from weaviate.collections.aggregations.near_text import _NearText
from weaviate.collections.aggregations.near_vector import _NearVector
from weaviate.collections.aggregations.over_all import _OverAll


class _AggregateCollection(_OverAll, _NearImage, _NearObject, _NearText, _NearVector):
class _AggregateCollection(_Hybrid, _NearImage, _NearObject, _NearText, _NearVector, _OverAll):
    """Combine the hybrid, near-image, near-object, near-text, near-vector and over-all aggregation classes.

    The class adds no members of its own; everything is inherited from the listed bases.
    """
38 changes: 32 additions & 6 deletions weaviate/collections/aggregations/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,13 @@ def __init__(
consistency_level: Optional[ConsistencyLevel],
tenant: Optional[str],
):
self.__connection = connection
self._connection = connection
self.__name = name
self._tenant = tenant
self._consistency_level = consistency_level

def _query(self) -> AggregateBuilder:
    """Return a new `AggregateBuilder` for this collection's name and connection."""
    # Single return reading the `_connection` attribute; a second, unreachable
    # return that read the former `__connection` attribute has been dropped.
    return AggregateBuilder(self.__name, self._connection)

def _to_aggregate_result(
self, response: dict, metrics: Optional[List[_Metrics]]
Expand Down Expand Up @@ -235,7 +235,33 @@ def _parse_near_options(
)

@staticmethod
def _add_near_image(
def _add_hybrid_to_builder(
    builder: AggregateBuilder,
    query: Optional[str],
    alpha: Optional[NUMBER],
    vector: Optional[List[float]],
    query_properties: Optional[List[str]],
    object_limit: Optional[int],
    target_vector: Optional[str],
) -> AggregateBuilder:
    """Attach a hybrid clause, and optionally an object limit, to `builder`.

    Only arguments that are not None end up in the payload handed to
    `builder.with_hybrid`. `query_properties` is sent under the key "properties"
    and `target_vector` is wrapped in a one-element list under "targetVectors".
    `builder.with_object_limit` is only called when `object_limit` is not None.

    Returns:
        The builder returned by the last `with_*` call.
    """
    # Insertion order mirrors the order in which the keys are meant to appear.
    candidates = {
        "query": query,
        "alpha": alpha,
        "vector": vector,
        "properties": query_properties,
        "targetVectors": None if target_vector is None else [target_vector],
    }
    payload: dict = {key: value for key, value in candidates.items() if value is not None}

    with_hybrid = builder.with_hybrid(payload)
    if object_limit is None:
        return with_hybrid
    return with_hybrid.with_object_limit(object_limit)

@staticmethod
def _add_near_image_to_builder(
builder: AggregateBuilder,
near_image: Union[str, pathlib.Path, io.BufferedReader],
certainty: Optional[NUMBER],
Expand Down Expand Up @@ -265,7 +291,7 @@ def _add_near_image(
return builder

@staticmethod
def _add_near_object(
def _add_near_object_to_builder(
builder: AggregateBuilder,
near_object: UUID,
certainty: Optional[NUMBER],
Expand Down Expand Up @@ -293,7 +319,7 @@ def _add_near_object(
return builder

@staticmethod
def _add_near_text(
def _add_near_text_to_builder(
builder: AggregateBuilder,
query: Union[List[str], str],
certainty: Optional[NUMBER],
Expand Down Expand Up @@ -334,7 +360,7 @@ def _add_near_text(
return builder

@staticmethod
def _add_near_vector(
def _add_near_vector_to_builder(
builder: AggregateBuilder,
near_vector: List[float],
certainty: Optional[NUMBER],
Expand Down
114 changes: 114 additions & 0 deletions weaviate/collections/aggregations/hybrid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
from typing import List, Literal, Optional, Union, overload

from weaviate.collections.aggregations.base import _Aggregate
from weaviate.collections.classes.aggregate import (
PropertiesMetrics,
AggregateReturn,
AggregateGroupByReturn,
GroupByAggregate,
)
from weaviate.collections.classes.filters import _Filters
from weaviate.exceptions import WeaviateNotImplementedError
from weaviate.types import NUMBER


class _Hybrid(_Aggregate):
    """Provide the `hybrid` aggregation method on top of the shared `_Aggregate` helpers."""

    @overload
    def hybrid(
        self,
        query: Optional[str],
        *,
        alpha: NUMBER = 0.7,
        vector: Optional[List[float]] = None,
        query_properties: Optional[List[str]] = None,
        object_limit: Optional[int] = None,
        filters: Optional[_Filters] = None,
        group_by: Literal[None] = None,
        target_vector: Optional[str] = None,
        total_count: bool = True,
        return_metrics: Optional[PropertiesMetrics] = None,
    ) -> AggregateReturn:
        ...

    @overload
    def hybrid(
        self,
        query: Optional[str],
        *,
        alpha: NUMBER = 0.7,
        vector: Optional[List[float]] = None,
        query_properties: Optional[List[str]] = None,
        object_limit: Optional[int] = None,
        filters: Optional[_Filters] = None,
        group_by: Union[str, GroupByAggregate],
        target_vector: Optional[str] = None,
        total_count: bool = True,
        return_metrics: Optional[PropertiesMetrics] = None,
    ) -> AggregateGroupByReturn:
        ...

    def hybrid(
        self,
        query: Optional[str],
        *,
        alpha: NUMBER = 0.7,
        vector: Optional[List[float]] = None,
        query_properties: Optional[List[str]] = None,
        object_limit: Optional[int] = None,
        filters: Optional[_Filters] = None,
        group_by: Optional[Union[str, GroupByAggregate]] = None,
        target_vector: Optional[str] = None,
        total_count: bool = True,
        return_metrics: Optional[PropertiesMetrics] = None,
    ) -> Union[AggregateReturn, AggregateGroupByReturn]:
        """Aggregate metrics over all the objects in this collection using the hybrid algorithm blending keyword-based BM25 and vector-based similarity.

        Arguments:
            `query`
                The keyword-based query to search for, REQUIRED (pass `None` explicitly to search by `vector` only). If query and vector are both None, a normal search will be performed.
            `alpha`
                The balance between the keyword and the vector search: `0` uses only the keyword-based BM25 search, `1` uses only the vector search. Defaults to `0.7`.
            `vector`
                The specific vector to search for. If not specified, the query is vectorized and used in the similarity search.
            `query_properties`
                The properties to search in. If not specified, all properties are searched.
            `object_limit`
                The maximum number of objects to return from the hybrid vector search prior to the aggregation.
            `filters`
                The filters to apply to the search.
            `group_by`
                How to group the aggregation by. Requires Weaviate `1.25.0` or newer.
            `target_vector`
                The name of the vector to use for the vector search when the collection has named vectors.
            `total_count`
                Whether to include the total number of objects that match the query in the response.
            `return_metrics`
                A list of property metrics to aggregate together after the text search.

        Returns:
            Depending on the presence of the `group_by` argument, either an `AggregateReturn` object or an `AggregateGroupByReturn` object that includes the aggregation objects.

        Raises:
            `weaviate.exceptions.WeaviateQueryError`:
                If an error occurs while performing the query against Weaviate.
            `weaviate.exceptions.WeaviateInvalidInputError`:
                If any of the input arguments are of the wrong type.
            `weaviate.exceptions.WeaviateNotImplementedError`:
                If `group_by` is provided and the Weaviate server version is lower than `1.25.0`.
        """
        # Only the grouped variant is version-gated, so the error names that feature
        # rather than hybrid aggregation as a whole.
        if group_by is not None and self._connection._weaviate_version.is_lower_than(1, 25, 0):
            raise WeaviateNotImplementedError(
                "Hybrid aggregation with group by", self._connection.server_version, "1.25.0"
            )
        # Accept a single metrics object as shorthand for a one-element list.
        return_metrics = (
            return_metrics
            if (return_metrics is None or isinstance(return_metrics, list))
            else [return_metrics]
        )
        builder = self._base(return_metrics, filters, total_count)
        builder = self._add_hybrid_to_builder(
            builder, query, alpha, vector, query_properties, object_limit, target_vector
        )
        builder = self._add_groupby_to_builder(builder, group_by)
        res = self._do(builder)
        return (
            self._to_aggregate_result(res, return_metrics)
            if group_by is None
            else self._to_group_by_result(res, return_metrics)
        )
2 changes: 1 addition & 1 deletion weaviate/collections/aggregations/near_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def near_image(
)
builder = self._base(return_metrics, filters, total_count)
builder = self._add_groupby_to_builder(builder, group_by)
builder = self._add_near_image(
builder = self._add_near_image_to_builder(
builder, near_image, certainty, distance, object_limit, target_vector
)
res = self._do(builder)
Expand Down
2 changes: 1 addition & 1 deletion weaviate/collections/aggregations/near_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def near_object(
)
builder = self._base(return_metrics, filters, total_count)
builder = self._add_groupby_to_builder(builder, group_by)
builder = self._add_near_object(
builder = self._add_near_object_to_builder(
builder, near_object, certainty, distance, object_limit, target_vector
)
res = self._do(builder)
Expand Down
Loading