diff --git a/integration/test_collection_aggregate.py b/integration/test_collection_aggregate.py index 44178b15c..6733321e3 100644 --- a/integration/test_collection_aggregate.py +++ b/integration/test_collection_aggregate.py @@ -71,7 +71,20 @@ def test_simple_aggregation(collection_factory: CollectionFactory) -> None: assert res.properties["text"].count == 1 -def test_aggregation_with_limit(collection_factory: CollectionFactory) -> None: +def test_aggregation_top_occurence_with_limit(collection_factory: CollectionFactory) -> None: + collection = collection_factory(properties=[Property(name="text", data_type=DataType.TEXT)]) + collection.data.insert({"text": "one"}) + collection.data.insert({"text": "one"}) + collection.data.insert({"text": "two"}) + res = collection.aggregate.over_all( + return_metrics=[Metrics("text").text(min_occurrences=1)], + ) + assert isinstance(res.properties["text"], AggregateText) + assert len(res.properties["text"].top_occurrences) == 1 + assert res.properties["text"].top_occurrences[0].count == 2 + + +def test_aggregation_groupby_with_limit(collection_factory: CollectionFactory) -> None: collection = collection_factory(properties=[Property(name="text", data_type=DataType.TEXT)]) collection.data.insert({"text": "one"}) collection.data.insert({"text": "two"}) diff --git a/weaviate/collections/classes/aggregate.py b/weaviate/collections/classes/aggregate.py index 1ec81b6b6..38243b574 100644 --- a/weaviate/collections/classes/aggregate.py +++ b/weaviate/collections/classes/aggregate.py @@ -138,13 +138,15 @@ class _MetricsBase(BaseModel): class _MetricsText(_MetricsBase): top_occurrences_count: bool top_occurrences_value: bool + min_occurrences: Optional[int] def to_gql(self) -> str: + limit = f"(limit: {self.min_occurrences})" if self.min_occurrences is not None else "" body = " ".join( [ "count" if self.count else "", ( - "topOccurrences {" + "topOccurrences" + limit + " {" if self.top_occurrences_count or self.top_occurrences_value else "" ), @@ -275,6 +277,7 @@ def text( count: bool = False, top_occurrences_count: bool = False, top_occurrences_value: bool = False, + min_occurrences: Optional[int] = None, ) -> _MetricsText: """Define the metrics to be returned for a TEXT or TEXT_ARRAY property when aggregating over a collection. @@ -287,6 +290,8 @@ def text( Whether to include the number of the top occurrences of a property's value. `top_occurrences_value` Whether to include the value of the top occurrences of a property's value. + `min_occurrences` + Only include entries with more occurrences than the given limit. Returns: A `_MetricsStr` object that includes the metrics to be returned. @@ -300,6 +305,7 @@ def text( count=count, top_occurrences_count=top_occurrences_count, top_occurrences_value=top_occurrences_value, + min_occurrences=min_occurrences, ) def integer(