diff --git a/distributed/deploy/cluster.py b/distributed/deploy/cluster.py index 64f36958d0..55aa82c842 100644 --- a/distributed/deploy/cluster.py +++ b/distributed/deploy/cluster.py @@ -321,6 +321,22 @@ def get_logs(self, cluster=True, scheduler=True, workers=True): def logs(self, *args, **kwargs): return self.get_logs(*args, **kwargs) + def get_client(self): + """Return client for the cluster + + If a client has already been initialized for the cluster, return that + otherwise initialize a new client object. + """ + from distributed.client import Client + + try: + current_client = Client.current() + if current_client and current_client.cluster == self: + return current_client + except ValueError: + pass + return Client(self) + @property def dashboard_link(self): try: diff --git a/distributed/deploy/tests/test_local.py b/distributed/deploy/tests/test_local.py index ad354a16fa..3c4a889cee 100644 --- a/distributed/deploy/tests/test_local.py +++ b/distributed/deploy/tests/test_local.py @@ -1267,3 +1267,14 @@ def test_localcluster_start_exception(loop): loop=loop, ): pass + + +def test_localcluster_get_client(): + with LocalCluster( + n_workers=0, asynchronous=False, dashboard_address=":0" + ) as cluster: + client1 = cluster.get_client() + assert client1.cluster == cluster + client2 = Client(cluster) + assert client1 != client2 + assert client2 == cluster.get_client()