Commit a95ac04

israbbani authored and edoakes committed
[core] (ray-get 2/n) Making ray.get thread-safe (#57911)
This PR makes the `ray.get` public API thread-safe. It also cleans up a lot of tech debt w.r.t.:

* Workers yielding CPU to the raylet when blocked.
* Cleaning up finished/in-flight Get requests.

Previously, the raylet coalesced all Get requests from the same worker into one Get (and Pull) request. However, Get request cleanup could happen on multiple threads, meaning **one thread could cancel in-flight Get requests for all threads in a worker**. This issue was reported in #54007.

### Changes in this PR

Raylet (server-side):

1. AsyncGetObjects will return a request_id.
2. LeaseDependencyManager no longer coalesces AsyncGetObjects requests from the same worker.
3. LeaseDependencyManager has two methods for cleanup: delete all requests for a worker (during worker disconnect/lease cleanup) and delete a specific request (called through CancelGetRequest).
4. Wait no longer cancels all Get requests for the worker (this was probably a bug).
5. NotifyWorkerUnblocked no longer cancels Get requests.

CoreWorker (client-side):

1. PlasmaStoreProvider::Get makes one call to AsyncGetObjects per batch.
2. PlasmaStoreProvider::Get stores scoped cleanup handlers that call CancelGetRequest for each AsyncGetObjects call, guaranteeing RAII-style cleanup.

Closes #54007.

---------

Signed-off-by: irabbani <israbbani@gmail.com>
Signed-off-by: Edward Oakes <ed.nmi.oakes@gmail.com>
Co-authored-by: Edward Oakes <ed.nmi.oakes@gmail.com>
Signed-off-by: elliot-barn <elliot.barnwell@anyscale.com>
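To make the RAII-style cleanup in the CoreWorker changes concrete, here is a minimal, self-contained sketch of the pattern. `IpcClient`, `ScopedGetRequest`, and all member names below are hypothetical stand-ins for illustration, not Ray's actual `ipc::ScopedResponse` implementation:

```cpp
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

// Hypothetical stand-in for the raylet IPC client.
class IpcClient {
 public:
  // Returns a fresh id per request, mirroring the new AsyncGetObjectsReply.
  int64_t AsyncGetObjects() { return next_request_id_++; }
  void CancelGetRequest(int64_t request_id) {
    std::cout << "canceling get request " << request_id << "\n";
  }

 private:
  int64_t next_request_id_ = 0;
};

// RAII handle: cancels exactly one get request when it goes out of scope.
class ScopedGetRequest {
 public:
  ScopedGetRequest(IpcClient *client, int64_t request_id)
      : client_(client), request_id_(request_id) {}

  // Move-only, so the cancel fires exactly once even if the handle is
  // moved into a container.
  ScopedGetRequest(ScopedGetRequest &&other) noexcept
      : client_(std::exchange(other.client_, nullptr)),
        request_id_(other.request_id_) {}
  ScopedGetRequest(const ScopedGetRequest &) = delete;
  ScopedGetRequest &operator=(const ScopedGetRequest &) = delete;

  ~ScopedGetRequest() {
    if (client_ != nullptr) {
      client_->CancelGetRequest(request_id_);
    }
  }

 private:
  IpcClient *client_;
  int64_t request_id_;
};

int main() {
  IpcClient client;
  {
    std::vector<ScopedGetRequest> cleanup_handlers;
    // One AsyncGetObjects per batch; each gets its own cleanup handler.
    cleanup_handlers.emplace_back(&client, client.AsyncGetObjects());
    cleanup_handlers.emplace_back(&client, client.AsyncGetObjects());
    // Any early return or thrown exception here still runs the destructors.
  }  // Only the two requests created in this scope are canceled here.
  return 0;
}
```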
1 parent 6d33dd3 · commit a95ac04

16 files changed: +418, -294 lines

python/ray/tests/BUILD.bazel

Lines changed: 1 addition & 0 deletions
@@ -884,6 +884,7 @@ py_test_module_list(
         "test_dataclient_disconnect.py",
         "test_iter.py",
         "test_placement_group.py",
+        "test_ray_get.py",
         "test_state_api_2.py",
         "test_task_events.py",
         "test_unavailable_actors.py",

python/ray/tests/test_draining.py

Lines changed: 15 additions & 8 deletions
@@ -427,14 +427,21 @@ def ping(self):

     # Simulate autoscaler terminates the worker node after the draining deadline.
     cluster.remove_node(node2, graceful)
-    try:
-        ray.get(actor.ping.remote())
-        raise
-    except ray.exceptions.ActorDiedError as e:
-        assert e.preempted
-        if graceful:
-            assert "The actor died because its node has died." in str(e)
-            assert "the actor's node was preempted: " + drain_reason_message in str(e)
+
+    def check_actor_died_error():
+        try:
+            ray.get(actor.ping.remote())
+            return False
+        except ray.exceptions.ActorDiedError as e:
+            assert e.preempted
+            if graceful:
+                assert "The actor died because its node has died." in str(e)
+                assert "the actor's node was preempted: " + drain_reason_message in str(
+                    e
+                )
+            return True
+
+    wait_for_condition(check_actor_died_error)


 def test_drain_node_actor_restart(ray_start_cluster):

python/ray/tests/test_ray_get.py

Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
+import sys
+import threading
+import time
+
+import numpy as np
+import pytest
+
+import ray
+
+
+def test_multithreaded_ray_get(ray_start_cluster):
+    # This test tries to get a large object from the head node to the worker node
+    # while making many concurrent ray.get requests for a local object in plasma.
+    # TODO(57923): Make this not rely on timing if possible.
+    ray_cluster = ray_start_cluster
+    ray_cluster.add_node(
+        # This will make the object transfer slower and allow the test to
+        # interleave Get requests.
+        _system_config={
+            "object_manager_max_bytes_in_flight": 1024**2,
+        }
+    )
+    ray.init(address=ray_cluster.address)
+    ray_cluster.add_node(resources={"worker": 1})
+
+    # max_concurrency >= 3 is required: one thread for small gets, one for large gets,
+    # one for setting the threading.Events.
+    @ray.remote(resources={"worker": 1}, max_concurrency=3)
+    class Actor:
+        def __init__(self):
+            # ray.put will ensure that the object is in plasma
+            # even if it's small.
+            self._local_small_ref = ray.put("1")
+
+            # Used to check the thread running the small `ray.gets` has made at least
+            # one API call successfully.
+            self._small_gets_started = threading.Event()
+
+            # Used to tell the thread running small `ray.gets` to exit.
+            self._stop_small_gets = threading.Event()
+
+        def small_gets_started(self):
+            self._small_gets_started.wait()
+
+        def stop_small_gets(self):
+            self._stop_small_gets.set()
+
+        def do_small_gets(self):
+            while not self._stop_small_gets.is_set():
+                ray.get(self._local_small_ref)
+                time.sleep(0.01)
+                self._small_gets_started.set()
+
+        def do_large_get(self, refs_to_get):
+            remote_large_ref = refs_to_get[0]
+            ray.get(remote_large_ref)
+
+    actor = Actor.remote()
+
+    # Start a task on one thread that will repeatedly call `ray.get` on small
+    # plasma objects.
+    small_gets_ref = actor.do_small_gets.remote()
+    ray.get(actor.small_gets_started.remote())
+
+    # Start a second task on another thread that will call `ray.get` on a large object.
+    # The transfer will be slow due to the system config set above.
+    large_ref = ray.put(np.ones(1024**3, dtype=np.int8))
+    ray.get(actor.do_large_get.remote([large_ref]))
+
+    # Check that all `ray.get` calls succeeded.
+    ray.get(actor.stop_small_gets.remote())
+    ray.get(small_gets_ref)
+
+
+if __name__ == "__main__":
+    sys.exit(pytest.main(["-sv", __file__]))

src/ray/core_worker/store_provider/plasma_store_provider.cc

Lines changed: 44 additions & 34 deletions
@@ -15,6 +15,7 @@
 #include "ray/core_worker/store_provider/plasma_store_provider.h"

 #include <algorithm>
+#include <cstdint>
 #include <memory>
 #include <string>
 #include <utility>
@@ -177,23 +178,20 @@ Status CoreWorkerPlasmaStoreProvider::Release(const ObjectID &object_id) {
   return store_client_->Release(object_id);
 }

-Status CoreWorkerPlasmaStoreProvider::PullObjectsAndGetFromPlasmaStore(
+Status CoreWorkerPlasmaStoreProvider::GetObjectsFromPlasmaStore(
     absl::flat_hash_set<ObjectID> &remaining,
-    const std::vector<ObjectID> &batch_ids,
+    const std::vector<ObjectID> &ids,
     int64_t timeout_ms,
     absl::flat_hash_map<ObjectID, std::shared_ptr<RayObject>> *results,
     bool *got_exception) {
-  const auto owner_addresses = reference_counter_.GetOwnerAddresses(batch_ids);
-  RAY_RETURN_NOT_OK(raylet_ipc_client_->AsyncGetObjects(batch_ids, owner_addresses));
-
   std::vector<plasma::ObjectBuffer> plasma_results;
-  RAY_RETURN_NOT_OK(store_client_->Get(batch_ids, timeout_ms, &plasma_results));
+  RAY_RETURN_NOT_OK(store_client_->Get(ids, timeout_ms, &plasma_results));

   // Add successfully retrieved objects to the result map and remove them from
   // the set of IDs to get.
   for (size_t i = 0; i < plasma_results.size(); i++) {
     if (plasma_results[i].data != nullptr || plasma_results[i].metadata != nullptr) {
-      const auto &object_id = batch_ids[i];
+      const auto &object_id = ids[i];
       std::shared_ptr<TrackedBuffer> data = nullptr;
       std::shared_ptr<Buffer> metadata = nullptr;
       if (plasma_results[i].data && plasma_results[i].data->Size() > 0) {
@@ -216,7 +214,6 @@ Status CoreWorkerPlasmaStoreProvider::PullObjectsAndGetFromPlasmaStore(
       (*results)[object_id] = std::move(result_object);
     }
   }
-
   return Status::OK();
 }

@@ -258,37 +255,48 @@ Status CoreWorkerPlasmaStoreProvider::Get(
     const absl::flat_hash_set<ObjectID> &object_ids,
     int64_t timeout_ms,
     absl::flat_hash_map<ObjectID, std::shared_ptr<RayObject>> *results) {
-  std::vector<ObjectID> batch_ids;
-  absl::flat_hash_set<ObjectID> remaining(object_ids.begin(), object_ids.end());
+  std::vector<ipc::ScopedResponse> get_request_cleanup_handlers;

-  // Send initial requests to pull all objects in parallel.
-  std::vector<ObjectID> id_vector(object_ids.begin(), object_ids.end());
-  int64_t total_size = static_cast<int64_t>(object_ids.size());
   bool got_exception = false;
-  for (int64_t start = 0; start < total_size; start += fetch_batch_size_) {
+  absl::flat_hash_set<ObjectID> remaining(object_ids.begin(), object_ids.end());
+  std::vector<ObjectID> id_vector(object_ids.begin(), object_ids.end());
+  std::vector<ObjectID> batch_ids;
+
+  int64_t num_total_objects = static_cast<int64_t>(object_ids.size());
+
+  // TODO(57923): Need to understand if batching is necessary. If it's necessary,
+  // then the reason needs to be documented.
+  for (int64_t start = 0; start < num_total_objects; start += fetch_batch_size_) {
     batch_ids.clear();
-    for (int64_t i = start; i < start + fetch_batch_size_ && i < total_size; i++) {
+    for (int64_t i = start; i < start + fetch_batch_size_ && i < num_total_objects; i++) {
       batch_ids.push_back(id_vector[i]);
     }
+
+    // 1. Make the request to pull all objects into local plasma if not local already.
+    std::vector<rpc::Address> owner_addresses =
+        reference_counter_.GetOwnerAddresses(batch_ids);
+    StatusOr<ipc::ScopedResponse> status_or_cleanup =
+        raylet_ipc_client_->AsyncGetObjects(batch_ids, owner_addresses);
+    RAY_RETURN_NOT_OK(status_or_cleanup.status());
+    get_request_cleanup_handlers.emplace_back(std::move(status_or_cleanup.value()));
+
+    // 2. Try to Get all objects that are already local from the plasma store.
     RAY_RETURN_NOT_OK(
-        PullObjectsAndGetFromPlasmaStore(remaining,
-                                         batch_ids,
-                                         /*timeout_ms=*/0,
-                                         // Mutable objects must be local before ray.get.
-                                         results,
-                                         &got_exception));
+        GetObjectsFromPlasmaStore(remaining,
+                                  batch_ids,
+                                  /*timeout_ms=*/0,
+                                  // Mutable objects must be local before ray.get.
+                                  results,
+                                  &got_exception));
   }

-  // If all objects were fetched successfully or if any of the returned
-  // objects contain an exception, clean up the Get request in the raylet
-  // and early exit.
   if (remaining.empty() || got_exception) {
-    return raylet_ipc_client_->CancelGetRequest();
+    return Status::OK();
   }

-  // If not all objects were successfully fetched, repeatedly call FetchOrReconstruct
-  // and Get from the local object store in batches. This loop will run indefinitely
-  // until the objects are all fetched if timeout is -1.
+  // 3. If not all objects were successfully fetched, repeatedly call
+  // GetObjectsFromPlasmaStore in batches. This loop will run indefinitely until the
+  // objects are all fetched if timeout is -1.
   bool should_break = false;
   bool timed_out = false;
   int64_t remaining_timeout = timeout_ms;
@@ -312,7 +320,7 @@ Status CoreWorkerPlasmaStoreProvider::Get(
     }

     size_t previous_size = remaining.size();
-    RAY_RETURN_NOT_OK(PullObjectsAndGetFromPlasmaStore(
+    RAY_RETURN_NOT_OK(GetObjectsFromPlasmaStore(
         remaining, batch_ids, batch_timeout, results, &got_exception));
     should_break = timed_out || got_exception;

@@ -322,7 +330,6 @@ Status CoreWorkerPlasmaStoreProvider::Get(
     if (check_signals_) {
       Status status = check_signals_();
       if (!status.ok()) {
-        RAY_RETURN_NOT_OK(raylet_ipc_client_->CancelGetRequest());
         return status;
       }
     }
@@ -337,11 +344,14 @@ Status CoreWorkerPlasmaStoreProvider::Get(
   }

   if (!remaining.empty() && timed_out) {
-    RAY_RETURN_NOT_OK(raylet_ipc_client_->CancelGetRequest());
-    return Status::TimedOut("Get timed out: some object(s) not ready.");
+    return Status::TimedOut(absl::StrFormat(
+        "Could not fetch %d objects within the timeout of %dms. %d objects were not "
+        "ready.",
+        object_ids.size(),
+        timeout_ms,
+        remaining.size()));
   }
-
-  return raylet_ipc_client_->CancelGetRequest();
+  return Status::OK();
 }

 Status CoreWorkerPlasmaStoreProvider::Contains(const ObjectID &object_id,
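Worth noting in the diff above: `get_request_cleanup_handlers` is declared at the top of `Get` and destroyed only when `Get` returns, so every `AsyncGetObjects` request issued in the batching loop is cleaned up on all exit paths (success, timeout, signal-check failure, or exception), replacing the explicit `CancelGetRequest` calls that each early-return path previously had to make.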

src/ray/core_worker/store_provider/plasma_store_provider.h

Lines changed: 7 additions & 7 deletions
@@ -218,12 +218,12 @@ class CoreWorkerPlasmaStoreProvider {
   std::shared_ptr<plasma::PlasmaClientInterface> &store_client() { return store_client_; }

  private:
-  /// Ask the raylet to pull a set of objects and then attempt to get them
-  /// from the local plasma store. Successfully fetched objects will be removed
-  /// from the input set of remaining IDs and added to the results map.
+  /// Ask the plasma store to return objects within the timeout.
+  /// Successfully fetched objects will be removed from the input set of remaining IDs
+  /// and added to the results map.
   ///
   /// \param[in/out] remaining IDs of the remaining objects to get.
-  /// \param[in] batch_ids IDs of the objects to get.
+  /// \param[in] ids IDs of the objects to get.
   /// \param[in] timeout_ms Timeout in milliseconds.
   /// \param[out] results Map of objects to write results into. This method will only
   /// add to this map, not clear or remove from it, so the caller can pass in a non-empty
@@ -232,10 +232,10 @@ class CoreWorkerPlasmaStoreProvider {
   /// exception.
   /// \return Status::IOError if there is an error in communicating with the raylet or the
   /// plasma store.
-  /// \return Status::OK otherwise.
-  Status PullObjectsAndGetFromPlasmaStore(
+  /// \return Status::OK if successful.
+  Status GetObjectsFromPlasmaStore(
       absl::flat_hash_set<ObjectID> &remaining,
-      const std::vector<ObjectID> &batch_ids,
+      const std::vector<ObjectID> &ids,
       int64_t timeout_ms,
      absl::flat_hash_map<ObjectID, std::shared_ptr<RayObject>> *results,
      bool *got_exception);

src/ray/flatbuffers/node_manager.fbs

Lines changed: 9 additions & 4 deletions
@@ -39,15 +39,15 @@ enum MessageType:int {
   // The client should block until it receives this message before closing the socket.
   DisconnectClientReply,
   // Request the Raylet to pull a set of objects to the local node.
-  // a raylet.
   AsyncGetObjectsRequest,
-  // Cancel outstanding get requests from the worker.
+  // Reply contains the request id that will be used to clean up the request.
+  AsyncGetObjectsReply,
+  // Cleanup a given get request on the raylet.
   CancelGetRequest,
   // Notify the current worker is blocked for objects to become available. The raylet
   // will release the worker's resources.
   NotifyWorkerBlocked,
-  // Notify the current worker is unblocked. The raylet will cancel any inflight
-  // pull requests for objects.
+  // Notify the current worker is unblocked.
   NotifyWorkerUnblocked,
   // Wait for objects to be ready either from local or remote Plasma stores.
   WaitRequest,
@@ -136,7 +136,12 @@ table AsyncGetObjectsRequest {
   owner_addresses: [Address];
 }

+table AsyncGetObjectsReply {
+  request_id: long;
+}
+
 table CancelGetRequest {
+  request_id: long;
 }

 table NotifyWorkerBlocked {
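This schema change is what scopes cancellation to a single request: the worker receives a `request_id` in `AsyncGetObjectsReply` and echoes it back in `CancelGetRequest`, so the raylet tears down only that request rather than every outstanding Get for the worker, which is what previously let one thread cancel another thread's in-flight gets.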
