Skip to content

Commit

Permalink
[client][placement groups] Client placement group hooks, attempt #3 (r…
Browse files Browse the repository at this point in the history
  • Loading branch information
DmitriGekhtman authored Apr 23, 2021
1 parent af01a47 commit 0d0c241
Show file tree
Hide file tree
Showing 11 changed files with 1,094 additions and 791 deletions.
33 changes: 31 additions & 2 deletions python/ray/_private/client_mode_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,30 @@ def client_mode_should_convert():
return client_mode_enabled and _client_hook_enabled


def client_mode_wrap(func):
"""Wraps a function called during client mode for execution as a remote
task.
Can be used to implement public features of ray client which do not
belong in the main ray API (`ray.*`), yet require server-side execution.
An example is the creation of placement groups:
`ray.util.placement_group.placement_group()`. When called on the client
side, this function is wrapped in a task to facilitate interaction with
the GCS.
"""
from ray.util.client import ray

@wraps(func)
def wrapper(*args, **kwargs):
if client_mode_should_convert():
f = ray.remote(num_cpus=0)(func)
ref = f.remote(*args, **kwargs)
return ray.get(ref)
return func(*args, **kwargs)

return wrapper


def client_mode_convert_function(func_cls, in_args, in_kwargs, **kwargs):
"""Runs a preregistered ray RemoteFunction through the ray client.
Expand All @@ -80,7 +104,10 @@ def client_mode_convert_function(func_cls, in_args, in_kwargs, **kwargs):
from ray.util.client import ray

key = getattr(func_cls, RAY_CLIENT_MODE_ATTR, None)
if key is None:

# Second part of "or" is needed in case func_cls is reused between Ray
# client sessions in one Python interpreter session.
if (key is None) or (not ray._converted_key_exists(key)):
key = ray._convert_function(func_cls)
setattr(func_cls, RAY_CLIENT_MODE_ATTR, key)
client_func = ray._get_converted(key)
Expand All @@ -98,7 +125,9 @@ def client_mode_convert_actor(actor_cls, in_args, in_kwargs, **kwargs):
from ray.util.client import ray

key = getattr(actor_cls, RAY_CLIENT_MODE_ATTR, None)
if key is None:
# Second part of "or" is needed in case actor_cls is reused between Ray
# client sessions in one Python interpreter session.
if (key is None) or (not ray._converted_key_exists(key)):
key = ray._convert_actor(actor_cls)
setattr(actor_cls, RAY_CLIENT_MODE_ATTR, key)
client_actor = ray._get_converted(key)
Expand Down
7 changes: 7 additions & 0 deletions python/ray/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,13 @@ def ray_start_cluster(request):
yield res


@pytest.fixture
def ray_start_cluster_init(request):
param = getattr(request, "param", {})
with _ray_start_cluster(do_init=True, **param) as res:
yield res


@pytest.fixture
def ray_start_cluster_head(request):
param = getattr(request, "param", {})
Expand Down
51 changes: 41 additions & 10 deletions python/ray/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,12 @@
import _thread

import ray.util.client.server.server as ray_client_server
from ray.tests.client_test_utils import create_remote_signal_actor
from ray.util.client.common import ClientObjectRef
from ray.util.client.ray_client_helpers import connect_to_client_or_not
from ray.util.client.ray_client_helpers import ray_start_client_server
from ray._private.client_mode_hook import _explicitly_enable_client_mode
from ray._private.client_mode_hook import client_mode_should_convert
from ray._private.client_mode_hook import enable_client_mode


@pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.")
Expand Down Expand Up @@ -179,6 +182,8 @@ def test_wait(ray_start_regular_shared):
@pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.")
def test_remote_functions(ray_start_regular_shared):
with ray_start_client_server() as ray:
SignalActor = create_remote_signal_actor(ray)
signaler = SignalActor.remote()

@ray.remote
def plus2(x):
Expand Down Expand Up @@ -220,6 +225,18 @@ def fact(x):
all_vals = ray.get(res[0])
assert all_vals == [236, 2_432_902_008_176_640_000, 120, 3628800]

# Timeout 0 on ray.wait leads to immediate return
# (not indefinite wait for first return as with timeout None):
unready_ref = signaler.wait.remote()
res = ray.wait([unready_ref], timeout=0)
# Not ready.
assert res[0] == [] and len(res[1]) == 1
ray.get(signaler.send.remote())
ready_ref = signaler.wait.remote()
# Ready.
res = ray.wait([ready_ref], timeout=10)
assert len(res[0]) == 1 and res[1] == []


@pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.")
def test_function_calling_function(ray_start_regular_shared):
Expand Down Expand Up @@ -523,16 +540,16 @@ def test_client_gpu_ids(call_ray_stop_only):
import ray
ray.init(num_cpus=2)

_explicitly_enable_client_mode()
# No client connection.
with pytest.raises(Exception) as e:
ray.get_gpu_ids()
assert str(e.value) == "Ray Client is not connected."\
" Please connect by calling `ray.connect`."
with enable_client_mode():
# No client connection.
with pytest.raises(Exception) as e:
ray.get_gpu_ids()
assert str(e.value) == "Ray Client is not connected."\
" Please connect by calling `ray.connect`."

with ray_start_client_server():
# Now have a client connection.
assert ray.get_gpu_ids() == []
with ray_start_client_server():
# Now have a client connection.
assert ray.get_gpu_ids() == []


def test_client_serialize_addon(call_ray_stop_only):
Expand All @@ -548,5 +565,19 @@ class User(pydantic.BaseModel):
assert ray.get(ray.put(User(name="ray"))).name == "ray"


@pytest.mark.parametrize("connect_to_client", [False, True])
def test_client_context_manager(ray_start_regular_shared, connect_to_client):
import ray
with connect_to_client_or_not(connect_to_client):
if connect_to_client:
# Client mode is on.
assert client_mode_should_convert() is True
# We're connected to Ray client.
assert ray.util.client.ray.is_connected() is True
else:
assert client_mode_should_convert() is False
assert ray.util.client.ray.is_connected() is False


if __name__ == "__main__":
sys.exit(pytest.main(["-v", __file__]))
31 changes: 19 additions & 12 deletions python/ray/tests/test_client_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,19 +76,26 @@ def test_validate_port():


def test_basic_preregister(init_and_serve):
"""Tests conversion of Ray actors and remote functions to client actors
and client remote functions.
Checks that the conversion works when disconnecting and reconnecting client
sessions.
"""
from ray.util.client import ray
ray.connect("localhost:50051")
val = ray.get(hello_world.remote())
print(val)
assert val >= 20
assert val <= 200
c = C.remote(3)
x = c.double.remote()
y = c.double.remote()
ray.wait([x, y])
val = ray.get(c.get.remote())
assert val == 12
ray.disconnect()
for _ in range(2):
ray.connect("localhost:50051")
val = ray.get(hello_world.remote())
print(val)
assert val >= 20
assert val <= 200
c = C.remote(3)
x = c.double.remote()
y = c.double.remote()
ray.wait([x, y])
val = ray.get(c.get.remote())
assert val == 12
ray.disconnect()


def test_num_clients(init_and_serve_lazy):
Expand Down
Loading

0 comments on commit 0d0c241

Please sign in to comment.