Skip to content

Commit

Permalink
[ray_client] close ray connection upon client deactivation (#13919)
Browse files Browse the repository at this point in the history
  • Loading branch information
richardliaw authored Feb 7, 2021
1 parent 4b49414 commit 3a230fa
Show file tree
Hide file tree
Showing 7 changed files with 214 additions and 152 deletions.
1 change: 1 addition & 0 deletions ci/travis/ci.sh
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ test_python() {
-python/ray/tests:test_basic_3 # timeout
-python/ray/tests:test_basic_3_client_mode
-python/ray/tests:test_cli
-python/ray/tests:test_client_init # timeout
-python/ray/tests:test_failure
-python/ray/tests:test_global_gc
-python/ray/tests:test_job
Expand Down
4 changes: 2 additions & 2 deletions python/ray/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ py_test_module_list(
"test_basic_3.py",
"test_cancel.py",
"test_cli.py",
"test_client.py",
"test_client_init.py",
"test_component_failures_2.py",
"test_component_failures_3.py",
"test_error_ray_not_initialized.py",
Expand Down Expand Up @@ -80,9 +82,7 @@ py_test_module_list(
"test_asyncio.py",
"test_autoscaler.py",
"test_autoscaler_yaml.py",
"test_client_init.py",
"test_client_metadata.py",
"test_client.py",
"test_client_references.py",
"test_client_terminate.py",
"test_command_runner.py",
Expand Down
260 changes: 138 additions & 122 deletions python/ray/tests/test_client_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,130 +38,146 @@ def get(self):
return self.val


def test_basic_preregister():
@pytest.fixture
def init_and_serve():
server_handle, _ = ray_client_server.init_and_serve("localhost:50051")
yield server_handle
ray_client_server.shutdown_with_server(server_handle.grpc_server)
time.sleep(2)


@pytest.fixture
def init_and_serve_lazy():
cluster = ray.cluster_utils.Cluster()
cluster.add_node(num_cpus=1, num_gpus=0)
address = cluster.address

def connect():
ray.init(address=address)

server_handle = ray_client_server.serve("localhost:50051", connect)
yield server_handle
ray_client_server.shutdown_with_server(server_handle.grpc_server)
time.sleep(2)


def test_basic_preregister(init_and_serve):
from ray.util.client import ray
server, _ = ray_client_server.init_and_serve("localhost:50051")
try:
ray.connect("localhost:50051")
val = ray.get(hello_world.remote())
print(val)
assert val >= 20
assert val <= 200
c = C.remote(3)
x = c.double.remote()
y = c.double.remote()
ray.wait([x, y])
val = ray.get(c.get.remote())
assert val == 12
finally:
ray.disconnect()
ray_client_server.shutdown_with_server(server)
time.sleep(2)


def test_num_clients():
ray.connect("localhost:50051")
val = ray.get(hello_world.remote())
print(val)
assert val >= 20
assert val <= 200
c = C.remote(3)
x = c.double.remote()
y = c.double.remote()
ray.wait([x, y])
val = ray.get(c.get.remote())
assert val == 12
ray.disconnect()


def test_num_clients(init_and_serve_lazy):
# Tests num clients reporting; useful if you want to build an app that
# load balances clients between Ray client servers.
server_handle, _ = ray_client_server.init_and_serve("localhost:50051")
server = server_handle.grpc_server
try:
api1 = RayAPIStub()
info1 = api1.connect("localhost:50051")
assert info1["num_clients"] == 1, info1
api2 = RayAPIStub()
info2 = api2.connect("localhost:50051")
assert info2["num_clients"] == 2, info2

# Disconnect the first two clients.
api1.disconnect()
api2.disconnect()
time.sleep(1)

api3 = RayAPIStub()
info3 = api3.connect("localhost:50051")
assert info3["num_clients"] == 1, info3

# Check info contains ray and python version.
assert isinstance(info3["ray_version"], str), info3
assert isinstance(info3["ray_commit"], str), info3
assert isinstance(info3["python_version"], str), info3
assert isinstance(info3["protocol_version"], str), info3
api3.disconnect()
finally:
ray_client_server.shutdown_with_server(server)
time.sleep(2)


def test_python_version():

server_handle, _ = ray_client_server.init_and_serve("localhost:50051")
try:
ray = RayAPIStub()
info1 = ray.connect("localhost:50051")
assert info1["python_version"] == ".".join(
[str(x) for x in list(sys.version_info)[:3]])
ray.disconnect()
time.sleep(1)

def mock_connection_response():
return ray_client_pb2.ConnectionInfoResponse(
num_clients=1,
python_version="2.7.12",
ray_version="",
ray_commit="",
protocol_version=CURRENT_PROTOCOL_VERSION,
)

# inject mock connection function
server_handle.data_servicer._build_connection_response = \
mock_connection_response

ray = RayAPIStub()
with pytest.raises(RuntimeError):
_ = ray.connect("localhost:50051")

ray = RayAPIStub()
info3 = ray.connect("localhost:50051", ignore_version=True)
assert info3["num_clients"] == 1, info3
ray.disconnect()
finally:
ray_client_server.shutdown_with_server(server_handle.grpc_server)
time.sleep(2)


def test_protocol_version():
def get_job_id(api):
return api.get_runtime_context().worker.current_job_id

server_handle, _ = ray_client_server.init_and_serve("localhost:50051")
try:
ray = RayAPIStub()
info1 = ray.connect("localhost:50051")
local_py_version = ".".join(
[str(x) for x in list(sys.version_info)[:3]])
assert info1["protocol_version"] == CURRENT_PROTOCOL_VERSION, info1
ray.disconnect()
time.sleep(1)

def mock_connection_response():
return ray_client_pb2.ConnectionInfoResponse(
num_clients=1,
python_version=local_py_version,
ray_version="",
ray_commit="",
protocol_version="2050-01-01", # from the future
)

# inject mock connection function
server_handle.data_servicer._build_connection_response = \
mock_connection_response

ray = RayAPIStub()
with pytest.raises(RuntimeError):
_ = ray.connect("localhost:50051")

ray = RayAPIStub()
info3 = ray.connect("localhost:50051", ignore_version=True)
assert info3["num_clients"] == 1, info3
ray.disconnect()
finally:
ray_client_server.shutdown_with_server(server_handle.grpc_server)
time.sleep(2)
api1 = RayAPIStub()
info1 = api1.connect("localhost:50051")
job_id_1 = get_job_id(api1)
assert info1["num_clients"] == 1, info1
api2 = RayAPIStub()
info2 = api2.connect("localhost:50051")
job_id_2 = get_job_id(api2)
assert info2["num_clients"] == 2, info2

assert job_id_1 == job_id_2

# Disconnect the first two clients.
api1.disconnect()
api2.disconnect()
time.sleep(1)

api3 = RayAPIStub()
info3 = api3.connect("localhost:50051")
job_id_3 = get_job_id(api3)
assert info3["num_clients"] == 1, info3
assert job_id_1 != job_id_3

# Check info contains ray and python version.
assert isinstance(info3["ray_version"], str), info3
assert isinstance(info3["ray_commit"], str), info3
assert isinstance(info3["python_version"], str), info3
assert isinstance(info3["protocol_version"], str), info3
api3.disconnect()


def test_python_version(init_and_serve):
server_handle = init_and_serve
ray = RayAPIStub()
info1 = ray.connect("localhost:50051")
assert info1["python_version"] == ".".join(
[str(x) for x in list(sys.version_info)[:3]])
ray.disconnect()
time.sleep(1)

def mock_connection_response():
return ray_client_pb2.ConnectionInfoResponse(
num_clients=1,
python_version="2.7.12",
ray_version="",
ray_commit="",
protocol_version=CURRENT_PROTOCOL_VERSION,
)

# inject mock connection function
server_handle.data_servicer._build_connection_response = \
mock_connection_response

ray = RayAPIStub()
with pytest.raises(RuntimeError):
_ = ray.connect("localhost:50051")

ray = RayAPIStub()
info3 = ray.connect("localhost:50051", ignore_version=True)
assert info3["num_clients"] == 1, info3
ray.disconnect()


def test_protocol_version(init_and_serve):
server_handle = init_and_serve
ray = RayAPIStub()
info1 = ray.connect("localhost:50051")
local_py_version = ".".join([str(x) for x in list(sys.version_info)[:3]])
assert info1["protocol_version"] == CURRENT_PROTOCOL_VERSION, info1
ray.disconnect()
time.sleep(1)

def mock_connection_response():
return ray_client_pb2.ConnectionInfoResponse(
num_clients=1,
python_version=local_py_version,
ray_version="",
ray_commit="",
protocol_version="2050-01-01", # from the future
)

# inject mock connection function
server_handle.data_servicer._build_connection_response = \
mock_connection_response

ray = RayAPIStub()
with pytest.raises(RuntimeError):
_ = ray.connect("localhost:50051")

ray = RayAPIStub()
info3 = ray.connect("localhost:50051", ignore_version=True)
assert info3["num_clients"] == 1, info3
ray.disconnect()


if __name__ == "__main__":
import pytest
sys.exit(pytest.main(["-v", __file__] + sys.argv[1:]))
4 changes: 2 additions & 2 deletions python/ray/tests/test_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(self):
assert len(actor_table) == 1

job_table = ray.jobs()
assert len(job_table) == 3 # dash, ray client server
assert len(job_table) == 2 # dash

# Kill the driver process.
p.kill()
Expand Down Expand Up @@ -79,7 +79,7 @@ def value(self):
assert len(actor_table) == 1

job_table = ray.jobs()
assert len(job_table) == 3 # dash, ray client server
assert len(job_table) == 2 # dash

# Kill the driver process.
p.kill()
Expand Down
15 changes: 13 additions & 2 deletions python/ray/util/client/server/dataservicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
import grpc
import sys

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Callable
from threading import Lock

import ray.core.generated.ray_client_pb2 as ray_client_pb2
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
from ray.util.client import CURRENT_PROTOCOL_VERSION
from ray._private.client_mode_hook import disable_client_hook

if TYPE_CHECKING:
from ray.util.client.server.server import RayletServicer
Expand All @@ -17,10 +18,12 @@


class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
def __init__(self, basic_service: "RayletServicer"):
def __init__(self, basic_service: "RayletServicer",
ray_connect_handler: Callable):
self.basic_service = basic_service
self._clients_lock = Lock()
self._num_clients = 0 # guarded by self._clients_lock
self.ray_connect_handler = ray_connect_handler

def Datapath(self, request_iterator, context):
metadata = {k: v for k, v in context.invocation_metadata()}
Expand All @@ -31,6 +34,9 @@ def Datapath(self, request_iterator, context):
logger.info(f"New data connection from client {client_id}")
try:
with self._clients_lock:
with disable_client_hook():
if self._num_clients == 0 and not ray.is_initialized():
self.ray_connect_handler()
self._num_clients += 1
for req in request_iterator:
resp = None
Expand Down Expand Up @@ -63,9 +69,14 @@ def Datapath(self, request_iterator, context):
finally:
logger.info(f"Lost data connection from client {client_id}")
self.basic_service.release_all(client_id)

with self._clients_lock:
self._num_clients -= 1

with disable_client_hook():
if self._num_clients == 0:
ray.shutdown()

def _build_connection_response(self):
with self._clients_lock:
cur_num_clients = self._num_clients
Expand Down
Loading

0 comments on commit 3a230fa

Please sign in to comment.