diff --git a/src/ray/raylet/lease_dependency_manager.cc b/src/ray/raylet/lease_dependency_manager.cc index a86a6313bdba..357b11ae85c1 100644 --- a/src/ray/raylet/lease_dependency_manager.cc +++ b/src/ray/raylet/lease_dependency_manager.cc @@ -126,7 +126,7 @@ void LeaseDependencyManager::StartGetRequest( const auto obj_id = ObjectRefToId(ref); object_ids.emplace_back(obj_id); auto it = GetOrInsertRequiredObject(obj_id, ref); - it->second.dependent_get_requests.insert(worker_id); + ++it->second.dependent_get_requests[worker_id]; } uint64_t new_pull_request_id = object_manager_.Pull( @@ -158,7 +158,10 @@ void LeaseDependencyManager::CancelGetRequest(const WorkerID &worker_id, for (const auto &obj_id : object_ids) { auto obj_iter = required_objects_.find(obj_id); RAY_CHECK(obj_iter != required_objects_.end()); - obj_iter->second.dependent_get_requests.erase(worker_id); + --obj_iter->second.dependent_get_requests[worker_id]; + if (obj_iter->second.dependent_get_requests[worker_id] == 0) { + obj_iter->second.dependent_get_requests.erase(worker_id); + } RemoveObjectIfNotNeeded(obj_iter); } @@ -193,7 +196,10 @@ void LeaseDependencyManager::CancelGetRequest(const WorkerID &worker_id) { for (const auto &obj_id : object_ids) { auto obj_iter = required_objects_.find(obj_id); RAY_CHECK(obj_iter != required_objects_.end()); - obj_iter->second.dependent_get_requests.erase(worker_id); + --obj_iter->second.dependent_get_requests[worker_id]; + if (obj_iter->second.dependent_get_requests[worker_id] == 0) { + obj_iter->second.dependent_get_requests.erase(worker_id); + } RemoveObjectIfNotNeeded(obj_iter); } diff --git a/src/ray/raylet/lease_dependency_manager.h b/src/ray/raylet/lease_dependency_manager.h index 7e71a0a9b59f..439ce19aa407 100644 --- a/src/ray/raylet/lease_dependency_manager.h +++ b/src/ray/raylet/lease_dependency_manager.h @@ -18,6 +18,7 @@ #include #include #include +#include #include #include #include @@ -218,8 +219,8 @@ class LeaseDependencyManager : public LeaseDependencyManagerInterface { /// argument or because the lease of the lease called `ray.get` on the object. std::unordered_set dependent_leases; /// The workers that depend on this object because they called `ray.get` on the - /// object. - std::unordered_set dependent_get_requests; + /// object and the count of outstanding get_requests per worker. + std::unordered_map dependent_get_requests; /// The workers that depend on this object because they called `ray.wait` on the /// object. std::unordered_set dependent_wait_requests; diff --git a/src/ray/raylet/tests/lease_dependency_manager_test.cc b/src/ray/raylet/tests/lease_dependency_manager_test.cc index 18dcb2e79095..3ba6ebd80534 100644 --- a/src/ray/raylet/tests/lease_dependency_manager_test.cc +++ b/src/ray/raylet/tests/lease_dependency_manager_test.cc @@ -256,6 +256,23 @@ TEST_F(LeaseDependencyManagerTest, TestCancelingSingleGetRequestForWorker) { AssertNoLeaks(); } +TEST_F(LeaseDependencyManagerTest, + TestCancelingMultipleGetRequestsForSameObjectForWorker) { + WorkerID worker_id = WorkerID::FromRandom(); + ObjectID argument_id = ObjectID::FromRandom(); + int num_requests = 5; + for (int64_t i = 0; i < num_requests; i++) { + lease_dependency_manager_.StartGetRequest( + worker_id, ObjectIdsToRefs({argument_id}), i); + } + ASSERT_EQ(object_manager_mock_.active_get_requests.size(), num_requests); + for (int64_t i = 0; i < num_requests; i++) { + lease_dependency_manager_.CancelGetRequest(worker_id, i); + ASSERT_EQ(object_manager_mock_.active_get_requests.size(), num_requests - (i + 1)); + } + AssertNoLeaks(); +} + TEST_F(LeaseDependencyManagerTest, TestCancelingAllGetRequestsForWorker) { WorkerID worker_id = WorkerID::FromRandom(); int num_requests = 5;