Skip to content

Commit

Permalink
convert cluster client to property
Browse files Browse the repository at this point in the history
  • Loading branch information
carolineechen committed Jul 12, 2024
1 parent 7501e09 commit 5071c04
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
12 changes: 8 additions & 4 deletions runhouse/resources/hardware/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def __init__(
self._creds = creds

self.ips = ips
self.client = None
self._http_client = None
self.den_auth = den_auth or False
self.cert_config = TLSCertConfig(cert_path=ssl_certfile, key_path=ssl_keyfile)

Expand All @@ -117,6 +117,10 @@ def address(self, addr):
self.ips = self.ips or [None]
self.ips[0] = addr

@property
def client(self):
return self._http_client

@property
def creds_values(self) -> Dict:
if not self._creds:
Expand Down Expand Up @@ -616,7 +620,7 @@ def connect_server_client(self, force_reconnect=False):
# Connecting to localhost because it's tunneled into the server at the specified port.
# As long as the tunnel was initialized,
# self.client_port has been set to the correct port
self.client = HTTPClient(
self._http_client = HTTPClient(
host=LOCALHOST,
port=self.client_port,
resource_address=self.rns_address,
Expand All @@ -640,7 +644,7 @@ def connect_server_client(self, force_reconnect=False):

self.client_port = self.client_port or self.server_port

self.client = HTTPClient(
self._http_client = HTTPClient(
host=self.server_address,
port=self.client_port,
cert_path=cert_path,
Expand Down Expand Up @@ -1028,7 +1032,7 @@ def disconnect(self):
def __getstate__(self):
"""Delete non-serializable elements (e.g. thread locks) before pickling."""
state = self.__dict__.copy()
state["client"] = None
state["_http_client"] = None
state["_rpc_tunnel"] = None
return state

Expand Down
2 changes: 1 addition & 1 deletion runhouse/resources/hardware/sagemaker/sagemaker_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,7 @@ def __getstate__(self):
"""Delete non-serializable elements (e.g. sagemaker session object) before pickling."""
state = self.__dict__.copy()
state["_sagemaker_session"] = None
state["client"] = None
state["_http_client"] = None
state["_rpc_tunnel"] = None
return state

Expand Down

0 comments on commit 5071c04

Please sign in to comment.