diff --git a/redis/cluster.py b/redis/cluster.py index 3ecc2dab56..2ab173ded9 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -113,6 +113,13 @@ def parse_cluster_shards(resp, **options): return shards +def parse_cluster_myshardid(resp, **options): + """ + Parse CLUSTER MYSHARDID response. + """ + return resp.decode("utf-8") + + PRIMARY = "primary" REPLICA = "replica" SLOT_ID = "slot-id" @@ -341,6 +348,7 @@ class AbstractRedisCluster: CLUSTER_COMMANDS_RESPONSE_CALLBACKS = { "CLUSTER SLOTS": parse_cluster_slots, "CLUSTER SHARDS": parse_cluster_shards, + "CLUSTER MYSHARDID": parse_cluster_myshardid, } RESULT_CALLBACKS = dict_merge( diff --git a/redis/commands/cluster.py b/redis/commands/cluster.py index a23a94a3d3..cd93a85aba 100644 --- a/redis/commands/cluster.py +++ b/redis/commands/cluster.py @@ -45,7 +45,6 @@ if TYPE_CHECKING: from redis.asyncio.cluster import TargetNodesT - # Not complete, but covers the major ones # https://redis.io/commands READ_COMMANDS = frozenset( @@ -634,6 +633,14 @@ def cluster_shards(self, target_nodes=None): """ return self.execute_command("CLUSTER SHARDS", target_nodes=target_nodes) + def cluster_myshardid(self, target_nodes=None): + """ + Returns the shard ID of the node. + + For more information see https://redis.io/commands/cluster-myshardid/ + """ + return self.execute_command("CLUSTER MYSHARDID", target_nodes=target_nodes) + def cluster_links(self, target_node: "TargetNodesT") -> ResponseT: """ Each node in a Redis Cluster maintains a pair of long-lived TCP link with each diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index 2d6099f6a9..17aa879b0f 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -1006,6 +1006,13 @@ async def test_cluster_myid(self, r: RedisCluster) -> None: myid = await r.cluster_myid(node) assert len(myid) == 40 + @skip_if_server_version_lt("7.2.0") + @skip_if_redis_enterprise() + async def test_cluster_myshardid(self, r: RedisCluster) -> None: + node = r.get_random_node() + myshardid = await r.cluster_myshardid(node) + assert len(myshardid) == 40 + @skip_if_redis_enterprise() async def test_cluster_slots(self, r: RedisCluster) -> None: mock_all_nodes_resp(r, default_cluster_slots) diff --git a/tests/test_cluster.py b/tests/test_cluster.py index 8371cc577f..705e753bd6 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -1162,6 +1162,13 @@ def test_cluster_shards(self, r): for attribute in node.keys(): assert attribute in attributes + @skip_if_server_version_lt("7.2.0") + @skip_if_redis_enterprise() + def test_cluster_myshardid(self, r): + myshardid = r.cluster_myshardid() + assert isinstance(myshardid, str) + assert len(myshardid) > 0 + @skip_if_redis_enterprise() def test_cluster_addslots(self, r): node = r.get_random_node()