@@ -1043,11 +1043,8 @@ void NodeManager::ProcessClientMessage(const std::shared_ptr<ClientConnection> &
10431043 case protocol::MessageType::NotifyDirectCallTaskUnblocked: {
10441044 HandleDirectCallTaskUnblocked (registered_worker);
10451045 } break ;
1046- case protocol::MessageType::NotifyUnblocked: {
1047- // TODO(ekl) this is still used from core worker even in direct call mode to
1048- // finish up get requests.
1049- auto message = flatbuffers::GetRoot<protocol::NotifyUnblocked>(message_data);
1050- AsyncResolveObjectsFinish (client, from_flatbuf<TaskID>(*message->task_id ()));
1046+ case protocol::MessageType::CancelGetRequest: {
1047+ CancelGetRequest (client);
10511048 } break ;
10521049 case protocol::MessageType::WaitRequest: {
10531050 ProcessWaitRequestMessage (client, message_data);
@@ -1492,11 +1489,9 @@ void NodeManager::ProcessFetchOrReconstructMessage(
14921489 // subscribe to in the task dependency manager. These objects will be
14931490 // pulled from remote node managers. If an object's owner dies, an error
14941491 // will be stored as the object's value.
1495- const TaskID task_id = from_flatbuf<TaskID>(*message->task_id ());
1496- AsyncResolveObjects (client,
1497- refs,
1498- task_id,
1499- /* ray_get=*/ true );
1492+ AsyncGetOrWait (client,
1493+ refs,
1494+ /* is_get_request=*/ true );
15001495 }
15011496}
15021497
@@ -1508,28 +1503,24 @@ void NodeManager::ProcessWaitRequestMessage(
15081503 const auto refs =
15091504 FlatbufferToObjectReference (*message->object_ids (), *message->owner_addresses ());
15101505
1511- bool resolve_objects = false ;
1506+ bool all_objects_local = true ;
15121507 for (auto const &object_id : object_ids) {
15131508 if (!dependency_manager_.CheckObjectLocal (object_id)) {
1514- // At least one object requires resolution.
1515- resolve_objects = true ;
1509+ all_objects_local = false ;
15161510 }
15171511 }
15181512
1519- const TaskID ¤t_task_id = from_flatbuf<TaskID>(*message->task_id ());
1520- if (resolve_objects) {
1513+ if (!all_objects_local) {
15211514 // Resolve any missing objects. This is a no-op for any objects that are
15221515 // already local. Missing objects will be pulled from remote node managers.
15231516 // If an object's owner dies, an error will be stored as the object's
15241517 // value.
1525- AsyncResolveObjects (client,
1526- refs,
1527- current_task_id,
1528- /* ray_get=*/ false );
1518+ AsyncGetOrWait (client, refs, /* is_get_request=*/ false );
15291519 }
1520+
15301521 if (message->num_required_objects () == 0 ) {
15311522 // If we don't need to wait for any, return immediately after making the pull
1532- // requests through AsyncResolveObjects above.
1523+ // requests through AsyncGetOrWait above.
15331524 flatbuffers::FlatBufferBuilder fbb;
15341525 auto wait_reply = protocol::CreateWaitReply (fbb,
15351526 to_flatbuf (fbb, std::vector<ObjectID>{}),
@@ -1539,11 +1530,7 @@ void NodeManager::ProcessWaitRequestMessage(
15391530 client->WriteMessage (static_cast <int64_t >(protocol::MessageType::WaitReply),
15401531 fbb.GetSize (),
15411532 fbb.GetBufferPointer ());
1542- if (status.ok ()) {
1543- if (resolve_objects) {
1544- AsyncResolveObjectsFinish (client, current_task_id);
1545- }
1546- } else {
1533+ if (!status.ok ()) {
15471534 // We failed to write to the client, so disconnect the client.
15481535 std::ostringstream stream;
15491536 stream << " Failed to write WaitReply to the client. Status " << status;
@@ -1557,8 +1544,8 @@ void NodeManager::ProcessWaitRequestMessage(
15571544 object_ids,
15581545 message->timeout (),
15591546 num_required_objects,
1560- [this , resolve_objects, client, current_task_id ](std::vector<ObjectID> ready,
1561- std::vector<ObjectID> remaining) {
1547+ [this , client, all_objects_local ](std::vector<ObjectID> ready,
1548+ std::vector<ObjectID> remaining) {
15621549 // Write the data.
15631550 flatbuffers::FlatBufferBuilder fbb;
15641551 flatbuffers::Offset<protocol::WaitReply> wait_reply = protocol::CreateWaitReply (
@@ -1570,10 +1557,8 @@ void NodeManager::ProcessWaitRequestMessage(
15701557 fbb.GetSize (),
15711558 fbb.GetBufferPointer ());
15721559 if (status.ok ()) {
1573- // The client is unblocked now because the wait call has
1574- // returned.
1575- if (resolve_objects) {
1576- AsyncResolveObjectsFinish (client, current_task_id);
1560+ if (!all_objects_local) {
1561+ CancelGetRequest (client);
15771562 }
15781563 } else {
15791564 // We failed to write to the client, so disconnect the client.
@@ -1589,19 +1574,14 @@ void NodeManager::ProcessWaitRequestMessage(
15891574
15901575void NodeManager::ProcessWaitForActorCallArgsRequestMessage (
15911576 const std::shared_ptr<ClientConnection> &client, const uint8_t *message_data) {
1592- // Read the data.
15931577 auto message =
15941578 flatbuffers::GetRoot<protocol::WaitForActorCallArgsRequest>(message_data);
15951579 std::vector<ObjectID> object_ids = from_flatbuf<ObjectID>(*message->object_ids ());
15961580 int64_t tag = message->tag ();
1597- // Resolve any missing objects. This will pull the objects from remote node
1598- // managers or store an error if the objects have failed.
1581+ // Pull any missing objects to the local node.
15991582 const auto refs =
16001583 FlatbufferToObjectReference (*message->object_ids (), *message->owner_addresses ());
1601- AsyncResolveObjects (client,
1602- refs,
1603- TaskID::Nil (),
1604- /* ray_get=*/ false );
1584+ AsyncGetOrWait (client, refs, /* is_get_request=*/ false );
16051585 // De-duplicate the object IDs.
16061586 absl::flat_hash_set<ObjectID> object_id_set (object_ids.begin (), object_ids.end ());
16071587 object_ids.assign (object_id_set.begin (), object_id_set.end ());
@@ -2096,44 +2076,31 @@ void NodeManager::HandleDirectCallTaskUnblocked(
20962076 }
20972077}
20982078
2099- void NodeManager::AsyncResolveObjects (
2100- const std::shared_ptr<ClientConnection> &client,
2101- const std::vector<rpc::ObjectReference> &required_object_refs,
2102- const TaskID ¤t_task_id,
2103- bool ray_get) {
2079+ void NodeManager::AsyncGetOrWait (const std::shared_ptr<ClientConnection> &client,
2080+ const std::vector<rpc::ObjectReference> &object_refs,
2081+ bool is_get_request) {
21042082 std::shared_ptr<WorkerInterface> worker = worker_pool_.GetRegisteredWorker (client);
21052083 if (!worker) {
2106- // The client is a driver. Drivers do not hold resources, so we simply mark
2107- // the task as blocked.
21082084 worker = worker_pool_.GetRegisteredDriver (client);
21092085 }
2110-
21112086 RAY_CHECK (worker);
2112- // Subscribe to the objects required by the task. These objects will be
2113- // fetched and/or restarted as necessary, until the objects become local
2114- // or are unsubscribed .
2115- if (ray_get ) {
2116- dependency_manager_.StartOrUpdateGetRequest (worker->WorkerId (), required_object_refs );
2087+
2088+ // Start an async request to get or wait for the objects.
2089+ // The objects will be fetched locally unless the get or wait request is canceled .
2090+ if (is_get_request ) {
2091+ dependency_manager_.StartOrUpdateGetRequest (worker->WorkerId (), object_refs );
21172092 } else {
2118- dependency_manager_.StartOrUpdateWaitRequest (worker->WorkerId (),
2119- required_object_refs);
2093+ dependency_manager_.StartOrUpdateWaitRequest (worker->WorkerId (), object_refs);
21202094 }
21212095}
21222096
2123- void NodeManager::AsyncResolveObjectsFinish (
2124- const std::shared_ptr<ClientConnection> &client, const TaskID ¤t_task_id) {
2097+ void NodeManager::CancelGetRequest (const std::shared_ptr<ClientConnection> &client) {
21252098 std::shared_ptr<WorkerInterface> worker = worker_pool_.GetRegisteredWorker (client);
21262099 if (!worker) {
2127- // The client is a driver. Drivers do not hold resources, so we simply
2128- // mark the driver as unblocked.
21292100 worker = worker_pool_.GetRegisteredDriver (client);
21302101 }
2131-
21322102 RAY_CHECK (worker);
2133- // Unsubscribe from any `ray.get` objects that the task was blocked on. Any
2134- // fetch or reconstruction operations to make the objects local are canceled.
2135- // `ray.wait` calls will stay active until the objects become local, or the
2136- // task/actor that called `ray.wait` exits.
2103+
21372104 dependency_manager_.CancelGetRequest (worker->WorkerId ());
21382105}
21392106
0 commit comments