Skip to content

Commit 0b7e03b

Browse files
edoakes authored and kamil-kaczmarek committed
[core] Rename NotifyUnblocked to CancelGetRequest (#55081)
The `NotifyUnblocked` naming is a legacy from before Ray 1.0. This commit also removes `task_id` from the various places where it was not needed. Note that there is an ongoing bugfix to cancel only the specific get request instead of all requests for the worker: #54495
---------
Signed-off-by: Edward Oakes <ed.nmi.oakes@gmail.com>
Signed-off-by: Kamil Kaczmarek <kamil@anyscale.com>
1 parent 936db47 commit 0b7e03b

File tree

6 files changed

+65
-129
lines changed

6 files changed

+65
-129
lines changed

src/ray/core_worker/store_provider/plasma_store_provider.cc

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -184,8 +184,8 @@ Status CoreWorkerPlasmaStoreProvider::FetchAndGetFromPlasmaStore(
184184
absl::flat_hash_map<ObjectID, std::shared_ptr<RayObject>> *results,
185185
bool *got_exception) {
186186
const auto owner_addresses = reference_counter_.GetOwnerAddresses(batch_ids);
187-
RAY_RETURN_NOT_OK(raylet_client_->FetchOrReconstruct(
188-
batch_ids, owner_addresses, fetch_only, task_id));
187+
RAY_RETURN_NOT_OK(
188+
raylet_client_->FetchOrReconstruct(batch_ids, owner_addresses, fetch_only));
189189

190190
std::vector<plasma::ObjectBuffer> plasma_results;
191191
RAY_RETURN_NOT_OK(store_client_->Get(batch_ids,
@@ -273,7 +273,7 @@ Status UnblockIfNeeded(const std::shared_ptr<raylet::RayletClient> &client,
273273
return Status::OK(); // We don't need to release resources.
274274
}
275275
} else {
276-
return client->NotifyUnblocked(ctx.GetCurrentTaskID());
276+
return client->CancelGetRequest();
277277
}
278278
}
279279

@@ -403,12 +403,9 @@ Status CoreWorkerPlasmaStoreProvider::Wait(
403403
}
404404

405405
const auto owner_addresses = reference_counter_.GetOwnerAddresses(id_vector);
406-
RAY_ASSIGN_OR_RETURN(ready_in_plasma,
407-
raylet_client_->Wait(id_vector,
408-
owner_addresses,
409-
num_objects,
410-
call_timeout,
411-
ctx.GetCurrentTaskID()));
406+
RAY_ASSIGN_OR_RETURN(
407+
ready_in_plasma,
408+
raylet_client_->Wait(id_vector, owner_addresses, num_objects, call_timeout));
412409

413410
if (ready_in_plasma.size() >= static_cast<size_t>(num_objects)) {
414411
should_break = true;

src/ray/raylet/format/node_manager.fbs

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,8 @@ enum MessageType:int {
4949
// Reconstruct or fetch possibly lost objects. This is sent from a worker to
5050
// a raylet.
5151
FetchOrReconstruct,
52-
// For a worker that was blocked on some object(s), tell the raylet
53-
// that the worker is now unblocked. This is sent from a worker to a raylet.
54-
NotifyUnblocked,
52+
// Cancel outstanding get requests from the worker.
53+
CancelGetRequest,
5554
// Notify the current worker is blocked. This is only used by direct task calls;
5655
// otherwise the block command is piggybacked on other messages.
5756
NotifyDirectCallTaskBlocked,
@@ -176,13 +175,9 @@ table FetchOrReconstruct {
176175
owner_addresses: [Address];
177176
// Do we only want to fetch the objects or also reconstruct them?
178177
fetch_only: bool;
179-
// The current task ID.
180-
task_id: string;
181178
}
182179

183-
table NotifyUnblocked {
184-
// The current task ID. This task is no longer blocked.
185-
task_id: string;
180+
table CancelGetRequest {
186181
}
187182

188183
table NotifyDirectCallTaskBlocked {
@@ -201,8 +196,6 @@ table WaitRequest {
201196
num_required_objects: int;
202197
// timeout
203198
timeout: long;
204-
// The current task ID.
205-
task_id: string;
206199
}
207200

208201
table WaitReply {

src/ray/raylet/node_manager.cc

Lines changed: 29 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1043,11 +1043,8 @@ void NodeManager::ProcessClientMessage(const std::shared_ptr<ClientConnection> &
10431043
case protocol::MessageType::NotifyDirectCallTaskUnblocked: {
10441044
HandleDirectCallTaskUnblocked(registered_worker);
10451045
} break;
1046-
case protocol::MessageType::NotifyUnblocked: {
1047-
// TODO(ekl) this is still used from core worker even in direct call mode to
1048-
// finish up get requests.
1049-
auto message = flatbuffers::GetRoot<protocol::NotifyUnblocked>(message_data);
1050-
AsyncResolveObjectsFinish(client, from_flatbuf<TaskID>(*message->task_id()));
1046+
case protocol::MessageType::CancelGetRequest: {
1047+
CancelGetRequest(client);
10511048
} break;
10521049
case protocol::MessageType::WaitRequest: {
10531050
ProcessWaitRequestMessage(client, message_data);
@@ -1492,11 +1489,9 @@ void NodeManager::ProcessFetchOrReconstructMessage(
14921489
// subscribe to in the task dependency manager. These objects will be
14931490
// pulled from remote node managers. If an object's owner dies, an error
14941491
// will be stored as the object's value.
1495-
const TaskID task_id = from_flatbuf<TaskID>(*message->task_id());
1496-
AsyncResolveObjects(client,
1497-
refs,
1498-
task_id,
1499-
/*ray_get=*/true);
1492+
AsyncGetOrWait(client,
1493+
refs,
1494+
/*is_get_request=*/true);
15001495
}
15011496
}
15021497

@@ -1508,28 +1503,24 @@ void NodeManager::ProcessWaitRequestMessage(
15081503
const auto refs =
15091504
FlatbufferToObjectReference(*message->object_ids(), *message->owner_addresses());
15101505

1511-
bool resolve_objects = false;
1506+
bool all_objects_local = true;
15121507
for (auto const &object_id : object_ids) {
15131508
if (!dependency_manager_.CheckObjectLocal(object_id)) {
1514-
// At least one object requires resolution.
1515-
resolve_objects = true;
1509+
all_objects_local = false;
15161510
}
15171511
}
15181512

1519-
const TaskID &current_task_id = from_flatbuf<TaskID>(*message->task_id());
1520-
if (resolve_objects) {
1513+
if (!all_objects_local) {
15211514
// Resolve any missing objects. This is a no-op for any objects that are
15221515
// already local. Missing objects will be pulled from remote node managers.
15231516
// If an object's owner dies, an error will be stored as the object's
15241517
// value.
1525-
AsyncResolveObjects(client,
1526-
refs,
1527-
current_task_id,
1528-
/*ray_get=*/false);
1518+
AsyncGetOrWait(client, refs, /*is_get_request=*/false);
15291519
}
1520+
15301521
if (message->num_required_objects() == 0) {
15311522
// If we don't need to wait for any, return immediately after making the pull
1532-
// requests through AsyncResolveObjects above.
1523+
// requests through AsyncGetOrWait above.
15331524
flatbuffers::FlatBufferBuilder fbb;
15341525
auto wait_reply = protocol::CreateWaitReply(fbb,
15351526
to_flatbuf(fbb, std::vector<ObjectID>{}),
@@ -1539,11 +1530,7 @@ void NodeManager::ProcessWaitRequestMessage(
15391530
client->WriteMessage(static_cast<int64_t>(protocol::MessageType::WaitReply),
15401531
fbb.GetSize(),
15411532
fbb.GetBufferPointer());
1542-
if (status.ok()) {
1543-
if (resolve_objects) {
1544-
AsyncResolveObjectsFinish(client, current_task_id);
1545-
}
1546-
} else {
1533+
if (!status.ok()) {
15471534
// We failed to write to the client, so disconnect the client.
15481535
std::ostringstream stream;
15491536
stream << "Failed to write WaitReply to the client. Status " << status;
@@ -1557,8 +1544,8 @@ void NodeManager::ProcessWaitRequestMessage(
15571544
object_ids,
15581545
message->timeout(),
15591546
num_required_objects,
1560-
[this, resolve_objects, client, current_task_id](std::vector<ObjectID> ready,
1561-
std::vector<ObjectID> remaining) {
1547+
[this, client, all_objects_local](std::vector<ObjectID> ready,
1548+
std::vector<ObjectID> remaining) {
15621549
// Write the data.
15631550
flatbuffers::FlatBufferBuilder fbb;
15641551
flatbuffers::Offset<protocol::WaitReply> wait_reply = protocol::CreateWaitReply(
@@ -1570,10 +1557,8 @@ void NodeManager::ProcessWaitRequestMessage(
15701557
fbb.GetSize(),
15711558
fbb.GetBufferPointer());
15721559
if (status.ok()) {
1573-
// The client is unblocked now because the wait call has
1574-
// returned.
1575-
if (resolve_objects) {
1576-
AsyncResolveObjectsFinish(client, current_task_id);
1560+
if (!all_objects_local) {
1561+
CancelGetRequest(client);
15771562
}
15781563
} else {
15791564
// We failed to write to the client, so disconnect the client.
@@ -1589,19 +1574,14 @@ void NodeManager::ProcessWaitRequestMessage(
15891574

15901575
void NodeManager::ProcessWaitForActorCallArgsRequestMessage(
15911576
const std::shared_ptr<ClientConnection> &client, const uint8_t *message_data) {
1592-
// Read the data.
15931577
auto message =
15941578
flatbuffers::GetRoot<protocol::WaitForActorCallArgsRequest>(message_data);
15951579
std::vector<ObjectID> object_ids = from_flatbuf<ObjectID>(*message->object_ids());
15961580
int64_t tag = message->tag();
1597-
// Resolve any missing objects. This will pull the objects from remote node
1598-
// managers or store an error if the objects have failed.
1581+
// Pull any missing objects to the local node.
15991582
const auto refs =
16001583
FlatbufferToObjectReference(*message->object_ids(), *message->owner_addresses());
1601-
AsyncResolveObjects(client,
1602-
refs,
1603-
TaskID::Nil(),
1604-
/*ray_get=*/false);
1584+
AsyncGetOrWait(client, refs, /*is_get_request=*/false);
16051585
// De-duplicate the object IDs.
16061586
absl::flat_hash_set<ObjectID> object_id_set(object_ids.begin(), object_ids.end());
16071587
object_ids.assign(object_id_set.begin(), object_id_set.end());
@@ -2096,44 +2076,31 @@ void NodeManager::HandleDirectCallTaskUnblocked(
20962076
}
20972077
}
20982078

2099-
void NodeManager::AsyncResolveObjects(
2100-
const std::shared_ptr<ClientConnection> &client,
2101-
const std::vector<rpc::ObjectReference> &required_object_refs,
2102-
const TaskID &current_task_id,
2103-
bool ray_get) {
2079+
void NodeManager::AsyncGetOrWait(const std::shared_ptr<ClientConnection> &client,
2080+
const std::vector<rpc::ObjectReference> &object_refs,
2081+
bool is_get_request) {
21042082
std::shared_ptr<WorkerInterface> worker = worker_pool_.GetRegisteredWorker(client);
21052083
if (!worker) {
2106-
// The client is a driver. Drivers do not hold resources, so we simply mark
2107-
// the task as blocked.
21082084
worker = worker_pool_.GetRegisteredDriver(client);
21092085
}
2110-
21112086
RAY_CHECK(worker);
2112-
// Subscribe to the objects required by the task. These objects will be
2113-
// fetched and/or restarted as necessary, until the objects become local
2114-
// or are unsubscribed.
2115-
if (ray_get) {
2116-
dependency_manager_.StartOrUpdateGetRequest(worker->WorkerId(), required_object_refs);
2087+
2088+
// Start an async request to get or wait for the objects.
2089+
// The objects will be fetched locally unless the get or wait request is canceled.
2090+
if (is_get_request) {
2091+
dependency_manager_.StartOrUpdateGetRequest(worker->WorkerId(), object_refs);
21172092
} else {
2118-
dependency_manager_.StartOrUpdateWaitRequest(worker->WorkerId(),
2119-
required_object_refs);
2093+
dependency_manager_.StartOrUpdateWaitRequest(worker->WorkerId(), object_refs);
21202094
}
21212095
}
21222096

2123-
void NodeManager::AsyncResolveObjectsFinish(
2124-
const std::shared_ptr<ClientConnection> &client, const TaskID &current_task_id) {
2097+
void NodeManager::CancelGetRequest(const std::shared_ptr<ClientConnection> &client) {
21252098
std::shared_ptr<WorkerInterface> worker = worker_pool_.GetRegisteredWorker(client);
21262099
if (!worker) {
2127-
// The client is a driver. Drivers do not hold resources, so we simply
2128-
// mark the driver as unblocked.
21292100
worker = worker_pool_.GetRegisteredDriver(client);
21302101
}
2131-
21322102
RAY_CHECK(worker);
2133-
// Unsubscribe from any `ray.get` objects that the task was blocked on. Any
2134-
// fetch or reconstruction operations to make the objects local are canceled.
2135-
// `ray.wait` calls will stay active until the objects become local, or the
2136-
// task/actor that called `ray.wait` exits.
2103+
21372104
dependency_manager_.CancelGetRequest(worker->WorkerId());
21382105
}
21392106

src/ray/raylet/node_manager.h

Lines changed: 15 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -372,32 +372,22 @@ class NodeManager : public rpc::NodeManagerServiceHandler,
372372
void FinishAssignedActorCreationTask(const std::shared_ptr<WorkerInterface> &worker,
373373
const RayTask &task);
374374

375-
/// Handle blocking gets of objects. This could be a task assigned to a worker,
376-
/// an out-of-band task (e.g., a thread created by the application), or a
377-
/// driver task. This can be triggered when a client starts a get call or a
378-
/// wait call.
379-
///
380-
/// \param client The client that is executing the blocked task.
381-
/// \param required_object_refs The objects that the client is blocked waiting for.
382-
/// \param current_task_id The task that is blocked.
383-
/// \param ray_get Whether the task is blocked in a `ray.get` call.
384-
/// \return Void.
385-
void AsyncResolveObjects(const std::shared_ptr<ClientConnection> &client,
386-
const std::vector<rpc::ObjectReference> &required_object_refs,
387-
const TaskID &current_task_id,
388-
bool ray_get);
389-
390-
/// Handle end of a blocking object get. This could be a task assigned to a
391-
/// worker, an out-of-band task (e.g., a thread created by the application),
392-
/// or a driver task. This can be triggered when a client finishes a get call
393-
/// or a wait call. The given task must be blocked, via a previous call to
394-
/// AsyncResolveObjects.
395-
///
396-
/// \param client The client that is executing the unblocked task.
397-
/// \param current_task_id The task that is unblocked.
375+
/// Start a get or wait request for the requested objects.
376+
///
377+
/// \param client The client that is requesting the objects.
378+
/// \param object_refs The objects that are requested.
379+
/// \param is_get_request If this is a get request, else it's a wait request.
398380
/// \return Void.
399-
void AsyncResolveObjectsFinish(const std::shared_ptr<ClientConnection> &client,
400-
const TaskID &current_task_id);
381+
void AsyncGetOrWait(const std::shared_ptr<ClientConnection> &client,
382+
const std::vector<rpc::ObjectReference> &object_refs,
383+
bool is_get_request);
384+
385+
/// Cancel all ongoing get requests from the client.
386+
///
387+
/// This does *not* cancel ongoing wait requests.
388+
///
389+
/// \param client The client whose get requests will be canceled.
390+
void CancelGetRequest(const std::shared_ptr<ClientConnection> &client);
401391

402392
/// Handle a task that is blocked. Note that this callback may
403393
/// arrive after the worker lease has been returned to the node manager.

src/ray/raylet_client/raylet_client.cc

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -137,26 +137,21 @@ Status RayletClient::ActorCreationTaskDone() {
137137

138138
Status RayletClient::FetchOrReconstruct(const std::vector<ObjectID> &object_ids,
139139
const std::vector<rpc::Address> &owner_addresses,
140-
bool fetch_only,
141-
const TaskID &current_task_id) {
140+
bool fetch_only) {
142141
RAY_CHECK(object_ids.size() == owner_addresses.size());
143142
flatbuffers::FlatBufferBuilder fbb;
144143
auto object_ids_message = to_flatbuf(fbb, object_ids);
145-
auto message =
146-
protocol::CreateFetchOrReconstruct(fbb,
147-
object_ids_message,
148-
AddressesToFlatbuffer(fbb, owner_addresses),
149-
fetch_only,
150-
to_flatbuf(fbb, current_task_id));
144+
auto message = protocol::CreateFetchOrReconstruct(
145+
fbb, object_ids_message, AddressesToFlatbuffer(fbb, owner_addresses), fetch_only);
151146
fbb.Finish(message);
152147
return conn_->WriteMessage(MessageType::FetchOrReconstruct, &fbb);
153148
}
154149

155-
Status RayletClient::NotifyUnblocked(const TaskID &current_task_id) {
150+
Status RayletClient::CancelGetRequest() {
156151
flatbuffers::FlatBufferBuilder fbb;
157-
auto message = protocol::CreateNotifyUnblocked(fbb, to_flatbuf(fbb, current_task_id));
152+
auto message = protocol::CreateCancelGetRequest(fbb);
158153
fbb.Finish(message);
159-
return conn_->WriteMessage(MessageType::NotifyUnblocked, &fbb);
154+
return conn_->WriteMessage(MessageType::CancelGetRequest, &fbb);
160155
}
161156

162157
Status RayletClient::NotifyDirectCallTaskBlocked() {
@@ -177,16 +172,14 @@ StatusOr<absl::flat_hash_set<ObjectID>> RayletClient::Wait(
177172
const std::vector<ObjectID> &object_ids,
178173
const std::vector<rpc::Address> &owner_addresses,
179174
int num_returns,
180-
int64_t timeout_milliseconds,
181-
const TaskID &current_task_id) {
175+
int64_t timeout_milliseconds) {
182176
// Write request.
183177
flatbuffers::FlatBufferBuilder fbb;
184178
auto message = protocol::CreateWaitRequest(fbb,
185179
to_flatbuf(fbb, object_ids),
186180
AddressesToFlatbuffer(fbb, owner_addresses),
187181
num_returns,
188-
timeout_milliseconds,
189-
to_flatbuf(fbb, current_task_id));
182+
timeout_milliseconds);
190183
fbb.Finish(message);
191184
std::vector<uint8_t> reply;
192185
RAY_RETURN_NOT_OK(conn_->AtomicRequestReply(

src/ray/raylet_client/raylet_client.h

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -297,14 +297,12 @@ class RayletClient : public RayletClientInterface {
297297
/// \return int 0 means correct, other numbers mean error.
298298
ray::Status FetchOrReconstruct(const std::vector<ObjectID> &object_ids,
299299
const std::vector<rpc::Address> &owner_addresses,
300-
bool fetch_only,
301-
const TaskID &current_task_id);
300+
bool fetch_only);
302301

303-
/// Notify the raylet that this client (worker) is no longer blocked.
302+
/// Tell the Raylet to cancel the get request from this worker.
304303
///
305-
/// \param current_task_id The task that is no longer blocked.
306304
/// \return ray::Status.
307-
ray::Status NotifyUnblocked(const TaskID &current_task_id);
305+
ray::Status CancelGetRequest();
308306

309307
/// Notify the raylet that this client is blocked. This is only used for direct task
310308
/// calls. Note that ordering of this with respect to Unblock calls is important.
@@ -325,7 +323,6 @@ class RayletClient : public RayletClientInterface {
325323
/// \param owner_addresses The addresses of the workers that own the objects.
326324
/// \param num_returns The number of objects to wait for.
327325
/// \param timeout_milliseconds Duration, in milliseconds, to wait before returning.
328-
/// \param current_task_id The task that called wait.
329326
/// \param result A pair with the first element containing the object ids that were
330327
/// found, and the second element the objects that were not found.
331328
/// \return ray::StatusOr containing error status or the set of object ids that were
@@ -334,8 +331,7 @@ class RayletClient : public RayletClientInterface {
334331
const std::vector<ObjectID> &object_ids,
335332
const std::vector<rpc::Address> &owner_addresses,
336333
int num_returns,
337-
int64_t timeout_milliseconds,
338-
const TaskID &current_task_id);
334+
int64_t timeout_milliseconds);
339335

340336
/// Wait for the given objects, asynchronously. The core worker is notified when
341337
/// the wait completes.

0 commit comments

Comments
 (0)