From 5eb42d244934e324f0ab3676c7964c3a942027be Mon Sep 17 00:00:00 2001 From: yuchao-wang Date: Thu, 13 Nov 2025 22:36:03 +0800 Subject: [PATCH] rename gcs_callback_types.h to grpc_callback_types.h Signed-off-by: yuchao-wang --- .../gcs/store_client/in_memory_store_client.h | 2 +- .../ray/gcs/store_client/redis_store_client.h | 2 +- src/mock/ray/gcs/store_client/store_client.h | 2 +- src/mock/ray/gcs_client/accessor.h | 64 +++++----- .../accessors/actor_info_accessor.h | 24 ++-- src/ray/common/BUILD.bazel | 8 -- src/ray/common/gcs_callback_types.h | 55 -------- src/ray/core_worker/actor_creator.cc | 10 +- src/ray/core_worker/actor_creator.h | 18 +-- src/ray/core_worker/fake_actor_creator.h | 10 +- .../tests/task_event_buffer_test.cc | 14 +-- .../tests/core_worker_client_pool_test.cc | 8 +- src/ray/gcs/gcs_placement_group_manager.cc | 7 +- src/ray/gcs/gcs_placement_group_manager.h | 10 +- src/ray/gcs/gcs_placement_group_scheduler.cc | 4 +- src/ray/gcs/gcs_placement_group_scheduler.h | 4 +- src/ray/gcs/store_client/BUILD.bazel | 2 +- .../store_client/in_memory_store_client.cc | 2 +- .../gcs/store_client/in_memory_store_client.h | 2 +- .../store_client/observable_store_client.cc | 2 +- .../store_client/observable_store_client.h | 2 +- .../gcs/store_client/redis_store_client.cc | 7 +- src/ray/gcs/store_client/redis_store_client.h | 2 +- src/ray/gcs/store_client/store_client.h | 4 +- .../tests/gcs_placement_group_manager_test.cc | 3 +- src/ray/gcs_rpc_client/BUILD.bazel | 2 +- src/ray/gcs_rpc_client/accessor.cc | 119 +++++++++--------- src/ray/gcs_rpc_client/accessor.h | 84 +++++++------ .../accessors/actor_info_accessor.cc | 23 ++-- .../accessors/actor_info_accessor.h | 26 ++-- .../accessors/actor_info_accessor_interface.h | 38 +++--- .../gcs_rpc_client/global_state_accessor.h | 8 +- .../tests/gcs_client_injectable_test.cc | 22 ++-- .../gcs_rpc_client/tests/gcs_client_test.cc | 6 +- src/ray/pubsub/BUILD.bazel | 2 +- src/ray/pubsub/gcs_subscriber.cc | 20 +-- src/ray/pubsub/gcs_subscriber.h | 22 ++-- src/ray/raylet/tests/node_manager_test.cc | 16 +-- src/ray/raylet_rpc_client/BUILD.bazel | 1 - src/ray/raylet_rpc_client/raylet_client.cc | 2 +- src/ray/raylet_rpc_client/raylet_client.h | 3 +- .../raylet_client_with_io_context.cc | 6 +- .../raylet_client_with_io_context.h | 4 +- .../tests/raylet_client_pool_test.cc | 4 +- src/ray/rpc/rpc_callback_types.h | 33 +++++ 45 files changed, 349 insertions(+), 360 deletions(-) delete mode 100644 src/ray/common/gcs_callback_types.h diff --git a/src/mock/ray/gcs/store_client/in_memory_store_client.h b/src/mock/ray/gcs/store_client/in_memory_store_client.h index 16a7a5cab895..85e138aeef04 100644 --- a/src/mock/ray/gcs/store_client/in_memory_store_client.h +++ b/src/mock/ray/gcs/store_client/in_memory_store_client.h @@ -30,7 +30,7 @@ class MockInMemoryStoreClient : public InMemoryStoreClient { AsyncGet, (const std::string &table_name, const std::string &key, - ToPostable> callback), + ToPostable> callback), (override)); MOCK_METHOD(void, diff --git a/src/mock/ray/gcs/store_client/redis_store_client.h b/src/mock/ray/gcs/store_client/redis_store_client.h index 7a73e5b045dd..5027fd4b177c 100644 --- a/src/mock/ray/gcs/store_client/redis_store_client.h +++ b/src/mock/ray/gcs/store_client/redis_store_client.h @@ -29,7 +29,7 @@ class MockStoreClient : public StoreClient { AsyncGet, (const std::string &table_name, const std::string &key, - ToPostable> callback), + ToPostable> callback), (override)); MOCK_METHOD(void, AsyncGetAll, diff --git a/src/mock/ray/gcs/store_client/store_client.h b/src/mock/ray/gcs/store_client/store_client.h index 7a73e5b045dd..5027fd4b177c 100644 --- a/src/mock/ray/gcs/store_client/store_client.h +++ b/src/mock/ray/gcs/store_client/store_client.h @@ -29,7 +29,7 @@ class MockStoreClient : public StoreClient { AsyncGet, (const std::string &table_name, const std::string &key, - ToPostable> callback), + ToPostable> callback), (override)); MOCK_METHOD(void, AsyncGetAll, diff --git a/src/mock/ray/gcs_client/accessor.h b/src/mock/ray/gcs_client/accessor.h index 463f8939cce8..7448ead11a03 100644 --- a/src/mock/ray/gcs_client/accessor.h +++ b/src/mock/ray/gcs_client/accessor.h @@ -23,27 +23,30 @@ class MockJobInfoAccessor : public JobInfoAccessor { MOCK_METHOD(void, AsyncAdd, (const std::shared_ptr &data_ptr, - const StatusCallback &callback), + const rpc::StatusCallback &callback), (override)); MOCK_METHOD(void, AsyncMarkFinished, - (const JobID &job_id, const StatusCallback &callback), + (const JobID &job_id, const rpc::StatusCallback &callback), (override)); MOCK_METHOD(void, AsyncSubscribeAll, - ((const SubscribeCallback &subscribe), - const StatusCallback &done), + ((const rpc::SubscribeCallback &subscribe), + const rpc::StatusCallback &done), (override)); MOCK_METHOD(void, AsyncGetAll, (const std::optional &job_or_submission_id, bool skip_submission_job_info_field, bool skip_is_running_tasks_field, - const MultiItemCallback &callback, + const rpc::MultiItemCallback &callback, int64_t timeout_ms), (override)); MOCK_METHOD(void, AsyncResubscribe, (), (override)); - MOCK_METHOD(void, AsyncGetNextJobID, (const ItemCallback &callback), (override)); + MOCK_METHOD(void, + AsyncGetNextJobID, + (const rpc::ItemCallback &callback), + (override)); }; } // namespace gcs @@ -56,27 +59,27 @@ class MockNodeInfoAccessor : public NodeInfoAccessor { public: MOCK_METHOD(void, RegisterSelf, - (rpc::GcsNodeInfo && local_node_info, const StatusCallback &callback), + (rpc::GcsNodeInfo && local_node_info, const rpc::StatusCallback &callback), (override)); MOCK_METHOD(void, AsyncRegister, - (const rpc::GcsNodeInfo &node_info, const StatusCallback &callback), + (const rpc::GcsNodeInfo &node_info, const rpc::StatusCallback &callback), (override)); MOCK_METHOD(void, AsyncCheckAlive, (const std::vector &node_ids, int64_t timeout_ms, - const MultiItemCallback &callback), + const rpc::MultiItemCallback &callback), (override)); MOCK_METHOD(void, AsyncGetAll, - (const MultiItemCallback &callback, + (const rpc::MultiItemCallback &callback, int64_t timeout_ms, const std::vector &node_ids), (override)); MOCK_METHOD(void, AsyncGetAllNodeAddressAndLiveness, - (const MultiItemCallback &callback, + (const rpc::MultiItemCallback &callback, int64_t timeout_ms, const std::vector &node_ids), (override)); @@ -84,7 +87,7 @@ class MockNodeInfoAccessor : public NodeInfoAccessor { void, AsyncSubscribeToNodeAddressAndLivenessChange, (std::function subscribe, - StatusCallback done), + rpc::StatusCallback done), (override)); MOCK_METHOD(std::optional, GetNodeAddressAndLiveness, @@ -115,11 +118,11 @@ class MockNodeResourceInfoAccessor : public NodeResourceInfoAccessor { public: MOCK_METHOD(void, AsyncGetAllAvailableResources, - (const MultiItemCallback &callback), + (const rpc::MultiItemCallback &callback), (override)); MOCK_METHOD(void, AsyncGetAllResourceUsage, - (const ItemCallback &callback), + (const rpc::ItemCallback &callback), (override)); }; @@ -144,7 +147,8 @@ class MockTaskInfoAccessor : public TaskInfoAccessor { public: MOCK_METHOD(void, AsyncAddTaskEventData, - (std::unique_ptr data_ptr, StatusCallback callback), + (std::unique_ptr data_ptr, + rpc::StatusCallback callback), (override)); }; @@ -158,27 +162,27 @@ class MockWorkerInfoAccessor : public WorkerInfoAccessor { public: MOCK_METHOD(void, AsyncSubscribeToWorkerFailures, - (const ItemCallback &subscribe, - const StatusCallback &done), + (const rpc::ItemCallback &subscribe, + const rpc::StatusCallback &done), (override)); MOCK_METHOD(void, AsyncReportWorkerFailure, (const std::shared_ptr &data_ptr, - const StatusCallback &callback), + const rpc::StatusCallback &callback), (override)); MOCK_METHOD(void, AsyncGet, (const WorkerID &worker_id, - const OptionalItemCallback &callback), + const rpc::OptionalItemCallback &callback), (override)); MOCK_METHOD(void, AsyncGetAll, - (const MultiItemCallback &callback), + (const rpc::MultiItemCallback &callback), (override)); MOCK_METHOD(void, AsyncAdd, (const std::shared_ptr &data_ptr, - const StatusCallback &callback), + const rpc::StatusCallback &callback), (override)); MOCK_METHOD(void, AsyncResubscribe, (), (override)); }; @@ -198,18 +202,18 @@ class MockPlacementGroupInfoAccessor : public PlacementGroupInfoAccessor { MOCK_METHOD(void, AsyncGet, (const PlacementGroupID &placement_group_id, - const OptionalItemCallback &callback), + const rpc::OptionalItemCallback &callback), (override)); MOCK_METHOD(void, AsyncGetByName, (const std::string &placement_group_name, const std::string &ray_namespace, - const OptionalItemCallback &callback, + const rpc::OptionalItemCallback &callback, int64_t timeout_ms), (override)); MOCK_METHOD(void, AsyncGetAll, - (const MultiItemCallback &callback), + (const rpc::MultiItemCallback &callback), (override)); MOCK_METHOD(Status, SyncRemovePlacementGroup, @@ -234,14 +238,14 @@ class MockInternalKVAccessor : public InternalKVAccessor { (const std::string &ns, const std::string &prefix, const int64_t timeout_ms, - const OptionalItemCallback> &callback), + const rpc::OptionalItemCallback> &callback), (override)); MOCK_METHOD(void, AsyncInternalKVGet, (const std::string &ns, const std::string &key, const int64_t timeout_ms, - const OptionalItemCallback &callback), + const rpc::OptionalItemCallback &callback), (override)); MOCK_METHOD(void, AsyncInternalKVPut, @@ -250,14 +254,14 @@ class MockInternalKVAccessor : public InternalKVAccessor { const std::string &value, bool overwrite, const int64_t timeout_ms, - const OptionalItemCallback &callback), + const rpc::OptionalItemCallback &callback), (override)); MOCK_METHOD(void, AsyncInternalKVExists, (const std::string &ns, const std::string &key, const int64_t timeout_ms, - const OptionalItemCallback &callback), + const rpc::OptionalItemCallback &callback), (override)); MOCK_METHOD(void, AsyncInternalKVDel, @@ -265,11 +269,11 @@ class MockInternalKVAccessor : public InternalKVAccessor { const std::string &key, bool del_by_prefix, const int64_t timeout_ms, - const OptionalItemCallback &callback), + const rpc::OptionalItemCallback &callback), (override)); MOCK_METHOD(void, AsyncGetInternalConfig, - (const OptionalItemCallback &callback), + (const rpc::OptionalItemCallback &callback), (override)); }; diff --git a/src/mock/ray/gcs_client/accessors/actor_info_accessor.h b/src/mock/ray/gcs_client/accessors/actor_info_accessor.h index 8652db89bd10..acff258b83d7 100644 --- a/src/mock/ray/gcs_client/accessors/actor_info_accessor.h +++ b/src/mock/ray/gcs_client/accessors/actor_info_accessor.h @@ -32,15 +32,15 @@ class FakeActorInfoAccessor : public gcs::ActorInfoAccessorInterface { // Stub implementations for interface methods not used by this test void AsyncGet(const ActorID &, - const gcs::OptionalItemCallback &) override {} + const rpc::OptionalItemCallback &) override {} void AsyncGetAllByFilter(const std::optional &, const std::optional &, const std::optional &, - const gcs::MultiItemCallback &, + const rpc::MultiItemCallback &, int64_t = -1) override {} void AsyncGetByName(const std::string &, const std::string &, - const gcs::OptionalItemCallback &, + const rpc::OptionalItemCallback &, int64_t = -1) override {} Status SyncGetByName(const std::string &, const std::string &, @@ -56,20 +56,20 @@ class FakeActorInfoAccessor : public gcs::ActorInfoAccessorInterface { } void AsyncReportActorOutOfScope(const ActorID &, uint64_t, - const gcs::StatusCallback &, + const rpc::StatusCallback &, int64_t = -1) override {} void AsyncRegisterActor(const TaskSpecification &task_spec, - const gcs::StatusCallback &callback, + const rpc::StatusCallback &callback, int64_t = -1) override { async_register_actor_callback_ = callback; } void AsyncRestartActorForLineageReconstruction(const ActorID &, uint64_t, - const gcs::StatusCallback &, + const rpc::StatusCallback &, int64_t = -1) override {} Status SyncRegisterActor(const TaskSpecification &) override { return Status::OK(); } void AsyncKillActor( - const ActorID &, bool, bool, const gcs::StatusCallback &, int64_t = -1) override {} + const ActorID &, bool, bool, const rpc::StatusCallback &, int64_t = -1) override {} void AsyncCreateActor( const TaskSpecification &task_spec, const rpc::ClientCallback &callback) override { @@ -78,8 +78,8 @@ class FakeActorInfoAccessor : public gcs::ActorInfoAccessorInterface { void AsyncSubscribe( const ActorID &actor_id, - const gcs::SubscribeCallback &subscribe, - const gcs::StatusCallback &done) override { + const rpc::SubscribeCallback &subscribe, + const rpc::StatusCallback &done) override { auto callback_entry = std::make_pair(actor_id, subscribe); callback_map_.emplace(actor_id, subscribe); subscribe_finished_callback_map_[actor_id] = done; @@ -124,14 +124,14 @@ class FakeActorInfoAccessor : public gcs::ActorInfoAccessorInterface { return true; } - absl::flat_hash_map> + absl::flat_hash_map> callback_map_; - absl::flat_hash_map subscribe_finished_callback_map_; + absl::flat_hash_map subscribe_finished_callback_map_; absl::flat_hash_map actor_subscribed_times_; // Callbacks for AsyncCreateActor and AsyncRegisterActor rpc::ClientCallback async_create_actor_callback_; - gcs::StatusCallback async_register_actor_callback_; + rpc::StatusCallback async_register_actor_callback_; }; } // namespace gcs diff --git a/src/ray/common/BUILD.bazel b/src/ray/common/BUILD.bazel index f36c260daf62..ef2ee3e46da1 100644 --- a/src/ray/common/BUILD.bazel +++ b/src/ray/common/BUILD.bazel @@ -390,14 +390,6 @@ ray_cc_library( ], ) -ray_cc_library( - name = "gcs_callback_types", - hdrs = ["gcs_callback_types.h"], - deps = [ - "//src/ray/common:status", - ], -) - ray_cc_library( name = "metrics", hdrs = ["metrics.h"], diff --git a/src/ray/common/gcs_callback_types.h b/src/ray/common/gcs_callback_types.h deleted file mode 100644 index 1d5da52fec9b..000000000000 --- a/src/ray/common/gcs_callback_types.h +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright 2017 The Ray Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include - -#include "ray/common/status.h" - -namespace ray { -namespace gcs { - -/// This callback is used to notify when a write/subscribe to GCS completes. -/// \param status Status indicates whether the write/subscribe was successful. -using StatusCallback = std::function; - -/// This callback is used to receive one item from GCS when a read completes. -/// \param status Status indicates whether the read was successful. -/// \param result The item returned by GCS. If the item to read doesn't exist, -/// this optional object is empty. -template -using OptionalItemCallback = - std::function result)>; - -/// This callback is used to receive multiple items from GCS when a read completes. -/// \param status Status indicates whether the read was successful. -/// \param result The items returned by GCS. -template -using MultiItemCallback = std::function result)>; - -/// This callback is used to receive notifications of the subscribed items in the GCS. -/// \param id The id of the item. -/// \param result The notification message. -template -using SubscribeCallback = std::function; - -/// This callback is used to receive a single item from GCS. -/// \param result The item returned by GCS. -template -using ItemCallback = std::function; - -} // namespace gcs -} // namespace ray diff --git a/src/ray/core_worker/actor_creator.cc b/src/ray/core_worker/actor_creator.cc index b5d9e10c99a3..bfa57af1548e 100644 --- a/src/ray/core_worker/actor_creator.cc +++ b/src/ray/core_worker/actor_creator.cc @@ -33,14 +33,14 @@ Status ActorCreator::RegisterActor(const TaskSpecification &task_spec) const { } void ActorCreator::AsyncRegisterActor(const TaskSpecification &task_spec, - gcs::StatusCallback callback) { + rpc::StatusCallback callback) { auto actor_id = task_spec.ActorCreationId(); (*registering_actors_)[actor_id] = {}; if (callback != nullptr) { (*registering_actors_)[actor_id].emplace_back(std::move(callback)); } actor_client_.AsyncRegisterActor(task_spec, [actor_id, this](Status status) { - std::vector cbs; + std::vector cbs; cbs = std::move((*registering_actors_)[actor_id]); registering_actors_->erase(actor_id); for (auto &cb : cbs) { @@ -52,7 +52,7 @@ void ActorCreator::AsyncRegisterActor(const TaskSpecification &task_spec, void ActorCreator::AsyncRestartActorForLineageReconstruction( const ActorID &actor_id, uint64_t num_restarts_due_to_lineage_reconstructions, - gcs::StatusCallback callback) { + rpc::StatusCallback callback) { actor_client_.AsyncRestartActorForLineageReconstruction( actor_id, num_restarts_due_to_lineage_reconstructions, callback); } @@ -60,7 +60,7 @@ void ActorCreator::AsyncRestartActorForLineageReconstruction( void ActorCreator::AsyncReportActorOutOfScope( const ActorID &actor_id, uint64_t num_restarts_due_to_lineage_reconstruction, - gcs::StatusCallback callback) { + rpc::StatusCallback callback) { actor_client_.AsyncReportActorOutOfScope( actor_id, num_restarts_due_to_lineage_reconstruction, callback); } @@ -70,7 +70,7 @@ bool ActorCreator::IsActorInRegistering(const ActorID &actor_id) const { } void ActorCreator::AsyncWaitForActorRegisterFinish(const ActorID &actor_id, - gcs::StatusCallback callback) { + rpc::StatusCallback callback) { auto iter = registering_actors_->find(actor_id); RAY_CHECK(iter != registering_actors_->end()); iter->second.emplace_back(std::move(callback)); diff --git a/src/ray/core_worker/actor_creator.h b/src/ray/core_worker/actor_creator.h index 974d1360be9f..5a40d96a62f4 100644 --- a/src/ray/core_worker/actor_creator.h +++ b/src/ray/core_worker/actor_creator.h @@ -38,17 +38,17 @@ class ActorCreatorInterface { /// \param callback Callback that will be called after the actor info is registered to /// GCS virtual void AsyncRegisterActor(const TaskSpecification &task_spec, - gcs::StatusCallback callback) = 0; + rpc::StatusCallback callback) = 0; virtual void AsyncRestartActorForLineageReconstruction( const ActorID &actor_id, uint64_t num_restarts_due_to_lineage_reconstructions, - gcs::StatusCallback callback) = 0; + rpc::StatusCallback callback) = 0; virtual void AsyncReportActorOutOfScope( const ActorID &actor_id, uint64_t num_restarts_due_to_lineage_reconstructions, - gcs::StatusCallback callback) = 0; + rpc::StatusCallback callback) = 0; /// Asynchronously request GCS to create the actor. /// @@ -63,7 +63,7 @@ class ActorCreatorInterface { /// \param actor_id The actor id to wait /// \param callback The callback that will be called after actor registered virtual void AsyncWaitForActorRegisterFinish(const ActorID &actor_id, - gcs::StatusCallback callback) = 0; + rpc::StatusCallback callback) = 0; /// Check whether actor is activately under registering /// @@ -80,21 +80,21 @@ class ActorCreator : public ActorCreatorInterface { Status RegisterActor(const TaskSpecification &task_spec) const override; void AsyncRegisterActor(const TaskSpecification &task_spec, - gcs::StatusCallback callback) override; + rpc::StatusCallback callback) override; void AsyncRestartActorForLineageReconstruction( const ActorID &actor_id, uint64_t num_restarts_due_to_lineage_reconstructions, - gcs::StatusCallback callback) override; + rpc::StatusCallback callback) override; void AsyncReportActorOutOfScope(const ActorID &actor_id, uint64_t num_restarts_due_to_lineage_reconstruction, - gcs::StatusCallback callback) override; + rpc::StatusCallback callback) override; bool IsActorInRegistering(const ActorID &actor_id) const override; void AsyncWaitForActorRegisterFinish(const ActorID &actor_id, - gcs::StatusCallback callback) override; + rpc::StatusCallback callback) override; void AsyncCreateActor( const TaskSpecification &task_spec, @@ -103,7 +103,7 @@ class ActorCreator : public ActorCreatorInterface { private: gcs::ActorInfoAccessorInterface &actor_client_; using RegisteringActorType = - absl::flat_hash_map>; + absl::flat_hash_map>; ThreadPrivate registering_actors_; }; diff --git a/src/ray/core_worker/fake_actor_creator.h b/src/ray/core_worker/fake_actor_creator.h index 08deb9bf6cda..462a20da6fff 100644 --- a/src/ray/core_worker/fake_actor_creator.h +++ b/src/ray/core_worker/fake_actor_creator.h @@ -31,23 +31,23 @@ class FakeActorCreator : public ActorCreatorInterface { }; void AsyncRegisterActor(const TaskSpecification &task_spec, - gcs::StatusCallback callback) override {} + rpc::StatusCallback callback) override {} void AsyncRestartActorForLineageReconstruction( const ActorID &actor_id, uint64_t num_restarts_due_to_lineage_reconstructions, - gcs::StatusCallback callback) override {} + rpc::StatusCallback callback) override {} void AsyncReportActorOutOfScope(const ActorID &actor_id, uint64_t num_restarts_due_to_lineage_reconstruction, - gcs::StatusCallback callback) override {} + rpc::StatusCallback callback) override {} void AsyncCreateActor( const TaskSpecification &task_spec, const rpc::ClientCallback &callback) override {} void AsyncWaitForActorRegisterFinish(const ActorID &, - gcs::StatusCallback callback) override { + rpc::StatusCallback callback) override { callbacks.push_back(callback); } @@ -55,7 +55,7 @@ class FakeActorCreator : public ActorCreatorInterface { return actor_pending; } - std::list callbacks; + std::list callbacks; bool actor_pending = false; }; diff --git a/src/ray/core_worker/tests/task_event_buffer_test.cc b/src/ray/core_worker/tests/task_event_buffer_test.cc index f7127522beed..fa827827a73b 100644 --- a/src/ray/core_worker/tests/task_event_buffer_test.cc +++ b/src/ray/core_worker/tests/task_event_buffer_test.cc @@ -462,7 +462,7 @@ TEST_P(TaskEventBufferTestDifferentDestination, TestFlushEvents) { if (to_gcs) { EXPECT_CALL(*task_gcs_accessor, AsyncAddTaskEventData(_, _)) .WillOnce([&](std::unique_ptr actual_data, - ray::gcs::StatusCallback callback) { + ray::rpc::StatusCallback callback) { CompareTaskEventData(*actual_data, expected_task_event_data); return Status::OK(); }); @@ -518,12 +518,12 @@ TEST_P(TaskEventBufferTestDifferentDestination, TestFailedFlush) { EXPECT_CALL(*task_gcs_accessor, AsyncAddTaskEventData) .Times(2) .WillOnce([&](std::unique_ptr actual_data, - ray::gcs::StatusCallback callback) { + ray::rpc::StatusCallback callback) { callback(Status::RpcError("grpc error", grpc::StatusCode::UNKNOWN)); return Status::OK(); }) .WillOnce([&](std::unique_ptr actual_data, - ray::gcs::StatusCallback callback) { + ray::rpc::StatusCallback callback) { callback(Status::OK()); return Status::OK(); }); @@ -678,7 +678,7 @@ TEST_P(TaskEventBufferTestBatchSendDifferentDestination, TestBatchedSend) { EXPECT_CALL(*task_gcs_accessor, AsyncAddTaskEventData) .Times(num_events / batch_size) .WillRepeatedly([&batch_size](std::unique_ptr actual_data, - ray::gcs::StatusCallback callback) { + ray::rpc::StatusCallback callback) { EXPECT_EQ(actual_data->events_by_task_size(), batch_size); callback(Status::OK()); return Status::OK(); @@ -785,7 +785,7 @@ TEST_P(TaskEventBufferTestLimitBufferDifferentDestination, if (to_gcs) { EXPECT_CALL(*task_gcs_accessor, AsyncAddTaskEventData(_, _)) .WillOnce([&](std::unique_ptr actual_data, - ray::gcs::StatusCallback callback) { + ray::rpc::StatusCallback callback) { // Sort and compare CompareTaskEventData(*actual_data, expected_data); return Status::OK(); @@ -860,7 +860,7 @@ TEST_F(TaskEventBufferTestLimitProfileEvents, TestBufferSizeLimitProfileEvents) EXPECT_CALL(*task_gcs_accessor, AsyncAddTaskEventData(_, _)) .WillOnce([&](std::unique_ptr actual_data, - ray::gcs::StatusCallback callback) { + ray::rpc::StatusCallback callback) { EXPECT_EQ(actual_data->num_profile_events_dropped(), num_profile_dropped); EXPECT_EQ(actual_data->events_by_task_size(), num_limit_profile_events); return Status::OK(); @@ -1107,7 +1107,7 @@ TEST_P(TaskEventBufferTestDifferentDestination, if (to_gcs) { EXPECT_CALL(*task_gcs_accessor, AsyncAddTaskEventData(_, _)) .WillOnce([&](std::unique_ptr actual_data, - ray::gcs::StatusCallback callback) { + ray::rpc::StatusCallback callback) { CompareTaskEventData(*actual_data, expected_task_event_data); return Status::OK(); }); diff --git a/src/ray/core_worker_rpc_client/tests/core_worker_client_pool_test.cc b/src/ray/core_worker_rpc_client/tests/core_worker_client_pool_test.cc index 013e6c1fb1f0..df111823779c 100644 --- a/src/ray/core_worker_rpc_client/tests/core_worker_client_pool_test.cc +++ b/src/ray/core_worker_rpc_client/tests/core_worker_client_pool_test.cc @@ -107,14 +107,14 @@ class MockGcsClientNodeAccessor : public gcs::NodeInfoAccessor { MOCK_METHOD(void, AsyncGetAll, - (const gcs::MultiItemCallback &, + (const rpc::MultiItemCallback &, int64_t, const std::vector &), (override)); MOCK_METHOD(void, AsyncGetAllNodeAddressAndLiveness, - (const gcs::MultiItemCallback &, + (const rpc::MultiItemCallback &, int64_t, const std::vector &), (override)); @@ -182,7 +182,7 @@ TEST_P(DefaultUnavailableTimeoutCallbackTest, NodeDeath) { [](std::vector node_info_vector) { return Invoke( [node_info_vector]( - const gcs::MultiItemCallback &callback, + const rpc::MultiItemCallback &callback, int64_t, const std::vector &) { callback(Status::OK(), node_info_vector); @@ -283,7 +283,7 @@ TEST_P(DefaultUnavailableTimeoutCallbackTest, WorkerDeath) { AsyncGetAllNodeAddressAndLiveness(_, _, _)) .Times(2) .WillRepeatedly(Invoke( - [&](const gcs::MultiItemCallback &callback, + [&](const rpc::MultiItemCallback &callback, int64_t, const std::vector &) { callback(Status::OK(), {node_info_alive}); diff --git a/src/ray/gcs/gcs_placement_group_manager.cc b/src/ray/gcs/gcs_placement_group_manager.cc index 0fb0e21921fe..c6bbc7a1008d 100644 --- a/src/ray/gcs/gcs_placement_group_manager.cc +++ b/src/ray/gcs/gcs_placement_group_manager.cc @@ -106,7 +106,8 @@ GcsPlacementGroupManager::GcsPlacementGroupManager( } void GcsPlacementGroupManager::RegisterPlacementGroup( - const std::shared_ptr &placement_group, StatusCallback callback) { + const std::shared_ptr &placement_group, + rpc::StatusCallback callback) { // NOTE: After the abnormal recovery of the network between GCS client and GCS server or // the GCS server is restarted, it is required to continue to register placement group // successfully. @@ -399,7 +400,7 @@ void GcsPlacementGroupManager::HandleRemovePlacementGroup( void GcsPlacementGroupManager::RemovePlacementGroup( const PlacementGroupID &placement_group_id, - StatusCallback on_placement_group_removed) { + rpc::StatusCallback on_placement_group_removed) { RAY_CHECK(on_placement_group_removed); // If the placement group has been already removed, don't do anything. auto placement_group_it = registered_placement_groups_.find(placement_group_id); @@ -599,7 +600,7 @@ void GcsPlacementGroupManager::HandleWaitPlacementGroupUntilReady( } void GcsPlacementGroupManager::WaitPlacementGroup( - const PlacementGroupID &placement_group_id, StatusCallback callback) { + const PlacementGroupID &placement_group_id, rpc::StatusCallback callback) { // If the placement group does not exist or it has been successfully created, return // directly. const auto &iter = registered_placement_groups_.find(placement_group_id); diff --git a/src/ray/gcs/gcs_placement_group_manager.h b/src/ray/gcs/gcs_placement_group_manager.h index cbd08ad84bac..d618b4b0d72b 100644 --- a/src/ray/gcs/gcs_placement_group_manager.h +++ b/src/ray/gcs/gcs_placement_group_manager.h @@ -101,7 +101,7 @@ class GcsPlacementGroupManager : public rpc::PlacementGroupInfoGcsServiceHandler /// \param callback Will be invoked after the placement group is created successfully or /// be invoked if the placement group is deleted before create successfully. void WaitPlacementGroup(const PlacementGroupID &placement_group_id, - StatusCallback callback); + rpc::StatusCallback callback); /// Register placement_group asynchronously. /// @@ -111,7 +111,7 @@ class GcsPlacementGroupManager : public rpc::PlacementGroupInfoGcsServiceHandler /// `registered_placement_groups_` and its state is `CREATED`. The callback will not be /// called in this case. void RegisterPlacementGroup(const std::shared_ptr &placement_group, - StatusCallback callback); + rpc::StatusCallback callback); /// Schedule placement_groups in the `pending_placement_groups_` queue. /// The method handles all states of placement groups @@ -144,7 +144,7 @@ class GcsPlacementGroupManager : public rpc::PlacementGroupInfoGcsServiceHandler /// Remove the placement group of a given id. void RemovePlacementGroup(const PlacementGroupID &placement_group_id, - StatusCallback on_placement_group_removed); + rpc::StatusCallback on_placement_group_removed); /// Handle a node death. This will reschedule all bundles associated with the /// specified node id. @@ -287,11 +287,11 @@ class GcsPlacementGroupManager : public rpc::PlacementGroupInfoGcsServiceHandler /// Callbacks of pending `RegisterPlacementGroup` requests. /// Maps placement group ID to placement group registration callbacks, which is used to /// filter duplicated messages from a driver/worker caused by some network problems. - absl::flat_hash_map> + absl::flat_hash_map> placement_group_to_register_callbacks_; /// Callback of `WaitPlacementGroupUntilReady` requests. - absl::flat_hash_map> + absl::flat_hash_map> placement_group_to_create_callbacks_; /// All registered placement_groups (pending placement_groups are also included). diff --git a/src/ray/gcs/gcs_placement_group_scheduler.cc b/src/ray/gcs/gcs_placement_group_scheduler.cc index d77d59f1eaf1..4e2fb86d2311 100644 --- a/src/ray/gcs/gcs_placement_group_scheduler.cc +++ b/src/ray/gcs/gcs_placement_group_scheduler.cc @@ -178,7 +178,7 @@ void GcsPlacementGroupScheduler::MarkScheduleCancelled( void GcsPlacementGroupScheduler::PrepareResources( const std::vector> &bundles, const std::optional> &node, - const StatusCallback &callback) { + const rpc::StatusCallback &callback) { if (!node.has_value()) { callback(Status::NotFound("Node is already dead.")); return; @@ -209,7 +209,7 @@ void GcsPlacementGroupScheduler::PrepareResources( void GcsPlacementGroupScheduler::CommitResources( const std::vector> &bundles, const std::optional> &node, - const StatusCallback callback) { + const rpc::StatusCallback callback) { RAY_CHECK(node.has_value()); const auto raylet_client = GetRayletClientFromNode(node.value()); const auto node_id = NodeID::FromBinary(node.value()->node_id()); diff --git a/src/ray/gcs/gcs_placement_group_scheduler.h b/src/ray/gcs/gcs_placement_group_scheduler.h index 8a8ac25bb8c3..9f509cc01561 100644 --- a/src/ray/gcs/gcs_placement_group_scheduler.h +++ b/src/ray/gcs/gcs_placement_group_scheduler.h @@ -374,7 +374,7 @@ class GcsPlacementGroupScheduler : public GcsPlacementGroupSchedulerInterface { void PrepareResources( const std::vector> &bundles, const std::optional> &node, - const StatusCallback &callback); + const rpc::StatusCallback &callback); /// Send bundles COMMIT request to a node. This means the placement group creation /// is ready and GCS will commit resources on a given node. @@ -385,7 +385,7 @@ class GcsPlacementGroupScheduler : public GcsPlacementGroupSchedulerInterface { void CommitResources( const std::vector> &bundles, const std::optional> &node, - const StatusCallback callback); + const rpc::StatusCallback callback); /// Cacnel prepared or committed resources from a node. /// Nodes will be in charge of tracking state of a bundle. diff --git a/src/ray/gcs/store_client/BUILD.bazel b/src/ray/gcs/store_client/BUILD.bazel index 52d4194298c2..d5239e97868e 100644 --- a/src/ray/gcs/store_client/BUILD.bazel +++ b/src/ray/gcs/store_client/BUILD.bazel @@ -5,10 +5,10 @@ ray_cc_library( hdrs = ["store_client.h"], deps = [ "//src/ray/common:asio", - "//src/ray/common:gcs_callback_types", "//src/ray/common:id", "//src/ray/common:status", "//src/ray/gcs/postable", + "//src/ray/rpc:rpc_callback_types", ], ) diff --git a/src/ray/gcs/store_client/in_memory_store_client.cc b/src/ray/gcs/store_client/in_memory_store_client.cc index cea449dd71d5..54b8ee30e44a 100644 --- a/src/ray/gcs/store_client/in_memory_store_client.cc +++ b/src/ray/gcs/store_client/in_memory_store_client.cc @@ -38,7 +38,7 @@ void InMemoryStoreClient::AsyncPut(const std::string &table_name, void InMemoryStoreClient::AsyncGet( const std::string &table_name, const std::string &key, - ToPostable> callback) { + ToPostable> callback) { auto table = GetTable(table_name); std::optional data; if (table != nullptr) { diff --git a/src/ray/gcs/store_client/in_memory_store_client.h b/src/ray/gcs/store_client/in_memory_store_client.h index d956ad40752c..dee1a872b332 100644 --- a/src/ray/gcs/store_client/in_memory_store_client.h +++ b/src/ray/gcs/store_client/in_memory_store_client.h @@ -41,7 +41,7 @@ class InMemoryStoreClient : public StoreClient { void AsyncGet(const std::string &table_name, const std::string &key, - ToPostable> callback) override; + ToPostable> callback) override; void AsyncGetAll( const std::string &table_name, diff --git a/src/ray/gcs/store_client/observable_store_client.cc b/src/ray/gcs/store_client/observable_store_client.cc index e8c1ead6088f..66b9c6734eb2 100644 --- a/src/ray/gcs/store_client/observable_store_client.cc +++ b/src/ray/gcs/store_client/observable_store_client.cc @@ -46,7 +46,7 @@ void ObservableStoreClient::AsyncPut(const std::string &table_name, void ObservableStoreClient::AsyncGet( const std::string &table_name, const std::string &key, - ToPostable> callback) { + ToPostable> callback) { auto start = absl::GetCurrentTimeNanos(); storage_operation_count_counter_.Record(1, {{"Operation", "Get"}}); delegate_->AsyncGet(table_name, key, std::move(callback).OnInvocation([this, start]() { diff --git a/src/ray/gcs/store_client/observable_store_client.h b/src/ray/gcs/store_client/observable_store_client.h index 127f8e9db2ad..5d2bf6d77680 100644 --- a/src/ray/gcs/store_client/observable_store_client.h +++ b/src/ray/gcs/store_client/observable_store_client.h @@ -46,7 +46,7 @@ class ObservableStoreClient : public StoreClient { void AsyncGet(const std::string &table_name, const std::string &key, - ToPostable> callback) override; + ToPostable> callback) override; void AsyncGetAll( const std::string &table_name, diff --git a/src/ray/gcs/store_client/redis_store_client.cc b/src/ray/gcs/store_client/redis_store_client.cc index 18fdf28b83d7..66f94c2be803 100644 --- a/src/ray/gcs/store_client/redis_store_client.cc +++ b/src/ray/gcs/store_client/redis_store_client.cc @@ -159,9 +159,10 @@ void RedisStoreClient::AsyncPut(const std::string &table_name, SendRedisCmdWithKeys({key}, std::move(command), std::move(write_callback)); } -void RedisStoreClient::AsyncGet(const std::string &table_name, - const std::string &key, - ToPostable> callback) { +void RedisStoreClient::AsyncGet( + const std::string &table_name, + const std::string &key, + ToPostable> callback) { auto redis_callback = [callback = std::move(callback)]( const std::shared_ptr &reply) mutable { std::optional result; diff --git a/src/ray/gcs/store_client/redis_store_client.h b/src/ray/gcs/store_client/redis_store_client.h index 59c77c148eb6..ad3d8751726e 100644 --- a/src/ray/gcs/store_client/redis_store_client.h +++ b/src/ray/gcs/store_client/redis_store_client.h @@ -140,7 +140,7 @@ class RedisStoreClient : public StoreClient { void AsyncGet(const std::string &table_name, const std::string &key, - ToPostable> callback) override; + ToPostable> callback) override; void AsyncGetAll( const std::string &table_name, diff --git a/src/ray/gcs/store_client/store_client.h b/src/ray/gcs/store_client/store_client.h index dbbbbb8b3657..644463f379fc 100644 --- a/src/ray/gcs/store_client/store_client.h +++ b/src/ray/gcs/store_client/store_client.h @@ -19,10 +19,10 @@ #include #include "ray/common/asio/io_service_pool.h" -#include "ray/common/gcs_callback_types.h" #include "ray/common/id.h" #include "ray/common/status.h" #include "ray/gcs/postable/postable.h" +#include "ray/rpc/rpc_callback_types.h" namespace ray { @@ -56,7 +56,7 @@ class StoreClient { /// \param callback returns the value or null. virtual void AsyncGet(const std::string &table_name, const std::string &key, - ToPostable> callback) = 0; + ToPostable> callback) = 0; /// Get all data from the given table asynchronously. /// diff --git a/src/ray/gcs/tests/gcs_placement_group_manager_test.cc b/src/ray/gcs/tests/gcs_placement_group_manager_test.cc index 1d4a08cd451f..ad1d5e596a39 100644 --- a/src/ray/gcs/tests/gcs_placement_group_manager_test.cc +++ b/src/ray/gcs/tests/gcs_placement_group_manager_test.cc @@ -33,7 +33,6 @@ namespace ray { namespace gcs { using ::testing::_; -using StatusCallback = std::function; class MockPlacementGroupScheduler : public gcs::GcsPlacementGroupSchedulerInterface { public: @@ -113,7 +112,7 @@ class GcsPlacementGroupManagerTest : public ::testing::Test { // Make placement group registration sync. void RegisterPlacementGroup(const ray::rpc::CreatePlacementGroupRequest &request, - StatusCallback callback) { + rpc::StatusCallback callback) { std::promise promise; JobID job_id = JobID::FromBinary(request.placement_group_spec().creator_job_id()); std::string ray_namespace = job_namespace_table_[job_id]; diff --git a/src/ray/gcs_rpc_client/BUILD.bazel b/src/ray/gcs_rpc_client/BUILD.bazel index 8665009c0e64..ddd4c51b3372 100644 --- a/src/ray/gcs_rpc_client/BUILD.bazel +++ b/src/ray/gcs_rpc_client/BUILD.bazel @@ -56,12 +56,12 @@ ray_cc_library( ], visibility = ["//visibility:public"], deps = [ - "//src/ray/common:gcs_callback_types", "//src/ray/common:id", "//src/ray/common:placement_group", "//src/ray/common:task_common", "//src/ray/protobuf:autoscaler_cc_proto", "//src/ray/protobuf:gcs_cc_proto", + "//src/ray/rpc:rpc_callback_types", ], ) diff --git a/src/ray/gcs_rpc_client/accessor.cc b/src/ray/gcs_rpc_client/accessor.cc index 68e85dc0e496..04880397d40c 100644 --- a/src/ray/gcs_rpc_client/accessor.cc +++ b/src/ray/gcs_rpc_client/accessor.cc @@ -31,7 +31,7 @@ namespace gcs { JobInfoAccessor::JobInfoAccessor(GcsClient *client_impl) : client_impl_(client_impl) {} void JobInfoAccessor::AsyncAdd(const std::shared_ptr &data_ptr, - const StatusCallback &callback) { + const rpc::StatusCallback &callback) { JobID job_id = JobID::FromBinary(data_ptr->job_id()); RAY_LOG(DEBUG).WithField(job_id) << "Adding job, driver pid = " << data_ptr->driver_pid(); @@ -49,7 +49,7 @@ void JobInfoAccessor::AsyncAdd(const std::shared_ptr &data_pt } void JobInfoAccessor::AsyncMarkFinished(const JobID &job_id, - const StatusCallback &callback) { + const rpc::StatusCallback &callback) { RAY_LOG(DEBUG).WithField(job_id) << "Marking job state"; rpc::MarkJobFinishedRequest request; request.set_job_id(job_id.Binary()); @@ -65,10 +65,11 @@ void JobInfoAccessor::AsyncMarkFinished(const JobID &job_id, } void JobInfoAccessor::AsyncSubscribeAll( - const SubscribeCallback &subscribe, - const StatusCallback &done) { + const rpc::SubscribeCallback &subscribe, + const rpc::StatusCallback &done) { RAY_CHECK(subscribe != nullptr); - fetch_all_data_operation_ = [this, subscribe](const StatusCallback &done_callback) { + fetch_all_data_operation_ = [this, + subscribe](const rpc::StatusCallback &done_callback) { auto callback = [subscribe, done_callback]( const Status &status, std::vector &&job_info_list) { @@ -85,7 +86,7 @@ void JobInfoAccessor::AsyncSubscribeAll( callback, /*timeout_ms=*/-1); }; - subscribe_operation_ = [this, subscribe](const StatusCallback &done_callback) { + subscribe_operation_ = [this, subscribe](const rpc::StatusCallback &done_callback) { client_impl_->GetGcsSubscriber().SubscribeAllJobs(subscribe, done_callback); }; subscribe_operation_( @@ -106,11 +107,12 @@ void JobInfoAccessor::AsyncResubscribe() { } } -void JobInfoAccessor::AsyncGetAll(const std::optional &job_or_submission_id, - bool skip_submission_job_info_field, - bool skip_is_running_tasks_field, - const MultiItemCallback &callback, - int64_t timeout_ms) { +void JobInfoAccessor::AsyncGetAll( + const std::optional &job_or_submission_id, + bool skip_submission_job_info_field, + bool skip_is_running_tasks_field, + const rpc::MultiItemCallback &callback, + int64_t timeout_ms) { RAY_LOG(DEBUG) << "Getting all job info."; RAY_CHECK(callback); rpc::GetAllJobInfoRequest request; @@ -146,7 +148,7 @@ Status JobInfoAccessor::GetAll(const std::optional &job_or_submissi return Status::OK(); } -void JobInfoAccessor::AsyncGetNextJobID(const ItemCallback &callback) { +void JobInfoAccessor::AsyncGetNextJobID(const rpc::ItemCallback &callback) { RAY_LOG(DEBUG) << "Getting next job id"; rpc::GetNextJobIDRequest request; client_impl_->GetGcsRpcClient().GetNextJobID( @@ -162,7 +164,7 @@ void JobInfoAccessor::AsyncGetNextJobID(const ItemCallback &callback) { NodeInfoAccessor::NodeInfoAccessor(GcsClient *client_impl) : client_impl_(client_impl) {} void NodeInfoAccessor::RegisterSelf(rpc::GcsNodeInfo &&local_node_info, - const StatusCallback &callback) { + const rpc::StatusCallback &callback) { auto node_id = NodeID::FromBinary(local_node_info.node_id()); RAY_LOG(DEBUG).WithField(node_id) << "Registering node info, address is = " << local_node_info.node_manager_address(); @@ -198,7 +200,7 @@ void NodeInfoAccessor::UnregisterSelf(const NodeID &node_id, } void NodeInfoAccessor::AsyncRegister(const rpc::GcsNodeInfo &node_info, - const StatusCallback &callback) { + const rpc::StatusCallback &callback) { NodeID node_id = NodeID::FromBinary(node_info.node_id()); RAY_LOG(DEBUG).WithField(node_id) << "Registering node info"; rpc::RegisterNodeRequest request; @@ -216,7 +218,7 @@ void NodeInfoAccessor::AsyncRegister(const rpc::GcsNodeInfo &node_info, void NodeInfoAccessor::AsyncCheckAlive(const std::vector &node_ids, int64_t timeout_ms, - const MultiItemCallback &callback) { + const rpc::MultiItemCallback &callback) { rpc::CheckAliveRequest request; for (const auto &node_id : node_ids) { request.add_node_ids(node_id.Binary()); @@ -260,7 +262,7 @@ Status NodeInfoAccessor::DrainNodes(const std::vector &node_ids, } void NodeInfoAccessor::AsyncGetAllNodeAddressAndLiveness( - const MultiItemCallback &callback, + const rpc::MultiItemCallback &callback, int64_t timeout_ms, const std::vector &node_ids) { rpc::GetAllNodeAddressAndLivenessRequest request; @@ -277,9 +279,10 @@ void NodeInfoAccessor::AsyncGetAllNodeAddressAndLiveness( timeout_ms); } -void NodeInfoAccessor::AsyncGetAll(const MultiItemCallback &callback, - int64_t timeout_ms, - const std::vector &node_ids) { +void NodeInfoAccessor::AsyncGetAll( + const rpc::MultiItemCallback &callback, + int64_t timeout_ms, + const std::vector &node_ids) { RAY_LOG(DEBUG) << "Getting information of all nodes."; rpc::GetAllNodeInfoRequest request; for (const auto &node_id : node_ids) { @@ -297,7 +300,7 @@ void NodeInfoAccessor::AsyncGetAll(const MultiItemCallback &ca void NodeInfoAccessor::AsyncSubscribeToNodeAddressAndLivenessChange( std::function subscribe, - StatusCallback done) { + rpc::StatusCallback done) { /** 1. Subscribe to node info 2. Once the subscription is made, ask for all node info. @@ -313,7 +316,7 @@ void NodeInfoAccessor::AsyncSubscribeToNodeAddressAndLivenessChange( RAY_CHECK(node_change_callback_address_and_liveness_ != nullptr); fetch_node_address_and_liveness_data_operation_ = - [this](const StatusCallback &done_callback) { + [this](const rpc::StatusCallback &done_callback) { AsyncGetAllNodeAddressAndLiveness( [this, done_callback]( const Status &status, @@ -475,7 +478,7 @@ NodeResourceInfoAccessor::NodeResourceInfoAccessor(GcsClient *client_impl) : client_impl_(client_impl) {} void NodeResourceInfoAccessor::AsyncGetAllAvailableResources( - const MultiItemCallback &callback) { + const rpc::MultiItemCallback &callback) { rpc::GetAllAvailableResourcesRequest request; client_impl_->GetGcsRpcClient().GetAllAvailableResources( std::move(request), @@ -487,7 +490,7 @@ void NodeResourceInfoAccessor::AsyncGetAllAvailableResources( } void NodeResourceInfoAccessor::AsyncGetAllTotalResources( - const MultiItemCallback &callback) { + const rpc::MultiItemCallback &callback) { rpc::GetAllTotalResourcesRequest request; client_impl_->GetGcsRpcClient().GetAllTotalResources( std::move(request), @@ -499,7 +502,7 @@ void NodeResourceInfoAccessor::AsyncGetAllTotalResources( } void NodeResourceInfoAccessor::AsyncGetDrainingNodes( - const ItemCallback> &callback) { + const rpc::ItemCallback> &callback) { rpc::GetDrainingNodesRequest request; client_impl_->GetGcsRpcClient().GetDrainingNodes( std::move(request), @@ -515,7 +518,7 @@ void NodeResourceInfoAccessor::AsyncGetDrainingNodes( } void NodeResourceInfoAccessor::AsyncGetAllResourceUsage( - const ItemCallback &callback) { + const rpc::ItemCallback &callback) { rpc::GetAllResourceUsageRequest request; client_impl_->GetGcsRpcClient().GetAllResourceUsage( std::move(request), @@ -534,7 +537,7 @@ Status NodeResourceInfoAccessor::GetAllResourceUsage( } void TaskInfoAccessor::AsyncAddTaskEventData(std::unique_ptr data_ptr, - StatusCallback callback) { + rpc::StatusCallback callback) { rpc::AddTaskEventDataRequest request; // Prevent copy here request.mutable_data()->Swap(data_ptr.get()); @@ -549,7 +552,7 @@ void TaskInfoAccessor::AsyncAddTaskEventData(std::unique_ptr } void TaskInfoAccessor::AsyncGetTaskEvents( - const MultiItemCallback &callback) { + const rpc::MultiItemCallback &callback) { RAY_LOG(DEBUG) << "Getting all task events info."; RAY_CHECK(callback); rpc::GetTaskEventsRequest request; @@ -579,9 +582,10 @@ WorkerInfoAccessor::WorkerInfoAccessor(GcsClient *client_impl) : client_impl_(client_impl) {} void WorkerInfoAccessor::AsyncSubscribeToWorkerFailures( - const ItemCallback &subscribe, const StatusCallback &done) { + const rpc::ItemCallback &subscribe, + const rpc::StatusCallback &done) { RAY_CHECK(subscribe != nullptr); - subscribe_operation_ = [this, subscribe](const StatusCallback &done_callback) { + subscribe_operation_ = [this, subscribe](const rpc::StatusCallback &done_callback) { client_impl_->GetGcsSubscriber().SubscribeAllWorkerFailures(subscribe, done_callback); }; subscribe_operation_(done); @@ -599,7 +603,7 @@ void WorkerInfoAccessor::AsyncResubscribe() { void WorkerInfoAccessor::AsyncReportWorkerFailure( const std::shared_ptr &data_ptr, - const StatusCallback &callback) { + const rpc::StatusCallback &callback) { rpc::Address worker_address = data_ptr->worker_address(); RAY_LOG(DEBUG) << "Reporting worker failure, " << worker_address.DebugString(); rpc::ReportWorkerFailureRequest request; @@ -618,7 +622,7 @@ void WorkerInfoAccessor::AsyncReportWorkerFailure( void WorkerInfoAccessor::AsyncGet( const WorkerID &worker_id, - const OptionalItemCallback &callback) { + const rpc::OptionalItemCallback &callback) { RAY_LOG(DEBUG) << "Getting worker info, worker id = " << worker_id; rpc::GetWorkerInfoRequest request; request.set_worker_id(worker_id.Binary()); @@ -635,7 +639,7 @@ void WorkerInfoAccessor::AsyncGet( } void WorkerInfoAccessor::AsyncGetAll( - const MultiItemCallback &callback) { + const rpc::MultiItemCallback &callback) { RAY_LOG(DEBUG) << "Getting all worker info."; rpc::GetAllWorkerInfoRequest request; client_impl_->GetGcsRpcClient().GetAllWorkerInfo( @@ -648,7 +652,7 @@ void WorkerInfoAccessor::AsyncGetAll( } void WorkerInfoAccessor::AsyncAdd(const std::shared_ptr &data_ptr, - const StatusCallback &callback) { + const rpc::StatusCallback &callback) { rpc::AddWorkerInfoRequest request; request.mutable_worker_data()->CopyFrom(*data_ptr); client_impl_->GetGcsRpcClient().AddWorkerInfo( @@ -662,7 +666,7 @@ void WorkerInfoAccessor::AsyncAdd(const std::shared_ptr &d void WorkerInfoAccessor::AsyncUpdateDebuggerPort(const WorkerID &worker_id, uint32_t debugger_port, - const StatusCallback &callback) { + const rpc::StatusCallback &callback) { rpc::UpdateWorkerDebuggerPortRequest request; request.set_worker_id(worker_id.Binary()); request.set_debugger_port(debugger_port); @@ -680,7 +684,7 @@ void WorkerInfoAccessor::AsyncUpdateDebuggerPort(const WorkerID &worker_id, void WorkerInfoAccessor::AsyncUpdateWorkerNumPausedThreads( const WorkerID &worker_id, const int num_paused_threads_delta, - const StatusCallback &callback) { + const rpc::StatusCallback &callback) { rpc::UpdateWorkerNumPausedThreadsRequest request; request.set_worker_id(worker_id.Binary()); request.set_num_paused_threads_delta(num_paused_threads_delta); @@ -727,7 +731,7 @@ Status PlacementGroupInfoAccessor::SyncRemovePlacementGroup( void PlacementGroupInfoAccessor::AsyncGet( const PlacementGroupID &placement_group_id, - const OptionalItemCallback &callback) { + const rpc::OptionalItemCallback &callback) { RAY_LOG(DEBUG).WithField(placement_group_id) << "Getting placement group info"; rpc::GetPlacementGroupRequest request; request.set_placement_group_id(placement_group_id.Binary()); @@ -748,7 +752,7 @@ void PlacementGroupInfoAccessor::AsyncGet( void PlacementGroupInfoAccessor::AsyncGetByName( const std::string &name, const std::string &ray_namespace, - const OptionalItemCallback &callback, + const rpc::OptionalItemCallback &callback, int64_t timeout_ms) { RAY_LOG(DEBUG) << "Getting named placement group info, name = " << name; rpc::GetNamedPlacementGroupRequest request; @@ -769,7 +773,7 @@ void PlacementGroupInfoAccessor::AsyncGetByName( } void PlacementGroupInfoAccessor::AsyncGetAll( - const MultiItemCallback &callback) { + const rpc::MultiItemCallback &callback) { RAY_LOG(DEBUG) << "Getting all placement group info."; rpc::GetAllPlacementGroupRequest request; client_impl_->GetGcsRpcClient().GetAllPlacementGroup( @@ -804,7 +808,7 @@ void InternalKVAccessor::AsyncInternalKVGet( const std::string &ns, const std::string &key, const int64_t timeout_ms, - const OptionalItemCallback &callback) { + const rpc::OptionalItemCallback &callback) { rpc::InternalKVGetRequest req; req.set_key(key); req.set_namespace_(ns); @@ -824,7 +828,8 @@ void InternalKVAccessor::AsyncInternalKVMultiGet( const std::string &ns, const std::vector &keys, const int64_t timeout_ms, - const OptionalItemCallback> &callback) { + const rpc::OptionalItemCallback> + &callback) { rpc::InternalKVMultiGetRequest req; for (const auto &key : keys) { req.add_keys(key); @@ -849,12 +854,13 @@ void InternalKVAccessor::AsyncInternalKVMultiGet( timeout_ms); } -void InternalKVAccessor::AsyncInternalKVPut(const std::string &ns, - const std::string &key, - const std::string &value, - bool overwrite, - const int64_t timeout_ms, - const OptionalItemCallback &callback) { +void InternalKVAccessor::AsyncInternalKVPut( + const std::string &ns, + const std::string &key, + const std::string &value, + bool overwrite, + const int64_t timeout_ms, + const rpc::OptionalItemCallback &callback) { rpc::InternalKVPutRequest req; req.set_namespace_(ns); req.set_key(key); @@ -872,7 +878,7 @@ void InternalKVAccessor::AsyncInternalKVExists( const std::string &ns, const std::string &key, const int64_t timeout_ms, - const OptionalItemCallback &callback) { + const rpc::OptionalItemCallback &callback) { rpc::InternalKVExistsRequest req; req.set_namespace_(ns); req.set_key(key); @@ -884,11 +890,12 @@ void InternalKVAccessor::AsyncInternalKVExists( timeout_ms); } -void InternalKVAccessor::AsyncInternalKVDel(const std::string &ns, - const std::string &key, - bool del_by_prefix, - const int64_t timeout_ms, - const OptionalItemCallback &callback) { +void InternalKVAccessor::AsyncInternalKVDel( + const std::string &ns, + const std::string &key, + bool del_by_prefix, + const int64_t timeout_ms, + const rpc::OptionalItemCallback &callback) { rpc::InternalKVDelRequest req; req.set_namespace_(ns); req.set_key(key); @@ -905,7 +912,7 @@ void InternalKVAccessor::AsyncInternalKVKeys( const std::string &ns, const std::string &prefix, const int64_t timeout_ms, - const OptionalItemCallback> &callback) { + const rpc::OptionalItemCallback> &callback) { rpc::InternalKVKeysRequest req; req.set_namespace_(ns); req.set_prefix(prefix); @@ -1039,7 +1046,7 @@ Status InternalKVAccessor::Exists(const std::string &ns, } void InternalKVAccessor::AsyncGetInternalConfig( - const OptionalItemCallback &callback) { + const rpc::OptionalItemCallback &callback) { rpc::GetInternalConfigRequest request; client_impl_->GetGcsRpcClient().GetInternalConfig( std::move(request), @@ -1133,7 +1140,7 @@ Status AutoscalerStateAccessor::GetClusterStatus(int64_t timeout_ms, void AutoscalerStateAccessor::AsyncGetClusterStatus( int64_t timeout_ms, - const OptionalItemCallback &callback) { + const rpc::OptionalItemCallback &callback) { rpc::autoscaler::GetClusterStatusRequest request; client_impl_->GetGcsRpcClient().GetClusterStatus( std::move(request), @@ -1228,7 +1235,7 @@ Status PublisherAccessor::PublishLogs(std::string key_id, void PublisherAccessor::AsyncPublishNodeResourceUsage( std::string key_id, std::string node_resource_usage_json, - const StatusCallback &done) { + const rpc::StatusCallback &done) { rpc::GcsPublishRequest request; auto *pub_message = request.add_pub_messages(); pub_message->set_channel_type(rpc::RAY_NODE_RESOURCE_USAGE_CHANNEL); diff --git a/src/ray/gcs_rpc_client/accessor.h b/src/ray/gcs_rpc_client/accessor.h index 0a5a984b11ca..8a9a321c27b2 100644 --- a/src/ray/gcs_rpc_client/accessor.h +++ b/src/ray/gcs_rpc_client/accessor.h @@ -20,7 +20,6 @@ #include #include "absl/synchronization/mutex.h" -#include "ray/common/gcs_callback_types.h" #include "ray/common/id.h" #include "ray/common/placement_group.h" #include "ray/common/status_or.h" @@ -34,8 +33,8 @@ namespace ray { namespace gcs { -using SubscribeOperation = std::function; -using FetchDataOperation = std::function; +using SubscribeOperation = std::function; +using FetchDataOperation = std::function; class GcsClient; /// \class JobInfoAccessor @@ -53,21 +52,22 @@ class JobInfoAccessor { /// \param callback Callback that will be called after job has been added /// to GCS. virtual void AsyncAdd(const std::shared_ptr &data_ptr, - const StatusCallback &callback); + const rpc::StatusCallback &callback); /// Mark job as finished in GCS asynchronously. /// /// \param job_id ID of the job that will be make finished to GCS. /// \param callback Callback that will be called after update finished. - virtual void AsyncMarkFinished(const JobID &job_id, const StatusCallback &callback); + virtual void AsyncMarkFinished(const JobID &job_id, + const rpc::StatusCallback &callback); /// Subscribe to job updates. /// /// \param subscribe Callback that will be called each time when a job updates. /// \param done Callback that will be called when subscription is complete. virtual void AsyncSubscribeAll( - const SubscribeCallback &subscribe, - const StatusCallback &done); + const rpc::SubscribeCallback &subscribe, + const rpc::StatusCallback &done); /// Get all job info from GCS asynchronously. /// @@ -76,7 +76,7 @@ class JobInfoAccessor { virtual void AsyncGetAll(const std::optional &job_or_submission_id, bool skip_submission_job_info_field, bool skip_is_running_tasks_field, - const MultiItemCallback &callback, + const rpc::MultiItemCallback &callback, int64_t timeout_ms); /// Get all job info from GCS synchronously. @@ -101,7 +101,7 @@ class JobInfoAccessor { /// Increment and get next job id. This is not idempotent. /// /// \param done Callback that will be called when request successfully. - virtual void AsyncGetNextJobID(const ItemCallback &callback); + virtual void AsyncGetNextJobID(const rpc::ItemCallback &callback); private: /// Save the fetch data operation in this function, so we can call it again when GCS @@ -130,7 +130,7 @@ class NodeInfoAccessor { /// \param node_info The information of node to register to GCS. /// \param callback Callback that will be called when registration is complete. virtual void RegisterSelf(rpc::GcsNodeInfo &&local_node_info, - const StatusCallback &callback); + const rpc::StatusCallback &callback); /// Unregister local node to GCS asynchronously. /// @@ -147,7 +147,7 @@ class NodeInfoAccessor { /// \param node_info The information of node to register to GCS. /// \param callback Callback that will be called when registration is complete. virtual void AsyncRegister(const rpc::GcsNodeInfo &node_info, - const StatusCallback &callback); + const rpc::StatusCallback &callback); /// Send a check alive request to GCS for the liveness of some nodes. /// @@ -155,7 +155,7 @@ class NodeInfoAccessor { /// \param timeout_ms The timeout for this request. virtual void AsyncCheckAlive(const std::vector &node_ids, int64_t timeout_ms, - const MultiItemCallback &callback); + const rpc::MultiItemCallback &callback); /// Get information of all nodes from GCS asynchronously. /// @@ -163,12 +163,12 @@ class NodeInfoAccessor { /// \param timeout_ms The timeout for this request. /// \param node_ids If this is not empty, only return the node info of the specified /// nodes. - virtual void AsyncGetAll(const MultiItemCallback &callback, + virtual void AsyncGetAll(const rpc::MultiItemCallback &callback, int64_t timeout_ms, const std::vector &node_ids = {}); virtual void AsyncGetAllNodeAddressAndLiveness( - const MultiItemCallback &callback, + const rpc::MultiItemCallback &callback, int64_t timeout_ms, const std::vector &node_ids = {}); @@ -212,7 +212,7 @@ class NodeInfoAccessor { /// \param done Callback that will be called when subscription is complete. virtual void AsyncSubscribeToNodeAddressAndLivenessChange( std::function subscribe, - StatusCallback done); + rpc::StatusCallback done); /// Send a check alive request to GCS for the liveness of some nodes. /// @@ -307,25 +307,25 @@ class NodeResourceInfoAccessor { /// /// \param callback Callback that will be called after lookup finishes. virtual void AsyncGetAllAvailableResources( - const MultiItemCallback &callback); + const rpc::MultiItemCallback &callback); /// Get total resources of all nodes from GCS asynchronously. /// /// \param callback Callback that will be called after lookup finishes. virtual void AsyncGetAllTotalResources( - const MultiItemCallback &callback); + const rpc::MultiItemCallback &callback); /// Get draining nodes from GCS asynchronously. /// /// \param callback Callback that will be called after lookup finishes. virtual void AsyncGetDrainingNodes( - const ItemCallback> &callback); + const rpc::ItemCallback> &callback); /// Get newest resource usage of all nodes from GCS asynchronously. /// /// \param callback Callback that will be called after lookup finishes. virtual void AsyncGetAllResourceUsage( - const ItemCallback &callback); + const rpc::ItemCallback &callback); /// Get newest resource usage of all nodes from GCS synchronously. /// @@ -379,12 +379,13 @@ class TaskInfoAccessor { /// \param data_ptr The task states event data that will be added to GCS. /// \param callback Callback that will be called when add is complete. virtual void AsyncAddTaskEventData(std::unique_ptr data_ptr, - StatusCallback callback); + rpc::StatusCallback callback); /// Get all info/events of all tasks stored in GCS asynchronously. /// /// \param callback Callback that will be called after lookup finishes. - virtual void AsyncGetTaskEvents(const MultiItemCallback &callback); + virtual void AsyncGetTaskEvents( + const rpc::MultiItemCallback &callback); private: GcsClient *client_impl_; @@ -406,7 +407,8 @@ class WorkerInfoAccessor { /// \param subscribe Callback that will be called each time when a worker failed. /// \param done Callback that will be called when subscription is complete. virtual void AsyncSubscribeToWorkerFailures( - const ItemCallback &subscribe, const StatusCallback &done); + const rpc::ItemCallback &subscribe, + const rpc::StatusCallback &done); /// Report a worker failure to GCS asynchronously. /// @@ -414,19 +416,19 @@ class WorkerInfoAccessor { /// \param callback Callback that will be called when report is complate. virtual void AsyncReportWorkerFailure( const std::shared_ptr &data_ptr, - const StatusCallback &callback); + const rpc::StatusCallback &callback); /// Get worker specification from GCS asynchronously. /// /// \param worker_id The ID of worker to look up in the GCS. /// \param callback Callback that will be called after lookup finishes. virtual void AsyncGet(const WorkerID &worker_id, - const OptionalItemCallback &callback); + const rpc::OptionalItemCallback &callback); /// Get all worker info from GCS asynchronously. /// /// \param callback Callback that will be called after lookup finished. - virtual void AsyncGetAll(const MultiItemCallback &callback); + virtual void AsyncGetAll(const rpc::MultiItemCallback &callback); /// Add worker information to GCS asynchronously. /// @@ -434,7 +436,7 @@ class WorkerInfoAccessor { /// \param callback Callback that will be called after worker information has been added /// to GCS. virtual void AsyncAdd(const std::shared_ptr &data_ptr, - const StatusCallback &callback); + const rpc::StatusCallback &callback); /// Update the worker debugger port in GCS asynchronously. /// @@ -443,7 +445,7 @@ class WorkerInfoAccessor { /// \param callback Callback that will be called after update finishes. virtual void AsyncUpdateDebuggerPort(const WorkerID &worker_id, uint32_t debugger_port, - const StatusCallback &callback); + const rpc::StatusCallback &callback); /// Update the number of worker's paused threads in GCS asynchronously. /// @@ -452,7 +454,7 @@ class WorkerInfoAccessor { /// \param callback Callback that will be called after update finishes. virtual void AsyncUpdateWorkerNumPausedThreads(const WorkerID &worker_id, int num_paused_threads_delta, - const StatusCallback &callback); + const rpc::StatusCallback &callback); /// Reestablish subscription. /// This should be called when GCS server restarts from a failure. /// PubSub server restart will cause GCS server restart. In this case, we need to @@ -489,7 +491,7 @@ class PlacementGroupInfoAccessor { /// \param placement_group_id The id of a placement group to obtain from GCS. virtual void AsyncGet( const PlacementGroupID &placement_group_id, - const OptionalItemCallback &callback); + const rpc::OptionalItemCallback &callback); /// Get a placement group data from GCS asynchronously by name. /// @@ -500,14 +502,14 @@ class PlacementGroupInfoAccessor { virtual void AsyncGetByName( const std::string &placement_group_name, const std::string &ray_namespace, - const OptionalItemCallback &callback, + const rpc::OptionalItemCallback &callback, int64_t timeout_ms = -1); /// Get all placement group info from GCS asynchronously. /// /// \param callback Callback that will be called after lookup finished. virtual void AsyncGetAll( - const MultiItemCallback &callback); + const rpc::MultiItemCallback &callback); /// Remove a placement group to GCS synchronously. /// @@ -547,7 +549,7 @@ class InternalKVAccessor { const std::string &ns, const std::string &prefix, const int64_t timeout_ms, - const OptionalItemCallback> &callback); + const rpc::OptionalItemCallback> &callback); /// Asynchronously get the value for a given key. /// @@ -558,7 +560,7 @@ class InternalKVAccessor { virtual void AsyncInternalKVGet(const std::string &ns, const std::string &key, const int64_t timeout_ms, - const OptionalItemCallback &callback); + const rpc::OptionalItemCallback &callback); /// Asynchronously get the value for multiple keys. /// @@ -570,7 +572,8 @@ class InternalKVAccessor { const std::string &ns, const std::vector &keys, const int64_t timeout_ms, - const OptionalItemCallback> &callback); + const rpc::OptionalItemCallback> + &callback); /// Asynchronously set the value for a given key. /// @@ -584,7 +587,7 @@ class InternalKVAccessor { const std::string &value, bool overwrite, const int64_t timeout_ms, - const OptionalItemCallback &callback); + const rpc::OptionalItemCallback &callback); /// Asynchronously check the existence of a given key /// @@ -596,7 +599,7 @@ class InternalKVAccessor { virtual void AsyncInternalKVExists(const std::string &ns, const std::string &key, const int64_t timeout_ms, - const OptionalItemCallback &callback); + const rpc::OptionalItemCallback &callback); /// Asynchronously delete a key /// @@ -610,7 +613,7 @@ class InternalKVAccessor { const std::string &key, bool del_by_prefix, const int64_t timeout_ms, - const OptionalItemCallback &callback); + const rpc::OptionalItemCallback &callback); // These are sync functions of the async above @@ -708,7 +711,8 @@ class InternalKVAccessor { /// Get the internal config string from GCS. /// /// \param callback Processes a map of config options - virtual void AsyncGetInternalConfig(const OptionalItemCallback &callback); + virtual void AsyncGetInternalConfig( + const rpc::OptionalItemCallback &callback); private: GcsClient *client_impl_; @@ -755,7 +759,7 @@ class AutoscalerStateAccessor { virtual void AsyncGetClusterStatus( int64_t timeout_ms, - const OptionalItemCallback &callback); + const rpc::OptionalItemCallback &callback); virtual Status ReportAutoscalingState(int64_t timeout_ms, const std::string &serialized_state); @@ -793,7 +797,7 @@ class PublisherAccessor { virtual void AsyncPublishNodeResourceUsage(std::string key_id, std::string node_resource_usage_json, - const StatusCallback &done); + const rpc::StatusCallback &done); private: GcsClient *client_impl_; diff --git a/src/ray/gcs_rpc_client/accessors/actor_info_accessor.cc b/src/ray/gcs_rpc_client/accessors/actor_info_accessor.cc index 4d5b3bbff49c..4ad033638b54 100644 --- a/src/ray/gcs_rpc_client/accessors/actor_info_accessor.cc +++ b/src/ray/gcs_rpc_client/accessors/actor_info_accessor.cc @@ -31,7 +31,8 @@ namespace gcs { ActorInfoAccessor::ActorInfoAccessor(GcsClientContext *context) : context_(context) {} void ActorInfoAccessor::AsyncGet( - const ActorID &actor_id, const OptionalItemCallback &callback) { + const ActorID &actor_id, + const rpc::OptionalItemCallback &callback) { RAY_LOG(DEBUG).WithField(actor_id).WithField(actor_id.JobId()) << "Getting actor info"; rpc::GetActorInfoRequest request; request.set_actor_id(actor_id.Binary()); @@ -52,7 +53,7 @@ void ActorInfoAccessor::AsyncGetAllByFilter( const std::optional &actor_id, const std::optional &job_id, const std::optional &actor_state_name, - const MultiItemCallback &callback, + const rpc::MultiItemCallback &callback, int64_t timeout_ms) { RAY_LOG(DEBUG) << "Getting all actor info."; rpc::GetAllActorInfoRequest request; @@ -86,7 +87,7 @@ void ActorInfoAccessor::AsyncGetAllByFilter( void ActorInfoAccessor::AsyncGetByName( const std::string &name, const std::string &ray_namespace, - const OptionalItemCallback &callback, + const rpc::OptionalItemCallback &callback, int64_t timeout_ms) { RAY_LOG(DEBUG) << "Getting actor info, name = " << name; rpc::GetNamedActorInfoRequest request; @@ -148,7 +149,7 @@ Status ActorInfoAccessor::SyncListNamedActors( void ActorInfoAccessor::AsyncRestartActorForLineageReconstruction( const ray::ActorID &actor_id, uint64_t num_restarts_due_to_lineage_reconstruction, - const ray::gcs::StatusCallback &callback, + const rpc::StatusCallback &callback, int64_t timeout_ms) { rpc::RestartActorForLineageReconstructionRequest request; request.set_actor_id(actor_id.Binary()); @@ -180,7 +181,7 @@ Status ComputeGcsStatus(const Status &grpc_status, const rpc::GcsStatus &gcs_sta } // namespace void ActorInfoAccessor::AsyncRegisterActor(const ray::TaskSpecification &task_spec, - const ray::gcs::StatusCallback &callback, + const rpc::StatusCallback &callback, int64_t timeout_ms) { RAY_CHECK(task_spec.IsActorCreationTask() && callback); rpc::RegisterActorRequest request; @@ -206,7 +207,7 @@ Status ActorInfoAccessor::SyncRegisterActor(const ray::TaskSpecification &task_s void ActorInfoAccessor::AsyncKillActor(const ActorID &actor_id, bool force_kill, bool no_restart, - const ray::gcs::StatusCallback &callback, + const rpc::StatusCallback &callback, int64_t timeout_ms) { rpc::KillActorViaGcsRequest request; request.set_actor_id(actor_id.Binary()); @@ -238,7 +239,7 @@ void ActorInfoAccessor::AsyncCreateActor( void ActorInfoAccessor::AsyncReportActorOutOfScope( const ActorID &actor_id, uint64_t num_restarts_due_to_lineage_reconstruction, - const StatusCallback &callback, + const rpc::StatusCallback &callback, int64_t timeout_ms) { rpc::ReportActorOutOfScopeRequest request; request.set_actor_id(actor_id.Binary()); @@ -256,14 +257,14 @@ void ActorInfoAccessor::AsyncReportActorOutOfScope( void ActorInfoAccessor::AsyncSubscribe( const ActorID &actor_id, - const SubscribeCallback &subscribe, - const StatusCallback &done) { + const rpc::SubscribeCallback &subscribe, + const rpc::StatusCallback &done) { RAY_LOG(DEBUG).WithField(actor_id).WithField(actor_id.JobId()) << "Subscribing update operations of actor"; RAY_CHECK(subscribe != nullptr) << "Failed to subscribe actor, actor id = " << actor_id; auto fetch_data_operation = - [this, actor_id, subscribe](const StatusCallback &fetch_done) { + [this, actor_id, subscribe](const rpc::StatusCallback &fetch_done) { auto callback = [actor_id, subscribe, fetch_done]( const Status &status, std::optional &&result) { @@ -280,7 +281,7 @@ void ActorInfoAccessor::AsyncSubscribe( { absl::MutexLock lock(&mutex_); resubscribe_operations_[actor_id] = [this, actor_id, subscribe]( - const StatusCallback &subscribe_done) { + const rpc::StatusCallback &subscribe_done) { context_->GetGcsSubscriber().SubscribeActor(actor_id, subscribe, subscribe_done); }; fetch_data_operations_[actor_id] = fetch_data_operation; diff --git a/src/ray/gcs_rpc_client/accessors/actor_info_accessor.h b/src/ray/gcs_rpc_client/accessors/actor_info_accessor.h index 2572b9adb7f9..da72e212e0b9 100644 --- a/src/ray/gcs_rpc_client/accessors/actor_info_accessor.h +++ b/src/ray/gcs_rpc_client/accessors/actor_info_accessor.h @@ -18,7 +18,6 @@ #include #include "absl/synchronization/mutex.h" -#include "ray/common/gcs_callback_types.h" #include "ray/gcs_rpc_client/accessors/actor_info_accessor_interface.h" #include "ray/gcs_rpc_client/gcs_client_context.h" #include "ray/rpc/rpc_callback_types.h" @@ -27,8 +26,8 @@ namespace ray { namespace gcs { -using SubscribeOperation = std::function; -using FetchDataOperation = std::function; +using SubscribeOperation = std::function; +using FetchDataOperation = std::function; /** @class ActorInfoAccessor @@ -48,7 +47,7 @@ class ActorInfoAccessor : public ActorInfoAccessorInterface { @param callback Callback that will be called after lookup finishes. */ void AsyncGet(const ActorID &actor_id, - const OptionalItemCallback &callback) override; + const rpc::OptionalItemCallback &callback) override; /** Get all actor specification from the GCS asynchronously. @@ -62,7 +61,7 @@ class ActorInfoAccessor : public ActorInfoAccessorInterface { void AsyncGetAllByFilter(const std::optional &actor_id, const std::optional &job_id, const std::optional &actor_state_name, - const MultiItemCallback &callback, + const rpc::MultiItemCallback &callback, int64_t timeout_ms = -1) override; /** @@ -75,7 +74,7 @@ class ActorInfoAccessor : public ActorInfoAccessorInterface { */ void AsyncGetByName(const std::string &name, const std::string &ray_namespace, - const OptionalItemCallback &callback, + const rpc::OptionalItemCallback &callback, int64_t timeout_ms = -1) override; /** @@ -121,7 +120,7 @@ class ActorInfoAccessor : public ActorInfoAccessorInterface { */ void AsyncReportActorOutOfScope(const ActorID &actor_id, uint64_t num_restarts_due_to_lineage_reconstruction, - const StatusCallback &callback, + const rpc::StatusCallback &callback, int64_t timeout_ms = -1) override; /** @@ -132,7 +131,7 @@ class ActorInfoAccessor : public ActorInfoAccessorInterface { @param timeout_ms RPC timeout ms. -1 means there's no timeout. */ void AsyncRegisterActor(const TaskSpecification &task_spec, - const StatusCallback &callback, + const rpc::StatusCallback &callback, int64_t timeout_ms = -1) override; /** @@ -147,7 +146,7 @@ class ActorInfoAccessor : public ActorInfoAccessorInterface { void AsyncRestartActorForLineageReconstruction( const ActorID &actor_id, uint64_t num_restarts_due_to_lineage_reconstructions, - const StatusCallback &callback, + const rpc::StatusCallback &callback, int64_t timeout_ms = -1) override; /** @@ -174,7 +173,7 @@ class ActorInfoAccessor : public ActorInfoAccessorInterface { void AsyncKillActor(const ActorID &actor_id, bool force_kill, bool no_restart, - const StatusCallback &callback, + const rpc::StatusCallback &callback, int64_t timeout_ms = -1) override; /** @@ -199,9 +198,10 @@ class ActorInfoAccessor : public ActorInfoAccessorInterface { @param subscribe Callback that will be called each time when the actor is updated. @param done Callback that will be called when subscription is complete. */ - void AsyncSubscribe(const ActorID &actor_id, - const SubscribeCallback &subscribe, - const StatusCallback &done) override; + void AsyncSubscribe( + const ActorID &actor_id, + const rpc::SubscribeCallback &subscribe, + const rpc::StatusCallback &done) override; /** Cancel subscription to an actor. diff --git a/src/ray/gcs_rpc_client/accessors/actor_info_accessor_interface.h b/src/ray/gcs_rpc_client/accessors/actor_info_accessor_interface.h index 3b87368d6887..9cf3b415a923 100644 --- a/src/ray/gcs_rpc_client/accessors/actor_info_accessor_interface.h +++ b/src/ray/gcs_rpc_client/accessors/actor_info_accessor_interface.h @@ -17,7 +17,6 @@ #include #include -#include "ray/common/gcs_callback_types.h" #include "ray/common/id.h" #include "ray/common/task/task_spec.h" #include "ray/rpc/rpc_callback_types.h" @@ -42,8 +41,9 @@ class ActorInfoAccessorInterface { @param actor_id The ID of actor to look up. @param callback Callback that will be called after lookup finishes. */ - virtual void AsyncGet(const ActorID &actor_id, - const OptionalItemCallback &callback) = 0; + virtual void AsyncGet( + const ActorID &actor_id, + const rpc::OptionalItemCallback &callback) = 0; /** Get all actor specifications asynchronously. @@ -54,11 +54,12 @@ class ActorInfoAccessorInterface { @param callback Callback that will be called after lookup finishes. @param timeout_ms request timeout, defaults to -1 for infinite timeout. */ - virtual void AsyncGetAllByFilter(const std::optional &actor_id, - const std::optional &job_id, - const std::optional &actor_state_name, - const MultiItemCallback &callback, - int64_t timeout_ms = -1) = 0; + virtual void AsyncGetAllByFilter( + const std::optional &actor_id, + const std::optional &job_id, + const std::optional &actor_state_name, + const rpc::MultiItemCallback &callback, + int64_t timeout_ms = -1) = 0; /** Get actor specification for a named actor asynchronously. @@ -68,10 +69,11 @@ class ActorInfoAccessorInterface { @param callback Callback that will be called after lookup finishes. @param timeout_ms RPC timeout in milliseconds. -1 means the default. */ - virtual void AsyncGetByName(const std::string &name, - const std::string &ray_namespace, - const OptionalItemCallback &callback, - int64_t timeout_ms = -1) = 0; + virtual void AsyncGetByName( + const std::string &name, + const std::string &ray_namespace, + const rpc::OptionalItemCallback &callback, + int64_t timeout_ms = -1) = 0; /** Get actor specification for a named actor synchronously. @@ -113,7 +115,7 @@ class ActorInfoAccessorInterface { virtual void AsyncReportActorOutOfScope( const ActorID &actor_id, uint64_t num_restarts_due_to_lineage_reconstruction, - const StatusCallback &callback, + const rpc::StatusCallback &callback, int64_t timeout_ms = -1) = 0; /** @@ -124,7 +126,7 @@ class ActorInfoAccessorInterface { @param timeout_ms Timeout ms. -1 means there's no timeout. */ virtual void AsyncRegisterActor(const TaskSpecification &task_spec, - const StatusCallback &callback, + const rpc::StatusCallback &callback, int64_t timeout_ms = -1) = 0; /** @@ -139,7 +141,7 @@ class ActorInfoAccessorInterface { virtual void AsyncRestartActorForLineageReconstruction( const ActorID &actor_id, uint64_t num_restarts_due_to_lineage_reconstructions, - const StatusCallback &callback, + const rpc::StatusCallback &callback, int64_t timeout_ms = -1) = 0; /** @@ -166,7 +168,7 @@ class ActorInfoAccessorInterface { virtual void AsyncKillActor(const ActorID &actor_id, bool force_kill, bool no_restart, - const StatusCallback &callback, + const rpc::StatusCallback &callback, int64_t timeout_ms = -1) = 0; /** @@ -188,8 +190,8 @@ class ActorInfoAccessorInterface { */ virtual void AsyncSubscribe( const ActorID &actor_id, - const SubscribeCallback &subscribe, - const StatusCallback &done) = 0; + const rpc::SubscribeCallback &subscribe, + const rpc::StatusCallback &done) = 0; /** Cancel subscription to an actor. diff --git a/src/ray/gcs_rpc_client/global_state_accessor.h b/src/ray/gcs_rpc_client/global_state_accessor.h index 4bf1ee31814d..3210971c24db 100644 --- a/src/ray/gcs_rpc_client/global_state_accessor.h +++ b/src/ray/gcs_rpc_client/global_state_accessor.h @@ -232,7 +232,7 @@ class GlobalStateAccessor { /// /// \return MultiItemCallback within in rpc type DATA. template - MultiItemCallback TransformForMultiItemCallback( + rpc::MultiItemCallback TransformForMultiItemCallback( std::vector &data_vec, std::promise &promise) { return [&data_vec, &promise](const Status &status, std::vector result) { RAY_CHECK_OK(status); @@ -248,7 +248,7 @@ class GlobalStateAccessor { /// /// \return OptionalItemCallback within in rpc type DATA. template - OptionalItemCallback TransformForOptionalItemCallback( + rpc::OptionalItemCallback TransformForOptionalItemCallback( std::unique_ptr &data, std::promise &promise) { return [&data, &promise](const Status &status, const std::optional &result) { RAY_CHECK_OK(status); @@ -263,8 +263,8 @@ class GlobalStateAccessor { /// /// \return ItemCallback within in rpc type DATA. template - ItemCallback TransformForItemCallback(std::unique_ptr &data, - std::promise &promise) { + rpc::ItemCallback TransformForItemCallback(std::unique_ptr &data, + std::promise &promise) { return [&data, &promise](const DATA &result) { data.reset(new std::string(result.SerializeAsString())); promise.set_value(true); diff --git a/src/ray/gcs_rpc_client/tests/gcs_client_injectable_test.cc b/src/ray/gcs_rpc_client/tests/gcs_client_injectable_test.cc index 8781d533f9c6..04b770cae620 100644 --- a/src/ray/gcs_rpc_client/tests/gcs_client_injectable_test.cc +++ b/src/ray/gcs_rpc_client/tests/gcs_client_injectable_test.cc @@ -83,15 +83,16 @@ class TestActorInfoAccessor : public ActorInfoAccessorInterface { bool IsFake() const { return is_fake_; } void AsyncGet(const ActorID &actor_id, - const OptionalItemCallback &callback) override {} + const rpc::OptionalItemCallback &callback) override { + } void AsyncGetAllByFilter(const std::optional &actor_id, const std::optional &job_id, const std::optional &actor_state_name, - const MultiItemCallback &callback, + const rpc::MultiItemCallback &callback, int64_t timeout_ms = -1) override {} void AsyncGetByName(const std::string &name, const std::string &ray_namespace, - const OptionalItemCallback &callback, + const rpc::OptionalItemCallback &callback, int64_t timeout_ms = -1) override {} Status SyncGetByName(const std::string &name, const std::string &ray_namespace, @@ -107,15 +108,15 @@ class TestActorInfoAccessor : public ActorInfoAccessorInterface { } void AsyncReportActorOutOfScope(const ActorID &actor_id, uint64_t num_restarts_due_to_lineage_reconstruction, - const StatusCallback &callback, + const rpc::StatusCallback &callback, int64_t timeout_ms = -1) override {} void AsyncRegisterActor(const TaskSpecification &task_spec, - const StatusCallback &callback, + const rpc::StatusCallback &callback, int64_t timeout_ms = -1) override {} void AsyncRestartActorForLineageReconstruction( const ActorID &actor_id, uint64_t num_restarts_due_to_lineage_reconstructions, - const StatusCallback &callback, + const rpc::StatusCallback &callback, int64_t timeout_ms = -1) override {} Status SyncRegisterActor(const ray::TaskSpecification &task_spec) override { return Status::OK(); @@ -123,14 +124,15 @@ class TestActorInfoAccessor : public ActorInfoAccessorInterface { void AsyncKillActor(const ActorID &actor_id, bool force_kill, bool no_restart, - const StatusCallback &callback, + const rpc::StatusCallback &callback, int64_t timeout_ms = -1) override {} void AsyncCreateActor( const TaskSpecification &task_spec, const rpc::ClientCallback &callback) override {} - void AsyncSubscribe(const ActorID &actor_id, - const SubscribeCallback &subscribe, - const StatusCallback &done) override {} + void AsyncSubscribe( + const ActorID &actor_id, + const rpc::SubscribeCallback &subscribe, + const rpc::StatusCallback &done) override {} void AsyncResubscribe() override {} void AsyncUnsubscribe(const ActorID &actor_id) override {} bool IsActorUnsubscribed(const ActorID &actor_id) override { return false; } diff --git a/src/ray/gcs_rpc_client/tests/gcs_client_test.cc b/src/ray/gcs_rpc_client/tests/gcs_client_test.cc index 17713e08e6a6..80723fff792c 100644 --- a/src/ray/gcs_rpc_client/tests/gcs_client_test.cc +++ b/src/ray/gcs_rpc_client/tests/gcs_client_test.cc @@ -233,7 +233,7 @@ class GcsClientTest : public ::testing::TestWithParam { } bool SubscribeToAllJobs( - const gcs::SubscribeCallback &subscribe) { + const rpc::SubscribeCallback &subscribe) { std::promise promise; gcs_client_->Jobs().AsyncSubscribeAll( subscribe, [&promise](Status status) { promise.set_value(status.ok()); }); @@ -269,7 +269,7 @@ class GcsClientTest : public ::testing::TestWithParam { bool SubscribeActor( const ActorID &actor_id, - const gcs::SubscribeCallback &subscribe) { + const rpc::SubscribeCallback &subscribe) { std::promise promise; gcs_client_->Actors().AsyncSubscribe(actor_id, subscribe, [&promise](Status status) { promise.set_value(status.ok()); @@ -423,7 +423,7 @@ class GcsClientTest : public ::testing::TestWithParam { } bool SubscribeToWorkerFailures( - const gcs::ItemCallback &subscribe) { + const rpc::ItemCallback &subscribe) { std::promise promise; gcs_client_->Workers().AsyncSubscribeToWorkerFailures( subscribe, [&promise](Status status) { promise.set_value(status.ok()); }); diff --git a/src/ray/pubsub/BUILD.bazel b/src/ray/pubsub/BUILD.bazel index ca05d7c028a4..68c12678df78 100644 --- a/src/ray/pubsub/BUILD.bazel +++ b/src/ray/pubsub/BUILD.bazel @@ -87,8 +87,8 @@ ray_cc_library( hdrs = ["gcs_subscriber.h"], deps = [ ":subscriber_interface", - "//src/ray/common:gcs_callback_types", "//src/ray/protobuf:gcs_cc_proto", + "//src/ray/rpc:rpc_callback_types", ], ) diff --git a/src/ray/pubsub/gcs_subscriber.cc b/src/ray/pubsub/gcs_subscriber.cc index 10437c6864fd..ee9ece1d4d5b 100644 --- a/src/ray/pubsub/gcs_subscriber.cc +++ b/src/ray/pubsub/gcs_subscriber.cc @@ -22,8 +22,8 @@ namespace ray { namespace pubsub { void GcsSubscriber::SubscribeAllJobs( - const gcs::SubscribeCallback &subscribe, - const gcs::StatusCallback &done) { + const rpc::SubscribeCallback &subscribe, + const rpc::StatusCallback &done) { auto subscribe_item_callback = [subscribe](rpc::PubMessage &&msg) { RAY_CHECK(msg.channel_type() == rpc::ChannelType::GCS_JOB_CHANNEL); const JobID id = JobID::FromBinary(msg.key_id()); @@ -48,8 +48,8 @@ void GcsSubscriber::SubscribeAllJobs( void GcsSubscriber::SubscribeActor( const ActorID &id, - const gcs::SubscribeCallback &subscribe, - const gcs::StatusCallback &done) { + const rpc::SubscribeCallback &subscribe, + const rpc::StatusCallback &done) { auto subscription_callback = [id, subscribe](rpc::PubMessage &&msg) { RAY_CHECK(msg.channel_type() == rpc::ChannelType::GCS_ACTOR_CHANNEL); RAY_CHECK(msg.key_id() == id.Binary()); @@ -86,8 +86,8 @@ bool GcsSubscriber::IsActorUnsubscribed(const ActorID &id) { } void GcsSubscriber::SubscribeAllNodeInfo( - const gcs::ItemCallback &subscribe, - const gcs::StatusCallback &done) { + const rpc::ItemCallback &subscribe, + const rpc::StatusCallback &done) { auto subscribe_item_callback = [subscribe](rpc::PubMessage &&msg) { RAY_CHECK(msg.channel_type() == rpc::ChannelType::GCS_NODE_INFO_CHANNEL); subscribe(std::move(*msg.mutable_node_info_message())); @@ -110,8 +110,8 @@ void GcsSubscriber::SubscribeAllNodeInfo( } void GcsSubscriber::SubscribeAllNodeAddressAndLiveness( - const gcs::ItemCallback &subscribe, - const gcs::StatusCallback &done) { + const rpc::ItemCallback &subscribe, + const rpc::StatusCallback &done) { auto subscribe_item_callback = [subscribe](rpc::PubMessage &&msg) { RAY_CHECK(msg.channel_type() == rpc::ChannelType::GCS_NODE_ADDRESS_AND_LIVENESS_CHANNEL); @@ -136,8 +136,8 @@ void GcsSubscriber::SubscribeAllNodeAddressAndLiveness( } void GcsSubscriber::SubscribeAllWorkerFailures( - const gcs::ItemCallback &subscribe, - const gcs::StatusCallback &done) { + const rpc::ItemCallback &subscribe, + const rpc::StatusCallback &done) { auto subscribe_item_callback = [subscribe](rpc::PubMessage &&msg) { RAY_CHECK(msg.channel_type() == rpc::ChannelType::GCS_WORKER_DELTA_CHANNEL); subscribe(std::move(*msg.mutable_worker_delta_message())); diff --git a/src/ray/pubsub/gcs_subscriber.h b/src/ray/pubsub/gcs_subscriber.h index ce439b3657db..44b5df13d0ab 100644 --- a/src/ray/pubsub/gcs_subscriber.h +++ b/src/ray/pubsub/gcs_subscriber.h @@ -18,8 +18,8 @@ #include #include -#include "ray/common/gcs_callback_types.h" #include "ray/pubsub/subscriber_interface.h" +#include "ray/rpc/rpc_callback_types.h" #include "src/ray/protobuf/gcs.pb.h" namespace ray { @@ -44,25 +44,25 @@ class GcsSubscriber { /// Uses GCS pubsub when created with `subscriber`. void SubscribeActor( const ActorID &id, - const gcs::SubscribeCallback &subscribe, - const gcs::StatusCallback &done); + const rpc::SubscribeCallback &subscribe, + const rpc::StatusCallback &done); void UnsubscribeActor(const ActorID &id); bool IsActorUnsubscribed(const ActorID &id); - void SubscribeAllJobs(const gcs::SubscribeCallback &subscribe, - const gcs::StatusCallback &done); + void SubscribeAllJobs(const rpc::SubscribeCallback &subscribe, + const rpc::StatusCallback &done); - void SubscribeAllNodeInfo(const gcs::ItemCallback &subscribe, - const gcs::StatusCallback &done); + void SubscribeAllNodeInfo(const rpc::ItemCallback &subscribe, + const rpc::StatusCallback &done); void SubscribeAllNodeAddressAndLiveness( - const gcs::ItemCallback &subscribe, - const gcs::StatusCallback &done); + const rpc::ItemCallback &subscribe, + const rpc::StatusCallback &done); void SubscribeAllWorkerFailures( - const gcs::ItemCallback &subscribe, - const gcs::StatusCallback &done); + const rpc::ItemCallback &subscribe, + const rpc::StatusCallback &done); /// Prints debugging info for the subscriber. std::string DebugString() const; diff --git a/src/ray/raylet/tests/node_manager_test.cc b/src/ray/raylet/tests/node_manager_test.cc index a52b27c0f573..149ee01f6924 100644 --- a/src/ray/raylet/tests/node_manager_test.cc +++ b/src/ray/raylet/tests/node_manager_test.cc @@ -521,11 +521,11 @@ TEST_F(NodeManagerTest, TestDetachedWorkerIsKilledByFailedWorker) { }); // Save the publish_worker_failure_callback for publishing a worker failure event later. - gcs::ItemCallback publish_worker_failure_callback; + rpc::ItemCallback publish_worker_failure_callback; EXPECT_CALL(*mock_gcs_client_->mock_worker_accessor, AsyncSubscribeToWorkerFailures(_, _)) - .WillOnce([&](const gcs::ItemCallback &subscribe, - const gcs::StatusCallback &done) { + .WillOnce([&](const rpc::ItemCallback &subscribe, + const rpc::StatusCallback &done) { publish_worker_failure_callback = subscribe; return Status::OK(); }); @@ -601,9 +601,9 @@ TEST_F(NodeManagerTest, TestDetachedWorkerIsKilledByFailedNode) { publish_node_change_callback; EXPECT_CALL(*mock_gcs_client_->mock_node_accessor, AsyncSubscribeToNodeAddressAndLivenessChange(_, _)) - .WillOnce([&](const gcs::SubscribeCallback + .WillOnce([&](const rpc::SubscribeCallback &subscribe, - const gcs::StatusCallback &done) { + const rpc::StatusCallback &done) { publish_node_change_callback = subscribe; }); node_manager_->RegisterGcs(); @@ -1325,13 +1325,13 @@ TEST_P(NodeManagerDeathTest, TestGcsPublishesSelfDead) { // started const bool shutting_down_during_death_publish = GetParam(); - gcs::SubscribeCallback + rpc::SubscribeCallback publish_node_change_callback; EXPECT_CALL(*mock_gcs_client_->mock_node_accessor, AsyncSubscribeToNodeAddressAndLivenessChange(_, _)) - .WillOnce([&](const gcs::SubscribeCallback + .WillOnce([&](const rpc::SubscribeCallback &subscribe, - const gcs::StatusCallback &done) { + const rpc::StatusCallback &done) { publish_node_change_callback = subscribe; }); node_manager_->RegisterGcs(); diff --git a/src/ray/raylet_rpc_client/BUILD.bazel b/src/ray/raylet_rpc_client/BUILD.bazel index f3208a4a9f27..99bdb9052937 100644 --- a/src/ray/raylet_rpc_client/BUILD.bazel +++ b/src/ray/raylet_rpc_client/BUILD.bazel @@ -35,7 +35,6 @@ ray_cc_library( deps = [ ":raylet_client_interface", "//src/ray/common:bundle_spec", - "//src/ray/common:gcs_callback_types", "//src/ray/common:ray_config", "//src/ray/protobuf:node_manager_cc_grpc", "//src/ray/rpc:retryable_grpc_client", diff --git a/src/ray/raylet_rpc_client/raylet_client.cc b/src/ray/raylet_rpc_client/raylet_client.cc index d76b49716331..210baaf6ad86 100644 --- a/src/ray/raylet_rpc_client/raylet_client.cc +++ b/src/ray/raylet_rpc_client/raylet_client.cc @@ -476,7 +476,7 @@ void RayletClient::GetNodeStats( } void RayletClient::GetWorkerPIDs( - const gcs::OptionalItemCallback> &callback, int64_t timeout_ms) { + const rpc::OptionalItemCallback> &callback, int64_t timeout_ms) { rpc::GetWorkerPIDsRequest request; auto client_callback = [callback](const Status &status, rpc::GetWorkerPIDsReply &&reply) { diff --git a/src/ray/raylet_rpc_client/raylet_client.h b/src/ray/raylet_rpc_client/raylet_client.h index 07de0073ef02..fb3289c9bb7c 100644 --- a/src/ray/raylet_rpc_client/raylet_client.h +++ b/src/ray/raylet_rpc_client/raylet_client.h @@ -22,7 +22,6 @@ #include #include -#include "ray/common/gcs_callback_types.h" #include "ray/raylet_rpc_client/raylet_client_interface.h" #include "ray/rpc/grpc_client.h" #include "ray/rpc/retryable_grpc_client.h" @@ -171,7 +170,7 @@ class RayletClient : public RayletClientInterface { /// Get the worker pids from raylet. /// \param callback The callback to set the worker pids. /// \param timeout_ms The timeout in milliseconds. - void GetWorkerPIDs(const gcs::OptionalItemCallback> &callback, + void GetWorkerPIDs(const rpc::OptionalItemCallback> &callback, int64_t timeout_ms); protected: diff --git a/src/ray/raylet_rpc_client/raylet_client_with_io_context.cc b/src/ray/raylet_rpc_client/raylet_client_with_io_context.cc index 6fbb971099f1..574318e363ea 100644 --- a/src/ray/raylet_rpc_client/raylet_client_with_io_context.cc +++ b/src/ray/raylet_rpc_client/raylet_client_with_io_context.cc @@ -19,7 +19,6 @@ #include #include -#include "ray/common/asio/asio_util.h" #include "ray/common/ray_config.h" #include "ray/util/logging.h" #include "src/ray/protobuf/node_manager.grpc.pb.h" @@ -31,8 +30,7 @@ RayletClientWithIoContext::RayletClientWithIoContext(const std::string &ip_addre int port) { // Connect to the raylet on a singleton io service with a dedicated thread. // This is to avoid creating multiple threads for multiple clients in python. - static InstrumentedIOContextWithThread io_context("raylet_client_io_service"); - instrumented_io_context &io_service = io_context.GetIoService(); + instrumented_io_context &io_service = io_context_.GetIoService(); client_call_manager_ = std::make_unique( io_service, /*record_stats=*/false, ip_address); auto raylet_unavailable_timeout_callback = []() { @@ -48,7 +46,7 @@ RayletClientWithIoContext::RayletClientWithIoContext(const std::string &ip_addre } void RayletClientWithIoContext::GetWorkerPIDs( - const gcs::OptionalItemCallback> &callback, int64_t timeout_ms) { + const rpc::OptionalItemCallback> &callback, int64_t timeout_ms) { raylet_client_->GetWorkerPIDs(callback, timeout_ms); } diff --git a/src/ray/raylet_rpc_client/raylet_client_with_io_context.h b/src/ray/raylet_rpc_client/raylet_client_with_io_context.h index 8800ea2c7e9b..ad3de9bcac14 100644 --- a/src/ray/raylet_rpc_client/raylet_client_with_io_context.h +++ b/src/ray/raylet_rpc_client/raylet_client_with_io_context.h @@ -14,6 +14,7 @@ #pragma once +#include "ray/common/asio/asio_util.h" #include "ray/raylet_rpc_client/raylet_client.h" #include "ray/rpc/grpc_client.h" @@ -35,7 +36,7 @@ class RayletClientWithIoContext { /// Get the worker pids from raylet. /// \param callback The callback to set the worker pids. /// \param timeout_ms The timeout in milliseconds. - void GetWorkerPIDs(const gcs::OptionalItemCallback> &callback, + void GetWorkerPIDs(const rpc::OptionalItemCallback> &callback, int64_t timeout_ms); private: @@ -43,6 +44,7 @@ class RayletClientWithIoContext { /// during the whole lifetime of client. std::unique_ptr client_call_manager_; std::unique_ptr raylet_client_; + inline static InstrumentedIOContextWithThread io_context_{"raylet_client_io_service"}; }; } // namespace rpc diff --git a/src/ray/raylet_rpc_client/tests/raylet_client_pool_test.cc b/src/ray/raylet_rpc_client/tests/raylet_client_pool_test.cc index d8f6e8748e50..424de222d2c5 100644 --- a/src/ray/raylet_rpc_client/tests/raylet_client_pool_test.cc +++ b/src/ray/raylet_rpc_client/tests/raylet_client_pool_test.cc @@ -66,7 +66,7 @@ class MockGcsClientNodeAccessor : public gcs::NodeInfoAccessor { MOCK_METHOD(void, AsyncGetAllNodeAddressAndLiveness, - (const gcs::MultiItemCallback &, + (const rpc::MultiItemCallback &, int64_t, const std::vector &), (override)); @@ -133,7 +133,7 @@ TEST_P(DefaultUnavailableTimeoutCallbackTest, NodeDeath) { [](std::vector node_info_vector) { return Invoke( [node_info_vector]( - const gcs::MultiItemCallback &callback, + const rpc::MultiItemCallback &callback, int64_t, const std::vector &) { callback(Status::OK(), node_info_vector); diff --git a/src/ray/rpc/rpc_callback_types.h b/src/ray/rpc/rpc_callback_types.h index 303a93b85c53..a448e050b53f 100644 --- a/src/ray/rpc/rpc_callback_types.h +++ b/src/ray/rpc/rpc_callback_types.h @@ -15,6 +15,8 @@ #pragma once #include +#include +#include #include "ray/common/status.h" @@ -37,5 +39,36 @@ using SendReplyCallback = std::function using ClientCallback = std::function; +/// This callback is used to notify when a write/subscribe to a rpc server completes. +/// \param status Status indicates whether the write/subscribe was successful. +using StatusCallback = std::function; + +/// This callback is used to receive one item from a rpc server when a read completes. +/// \param status Status indicates whether the read was successful. +/// \param result The item returned by the rpc server. If the item to read doesn't exist, +/// this optional object is empty. +template +using OptionalItemCallback = + std::function result)>; + +/// This callback is used to receive multiple items from a rpc server when a read +/// completes. +/// \param status Status indicates whether the read was successful. +/// \param result The items returned by the rpc server. +template +using MultiItemCallback = std::function result)>; + +/// This callback is used to receive notifications of the subscribed items in a rpc +/// server. +/// \param id The id of the item. +/// \param result The notification message. +template +using SubscribeCallback = std::function; + +/// This callback is used to receive a single item from a rpc server. +/// \param result The item returned by the rpc server. +template +using ItemCallback = std::function; + } // namespace rpc } // namespace ray