Skip to content

Commit 1475e5c

Browse files
authored
Add async supoort for SEARCH commands (#2096)
* Add async supoort for SEARCH commands * linters * linters * linters * linters * linters
1 parent c29d158 commit 1475e5c

File tree

10 files changed

+6865
-5
lines changed

10 files changed

+6865
-5
lines changed

redis/asyncio/client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@
4141
)
4242
from redis.commands import (
4343
AsyncCoreCommands,
44+
AsyncRedisModuleCommands,
4445
AsyncSentinelCommands,
45-
RedisModuleCommands,
4646
list_or_args,
4747
)
4848
from redis.compat import Protocol, TypedDict
@@ -81,7 +81,7 @@ async def __call__(self, response: Any, **kwargs):
8181

8282

8383
class Redis(
84-
AbstractRedis, RedisModuleCommands, AsyncCoreCommands, AsyncSentinelCommands
84+
AbstractRedis, AsyncRedisModuleCommands, AsyncCoreCommands, AsyncSentinelCommands
8585
):
8686
"""
8787
Implementation of the Redis protocol.

redis/commands/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from .core import AsyncCoreCommands, CoreCommands
33
from .helpers import list_or_args
44
from .parser import CommandsParser
5-
from .redismodules import RedisModuleCommands
5+
from .redismodules import AsyncRedisModuleCommands, RedisModuleCommands
66
from .sentinel import AsyncSentinelCommands, SentinelCommands
77

88
__all__ = [
@@ -11,6 +11,7 @@
1111
"AsyncCoreCommands",
1212
"CoreCommands",
1313
"list_or_args",
14+
"AsyncRedisModuleCommands",
1415
"RedisModuleCommands",
1516
"AsyncSentinelCommands",
1617
"SentinelCommands",

redis/commands/redismodules.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,3 +81,13 @@ def graph(self, index_name="idx"):
8181

8282
g = Graph(client=self, name=index_name)
8383
return g
84+
85+
86+
class AsyncRedisModuleCommands(RedisModuleCommands):
87+
def ft(self, index_name="idx"):
88+
"""Access the search namespace, providing support for redis search."""
89+
90+
from .search import AsyncSearch
91+
92+
s = AsyncSearch(client=self, index_name=index_name)
93+
return s

redis/commands/search/__init__.py

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import redis
22

3-
from .commands import SearchCommands
3+
from .commands import AsyncSearchCommands, SearchCommands
44

55

66
class Search(SearchCommands):
@@ -112,5 +112,67 @@ def pipeline(self, transaction=True, shard_hint=None):
112112
return p
113113

114114

115+
class AsyncSearch(Search, AsyncSearchCommands):
116+
class BatchIndexer(Search.BatchIndexer):
117+
"""
118+
A batch indexer allows you to automatically batch
119+
document indexing in pipelines, flushing it every N documents.
120+
"""
121+
122+
async def add_document(
123+
self,
124+
doc_id,
125+
nosave=False,
126+
score=1.0,
127+
payload=None,
128+
replace=False,
129+
partial=False,
130+
no_create=False,
131+
**fields,
132+
):
133+
"""
134+
Add a document to the batch query
135+
"""
136+
self.client._add_document(
137+
doc_id,
138+
conn=self._pipeline,
139+
nosave=nosave,
140+
score=score,
141+
payload=payload,
142+
replace=replace,
143+
partial=partial,
144+
no_create=no_create,
145+
**fields,
146+
)
147+
self.current_chunk += 1
148+
self.total += 1
149+
if self.current_chunk >= self.chunk_size:
150+
await self.commit()
151+
152+
async def commit(self):
153+
"""
154+
Manually commit and flush the batch indexing query
155+
"""
156+
await self._pipeline.execute()
157+
self.current_chunk = 0
158+
159+
def pipeline(self, transaction=True, shard_hint=None):
160+
"""Creates a pipeline for the SEARCH module, that can be used for executing
161+
SEARCH commands, as well as classic core commands.
162+
"""
163+
p = AsyncPipeline(
164+
connection_pool=self.client.connection_pool,
165+
response_callbacks=self.MODULE_CALLBACKS,
166+
transaction=transaction,
167+
shard_hint=shard_hint,
168+
)
169+
p.index_name = self.index_name
170+
return p
171+
172+
115173
class Pipeline(SearchCommands, redis.client.Pipeline):
116174
"""Pipeline for the module."""
175+
176+
177+
class AsyncPipeline(AsyncSearchCommands, redis.asyncio.client.Pipeline):
178+
"""AsyncPipeline for the module."""

redis/commands/search/commands.py

Lines changed: 242 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -857,3 +857,245 @@ def syndump(self):
857857
""" # noqa
858858
raw = self.execute_command(SYNDUMP_CMD, self.index_name)
859859
return {raw[i]: raw[i + 1] for i in range(0, len(raw), 2)}
860+
861+
862+
class AsyncSearchCommands(SearchCommands):
863+
async def info(self):
864+
"""
865+
Get info an stats about the the current index, including the number of
866+
documents, memory consumption, etc
867+
868+
For more information https://oss.redis.com/redisearch/Commands/#ftinfo
869+
"""
870+
871+
res = await self.execute_command(INFO_CMD, self.index_name)
872+
it = map(to_string, res)
873+
return dict(zip(it, it))
874+
875+
async def search(
876+
self,
877+
query: Union[str, Query],
878+
query_params: Dict[str, Union[str, int, float]] = None,
879+
):
880+
"""
881+
Search the index for a given query, and return a result of documents
882+
883+
### Parameters
884+
885+
- **query**: the search query. Either a text for simple queries with
886+
default parameters, or a Query object for complex queries.
887+
See RediSearch's documentation on query format
888+
889+
For more information: https://oss.redis.com/redisearch/Commands/#ftsearch
890+
""" # noqa
891+
args, query = self._mk_query_args(query, query_params=query_params)
892+
st = time.time()
893+
res = await self.execute_command(SEARCH_CMD, *args)
894+
895+
if isinstance(res, Pipeline):
896+
return res
897+
898+
return Result(
899+
res,
900+
not query._no_content,
901+
duration=(time.time() - st) * 1000.0,
902+
has_payload=query._with_payloads,
903+
with_scores=query._with_scores,
904+
)
905+
906+
async def aggregate(
907+
self,
908+
query: Union[str, Query],
909+
query_params: Dict[str, Union[str, int, float]] = None,
910+
):
911+
"""
912+
Issue an aggregation query.
913+
914+
### Parameters
915+
916+
**query**: This can be either an `AggregateRequest`, or a `Cursor`
917+
918+
An `AggregateResult` object is returned. You can access the rows from
919+
its `rows` property, which will always yield the rows of the result.
920+
921+
For more information: https://oss.redis.com/redisearch/Commands/#ftaggregate
922+
""" # noqa
923+
if isinstance(query, AggregateRequest):
924+
has_cursor = bool(query._cursor)
925+
cmd = [AGGREGATE_CMD, self.index_name] + query.build_args()
926+
elif isinstance(query, Cursor):
927+
has_cursor = True
928+
cmd = [CURSOR_CMD, "READ", self.index_name] + query.build_args()
929+
else:
930+
raise ValueError("Bad query", query)
931+
if query_params is not None:
932+
cmd += self.get_params_args(query_params)
933+
934+
raw = await self.execute_command(*cmd)
935+
return self._get_aggregate_result(raw, query, has_cursor)
936+
937+
async def spellcheck(self, query, distance=None, include=None, exclude=None):
938+
"""
939+
Issue a spellcheck query
940+
941+
### Parameters
942+
943+
**query**: search query.
944+
**distance***: the maximal Levenshtein distance for spelling
945+
suggestions (default: 1, max: 4).
946+
**include**: specifies an inclusion custom dictionary.
947+
**exclude**: specifies an exclusion custom dictionary.
948+
949+
For more information: https://oss.redis.com/redisearch/Commands/#ftspellcheck
950+
""" # noqa
951+
cmd = [SPELLCHECK_CMD, self.index_name, query]
952+
if distance:
953+
cmd.extend(["DISTANCE", distance])
954+
955+
if include:
956+
cmd.extend(["TERMS", "INCLUDE", include])
957+
958+
if exclude:
959+
cmd.extend(["TERMS", "EXCLUDE", exclude])
960+
961+
raw = await self.execute_command(*cmd)
962+
963+
corrections = {}
964+
if raw == 0:
965+
return corrections
966+
967+
for _correction in raw:
968+
if isinstance(_correction, int) and _correction == 0:
969+
continue
970+
971+
if len(_correction) != 3:
972+
continue
973+
if not _correction[2]:
974+
continue
975+
if not _correction[2][0]:
976+
continue
977+
978+
corrections[_correction[1]] = [
979+
{"score": _item[0], "suggestion": _item[1]} for _item in _correction[2]
980+
]
981+
982+
return corrections
983+
984+
async def config_set(self, option, value):
985+
"""Set runtime configuration option.
986+
987+
### Parameters
988+
989+
- **option**: the name of the configuration option.
990+
- **value**: a value for the configuration option.
991+
992+
For more information: https://oss.redis.com/redisearch/Commands/#ftconfig
993+
""" # noqa
994+
cmd = [CONFIG_CMD, "SET", option, value]
995+
raw = await self.execute_command(*cmd)
996+
return raw == "OK"
997+
998+
async def config_get(self, option):
999+
"""Get runtime configuration option value.
1000+
1001+
### Parameters
1002+
1003+
- **option**: the name of the configuration option.
1004+
1005+
For more information: https://oss.redis.com/redisearch/Commands/#ftconfig
1006+
""" # noqa
1007+
cmd = [CONFIG_CMD, "GET", option]
1008+
res = {}
1009+
raw = await self.execute_command(*cmd)
1010+
if raw:
1011+
for kvs in raw:
1012+
res[kvs[0]] = kvs[1]
1013+
return res
1014+
1015+
async def load_document(self, id):
1016+
"""
1017+
Load a single document by id
1018+
"""
1019+
fields = await self.client.hgetall(id)
1020+
f2 = {to_string(k): to_string(v) for k, v in fields.items()}
1021+
fields = f2
1022+
1023+
try:
1024+
del fields["id"]
1025+
except KeyError:
1026+
pass
1027+
1028+
return Document(id=id, **fields)
1029+
1030+
async def sugadd(self, key, *suggestions, **kwargs):
1031+
"""
1032+
Add suggestion terms to the AutoCompleter engine. Each suggestion has
1033+
a score and string.
1034+
If kwargs["increment"] is true and the terms are already in the
1035+
server's dictionary, we increment their scores.
1036+
1037+
For more information: https://oss.redis.com/redisearch/master/Commands/#ftsugadd
1038+
""" # noqa
1039+
# If Transaction is not False it will MULTI/EXEC which will error
1040+
pipe = self.pipeline(transaction=False)
1041+
for sug in suggestions:
1042+
args = [SUGADD_COMMAND, key, sug.string, sug.score]
1043+
if kwargs.get("increment"):
1044+
args.append("INCR")
1045+
if sug.payload:
1046+
args.append("PAYLOAD")
1047+
args.append(sug.payload)
1048+
1049+
pipe.execute_command(*args)
1050+
1051+
return (await pipe.execute())[-1]
1052+
1053+
async def sugget(
1054+
self, key, prefix, fuzzy=False, num=10, with_scores=False, with_payloads=False
1055+
):
1056+
"""
1057+
Get a list of suggestions from the AutoCompleter, for a given prefix.
1058+
1059+
Parameters:
1060+
1061+
prefix : str
1062+
The prefix we are searching. **Must be valid ascii or utf-8**
1063+
fuzzy : bool
1064+
If set to true, the prefix search is done in fuzzy mode.
1065+
**NOTE**: Running fuzzy searches on short (<3 letters) prefixes
1066+
can be very
1067+
slow, and even scan the entire index.
1068+
with_scores : bool
1069+
If set to true, we also return the (refactored) score of
1070+
each suggestion.
1071+
This is normally not needed, and is NOT the original score
1072+
inserted into the index.
1073+
with_payloads : bool
1074+
Return suggestion payloads
1075+
num : int
1076+
The maximum number of results we return. Note that we might
1077+
return less. The algorithm trims irrelevant suggestions.
1078+
1079+
Returns:
1080+
1081+
list:
1082+
A list of Suggestion objects. If with_scores was False, the
1083+
score of all suggestions is 1.
1084+
1085+
For more information: https://oss.redis.com/redisearch/master/Commands/#ftsugget
1086+
""" # noqa
1087+
args = [SUGGET_COMMAND, key, prefix, "MAX", num]
1088+
if fuzzy:
1089+
args.append(FUZZY)
1090+
if with_scores:
1091+
args.append(WITHSCORES)
1092+
if with_payloads:
1093+
args.append(WITHPAYLOADS)
1094+
1095+
ret = await self.execute_command(*args)
1096+
results = []
1097+
if not ret:
1098+
return results
1099+
1100+
parser = SuggestionParser(with_scores, with_payloads, ret)
1101+
return [s for s in parser]

tests/test_asyncio/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ async def wait_for_command(
202202
# generate key
203203
redis_version = REDIS_INFO["version"]
204204
if Version(redis_version) >= Version("5.0.0"):
205-
id_str = str(client.client_id())
205+
id_str = str(await client.client_id())
206206
else:
207207
id_str = f"{random.randrange(2 ** 32):08x}"
208208
key = f"__REDIS-PY-{id_str}__"

0 commit comments

Comments
 (0)