55import warnings
66from typing import (
77 Any ,
8+ Callable ,
89 Deque ,
910 Dict ,
1011 Generator ,
1112 List ,
1213 Mapping ,
1314 Optional ,
15+ Tuple ,
1416 Type ,
1517 TypeVar ,
1618 Union ,
@@ -250,6 +252,7 @@ def __init__(
250252 ssl_certfile : Optional [str ] = None ,
251253 ssl_check_hostname : bool = False ,
252254 ssl_keyfile : Optional [str ] = None ,
255+ host_port_remap : Optional [Callable [[str , int ], Tuple [str , int ]]] = None ,
253256 ) -> None :
254257 if db :
255258 raise RedisClusterException (
@@ -337,7 +340,12 @@ def __init__(
337340 if host and port :
338341 startup_nodes .append (ClusterNode (host , port , ** self .connection_kwargs ))
339342
340- self .nodes_manager = NodesManager (startup_nodes , require_full_coverage , kwargs )
343+ self .nodes_manager = NodesManager (
344+ startup_nodes ,
345+ require_full_coverage ,
346+ kwargs ,
347+ host_port_remap = host_port_remap ,
348+ )
341349 self .encoder = Encoder (encoding , encoding_errors , decode_responses )
342350 self .read_from_replicas = read_from_replicas
343351 self .reinitialize_steps = reinitialize_steps
@@ -1059,17 +1067,20 @@ class NodesManager:
10591067 "require_full_coverage" ,
10601068 "slots_cache" ,
10611069 "startup_nodes" ,
1070+ "host_port_remap" ,
10621071 )
10631072
10641073 def __init__ (
10651074 self ,
10661075 startup_nodes : List ["ClusterNode" ],
10671076 require_full_coverage : bool ,
10681077 connection_kwargs : Dict [str , Any ],
1078+ host_port_remap : Optional [Callable [[str , int ], Tuple [str , int ]]] = None ,
10691079 ) -> None :
10701080 self .startup_nodes = {node .name : node for node in startup_nodes }
10711081 self .require_full_coverage = require_full_coverage
10721082 self .connection_kwargs = connection_kwargs
1083+ self .host_port_remap = host_port_remap
10731084
10741085 self .default_node : "ClusterNode" = None
10751086 self .nodes_cache : Dict [str , "ClusterNode" ] = {}
@@ -1228,6 +1239,7 @@ async def initialize(self) -> None:
12281239 if host == "" :
12291240 host = startup_node .host
12301241 port = int (primary_node [1 ])
1242+ host , port = self .remap_host_port (host , port )
12311243
12321244 target_node = tmp_nodes_cache .get (get_node_name (host , port ))
12331245 if not target_node :
@@ -1246,6 +1258,7 @@ async def initialize(self) -> None:
12461258 for replica_node in replica_nodes :
12471259 host = replica_node [0 ]
12481260 port = replica_node [1 ]
1261+ host , port = self .remap_host_port (host , port )
12491262
12501263 target_replica_node = tmp_nodes_cache .get (
12511264 get_node_name (host , port )
@@ -1319,6 +1332,16 @@ async def close(self, attr: str = "nodes_cache") -> None:
13191332 )
13201333 )
13211334
1335+ def remap_host_port (self , host : str , port : int ) -> Tuple [str , int ]:
1336+ """
1337+ Remap the host and port returned from the cluster to a different
1338+ internal value. Useful if the client is not connecting directly
1339+ to the cluster.
1340+ """
1341+ if self .host_port_remap :
1342+ return self .host_port_remap (host , port )
1343+ return host , port
1344+
13221345
13231346class ClusterPipeline (AbstractRedis , AbstractRedisCluster , AsyncRedisClusterCommands ):
13241347 """
0 commit comments