Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 6 additions & 9 deletions src/ray/core_worker/store_provider/plasma_store_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,8 @@ Status CoreWorkerPlasmaStoreProvider::FetchAndGetFromPlasmaStore(
absl::flat_hash_map<ObjectID, std::shared_ptr<RayObject>> *results,
bool *got_exception) {
const auto owner_addresses = reference_counter_.GetOwnerAddresses(batch_ids);
RAY_RETURN_NOT_OK(raylet_client_->FetchOrReconstruct(
batch_ids, owner_addresses, fetch_only, task_id));
RAY_RETURN_NOT_OK(
raylet_client_->FetchOrReconstruct(batch_ids, owner_addresses, fetch_only));

std::vector<plasma::ObjectBuffer> plasma_results;
RAY_RETURN_NOT_OK(store_client_->Get(batch_ids,
Expand Down Expand Up @@ -273,7 +273,7 @@ Status UnblockIfNeeded(const std::shared_ptr<raylet::RayletClient> &client,
return Status::OK(); // We don't need to release resources.
}
} else {
return client->NotifyUnblocked(ctx.GetCurrentTaskID());
return client->CancelGetRequest();
}
}

Expand Down Expand Up @@ -403,12 +403,9 @@ Status CoreWorkerPlasmaStoreProvider::Wait(
}

const auto owner_addresses = reference_counter_.GetOwnerAddresses(id_vector);
RAY_ASSIGN_OR_RETURN(ready_in_plasma,
raylet_client_->Wait(id_vector,
owner_addresses,
num_objects,
call_timeout,
ctx.GetCurrentTaskID()));
RAY_ASSIGN_OR_RETURN(
ready_in_plasma,
raylet_client_->Wait(id_vector, owner_addresses, num_objects, call_timeout));

if (ready_in_plasma.size() >= static_cast<size_t>(num_objects)) {
should_break = true;
Expand Down
13 changes: 3 additions & 10 deletions src/ray/raylet/format/node_manager.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,8 @@ enum MessageType:int {
// Reconstruct or fetch possibly lost objects. This is sent from a worker to
// a raylet.
FetchOrReconstruct,
// For a worker that was blocked on some object(s), tell the raylet
// that the worker is now unblocked. This is sent from a worker to a raylet.
NotifyUnblocked,
// Cancel outstanding get requests from the worker.
CancelGetRequest,
// Notify the current worker is blocked. This is only used by direct task calls;
// otherwise the block command is piggybacked on other messages.
NotifyDirectCallTaskBlocked,
Expand Down Expand Up @@ -176,13 +175,9 @@ table FetchOrReconstruct {
owner_addresses: [Address];
// Do we only want to fetch the objects or also reconstruct them?
fetch_only: bool;
// The current task ID.
task_id: string;
}

table NotifyUnblocked {
// The current task ID. This task is no longer blocked.
task_id: string;
table CancelGetRequest {
}

table NotifyDirectCallTaskBlocked {
Expand All @@ -201,8 +196,6 @@ table WaitRequest {
num_required_objects: int;
// timeout
timeout: long;
// The current task ID.
task_id: string;
}

table WaitReply {
Expand Down
91 changes: 29 additions & 62 deletions src/ray/raylet/node_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1043,11 +1043,8 @@ void NodeManager::ProcessClientMessage(const std::shared_ptr<ClientConnection> &
case protocol::MessageType::NotifyDirectCallTaskUnblocked: {
HandleDirectCallTaskUnblocked(registered_worker);
} break;
case protocol::MessageType::NotifyUnblocked: {
// TODO(ekl) this is still used from core worker even in direct call mode to
// finish up get requests.
auto message = flatbuffers::GetRoot<protocol::NotifyUnblocked>(message_data);
AsyncResolveObjectsFinish(client, from_flatbuf<TaskID>(*message->task_id()));
case protocol::MessageType::CancelGetRequest: {
CancelGetRequest(client);
} break;
case protocol::MessageType::WaitRequest: {
ProcessWaitRequestMessage(client, message_data);
Expand Down Expand Up @@ -1492,11 +1489,9 @@ void NodeManager::ProcessFetchOrReconstructMessage(
// subscribe to in the task dependency manager. These objects will be
// pulled from remote node managers. If an object's owner dies, an error
// will be stored as the object's value.
const TaskID task_id = from_flatbuf<TaskID>(*message->task_id());
AsyncResolveObjects(client,
refs,
task_id,
/*ray_get=*/true);
AsyncGetOrWait(client,
refs,
/*is_get_request=*/true);
}
}

Expand All @@ -1508,28 +1503,24 @@ void NodeManager::ProcessWaitRequestMessage(
const auto refs =
FlatbufferToObjectReference(*message->object_ids(), *message->owner_addresses());

bool resolve_objects = false;
bool all_objects_local = true;
for (auto const &object_id : object_ids) {
if (!dependency_manager_.CheckObjectLocal(object_id)) {
// At least one object requires resolution.
resolve_objects = true;
all_objects_local = false;
}
}

const TaskID &current_task_id = from_flatbuf<TaskID>(*message->task_id());
if (resolve_objects) {
if (!all_objects_local) {
// Resolve any missing objects. This is a no-op for any objects that are
// already local. Missing objects will be pulled from remote node managers.
// If an object's owner dies, an error will be stored as the object's
// value.
AsyncResolveObjects(client,
refs,
current_task_id,
/*ray_get=*/false);
AsyncGetOrWait(client, refs, /*is_get_request=*/false);
}

if (message->num_required_objects() == 0) {
// If we don't need to wait for any, return immediately after making the pull
// requests through AsyncResolveObjects above.
// requests through AsyncGetOrWait above.
flatbuffers::FlatBufferBuilder fbb;
auto wait_reply = protocol::CreateWaitReply(fbb,
to_flatbuf(fbb, std::vector<ObjectID>{}),
Expand All @@ -1539,11 +1530,7 @@ void NodeManager::ProcessWaitRequestMessage(
client->WriteMessage(static_cast<int64_t>(protocol::MessageType::WaitReply),
fbb.GetSize(),
fbb.GetBufferPointer());
if (status.ok()) {
if (resolve_objects) {
AsyncResolveObjectsFinish(client, current_task_id);
}
} else {
if (!status.ok()) {
// We failed to write to the client, so disconnect the client.
std::ostringstream stream;
stream << "Failed to write WaitReply to the client. Status " << status;
Expand All @@ -1557,8 +1544,8 @@ void NodeManager::ProcessWaitRequestMessage(
object_ids,
message->timeout(),
num_required_objects,
[this, resolve_objects, client, current_task_id](std::vector<ObjectID> ready,
std::vector<ObjectID> remaining) {
[this, client, all_objects_local](std::vector<ObjectID> ready,
std::vector<ObjectID> remaining) {
// Write the data.
flatbuffers::FlatBufferBuilder fbb;
flatbuffers::Offset<protocol::WaitReply> wait_reply = protocol::CreateWaitReply(
Expand All @@ -1570,10 +1557,8 @@ void NodeManager::ProcessWaitRequestMessage(
fbb.GetSize(),
fbb.GetBufferPointer());
if (status.ok()) {
// The client is unblocked now because the wait call has
// returned.
if (resolve_objects) {
AsyncResolveObjectsFinish(client, current_task_id);
if (!all_objects_local) {
CancelGetRequest(client);
}
} else {
// We failed to write to the client, so disconnect the client.
Expand All @@ -1589,19 +1574,14 @@ void NodeManager::ProcessWaitRequestMessage(

void NodeManager::ProcessWaitForActorCallArgsRequestMessage(
const std::shared_ptr<ClientConnection> &client, const uint8_t *message_data) {
// Read the data.
auto message =
flatbuffers::GetRoot<protocol::WaitForActorCallArgsRequest>(message_data);
std::vector<ObjectID> object_ids = from_flatbuf<ObjectID>(*message->object_ids());
int64_t tag = message->tag();
// Resolve any missing objects. This will pull the objects from remote node
// managers or store an error if the objects have failed.
// Pull any missing objects to the local node.
const auto refs =
FlatbufferToObjectReference(*message->object_ids(), *message->owner_addresses());
AsyncResolveObjects(client,
refs,
TaskID::Nil(),
/*ray_get=*/false);
AsyncGetOrWait(client, refs, /*is_get_request=*/false);
// De-duplicate the object IDs.
absl::flat_hash_set<ObjectID> object_id_set(object_ids.begin(), object_ids.end());
object_ids.assign(object_id_set.begin(), object_id_set.end());
Expand Down Expand Up @@ -2096,44 +2076,31 @@ void NodeManager::HandleDirectCallTaskUnblocked(
}
}

void NodeManager::AsyncResolveObjects(
const std::shared_ptr<ClientConnection> &client,
const std::vector<rpc::ObjectReference> &required_object_refs,
const TaskID &current_task_id,
bool ray_get) {
void NodeManager::AsyncGetOrWait(const std::shared_ptr<ClientConnection> &client,
const std::vector<rpc::ObjectReference> &object_refs,
bool is_get_request) {
std::shared_ptr<WorkerInterface> worker = worker_pool_.GetRegisteredWorker(client);
if (!worker) {
// The client is a driver. Drivers do not hold resources, so we simply mark
// the task as blocked.
worker = worker_pool_.GetRegisteredDriver(client);
}

RAY_CHECK(worker);
// Subscribe to the objects required by the task. These objects will be
// fetched and/or restarted as necessary, until the objects become local
// or are unsubscribed.
if (ray_get) {
dependency_manager_.StartOrUpdateGetRequest(worker->WorkerId(), required_object_refs);

// Start an async request to get or wait for the objects.
// The objects will be fetched locally unless the get or wait request is canceled.
if (is_get_request) {
dependency_manager_.StartOrUpdateGetRequest(worker->WorkerId(), object_refs);
} else {
dependency_manager_.StartOrUpdateWaitRequest(worker->WorkerId(),
required_object_refs);
dependency_manager_.StartOrUpdateWaitRequest(worker->WorkerId(), object_refs);
}
}

void NodeManager::AsyncResolveObjectsFinish(
const std::shared_ptr<ClientConnection> &client, const TaskID &current_task_id) {
void NodeManager::CancelGetRequest(const std::shared_ptr<ClientConnection> &client) {
std::shared_ptr<WorkerInterface> worker = worker_pool_.GetRegisteredWorker(client);
if (!worker) {
// The client is a driver. Drivers do not hold resources, so we simply
// mark the driver as unblocked.
worker = worker_pool_.GetRegisteredDriver(client);
}

RAY_CHECK(worker);
// Unsubscribe from any `ray.get` objects that the task was blocked on. Any
// fetch or reconstruction operations to make the objects local are canceled.
// `ray.wait` calls will stay active until the objects become local, or the
// task/actor that called `ray.wait` exits.

dependency_manager_.CancelGetRequest(worker->WorkerId());
}

Expand Down
40 changes: 15 additions & 25 deletions src/ray/raylet/node_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -372,32 +372,22 @@ class NodeManager : public rpc::NodeManagerServiceHandler,
void FinishAssignedActorCreationTask(const std::shared_ptr<WorkerInterface> &worker,
const RayTask &task);

/// Handle blocking gets of objects. This could be a task assigned to a worker,
/// an out-of-band task (e.g., a thread created by the application), or a
/// driver task. This can be triggered when a client starts a get call or a
/// wait call.
///
/// \param client The client that is executing the blocked task.
/// \param required_object_refs The objects that the client is blocked waiting for.
/// \param current_task_id The task that is blocked.
/// \param ray_get Whether the task is blocked in a `ray.get` call.
/// \return Void.
void AsyncResolveObjects(const std::shared_ptr<ClientConnection> &client,
const std::vector<rpc::ObjectReference> &required_object_refs,
const TaskID &current_task_id,
bool ray_get);

/// Handle end of a blocking object get. This could be a task assigned to a
/// worker, an out-of-band task (e.g., a thread created by the application),
/// or a driver task. This can be triggered when a client finishes a get call
/// or a wait call. The given task must be blocked, via a previous call to
/// AsyncResolveObjects.
///
/// \param client The client that is executing the unblocked task.
/// \param current_task_id The task that is unblocked.
/// Start a get or wait request for the requested objects.
///
/// \param client The client that is requesting the objects.
/// \param object_refs The objects that are requested.
/// \param is_get_request If this is a get request, else it's a wait request.
/// \return Void.
void AsyncResolveObjectsFinish(const std::shared_ptr<ClientConnection> &client,
const TaskID &current_task_id);
void AsyncGetOrWait(const std::shared_ptr<ClientConnection> &client,
const std::vector<rpc::ObjectReference> &object_refs,
bool is_get_request);

/// Cancel all ongoing get requests from the client.
///
/// This does *not* cancel ongoing wait requests.
///
/// \param client The client whose get requests will be canceled.
void CancelGetRequest(const std::shared_ptr<ClientConnection> &client);

/// Handle a task that is blocked. Note that this callback may
/// arrive after the worker lease has been returned to the node manager.
Expand Down
23 changes: 8 additions & 15 deletions src/ray/raylet_client/raylet_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -137,26 +137,21 @@ Status RayletClient::ActorCreationTaskDone() {

Status RayletClient::FetchOrReconstruct(const std::vector<ObjectID> &object_ids,
const std::vector<rpc::Address> &owner_addresses,
bool fetch_only,
const TaskID &current_task_id) {
bool fetch_only) {
RAY_CHECK(object_ids.size() == owner_addresses.size());
flatbuffers::FlatBufferBuilder fbb;
auto object_ids_message = to_flatbuf(fbb, object_ids);
auto message =
protocol::CreateFetchOrReconstruct(fbb,
object_ids_message,
AddressesToFlatbuffer(fbb, owner_addresses),
fetch_only,
to_flatbuf(fbb, current_task_id));
auto message = protocol::CreateFetchOrReconstruct(
fbb, object_ids_message, AddressesToFlatbuffer(fbb, owner_addresses), fetch_only);
fbb.Finish(message);
return conn_->WriteMessage(MessageType::FetchOrReconstruct, &fbb);
}

Status RayletClient::NotifyUnblocked(const TaskID &current_task_id) {
Status RayletClient::CancelGetRequest() {
flatbuffers::FlatBufferBuilder fbb;
auto message = protocol::CreateNotifyUnblocked(fbb, to_flatbuf(fbb, current_task_id));
auto message = protocol::CreateCancelGetRequest(fbb);
fbb.Finish(message);
return conn_->WriteMessage(MessageType::NotifyUnblocked, &fbb);
return conn_->WriteMessage(MessageType::CancelGetRequest, &fbb);
}

Status RayletClient::NotifyDirectCallTaskBlocked() {
Expand All @@ -177,16 +172,14 @@ StatusOr<absl::flat_hash_set<ObjectID>> RayletClient::Wait(
const std::vector<ObjectID> &object_ids,
const std::vector<rpc::Address> &owner_addresses,
int num_returns,
int64_t timeout_milliseconds,
const TaskID &current_task_id) {
int64_t timeout_milliseconds) {
// Write request.
flatbuffers::FlatBufferBuilder fbb;
auto message = protocol::CreateWaitRequest(fbb,
to_flatbuf(fbb, object_ids),
AddressesToFlatbuffer(fbb, owner_addresses),
num_returns,
timeout_milliseconds,
to_flatbuf(fbb, current_task_id));
timeout_milliseconds);
fbb.Finish(message);
std::vector<uint8_t> reply;
RAY_RETURN_NOT_OK(conn_->AtomicRequestReply(
Expand Down
12 changes: 4 additions & 8 deletions src/ray/raylet_client/raylet_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -297,14 +297,12 @@ class RayletClient : public RayletClientInterface {
/// \return int 0 means correct, other numbers mean error.
ray::Status FetchOrReconstruct(const std::vector<ObjectID> &object_ids,
const std::vector<rpc::Address> &owner_addresses,
bool fetch_only,
const TaskID &current_task_id);
bool fetch_only);

/// Notify the raylet that this client (worker) is no longer blocked.
/// Tell the Raylet to cancel the get request from this worker.
///
/// \param current_task_id The task that is no longer blocked.
/// \return ray::Status.
ray::Status NotifyUnblocked(const TaskID &current_task_id);
ray::Status CancelGetRequest();

/// Notify the raylet that this client is blocked. This is only used for direct task
/// calls. Note that ordering of this with respect to Unblock calls is important.
Expand All @@ -325,7 +323,6 @@ class RayletClient : public RayletClientInterface {
/// \param owner_addresses The addresses of the workers that own the objects.
/// \param num_returns The number of objects to wait for.
/// \param timeout_milliseconds Duration, in milliseconds, to wait before returning.
/// \param current_task_id The task that called wait.
/// \param result A pair with the first element containing the object ids that were
/// found, and the second element the objects that were not found.
/// \return ray::StatusOr containing error status or the set of object ids that were
Expand All @@ -334,8 +331,7 @@ class RayletClient : public RayletClientInterface {
const std::vector<ObjectID> &object_ids,
const std::vector<rpc::Address> &owner_addresses,
int num_returns,
int64_t timeout_milliseconds,
const TaskID &current_task_id);
int64_t timeout_milliseconds);

/// Wait for the given objects, asynchronously. The core worker is notified when
/// the wait completes.
Expand Down