Skip to content

Commit

Permalink
Add support for SEARCH commands in cluster (#2042)
Browse files Browse the repository at this point in the history
* Add support for SEARCH commands in cluster

* delete json tests mark & list search commands

* linters
  • Loading branch information
dvora-h authored Mar 14, 2022
1 parent fdf4f1a commit b442110
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 37 deletions.
47 changes: 41 additions & 6 deletions redis/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,42 @@ class RedisCluster(RedisClusterCommands):
),
)

SEARCH_COMMANDS = (
[
"FT.CREATE",
"FT.SEARCH",
"FT.AGGREGATE",
"FT.EXPLAIN",
"FT.EXPLAINCLI",
"FT,PROFILE",
"FT.ALTER",
"FT.DROPINDEX",
"FT.ALIASADD",
"FT.ALIASUPDATE",
"FT.ALIASDEL",
"FT.TAGVALS",
"FT.SUGADD",
"FT.SUGGET",
"FT.SUGDEL",
"FT.SUGLEN",
"FT.SYNUPDATE",
"FT.SYNDUMP",
"FT.SPELLCHECK",
"FT.DICTADD",
"FT.DICTDEL",
"FT.DICTDUMP",
"FT.INFO",
"FT._LIST",
"FT.CONFIG",
"FT.ADD",
"FT.DEL",
"FT.DROP",
"FT.GET",
"FT.MGET",
"FT.SYNADD",
],
)

CLUSTER_COMMANDS_RESPONSE_CALLBACKS = {
"CLUSTER ADDSLOTS": bool,
"CLUSTER ADDSLOTSRANGE": bool,
Expand Down Expand Up @@ -854,6 +890,8 @@ def _determine_nodes(self, *args, **kwargs):
elif command_flag == self.__class__.DEFAULT_NODE:
# return the cluster's default node
return [self.nodes_manager.default_node]
elif command in self.__class__.SEARCH_COMMANDS[0]:
return [self.nodes_manager.default_node]
else:
# get the node that holds the key's slot
slot = self.determine_slot(*args)
Expand Down Expand Up @@ -1956,17 +1994,14 @@ def _send_cluster_commands(
# refer to our internal node -> slot table that
# tells us where a given
# command should route to.
slot = self.determine_slot(*c.args)
node = self.nodes_manager.get_node_from_slot(
slot, self.read_from_replicas and c.args[0] in READ_COMMANDS
)
node = self._determine_nodes(*c.args)

# now that we know the name of the node
# ( it's just a string in the form of host:port )
# we can build a list of commands for each node.
node_name = node.name
node_name = node[0].name
if node_name not in nodes:
redis_node = self.get_redis_connection(node)
redis_node = self.get_redis_connection(node[0])
connection = get_connection(redis_node, c.args)
nodes[node_name] = NodeCommands(
redis_node.parse_response, redis_node.connection_pool, connection
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def wait_for_cluster_creation(redis_url, cluster_nodes, timeout=60):
while now < end_time:
try:
client = redis.RedisCluster.from_url(redis_url)
if len(client.get_nodes()) == cluster_nodes:
if len(client.get_nodes()) == int(cluster_nodes):
print("All nodes are available!")
break
except RedisClusterException:
Expand Down
56 changes: 26 additions & 30 deletions tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import redis.commands.search
import redis.commands.search.aggregation as aggregations
import redis.commands.search.reducers as reducers
from redis import Redis
from redis.commands.json.path import Path
from redis.commands.search import Search
from redis.commands.search.field import GeoField, NumericField, TagField, TextField
Expand All @@ -19,10 +18,7 @@
from redis.commands.search.result import Result
from redis.commands.search.suggestion import Suggestion

from .conftest import default_redismod_url, skip_ifmodversion_lt

pytestmark = pytest.mark.onlynoncluster

from .conftest import skip_ifmodversion_lt

WILL_PLAY_TEXT = os.path.abspath(
os.path.join(os.path.dirname(__file__), "testdata", "will_play_text.csv.bz2")
Expand All @@ -36,7 +32,7 @@
def waitForIndex(env, idx, timeout=None):
delay = 0.1
while True:
res = env.execute_command("ft.info", idx)
res = env.execute_command("FT.INFO", idx)
try:
res.index("indexing")
except ValueError:
Expand All @@ -52,13 +48,12 @@ def waitForIndex(env, idx, timeout=None):
break


def getClient():
def getClient(client):
"""
Gets a client client attached to an index name which is ready to be
created
"""
rc = Redis.from_url(default_redismod_url, decode_responses=True)
return rc
return client


def createIndex(client, num_docs=100, definition=None):
Expand Down Expand Up @@ -96,12 +91,6 @@ def createIndex(client, num_docs=100, definition=None):
indexer.commit()


# override the default module client, search requires both db=0, and text
@pytest.fixture
def modclient():
return Redis.from_url(default_redismod_url, db=0, decode_responses=True)


@pytest.fixture
def client(modclient):
modclient.flushdb()
Expand Down Expand Up @@ -234,6 +223,7 @@ def test_payloads(client):


@pytest.mark.redismod
@pytest.mark.onlynoncluster
def test_scores(client):
client.ft().create_index((TextField("txt"),))

Expand Down Expand Up @@ -356,14 +346,14 @@ def test_sort_by(client):

@pytest.mark.redismod
@skip_ifmodversion_lt("2.0.0", "search")
def test_drop_index():
def test_drop_index(client):
"""
Ensure the index gets dropped by data remains by default
"""
for x in range(20):
for keep_docs in [[True, {}], [False, {"name": "haveit"}]]:
idx = "HaveIt"
index = getClient()
index = getClient(client)
index.hset("index:haveit", mapping={"name": "haveit"})
idef = IndexDefinition(prefix=["index:"])
index.ft(idx).create_index((TextField("name"),), definition=idef)
Expand Down Expand Up @@ -574,9 +564,9 @@ def test_summarize(client):

@pytest.mark.redismod
@skip_ifmodversion_lt("2.0.0", "search")
def test_alias():
index1 = getClient()
index2 = getClient()
def test_alias(client):
index1 = getClient(client)
index2 = getClient(client)

def1 = IndexDefinition(prefix=["index1:"])
def2 = IndexDefinition(prefix=["index2:"])
Expand All @@ -594,7 +584,7 @@ def test_alias():

# create alias and check for results
ftindex1.aliasadd("spaceballs")
alias_client = getClient().ft("spaceballs")
alias_client = getClient(client).ft("spaceballs")
res = alias_client.search("*").docs[0]
assert "index1:lonestar" == res.id

Expand All @@ -604,7 +594,7 @@ def test_alias():

# update alias and ensure new results
ftindex2.aliasupdate("spaceballs")
alias_client2 = getClient().ft("spaceballs")
alias_client2 = getClient(client).ft("spaceballs")

res = alias_client2.search("*").docs[0]
assert "index2:yogurt" == res.id
Expand All @@ -615,21 +605,21 @@ def test_alias():


@pytest.mark.redismod
def test_alias_basic():
def test_alias_basic(client):
# Creating a client with one index
getClient().flushdb()
index1 = getClient().ft("testAlias")
getClient(client).flushdb()
index1 = getClient(client).ft("testAlias")

index1.create_index((TextField("txt"),))
index1.add_document("doc1", txt="text goes here")

index2 = getClient().ft("testAlias2")
index2 = getClient(client).ft("testAlias2")
index2.create_index((TextField("txt"),))
index2.add_document("doc2", txt="text goes here")

# add the actual alias and check
index1.aliasadd("myalias")
alias_client = getClient().ft("myalias")
alias_client = getClient(client).ft("myalias")
res = sorted(alias_client.search("*").docs, key=lambda x: x.id)
assert "doc1" == res[0].id

Expand All @@ -639,7 +629,7 @@ def test_alias_basic():

# update the alias and ensure we get doc2
index2.aliasupdate("myalias")
alias_client2 = getClient().ft("myalias")
alias_client2 = getClient(client).ft("myalias")
res = sorted(alias_client2.search("*").docs, key=lambda x: x.id)
assert "doc1" == res[0].id

Expand Down Expand Up @@ -790,6 +780,7 @@ def test_phonetic_matcher(client):


@pytest.mark.redismod
@pytest.mark.onlynoncluster
def test_scorer(client):
client.ft().create_index((TextField("description"),))

Expand Down Expand Up @@ -842,6 +833,7 @@ def test_get(client):


@pytest.mark.redismod
@pytest.mark.onlynoncluster
@skip_ifmodversion_lt("2.2.0", "search")
def test_config(client):
assert client.ft().config_set("TIMEOUT", "100")
Expand All @@ -854,6 +846,7 @@ def test_config(client):


@pytest.mark.redismod
@pytest.mark.onlynoncluster
def test_aggregations_groupby(client):
# Creating the index definition and schema
client.ft().create_index(
Expand Down Expand Up @@ -1085,8 +1078,8 @@ def test_aggregations_apply(client):
CreatedDateTimeUTC="@CreatedDateTimeUTC * 10"
)
res = client.ft().aggregate(req)
assert res.rows[0] == ["CreatedDateTimeUTC", "6373878785249699840"]
assert res.rows[1] == ["CreatedDateTimeUTC", "6373878758592700416"]
res_set = set([res.rows[0][1], res.rows[1][1]])
assert res_set == set(["6373878785249699840", "6373878758592700416"])


@pytest.mark.redismod
Expand Down Expand Up @@ -1158,6 +1151,7 @@ def test_index_definition(client):


@pytest.mark.redismod
@pytest.mark.onlynoncluster
def testExpire(client):
client.ft().create_index((TextField("txt", sortable=True),), temporary=4)
ttl = client.execute_command("ft.debug", "TTL", "idx")
Expand Down Expand Up @@ -1477,6 +1471,7 @@ def test_json_with_jsonpath(client):


@pytest.mark.redismod
@pytest.mark.onlynoncluster
def test_profile(client):
client.ft().create_index((TextField("t"),))
client.ft().client.hset("1", "t", "hello")
Expand Down Expand Up @@ -1505,6 +1500,7 @@ def test_profile(client):


@pytest.mark.redismod
@pytest.mark.onlynoncluster
def test_profile_limited(client):
client.ft().create_index((TextField("t"),))
client.ft().client.hset("1", "t", "hello")
Expand Down

0 comments on commit b442110

Please sign in to comment.