Skip to content

Commit

Permalink
Added support for ADDSCORES modifier (#3329)
Browse files Browse the repository at this point in the history
* Added support for ADDSCORES modifier

* Fixed codestyle issues

* More codestyle fixes

* Updated test cases and testing image to represent latest

* Codestyle issues

* Added handling for dict responses
  • Loading branch information
vladvildanov committed Sep 27, 2024
1 parent c626da2 commit bc93b54
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .github/workflows/integration.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ env:
# this speeds up coverage with Python 3.12: https://github.com/nedbat/coveragepy/issues/1665
COVERAGE_CORE: sysmon
REDIS_IMAGE: redis:7.4-rc2
REDIS_STACK_IMAGE: redis/redis-stack-server:7.4.0-rc2
REDIS_STACK_IMAGE: redis/redis-stack-server:latest

jobs:
dependency-audit:
Expand Down
11 changes: 11 additions & 0 deletions redis/commands/search/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def __init__(self, query: str = "*") -> None:
self._verbatim = False
self._cursor = []
self._dialect = None
self._add_scores = False

def load(self, *fields: List[str]) -> "AggregateRequest":
"""
Expand Down Expand Up @@ -292,6 +293,13 @@ def with_schema(self) -> "AggregateRequest":
self._with_schema = True
return self

def add_scores(self) -> "AggregateRequest":
"""
If set, includes the score as an ordinary field of the row.
"""
self._add_scores = True
return self

def verbatim(self) -> "AggregateRequest":
self._verbatim = True
return self
Expand All @@ -315,6 +323,9 @@ def build_args(self) -> List[str]:
if self._verbatim:
ret.append("VERBATIM")

if self._add_scores:
ret.append("ADDSCORES")

if self._cursor:
ret += self._cursor

Expand Down
26 changes: 26 additions & 0 deletions tests/test_asyncio/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -1531,6 +1531,32 @@ async def test_withsuffixtrie(decoded_r: redis.Redis):
assert "WITHSUFFIXTRIE" in info["attributes"][0]["flags"]


@pytest.mark.redismod
@skip_ifmodversion_lt("2.10.05", "search")
async def test_aggregations_add_scores(decoded_r: redis.Redis):
assert await decoded_r.ft().create_index(
(
TextField("name", sortable=True, weight=5.0),
NumericField("age", sortable=True),
)
)

assert await decoded_r.hset("doc1", mapping={"name": "bar", "age": "25"})
assert await decoded_r.hset("doc2", mapping={"name": "foo", "age": "19"})

req = aggregations.AggregateRequest("*").add_scores()
res = await decoded_r.ft().aggregate(req)

if isinstance(res, dict):
assert len(res["results"]) == 2
assert res["results"][0]["extra_attributes"] == {"__score": "0.2"}
assert res["results"][1]["extra_attributes"] == {"__score": "0.2"}
else:
assert len(res.rows) == 2
assert res.rows[0] == ["__score", "0.2"]
assert res.rows[1] == ["__score", "0.2"]


@pytest.mark.redismod
@skip_if_redis_enterprise()
async def test_search_commands_in_pipeline(decoded_r: redis.Redis):
Expand Down
26 changes: 26 additions & 0 deletions tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -1441,6 +1441,32 @@ def test_aggregations_filter(client):
assert res["results"][1]["extra_attributes"] == {"age": "25"}


@pytest.mark.redismod
@skip_ifmodversion_lt("2.10.05", "search")
def test_aggregations_add_scores(client):
client.ft().create_index(
(
TextField("name", sortable=True, weight=5.0),
NumericField("age", sortable=True),
)
)

client.hset("doc1", mapping={"name": "bar", "age": "25"})
client.hset("doc2", mapping={"name": "foo", "age": "19"})

req = aggregations.AggregateRequest("*").add_scores()
res = client.ft().aggregate(req)

if isinstance(res, dict):
assert len(res["results"]) == 2
assert res["results"][0]["extra_attributes"] == {"__score": "0.2"}
assert res["results"][1]["extra_attributes"] == {"__score": "0.2"}
else:
assert len(res.rows) == 2
assert res.rows[0] == ["__score", "0.2"]
assert res.rows[1] == ["__score", "0.2"]


@pytest.mark.redismod
@skip_ifmodversion_lt("2.0.0", "search")
def test_index_definition(client):
Expand Down

0 comments on commit bc93b54

Please sign in to comment.