Skip to content

Commit e82c56e

Browse files
committed
Added support for ADDSCORES modifier (#3329)
* Added support for ADDSCORES modifier * Fixed codestyle issues * More codestyle fixes * Updated test cases and testing image to represent latest * Codestyle issues * Added handling for dict responses
1 parent 89a2898 commit e82c56e

File tree

3 files changed

+63
-0
lines changed

3 files changed

+63
-0
lines changed

redis/commands/search/aggregation.py

+11
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ def __init__(self, query: str = "*") -> None:
111111
self._verbatim = False
112112
self._cursor = []
113113
self._dialect = None
114+
self._add_scores = False
114115

115116
def load(self, *fields: List[str]) -> "AggregateRequest":
116117
"""
@@ -292,6 +293,13 @@ def with_schema(self) -> "AggregateRequest":
292293
self._with_schema = True
293294
return self
294295

296+
def add_scores(self) -> "AggregateRequest":
297+
"""
298+
If set, includes the score as an ordinary field of the row.
299+
"""
300+
self._add_scores = True
301+
return self
302+
295303
def verbatim(self) -> "AggregateRequest":
296304
self._verbatim = True
297305
return self
@@ -315,6 +323,9 @@ def build_args(self) -> List[str]:
315323
if self._verbatim:
316324
ret.append("VERBATIM")
317325

326+
if self._add_scores:
327+
ret.append("ADDSCORES")
328+
318329
if self._cursor:
319330
ret += self._cursor
320331

tests/test_asyncio/test_search.py

+26
Original file line numberDiff line numberDiff line change
@@ -1531,6 +1531,32 @@ async def test_withsuffixtrie(decoded_r: redis.Redis):
15311531
assert "WITHSUFFIXTRIE" in info["attributes"][0]["flags"]
15321532

15331533

1534+
@pytest.mark.redismod
1535+
@skip_ifmodversion_lt("2.10.05", "search")
1536+
async def test_aggregations_add_scores(decoded_r: redis.Redis):
1537+
assert await decoded_r.ft().create_index(
1538+
(
1539+
TextField("name", sortable=True, weight=5.0),
1540+
NumericField("age", sortable=True),
1541+
)
1542+
)
1543+
1544+
assert await decoded_r.hset("doc1", mapping={"name": "bar", "age": "25"})
1545+
assert await decoded_r.hset("doc2", mapping={"name": "foo", "age": "19"})
1546+
1547+
req = aggregations.AggregateRequest("*").add_scores()
1548+
res = await decoded_r.ft().aggregate(req)
1549+
1550+
if isinstance(res, dict):
1551+
assert len(res["results"]) == 2
1552+
assert res["results"][0]["extra_attributes"] == {"__score": "0.2"}
1553+
assert res["results"][1]["extra_attributes"] == {"__score": "0.2"}
1554+
else:
1555+
assert len(res.rows) == 2
1556+
assert res.rows[0] == ["__score", "0.2"]
1557+
assert res.rows[1] == ["__score", "0.2"]
1558+
1559+
15341560
@pytest.mark.redismod
15351561
@skip_if_redis_enterprise()
15361562
async def test_search_commands_in_pipeline(decoded_r: redis.Redis):

tests/test_search.py

+26
Original file line numberDiff line numberDiff line change
@@ -1443,6 +1443,32 @@ def test_aggregations_filter(client):
14431443
assert res["results"][1]["extra_attributes"] == {"age": "25"}
14441444

14451445

1446+
@pytest.mark.redismod
1447+
@skip_ifmodversion_lt("2.10.05", "search")
1448+
def test_aggregations_add_scores(client):
1449+
client.ft().create_index(
1450+
(
1451+
TextField("name", sortable=True, weight=5.0),
1452+
NumericField("age", sortable=True),
1453+
)
1454+
)
1455+
1456+
client.hset("doc1", mapping={"name": "bar", "age": "25"})
1457+
client.hset("doc2", mapping={"name": "foo", "age": "19"})
1458+
1459+
req = aggregations.AggregateRequest("*").add_scores()
1460+
res = client.ft().aggregate(req)
1461+
1462+
if isinstance(res, dict):
1463+
assert len(res["results"]) == 2
1464+
assert res["results"][0]["extra_attributes"] == {"__score": "0.2"}
1465+
assert res["results"][1]["extra_attributes"] == {"__score": "0.2"}
1466+
else:
1467+
assert len(res.rows) == 2
1468+
assert res.rows[0] == ["__score", "0.2"]
1469+
assert res.rows[1] == ["__score", "0.2"]
1470+
1471+
14461472
@pytest.mark.redismod
14471473
@skip_ifmodversion_lt("2.0.0", "search")
14481474
def test_index_definition(client):

0 commit comments

Comments
 (0)