diff --git a/python/ray/includes/common.pxd b/python/ray/includes/common.pxd index c406a82c5c6c..93d82613268c 100644 --- a/python/ray/includes/common.pxd +++ b/python/ray/includes/common.pxd @@ -423,8 +423,8 @@ cdef extern from "ray/common/python_callbacks.h" namespace "ray": void (object, object) nogil, object) nogil -cdef extern from "ray/gcs_rpc_client/accessor.h" nogil: - cdef cppclass CActorInfoAccessor "ray::gcs::ActorInfoAccessor": +cdef extern from "ray/gcs_rpc_client/accessors/actor_info_accessor_interface.h" nogil: + cdef cppclass CActorInfoAccessorInterface "ray::gcs::ActorInfoAccessorInterface": void AsyncGetAllByFilter( const optional[CActorID] &actor_id, const optional[CJobID] &job_id, @@ -438,6 +438,7 @@ cdef extern from "ray/gcs_rpc_client/accessor.h" nogil: const StatusPyCallback &callback, int64_t timeout_ms) +cdef extern from "ray/gcs_rpc_client/accessor.h" nogil: cdef cppclass CJobInfoAccessor "ray::gcs::JobInfoAccessor": CRayStatus GetAll( const optional[c_string] &job_or_submission_id, @@ -649,7 +650,7 @@ cdef extern from "ray/gcs_rpc_client/gcs_client.h" nogil: c_pair[c_string, int] GetGcsServerAddress() const CClusterID GetClusterId() const - CActorInfoAccessor& Actors() + CActorInfoAccessorInterface& Actors() CJobInfoAccessor& Jobs() CInternalKVAccessor& InternalKV() CNodeInfoAccessor& Nodes() diff --git a/src/mock/ray/gcs_client/accessor.h b/src/mock/ray/gcs_client/accessor.h index 819f384376d9..12c2c2698fee 100644 --- a/src/mock/ray/gcs_client/accessor.h +++ b/src/mock/ray/gcs_client/accessor.h @@ -18,68 +18,6 @@ namespace ray { namespace gcs { -class MockActorInfoAccessor : public ActorInfoAccessor { - public: - MOCK_METHOD(void, - AsyncGet, - (const ActorID &actor_id, - const OptionalItemCallback &callback), - (override)); - MOCK_METHOD(void, - AsyncGetAllByFilter, - (const std::optional &actor_id, - const std::optional &job_id, - const std::optional &actor_state_name, - const MultiItemCallback &callback, - int64_t timeout_ms), - 
(override)); - MOCK_METHOD(void, - AsyncGetByName, - (const std::string &name, - const std::string &ray_namespace, - const OptionalItemCallback &callback, - int64_t timeout_ms), - (override)); - MOCK_METHOD(void, - AsyncRegisterActor, - (const TaskSpecification &task_spec, - const StatusCallback &callback, - int64_t timeout_ms), - (override)); - MOCK_METHOD(Status, - SyncRegisterActor, - (const TaskSpecification &task_spec), - (override)); - MOCK_METHOD(void, - AsyncKillActor, - (const ActorID &actor_id, - bool force_kill, - bool no_restart, - const StatusCallback &callback, - int64_t timeout_ms), - (override)); - MOCK_METHOD(void, - AsyncCreateActor, - (const TaskSpecification &task_spec, - const rpc::ClientCallback &callback), - (override)); - MOCK_METHOD(void, - AsyncSubscribe, - (const ActorID &actor_id, - (const SubscribeCallback &subscribe), - const StatusCallback &done), - (override)); - MOCK_METHOD(void, AsyncUnsubscribe, (const ActorID &actor_id), (override)); - MOCK_METHOD(void, AsyncResubscribe, (), (override)); - MOCK_METHOD(bool, IsActorUnsubscribed, (const ActorID &actor_id), (override)); -}; - -} // namespace gcs -} // namespace ray - -namespace ray { -namespace gcs { - class MockJobInfoAccessor : public JobInfoAccessor { public: MOCK_METHOD(void, diff --git a/src/mock/ray/gcs_client/accessors/actor_info_accessor.h b/src/mock/ray/gcs_client/accessors/actor_info_accessor.h new file mode 100644 index 000000000000..8652db89bd10 --- /dev/null +++ b/src/mock/ray/gcs_client/accessors/actor_info_accessor.h @@ -0,0 +1,138 @@ +// Copyright 2025 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <optional>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "ray/common/id.h"
+#include "ray/common/status.h"
+#include "ray/common/task/task_spec.h"
+#include "ray/gcs_rpc_client/accessors/actor_info_accessor_interface.h"
+#include "src/ray/protobuf/gcs.pb.h"
+
+namespace ray {
+namespace gcs {
+
+// Test double for ActorInfoAccessorInterface. Most calls are no-op stubs; the register,
+// create and subscribe calls record their callbacks so that tests can fire them by hand.
+class FakeActorInfoAccessor : public gcs::ActorInfoAccessorInterface {
+ public:
+  // Stubs: no-ops except AsyncRegisterActor/AsyncCreateActor, which store the callback.
+  void AsyncGet(const ActorID &,
+                const gcs::OptionalItemCallback<rpc::ActorTableData> &) override {}
+  void AsyncGetAllByFilter(const std::optional<ActorID> &,
+                           const std::optional<JobID> &,
+                           const std::optional<std::string> &,
+                           const gcs::MultiItemCallback<rpc::ActorTableData> &,
+                           int64_t = -1) override {}
+  void AsyncGetByName(const std::string &,
+                      const std::string &,
+                      const gcs::OptionalItemCallback<rpc::ActorTableData> &,
+                      int64_t = -1) override {}
+  Status SyncGetByName(const std::string &,
+                       const std::string &,
+                       rpc::ActorTableData &,
+                       rpc::TaskSpec &) override {
+    return Status::OK();
+  }
+  Status SyncListNamedActors(
+      bool,
+      const std::string &,
+      std::vector<std::pair<std::string, std::string>> &) override {
+    return Status::OK();
+  }
+  void AsyncReportActorOutOfScope(const ActorID &,
+                                  uint64_t,
+                                  const gcs::StatusCallback &,
+                                  int64_t = -1) override {}
+  void AsyncRegisterActor(const TaskSpecification &,
+                          const gcs::StatusCallback &callback,
+                          int64_t = -1) override {
+    async_register_actor_callback_ = callback;
+  }
+  void AsyncRestartActorForLineageReconstruction(const ActorID &,
+                                                 uint64_t,
+                                                 const gcs::StatusCallback &,
+                                                 int64_t = -1) override {}
+  Status SyncRegisterActor(const TaskSpecification &) override { return Status::OK(); }
+  void AsyncKillActor(
+      const ActorID &, bool, bool, const gcs::StatusCallback &, int64_t = -1) override {}
+  void AsyncCreateActor(
+      const TaskSpecification &,
+      const rpc::ClientCallback<rpc::CreateActorReply> &callback) override {
+    async_create_actor_callback_ = callback;
+  }
+
+  // Records both callbacks for `actor_id` and counts the call; invokes neither of them.
+  void AsyncSubscribe(
+      const ActorID &actor_id,
+      const gcs::SubscribeCallback<ActorID, rpc::ActorTableData> &subscribe,
+      const gcs::StatusCallback &done) override {
+    callback_map_.emplace(actor_id, subscribe);
+    subscribe_finished_callback_map_[actor_id] = done;
+    actor_subscribed_times_[actor_id]++;
+  }
+
+  void AsyncUnsubscribe(const ActorID &) override {}
+  void AsyncResubscribe() override {}
+  bool IsActorUnsubscribed(const ActorID &) override { return false; }
+
+  // Feeds a copy of `actor_data` to the subscriber of `actor_id`; false if none.
+  bool ActorStateNotificationPublished(const ActorID &actor_id,
+                                       const rpc::ActorTableData &actor_data) {
+    auto it = callback_map_.find(actor_id);
+    if (it == callback_map_.end()) return false;
+    auto actor_state_notification_callback = it->second;
+    actor_state_notification_callback(actor_id, rpc::ActorTableData(actor_data));
+    return true;
+  }
+
+  // Returns whether AsyncSubscribe has recorded a subscription for `actor_id`.
+  bool CheckSubscriptionRequested(const ActorID &actor_id) {
+    return callback_map_.find(actor_id) != callback_map_.end();
+  }
+
+  // Mimics subscription completion (see `ActorInfoAccessor::AsyncSubscribe`): publishes
+  // `actor_data`, then fires the `done` callback. False if either callback is missing.
+  bool ActorSubscribeFinished(const ActorID &actor_id,
+                              const rpc::ActorTableData &actor_data) {
+    auto done_it = subscribe_finished_callback_map_.find(actor_id);
+    if (done_it == subscribe_finished_callback_map_.end()) return false;
+    if (!ActorStateNotificationPublished(actor_id, actor_data)) return false;
+
+    auto subscribe_finished_callback = done_it->second;
+    subscribe_finished_callback(Status::OK());
+    // The `done` callback is one-shot: drop it once the subscription has finished.
+    subscribe_finished_callback_map_.erase(done_it);
+    return true;
+  }
+
+  // Per-actor state recorded by AsyncSubscribe: callbacks and the number of calls.
+  absl::flat_hash_map<ActorID, gcs::SubscribeCallback<ActorID, rpc::ActorTableData>>
+      callback_map_;
+  absl::flat_hash_map<ActorID, gcs::StatusCallback> subscribe_finished_callback_map_;
+  absl::flat_hash_map<ActorID, uint32_t> actor_subscribed_times_;
+
+  // Most recent callbacks handed to AsyncCreateActor and AsyncRegisterActor.
+  rpc::ClientCallback<rpc::CreateActorReply> async_create_actor_callback_;
+  gcs::StatusCallback async_register_actor_callback_;
+};
+
+}  // namespace gcs
+}  // namespace ray
diff --git a/src/mock/ray/gcs_client/gcs_client.h b/src/mock/ray/gcs_client/gcs_client.h
index 1e94406ae09e..2aeb68a7760b 100644
--- a/src/mock/ray/gcs_client/gcs_client.h
+++ b/src/mock/ray/gcs_client/gcs_client.h
@@ -14,6 +14,7 @@

 #pragma once

 #include "mock/ray/gcs_client/accessor.h"
+#include "mock/ray/gcs_client/accessors/actor_info_accessor.h"
 #include "ray/gcs_rpc_client/gcs_client.h"

@@ -40,9 +41,9 @@ class MockGcsClient : public GcsClient {
   MOCK_METHOD((std::pair<std::string, int>), GetGcsServerAddress, (), (const, override));
   MOCK_METHOD(std::string, DebugString, (), (const, override));

-  MockGcsClient() {
+  MockGcsClient() : GcsClient(MockGcsClientOptions()) {
     mock_job_accessor = new MockJobInfoAccessor();
-    mock_actor_accessor = new MockActorInfoAccessor();
+    mock_actor_accessor = new FakeActorInfoAccessor();
     mock_node_accessor = new MockNodeInfoAccessor();
     mock_node_resource_accessor = new MockNodeResourceInfoAccessor();
mock_error_accessor = new MockErrorInfoAccessor(); @@ -61,7 +62,7 @@ class MockGcsClient : public GcsClient { GcsClient::internal_kv_accessor_.reset(mock_internal_kv_accessor); GcsClient::task_accessor_.reset(mock_task_accessor); } - MockActorInfoAccessor *mock_actor_accessor; + FakeActorInfoAccessor *mock_actor_accessor; MockJobInfoAccessor *mock_job_accessor; MockNodeInfoAccessor *mock_node_accessor; MockNodeResourceInfoAccessor *mock_node_resource_accessor; diff --git a/src/ray/core_worker/actor_creator.h b/src/ray/core_worker/actor_creator.h index cb15b869359f..974d1360be9f 100644 --- a/src/ray/core_worker/actor_creator.h +++ b/src/ray/core_worker/actor_creator.h @@ -18,7 +18,7 @@ #include #include -#include "ray/gcs_rpc_client/accessor.h" +#include "ray/gcs_rpc_client/accessors/actor_info_accessor_interface.h" #include "ray/util/thread_utils.h" namespace ray { @@ -74,7 +74,7 @@ class ActorCreatorInterface { class ActorCreator : public ActorCreatorInterface { public: - explicit ActorCreator(gcs::ActorInfoAccessor &actor_client) + explicit ActorCreator(gcs::ActorInfoAccessorInterface &actor_client) : actor_client_(actor_client) {} Status RegisterActor(const TaskSpecification &task_spec) const override; @@ -101,7 +101,7 @@ class ActorCreator : public ActorCreatorInterface { const rpc::ClientCallback &callback) override; private: - gcs::ActorInfoAccessor &actor_client_; + gcs::ActorInfoAccessorInterface &actor_client_; using RegisteringActorType = absl::flat_hash_map>; ThreadPrivate registering_actors_; diff --git a/src/ray/core_worker/task_submission/tests/direct_actor_transport_test.cc b/src/ray/core_worker/task_submission/tests/direct_actor_transport_test.cc index 75e1a8034180..a8267ae337ac 100644 --- a/src/ray/core_worker/task_submission/tests/direct_actor_transport_test.cc +++ b/src/ray/core_worker/task_submission/tests/direct_actor_transport_test.cc @@ -92,12 +92,9 @@ TEST_F(DirectTaskTransportTest, ActorCreationOk) { auto actor_id = 
ActorID::FromHex("f4ce02420592ca68c1738a0d01000000"); auto creation_task_spec = GetActorCreationTaskSpec(actor_id); EXPECT_CALL(*task_manager, CompletePendingTask(creation_task_spec.TaskId(), _, _, _)); - rpc::ClientCallback create_cb; - EXPECT_CALL(*gcs_client->mock_actor_accessor, - AsyncCreateActor(creation_task_spec, ::testing::_)) - .WillOnce(::testing::DoAll(::testing::SaveArg<1>(&create_cb))); actor_task_submitter->SubmitActorCreationTask(creation_task_spec); - create_cb(Status::OK(), rpc::CreateActorReply()); + gcs_client->mock_actor_accessor->async_create_actor_callback_(Status::OK(), + rpc::CreateActorReply()); } TEST_F(DirectTaskTransportTest, ActorCreationFail) { @@ -108,12 +105,9 @@ TEST_F(DirectTaskTransportTest, ActorCreationFail) { *task_manager, FailPendingTask( creation_task_spec.TaskId(), rpc::ErrorType::ACTOR_CREATION_FAILED, _, _)); - rpc::ClientCallback create_cb; - EXPECT_CALL(*gcs_client->mock_actor_accessor, - AsyncCreateActor(creation_task_spec, ::testing::_)) - .WillOnce(::testing::DoAll(::testing::SaveArg<1>(&create_cb))); actor_task_submitter->SubmitActorCreationTask(creation_task_spec); - create_cb(Status::IOError(""), rpc::CreateActorReply()); + gcs_client->mock_actor_accessor->async_create_actor_callback_(Status::IOError(""), + rpc::CreateActorReply()); } TEST_F(DirectTaskTransportTest, ActorRegisterFailure) { @@ -125,10 +119,6 @@ TEST_F(DirectTaskTransportTest, ActorRegisterFailure) { auto task_arg = task_spec.GetMutableMessage().add_args(); auto inline_obj_ref = task_arg->add_nested_inlined_refs(); inline_obj_ref->set_object_id(ObjectID::ForActorHandle(actor_id).Binary()); - std::function register_cb; - EXPECT_CALL(*gcs_client->mock_actor_accessor, - AsyncRegisterActor(creation_task_spec, ::testing::_, ::testing::_)) - .WillOnce(::testing::DoAll(::testing::SaveArg<1>(®ister_cb))); actor_creator->AsyncRegisterActor(creation_task_spec, nullptr); ASSERT_TRUE(actor_creator->IsActorInRegistering(actor_id)); 
actor_task_submitter->AddActorQueueIfNotExists(actor_id, @@ -141,7 +131,7 @@ TEST_F(DirectTaskTransportTest, ActorRegisterFailure) { *task_manager, FailOrRetryPendingTask( task_spec.TaskId(), rpc::ErrorType::DEPENDENCY_RESOLUTION_FAILED, _, _, _, _)); - register_cb(Status::IOError("")); + gcs_client->mock_actor_accessor->async_register_actor_callback_(Status::IOError("")); } TEST_F(DirectTaskTransportTest, ActorRegisterOk) { @@ -153,10 +143,6 @@ TEST_F(DirectTaskTransportTest, ActorRegisterOk) { auto task_arg = task_spec.GetMutableMessage().add_args(); auto inline_obj_ref = task_arg->add_nested_inlined_refs(); inline_obj_ref->set_object_id(ObjectID::ForActorHandle(actor_id).Binary()); - std::function register_cb; - EXPECT_CALL(*gcs_client->mock_actor_accessor, - AsyncRegisterActor(creation_task_spec, ::testing::_, ::testing::_)) - .WillOnce(::testing::DoAll(::testing::SaveArg<1>(®ister_cb))); actor_creator->AsyncRegisterActor(creation_task_spec, nullptr); ASSERT_TRUE(actor_creator->IsActorInRegistering(actor_id)); actor_task_submitter->AddActorQueueIfNotExists(actor_id, @@ -166,7 +152,7 @@ TEST_F(DirectTaskTransportTest, ActorRegisterOk) { /*owned*/ false); ASSERT_TRUE(CheckSubmitTask(task_spec)); EXPECT_CALL(*task_manager, FailOrRetryPendingTask(_, _, _, _, _, _)).Times(0); - register_cb(Status::OK()); + gcs_client->mock_actor_accessor->async_register_actor_callback_(Status::OK()); } } // namespace core diff --git a/src/ray/core_worker/tests/actor_creator_test.cc b/src/ray/core_worker/tests/actor_creator_test.cc index 10d3b3574c3e..9441612df49a 100644 --- a/src/ray/core_worker/tests/actor_creator_test.cc +++ b/src/ray/core_worker/tests/actor_creator_test.cc @@ -49,23 +49,15 @@ TEST_F(ActorCreatorTest, IsRegister) { auto actor_id = ActorID::FromHex("f4ce02420592ca68c1738a0d01000000"); ASSERT_FALSE(actor_creator->IsActorInRegistering(actor_id)); auto task_spec = GetTaskSpec(actor_id); - std::function cb; - EXPECT_CALL(*gcs_client->mock_actor_accessor, - 
AsyncRegisterActor(task_spec, ::testing::_, ::testing::_)) - .WillOnce(::testing::DoAll(::testing::SaveArg<1>(&cb))); actor_creator->AsyncRegisterActor(task_spec, nullptr); ASSERT_TRUE(actor_creator->IsActorInRegistering(actor_id)); - cb(Status::OK()); + gcs_client->mock_actor_accessor->async_register_actor_callback_(Status::OK()); ASSERT_FALSE(actor_creator->IsActorInRegistering(actor_id)); } TEST_F(ActorCreatorTest, AsyncWaitForFinish) { auto actor_id = ActorID::FromHex("f4ce02420592ca68c1738a0d01000000"); auto task_spec = GetTaskSpec(actor_id); - std::function cb; - EXPECT_CALL(*gcs_client->mock_actor_accessor, - AsyncRegisterActor(::testing::_, ::testing::_, ::testing::_)) - .WillRepeatedly(::testing::DoAll(::testing::SaveArg<1>(&cb))); int count = 0; auto per_finish_cb = [&count](Status status) { ASSERT_TRUE(status.ok()); @@ -76,7 +68,7 @@ TEST_F(ActorCreatorTest, AsyncWaitForFinish) { for (int i = 0; i < 10; ++i) { actor_creator->AsyncWaitForActorRegisterFinish(actor_id, per_finish_cb); } - cb(Status::OK()); + gcs_client->mock_actor_accessor->async_register_actor_callback_(Status::OK()); ASSERT_FALSE(actor_creator->IsActorInRegistering(actor_id)); ASSERT_EQ(11, count); } diff --git a/src/ray/core_worker/tests/actor_manager_test.cc b/src/ray/core_worker/tests/actor_manager_test.cc index 792158f4790d..f16bec11ca2d 100644 --- a/src/ray/core_worker/tests/actor_manager_test.cc +++ b/src/ray/core_worker/tests/actor_manager_test.cc @@ -17,12 +17,14 @@ #include #include #include +#include #include "gmock/gmock.h" #include "gtest/gtest.h" #include "mock/ray/core_worker/reference_counter.h" +#include "mock/ray/gcs_client/accessors/actor_info_accessor.h" #include "ray/common/test_utils.h" -#include "ray/gcs_rpc_client/accessor.h" +#include "ray/gcs_rpc_client/accessors/actor_info_accessor_interface.h" #include "ray/gcs_rpc_client/gcs_client.h" namespace ray { @@ -30,68 +32,11 @@ namespace core { using ::testing::_; -class MockActorInfoAccessor : public 
gcs::ActorInfoAccessor { - public: - explicit MockActorInfoAccessor(gcs::GcsClient *client) - : gcs::ActorInfoAccessor(client) {} - - ~MockActorInfoAccessor() {} - - void AsyncSubscribe( - const ActorID &actor_id, - const gcs::SubscribeCallback &subscribe, - const gcs::StatusCallback &done) { - auto callback_entry = std::make_pair(actor_id, subscribe); - callback_map_.emplace(actor_id, subscribe); - subscribe_finished_callback_map_[actor_id] = done; - actor_subscribed_times_[actor_id]++; - } - - bool ActorStateNotificationPublished(const ActorID &actor_id, - const rpc::ActorTableData &actor_data) { - auto it = callback_map_.find(actor_id); - if (it == callback_map_.end()) return false; - auto actor_state_notification_callback = it->second; - auto copied = actor_data; - actor_state_notification_callback(actor_id, std::move(copied)); - return true; - } - - bool CheckSubscriptionRequested(const ActorID &actor_id) { - return callback_map_.find(actor_id) != callback_map_.end(); - } - - // Mock the logic of subscribe finished. see `ActorInfoAccessor::AsyncSubscribe` - bool ActorSubscribeFinished(const ActorID &actor_id, - const rpc::ActorTableData &actor_data) { - auto subscribe_finished_callback_it = subscribe_finished_callback_map_.find(actor_id); - if (subscribe_finished_callback_it == subscribe_finished_callback_map_.end()) { - return false; - } - - auto copied = actor_data; - if (!ActorStateNotificationPublished(actor_id, std::move(copied))) { - return false; - } - - auto subscribe_finished_callback = subscribe_finished_callback_it->second; - subscribe_finished_callback(Status::OK()); - // Erase callback when actor subscribe is finished. 
- subscribe_finished_callback_map_.erase(subscribe_finished_callback_it); - return true; - } - - absl::flat_hash_map> - callback_map_; - absl::flat_hash_map subscribe_finished_callback_map_; - absl::flat_hash_map actor_subscribed_times_; -}; - class MockGcsClient : public gcs::GcsClient { public: explicit MockGcsClient(gcs::GcsClientOptions options) : gcs::GcsClient(options) {} - void Init(MockActorInfoAccessor *actor_info_accessor) { + void Init(gcs::FakeActorInfoAccessor *actor_info_accessor) { actor_accessor_.reset(actor_info_accessor); } }; @@ -132,7 +77,7 @@ class ActorManagerTest : public ::testing::Test { /*allow_cluster_id_nil=*/true, /*fetch_cluster_id_if_nil=*/false), gcs_client_mock_(new MockGcsClient(options_)), - actor_info_accessor_(new MockActorInfoAccessor(gcs_client_mock_.get())), + actor_info_accessor_(new gcs::FakeActorInfoAccessor()), actor_task_submitter_(new MockActorTaskSubmitter()), reference_counter_(new MockReferenceCounter()) { gcs_client_mock_->Init(actor_info_accessor_); @@ -182,7 +127,7 @@ class ActorManagerTest : public ::testing::Test { gcs::GcsClientOptions options_; std::shared_ptr gcs_client_mock_; - MockActorInfoAccessor *actor_info_accessor_; + gcs::FakeActorInfoAccessor *actor_info_accessor_; std::shared_ptr actor_task_submitter_; std::unique_ptr reference_counter_; std::shared_ptr actor_manager_; diff --git a/src/ray/core_worker/tests/core_worker_test.cc b/src/ray/core_worker/tests/core_worker_test.cc index e1cabcbd2c8b..3e280b0d2eee 100644 --- a/src/ray/core_worker/tests/core_worker_test.cc +++ b/src/ray/core_worker/tests/core_worker_test.cc @@ -1075,11 +1075,6 @@ TEST_P(HandleWaitForActorRefDeletedWhileRegisteringRetriesTest, actor_creation_spec->set_max_task_retries(0); TaskSpecification task_spec(task_spec_msg); - gcs::StatusCallback register_callback; - EXPECT_CALL(*mock_gcs_client_->mock_actor_accessor, - AsyncRegisterActor(::testing::_, ::testing::_, ::testing::_)) - .WillOnce(::testing::SaveArg<1>(®ister_callback)); - 
actor_creator_->AsyncRegisterActor(task_spec, nullptr); ASSERT_TRUE(actor_creator_->IsActorInRegistering(actor_id)); @@ -1113,7 +1108,7 @@ TEST_P(HandleWaitForActorRefDeletedWhileRegisteringRetriesTest, }); ASSERT_EQ(callback_count, 0); - register_callback(Status::OK()); + mock_gcs_client_->mock_actor_accessor->async_register_actor_callback_(Status::OK()); // Triggers the callbacks passed to AsyncWaitForActorRegisterFinish ASSERT_FALSE(actor_creator_->IsActorInRegistering(actor_id)); diff --git a/src/ray/core_worker_rpc_client/tests/core_worker_client_pool_test.cc b/src/ray/core_worker_rpc_client/tests/core_worker_client_pool_test.cc index d19d350f9710..727168213ef3 100644 --- a/src/ray/core_worker_rpc_client/tests/core_worker_client_pool_test.cc +++ b/src/ray/core_worker_rpc_client/tests/core_worker_client_pool_test.cc @@ -127,7 +127,9 @@ class MockGcsClientNodeAccessor : public gcs::NodeInfoAccessor { class MockGcsClient : public gcs::GcsClient { public: - explicit MockGcsClient(bool is_subscribed_to_node_change) { + explicit MockGcsClient(bool is_subscribed_to_node_change, + gcs::GcsClientOptions &options) + : GcsClient(options) { this->node_accessor_ = std::make_unique(is_subscribed_to_node_change); } @@ -141,7 +143,12 @@ class DefaultUnavailableTimeoutCallbackTest : public ::testing::TestWithParam([](const rpc::Address &) { return std::make_shared(); })), @@ -156,6 +163,7 @@ class DefaultUnavailableTimeoutCallbackTest : public ::testing::TestWithParam raylet_client_pool_; std::unique_ptr client_pool_; diff --git a/src/ray/gcs_rpc_client/BUILD.bazel b/src/ray/gcs_rpc_client/BUILD.bazel index e55b832275de..8665009c0e64 100644 --- a/src/ray/gcs_rpc_client/BUILD.bazel +++ b/src/ray/gcs_rpc_client/BUILD.bazel @@ -11,6 +11,10 @@ ray_cc_library( "gcs_client.h", ], deps = [ + ":accessor_factory_interface", + ":actor_accessor_interface", + ":gcs_client_context", + ":gcs_client_defaults", ":rpc_client", "//src/ray/common:asio", "//src/ray/common:id", @@ -36,6 +40,91 @@ 
ray_cc_library( ], ) +ray_cc_library( + name = "accessor_factory_interface", + hdrs = [ + "accessor_factory_interface.h", + ], + visibility = ["//visibility:public"], + deps = [], +) + +ray_cc_library( + name = "actor_accessor_interface", + hdrs = [ + "accessors/actor_info_accessor_interface.h", + ], + visibility = ["//visibility:public"], + deps = [ + "//src/ray/common:gcs_callback_types", + "//src/ray/common:id", + "//src/ray/common:placement_group", + "//src/ray/common:task_common", + "//src/ray/protobuf:autoscaler_cc_proto", + "//src/ray/protobuf:gcs_cc_proto", + ], +) + +ray_cc_library( + name = "gcs_client_context", + hdrs = [ + "gcs_client_context.h", + ], + visibility = [":__pkg__"], + deps = [ + ":rpc_client", + "//src/ray/pubsub:gcs_subscriber", + ], +) + +ray_cc_library( + name = "actor_accessor_implementation", + srcs = [ + "accessors/actor_info_accessor.cc", + ], + hdrs = [ + "accessors/actor_info_accessor.h", + ], + visibility = [":__pkg__"], + deps = [ + ":actor_accessor_interface", + ":gcs_client_context", + ":rpc_client", + "//src/ray/common:asio", + "//src/ray/common:id", + "//src/ray/common:placement_group", + "//src/ray/common:protobuf_utils", + "//src/ray/gcs/store_client:redis_store_client", + "//src/ray/protobuf:usage_cc_proto", + "//src/ray/pubsub:gcs_subscriber", + "//src/ray/pubsub:subscriber", + "//src/ray/util:container_util", + "//src/ray/util:network_util", + "//src/ray/util:sequencer", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + ], +) + +ray_cc_library( + name = "gcs_client_defaults", + srcs = [ + "default_accessor_factory.cc", + "default_gcs_client_context.cc", + ], + hdrs = [ + "default_accessor_factory.h", + "default_gcs_client_context.h", + ], + visibility = [":__pkg__"], + deps = [ + ":accessor_factory_interface", + ":actor_accessor_implementation", + ":actor_accessor_interface", + ":gcs_client_context", + ], +) + ray_cc_library( name = "rpc_client", hdrs = [ diff 
--git a/src/ray/gcs_rpc_client/accessor.cc b/src/ray/gcs_rpc_client/accessor.cc index 2447446e6da3..68ec43e04464 100644 --- a/src/ray/gcs_rpc_client/accessor.cc +++ b/src/ray/gcs_rpc_client/accessor.cc @@ -28,11 +28,6 @@ namespace ray { namespace gcs { -int64_t GetGcsTimeoutMs() { - return absl::ToInt64Milliseconds( - absl::Seconds(RayConfig::instance().gcs_server_request_timeout_seconds())); -} - JobInfoAccessor::JobInfoAccessor(GcsClient *client_impl) : client_impl_(client_impl) {} void JobInfoAccessor::AsyncAdd(const std::shared_ptr &data_ptr, @@ -164,307 +159,6 @@ void JobInfoAccessor::AsyncGetNextJobID(const ItemCallback &callback) { }); } -ActorInfoAccessor::ActorInfoAccessor(GcsClient *client_impl) - : client_impl_(client_impl) {} - -void ActorInfoAccessor::AsyncGet( - const ActorID &actor_id, const OptionalItemCallback &callback) { - RAY_LOG(DEBUG).WithField(actor_id).WithField(actor_id.JobId()) << "Getting actor info"; - rpc::GetActorInfoRequest request; - request.set_actor_id(actor_id.Binary()); - client_impl_->GetGcsRpcClient().GetActorInfo( - std::move(request), - [actor_id, callback](const Status &status, rpc::GetActorInfoReply &&reply) { - if (reply.has_actor_table_data()) { - callback(status, reply.actor_table_data()); - } else { - callback(status, std::nullopt); - } - RAY_LOG(DEBUG).WithField(actor_id).WithField(actor_id.JobId()) - << "Finished getting actor info, status = " << status; - }); -} - -void ActorInfoAccessor::AsyncGetAllByFilter( - const std::optional &actor_id, - const std::optional &job_id, - const std::optional &actor_state_name, - const MultiItemCallback &callback, - int64_t timeout_ms) { - RAY_LOG(DEBUG) << "Getting all actor info."; - rpc::GetAllActorInfoRequest request; - if (actor_id) { - request.mutable_filters()->set_actor_id(actor_id.value().Binary()); - } - if (job_id) { - request.mutable_filters()->set_job_id(job_id.value().Binary()); - } - if (actor_state_name) { - static absl::flat_hash_map - actor_state_map = { - 
{"DEPENDENCIES_UNREADY", rpc::ActorTableData::DEPENDENCIES_UNREADY}, - {"PENDING_CREATION", rpc::ActorTableData::PENDING_CREATION}, - {"ALIVE", rpc::ActorTableData::ALIVE}, - {"RESTARTING", rpc::ActorTableData::RESTARTING}, - {"DEAD", rpc::ActorTableData::DEAD}}; - request.mutable_filters()->set_state(actor_state_map[*actor_state_name]); - } - - client_impl_->GetGcsRpcClient().GetAllActorInfo( - std::move(request), - [callback](const Status &status, rpc::GetAllActorInfoReply &&reply) { - callback(status, - VectorFromProtobuf(std::move(*reply.mutable_actor_table_data()))); - RAY_LOG(DEBUG) << "Finished getting all actor info, status = " << status; - }, - timeout_ms); -} - -void ActorInfoAccessor::AsyncGetByName( - const std::string &name, - const std::string &ray_namespace, - const OptionalItemCallback &callback, - int64_t timeout_ms) { - RAY_LOG(DEBUG) << "Getting actor info, name = " << name; - rpc::GetNamedActorInfoRequest request; - request.set_name(name); - request.set_ray_namespace(ray_namespace); - client_impl_->GetGcsRpcClient().GetNamedActorInfo( - std::move(request), - [name, callback](const Status &status, rpc::GetNamedActorInfoReply &&reply) { - if (reply.has_actor_table_data()) { - callback(status, reply.actor_table_data()); - } else { - callback(status, std::nullopt); - } - RAY_LOG(DEBUG) << "Finished getting actor info, status = " << status - << ", name = " << name; - }, - timeout_ms); -} - -Status ActorInfoAccessor::SyncGetByName(const std::string &name, - const std::string &ray_namespace, - rpc::ActorTableData &actor_table_data, - rpc::TaskSpec &task_spec) { - rpc::GetNamedActorInfoRequest request; - rpc::GetNamedActorInfoReply reply; - request.set_name(name); - request.set_ray_namespace(ray_namespace); - auto status = client_impl_->GetGcsRpcClient().SyncGetNamedActorInfo( - std::move(request), &reply, GetGcsTimeoutMs()); - if (status.ok()) { - actor_table_data = std::move(*reply.mutable_actor_table_data()); - task_spec = 
std::move(*reply.mutable_task_spec()); - } - return status; -} - -Status ActorInfoAccessor::SyncListNamedActors( - bool all_namespaces, - const std::string &ray_namespace, - std::vector> &actors) { - rpc::ListNamedActorsRequest request; - request.set_all_namespaces(all_namespaces); - request.set_ray_namespace(ray_namespace); - rpc::ListNamedActorsReply reply; - auto status = client_impl_->GetGcsRpcClient().SyncListNamedActors( - std::move(request), &reply, GetGcsTimeoutMs()); - if (!status.ok()) { - return status; - } - actors.reserve(reply.named_actors_list_size()); - for (auto &actor_info : - VectorFromProtobuf(std::move(*reply.mutable_named_actors_list()))) { - actors.emplace_back(std::move(*actor_info.mutable_ray_namespace()), - std::move(*actor_info.mutable_name())); - } - return status; -} - -void ActorInfoAccessor::AsyncRestartActorForLineageReconstruction( - const ray::ActorID &actor_id, - uint64_t num_restarts_due_to_lineage_reconstruction, - const ray::gcs::StatusCallback &callback, - int64_t timeout_ms) { - rpc::RestartActorForLineageReconstructionRequest request; - request.set_actor_id(actor_id.Binary()); - request.set_num_restarts_due_to_lineage_reconstruction( - num_restarts_due_to_lineage_reconstruction); - client_impl_->GetGcsRpcClient().RestartActorForLineageReconstruction( - std::move(request), - [callback](const Status &status, - rpc::RestartActorForLineageReconstructionReply &&reply) { - callback(status); - }, - timeout_ms); -} - -namespace { - -// TODO(dayshah): Yes this is temporary. https://github.com/ray-project/ray/issues/54327 -Status ComputeGcsStatus(const Status &grpc_status, const rpc::GcsStatus &gcs_status) { - // If gRPC status is ok return the GCS status, otherwise return the gRPC status. - if (grpc_status.ok()) { - return gcs_status.code() == static_cast(StatusCode::OK) - ? 
Status::OK() - : Status(StatusCode(gcs_status.code()), gcs_status.message()); - } else { - return grpc_status; - } -} - -} // namespace - -void ActorInfoAccessor::AsyncRegisterActor(const ray::TaskSpecification &task_spec, - const ray::gcs::StatusCallback &callback, - int64_t timeout_ms) { - RAY_CHECK(task_spec.IsActorCreationTask() && callback); - rpc::RegisterActorRequest request; - request.mutable_task_spec()->CopyFrom(task_spec.GetMessage()); - client_impl_->GetGcsRpcClient().RegisterActor( - std::move(request), - [callback](const Status &status, rpc::RegisterActorReply &&reply) { - callback(ComputeGcsStatus(status, reply.status())); - }, - timeout_ms); -} - -Status ActorInfoAccessor::SyncRegisterActor(const ray::TaskSpecification &task_spec) { - RAY_CHECK(task_spec.IsActorCreationTask()); - rpc::RegisterActorRequest request; - rpc::RegisterActorReply reply; - request.mutable_task_spec()->CopyFrom(task_spec.GetMessage()); - auto status = client_impl_->GetGcsRpcClient().SyncRegisterActor( - std::move(request), &reply, GetGcsTimeoutMs()); - return ComputeGcsStatus(status, reply.status()); -} - -void ActorInfoAccessor::AsyncKillActor(const ActorID &actor_id, - bool force_kill, - bool no_restart, - const ray::gcs::StatusCallback &callback, - int64_t timeout_ms) { - rpc::KillActorViaGcsRequest request; - request.set_actor_id(actor_id.Binary()); - request.set_force_kill(force_kill); - request.set_no_restart(no_restart); - client_impl_->GetGcsRpcClient().KillActorViaGcs( - std::move(request), - [callback](const Status &status, rpc::KillActorViaGcsReply &&reply) { - if (callback) { - callback(status); - } - }, - timeout_ms); -} - -void ActorInfoAccessor::AsyncCreateActor( - const ray::TaskSpecification &task_spec, - const rpc::ClientCallback &callback) { - RAY_CHECK(task_spec.IsActorCreationTask() && callback); - rpc::CreateActorRequest request; - request.mutable_task_spec()->CopyFrom(task_spec.GetMessage()); - client_impl_->GetGcsRpcClient().CreateActor( - 
std::move(request), - [callback](const Status &status, rpc::CreateActorReply &&reply) { - callback(status, std::move(reply)); - }); -} - -void ActorInfoAccessor::AsyncReportActorOutOfScope( - const ActorID &actor_id, - uint64_t num_restarts_due_to_lineage_reconstruction, - const StatusCallback &callback, - int64_t timeout_ms) { - rpc::ReportActorOutOfScopeRequest request; - request.set_actor_id(actor_id.Binary()); - request.set_num_restarts_due_to_lineage_reconstruction( - num_restarts_due_to_lineage_reconstruction); - client_impl_->GetGcsRpcClient().ReportActorOutOfScope( - std::move(request), - [callback](const Status &status, rpc::ReportActorOutOfScopeReply &&reply) { - if (callback) { - callback(status); - } - }, - timeout_ms); -} - -void ActorInfoAccessor::AsyncSubscribe( - const ActorID &actor_id, - const SubscribeCallback &subscribe, - const StatusCallback &done) { - RAY_LOG(DEBUG).WithField(actor_id).WithField(actor_id.JobId()) - << "Subscribing update operations of actor"; - RAY_CHECK(subscribe != nullptr) << "Failed to subscribe actor, actor id = " << actor_id; - - auto fetch_data_operation = - [this, actor_id, subscribe](const StatusCallback &fetch_done) { - auto callback = [actor_id, subscribe, fetch_done]( - const Status &status, - std::optional &&result) { - if (result) { - subscribe(actor_id, std::move(*result)); - } - if (fetch_done) { - fetch_done(status); - } - }; - AsyncGet(actor_id, callback); - }; - - { - absl::MutexLock lock(&mutex_); - resubscribe_operations_[actor_id] = - [this, actor_id, subscribe](const StatusCallback &subscribe_done) { - client_impl_->GetGcsSubscriber().SubscribeActor( - actor_id, subscribe, subscribe_done); - }; - fetch_data_operations_[actor_id] = fetch_data_operation; - } - - client_impl_->GetGcsSubscriber().SubscribeActor( - actor_id, subscribe, [fetch_data_operation, done](const Status &) { - fetch_data_operation(done); - }); -} - -void ActorInfoAccessor::AsyncUnsubscribe(const ActorID &actor_id) { - 
RAY_LOG(DEBUG).WithField(actor_id).WithField(actor_id.JobId()) - << "Cancelling subscription to an actor"; - client_impl_->GetGcsSubscriber().UnsubscribeActor(actor_id); - absl::MutexLock lock(&mutex_); - resubscribe_operations_.erase(actor_id); - fetch_data_operations_.erase(actor_id); - RAY_LOG(DEBUG).WithField(actor_id).WithField(actor_id.JobId()) - << "Finished cancelling subscription to an actor"; -} - -void ActorInfoAccessor::AsyncResubscribe() { - RAY_LOG(DEBUG) << "Reestablishing subscription for actor info."; - // If only the GCS sever has restarted, we only need to fetch data from the GCS server. - // If the pub-sub server has also restarted, we need to resubscribe to the pub-sub - // server first, then fetch data from the GCS server. - absl::MutexLock lock(&mutex_); - for (auto &[actor_id, resubscribe_op] : resubscribe_operations_) { - resubscribe_op([this, id = actor_id](const Status &status) { - absl::MutexLock callback_lock(&mutex_); - auto fetch_data_operation = fetch_data_operations_[id]; - // `fetch_data_operation` is called in the callback function of subscribe. - // Before that, if the user calls `AsyncUnsubscribe` function, the corresponding - // fetch function will be deleted, so we need to check if it's null. 
- if (fetch_data_operation != nullptr) { - fetch_data_operation(nullptr); - } - }); - } -} - -bool ActorInfoAccessor::IsActorUnsubscribed(const ActorID &actor_id) { - return client_impl_->GetGcsSubscriber().IsActorUnsubscribed(actor_id); -} - NodeInfoAccessor::NodeInfoAccessor(GcsClient *client_impl) : client_impl_(client_impl) {} void NodeInfoAccessor::RegisterSelf(rpc::GcsNodeInfo &&local_node_info, @@ -1124,7 +818,7 @@ Status PlacementGroupInfoAccessor::SyncCreatePlacementGroup( rpc::CreatePlacementGroupReply reply; request.mutable_placement_group_spec()->CopyFrom(placement_group_spec.GetMessage()); auto status = client_impl_->GetGcsRpcClient().SyncCreatePlacementGroup( - std::move(request), &reply, GetGcsTimeoutMs()); + std::move(request), &reply, rpc::GetGcsTimeoutMs()); if (status.ok()) { RAY_LOG(DEBUG).WithField(placement_group_spec.PlacementGroupId()) << "Finished registering placement group."; @@ -1141,7 +835,7 @@ Status PlacementGroupInfoAccessor::SyncRemovePlacementGroup( rpc::RemovePlacementGroupReply reply; request.set_placement_group_id(placement_group_id.Binary()); auto status = client_impl_->GetGcsRpcClient().SyncRemovePlacementGroup( - std::move(request), &reply, GetGcsTimeoutMs()); + std::move(request), &reply, rpc::GetGcsTimeoutMs()); return status; } diff --git a/src/ray/gcs_rpc_client/accessor.h b/src/ray/gcs_rpc_client/accessor.h index eedb84be57ce..e1b5e09aa1a4 100644 --- a/src/ray/gcs_rpc_client/accessor.h +++ b/src/ray/gcs_rpc_client/accessor.h @@ -16,10 +16,8 @@ #include #include #include -#include #include -#include "absl/types/optional.h" #include "ray/common/gcs_callback_types.h" #include "ray/common/id.h" #include "ray/common/placement_group.h" @@ -34,180 +32,10 @@ namespace ray { namespace gcs { -// Default GCS Client timeout in milliseconds, as defined in -// RAY_gcs_server_request_timeout_seconds -int64_t GetGcsTimeoutMs(); - using SubscribeOperation = std::function; using FetchDataOperation = std::function; class GcsClient; - -/// 
\class ActorInfoAccessor -/// `ActorInfoAccessor` is a sub-interface of `GcsClient`. -/// This class includes all the methods that are related to accessing -/// actor information in the GCS. -class ActorInfoAccessor { - public: - ActorInfoAccessor() = default; - explicit ActorInfoAccessor(GcsClient *client_impl); - virtual ~ActorInfoAccessor() = default; - /// Get actor specification from GCS asynchronously. - /// - /// \param actor_id The ID of actor to look up in the GCS. - /// \param callback Callback that will be called after lookup finishes. - virtual void AsyncGet(const ActorID &actor_id, - const OptionalItemCallback &callback); - - /// Get all actor specification from the GCS asynchronously. - /// - /// \param actor_id To filter actors by actor_id. - /// \param job_id To filter actors by job_id. - /// \param actor_state_name To filter actors based on actor state. - /// \param callback Callback that will be called after lookup finishes. - /// \param timeout_ms -1 means infinite. - virtual void AsyncGetAllByFilter(const std::optional &actor_id, - const std::optional &job_id, - const std::optional &actor_state_name, - const MultiItemCallback &callback, - int64_t timeout_ms = -1); - - /// Get actor specification for a named actor from the GCS asynchronously. - /// - /// \param name The name of the detached actor to look up in the GCS. - /// \param ray_namespace The namespace to filter to. - /// \param callback Callback that will be called after lookup finishes. - /// \param timeout_ms RPC timeout in milliseconds. -1 means the default. - virtual void AsyncGetByName(const std::string &name, - const std::string &ray_namespace, - const OptionalItemCallback &callback, - int64_t timeout_ms = -1); - - /// Get actor specification for a named actor from the GCS synchronously. - /// - /// The RPC will timeout after the default GCS RPC timeout is exceeded. - /// - /// \param name The name of the detached actor to look up in the GCS. 
- /// \param ray_namespace The namespace to filter to. - /// \return Status. TimedOut status if RPC is timed out. - /// NotFound if the name doesn't exist. - virtual Status SyncGetByName(const std::string &name, - const std::string &ray_namespace, - rpc::ActorTableData &actor_table_data, - rpc::TaskSpec &task_spec); - - /// List all named actors from the GCS synchronously. - /// - /// The RPC will timeout after the default GCS RPC timeout is exceeded. - /// - /// \param all_namespaces Whether or not to include actors from all Ray namespaces. - /// \param ray_namespace The namespace to filter to if all_namespaces is false. - /// \param[out] actors The pair of list of named actors. Each pair includes the - /// namespace and name of the actor. \return Status. TimeOut if RPC times out. - virtual Status SyncListNamedActors( - bool all_namespaces, - const std::string &ray_namespace, - std::vector> &actors); - - virtual void AsyncReportActorOutOfScope( - const ActorID &actor_id, - uint64_t num_restarts_due_to_lineage_reconstruction, - const StatusCallback &callback, - int64_t timeout_ms = -1); - - /// Register actor to GCS asynchronously. - /// - /// \param task_spec The specification for the actor creation task. - /// \param callback Callback that will be called after the actor info is written to GCS. - /// \param timeout_ms RPC timeout ms. -1 means there's no timeout. - virtual void AsyncRegisterActor(const TaskSpecification &task_spec, - const StatusCallback &callback, - int64_t timeout_ms = -1); - - virtual void AsyncRestartActorForLineageReconstruction( - const ActorID &actor_id, - uint64_t num_restarts_due_to_lineage_reconstructions, - const StatusCallback &callback, - int64_t timeout_ms = -1); - - /// Register actor to GCS synchronously. - /// - /// The RPC will timeout after the default GCS RPC timeout is exceeded. - /// - /// \param task_spec The specification for the actor creation task. - /// \return Status. 
Timedout if actor is not registered by the global - /// GCS timeout. - virtual Status SyncRegisterActor(const ray::TaskSpecification &task_spec); - - /// Kill actor via GCS asynchronously. - /// - /// \param actor_id The ID of actor to destroy. - /// \param force_kill Whether to force kill an actor by killing the worker. - /// \param no_restart If set to true, the killed actor will not be restarted anymore. - /// \param callback Callback that will be called after the actor is destroyed. - /// \param timeout_ms RPC timeout in milliseconds. -1 means infinite. - virtual void AsyncKillActor(const ActorID &actor_id, - bool force_kill, - bool no_restart, - const StatusCallback &callback, - int64_t timeout_ms = -1); - - /// Asynchronously request GCS to create the actor. - /// - /// This should be called after the worker has resolved the actor dependencies. - /// TODO(...): Currently this request will only reply after the actor is created. - /// We should change it to reply immediately after GCS has persisted the actor - /// dependencies in storage. - /// - /// \param task_spec The specification for the actor creation task. - /// \param callback Callback that will be called after the actor info is written to GCS. - virtual void AsyncCreateActor( - const TaskSpecification &task_spec, - const rpc::ClientCallback &callback); - - /// Subscribe to any update operations of an actor. - /// - /// \param actor_id The ID of actor to be subscribed to. - /// \param subscribe Callback that will be called each time when the actor is updated. - /// \param done Callback that will be called when subscription is complete. - virtual void AsyncSubscribe( - const ActorID &actor_id, - const SubscribeCallback &subscribe, - const StatusCallback &done); - - /// Cancel subscription to an actor. - /// - /// \param actor_id The ID of the actor to be unsubscribed to. - virtual void AsyncUnsubscribe(const ActorID &actor_id); - - /// Reestablish subscription. 
- /// This should be called when GCS server restarts from a failure. - /// PubSub server restart will cause GCS server restart. In this case, we need to - /// resubscribe from PubSub server, otherwise we only need to fetch data from GCS - /// server. - virtual void AsyncResubscribe(); - - /// Check if the specified actor is unsubscribed. - /// - /// \param actor_id The ID of the actor. - /// \return Whether the specified actor is unsubscribed. - virtual bool IsActorUnsubscribed(const ActorID &actor_id); - - private: - // Mutex to protect the resubscribe_operations_ field and fetch_data_operations_ field. - absl::Mutex mutex_; - - /// Resubscribe operations for actors. - absl::flat_hash_map resubscribe_operations_ - ABSL_GUARDED_BY(mutex_); - - /// Save the fetch data operation of actors. - absl::flat_hash_map fetch_data_operations_ - ABSL_GUARDED_BY(mutex_); - - GcsClient *client_impl_; -}; - /// \class JobInfoAccessor /// `JobInfoAccessor` is a sub-interface of `GcsClient`. /// This class includes all the methods that are related to accessing diff --git a/src/ray/gcs_rpc_client/accessor_factory_interface.h b/src/ray/gcs_rpc_client/accessor_factory_interface.h new file mode 100644 index 000000000000..03d4ec445801 --- /dev/null +++ b/src/ray/gcs_rpc_client/accessor_factory_interface.h @@ -0,0 +1,47 @@ +// Copyright 2025 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include + +namespace ray { +namespace gcs { + +// Forward declarations for accessor interfaces +class ActorInfoAccessorInterface; +class GcsClientContext; + +/** + @interface AccessorFactoryInterface + + Factory interface for creating GCS accessor instances. + */ +class AccessorFactoryInterface { + public: + virtual ~AccessorFactoryInterface() = default; + + /** + Create an ActorInfoAccessor instance. + + @param context The GCS client implementation. + @return unique_ptr A unique pointer to the created + ActorInfoAccessor. + */ + virtual std::unique_ptr CreateActorInfoAccessor( + GcsClientContext *context) = 0; +}; + +} // namespace gcs +} // namespace ray diff --git a/src/ray/gcs_rpc_client/accessors/actor_info_accessor.cc b/src/ray/gcs_rpc_client/accessors/actor_info_accessor.cc new file mode 100644 index 000000000000..4d5b3bbff49c --- /dev/null +++ b/src/ray/gcs_rpc_client/accessors/actor_info_accessor.cc @@ -0,0 +1,331 @@ +// Copyright 2025 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "ray/gcs_rpc_client/accessors/actor_info_accessor.h" + +#include +#include +#include +#include +#include +#include + +#include "ray/gcs_rpc_client/rpc_client.h" +#include "ray/pubsub/gcs_subscriber.h" +#include "ray/util/container_util.h" + +namespace ray { +namespace gcs { + +ActorInfoAccessor::ActorInfoAccessor(GcsClientContext *context) : context_(context) {} + +void ActorInfoAccessor::AsyncGet( + const ActorID &actor_id, const OptionalItemCallback &callback) { + RAY_LOG(DEBUG).WithField(actor_id).WithField(actor_id.JobId()) << "Getting actor info"; + rpc::GetActorInfoRequest request; + request.set_actor_id(actor_id.Binary()); + context_->GetGcsRpcClient().GetActorInfo( + std::move(request), + [actor_id, callback](const Status &status, rpc::GetActorInfoReply &&reply) { + if (reply.has_actor_table_data()) { + callback(status, reply.actor_table_data()); + } else { + callback(status, std::nullopt); + } + RAY_LOG(DEBUG).WithField(actor_id).WithField(actor_id.JobId()) + << "Finished getting actor info, status = " << status; + }); +} + +void ActorInfoAccessor::AsyncGetAllByFilter( + const std::optional &actor_id, + const std::optional &job_id, + const std::optional &actor_state_name, + const MultiItemCallback &callback, + int64_t timeout_ms) { + RAY_LOG(DEBUG) << "Getting all actor info."; + rpc::GetAllActorInfoRequest request; + if (actor_id) { + request.mutable_filters()->set_actor_id(actor_id.value().Binary()); + } + if (job_id) { + request.mutable_filters()->set_job_id(job_id.value().Binary()); + } + if (actor_state_name) { + static absl::flat_hash_map + actor_state_map = { + {"DEPENDENCIES_UNREADY", rpc::ActorTableData::DEPENDENCIES_UNREADY}, + {"PENDING_CREATION", rpc::ActorTableData::PENDING_CREATION}, + {"ALIVE", rpc::ActorTableData::ALIVE}, + {"RESTARTING", rpc::ActorTableData::RESTARTING}, + {"DEAD", rpc::ActorTableData::DEAD}}; + request.mutable_filters()->set_state(actor_state_map[*actor_state_name]); + } + + 
context_->GetGcsRpcClient().GetAllActorInfo( + std::move(request), + [callback](const Status &status, rpc::GetAllActorInfoReply &&reply) { + callback(status, + VectorFromProtobuf(std::move(*reply.mutable_actor_table_data()))); + RAY_LOG(DEBUG) << "Finished getting all actor info, status = " << status; + }, + timeout_ms); +} + +void ActorInfoAccessor::AsyncGetByName( + const std::string &name, + const std::string &ray_namespace, + const OptionalItemCallback &callback, + int64_t timeout_ms) { + RAY_LOG(DEBUG) << "Getting actor info, name = " << name; + rpc::GetNamedActorInfoRequest request; + request.set_name(name); + request.set_ray_namespace(ray_namespace); + context_->GetGcsRpcClient().GetNamedActorInfo( + std::move(request), + [name, callback](const Status &status, rpc::GetNamedActorInfoReply &&reply) { + if (reply.has_actor_table_data()) { + callback(status, reply.actor_table_data()); + } else { + callback(status, std::nullopt); + } + RAY_LOG(DEBUG) << "Finished getting actor info, status = " << status + << ", name = " << name; + }, + timeout_ms); +} + +Status ActorInfoAccessor::SyncGetByName(const std::string &name, + const std::string &ray_namespace, + rpc::ActorTableData &actor_table_data, + rpc::TaskSpec &task_spec) { + rpc::GetNamedActorInfoRequest request; + rpc::GetNamedActorInfoReply reply; + request.set_name(name); + request.set_ray_namespace(ray_namespace); + auto status = context_->GetGcsRpcClient().SyncGetNamedActorInfo( + std::move(request), &reply, rpc::GetGcsTimeoutMs()); + if (status.ok()) { + actor_table_data = std::move(*reply.mutable_actor_table_data()); + task_spec = std::move(*reply.mutable_task_spec()); + } + return status; +} + +Status ActorInfoAccessor::SyncListNamedActors( + bool all_namespaces, + const std::string &ray_namespace, + std::vector> &actors) { + rpc::ListNamedActorsRequest request; + request.set_all_namespaces(all_namespaces); + request.set_ray_namespace(ray_namespace); + rpc::ListNamedActorsReply reply; + auto status = 
context_->GetGcsRpcClient().SyncListNamedActors( + std::move(request), &reply, rpc::GetGcsTimeoutMs()); + if (!status.ok()) { + return status; + } + actors.reserve(reply.named_actors_list_size()); + for (auto &actor_info : + VectorFromProtobuf(std::move(*reply.mutable_named_actors_list()))) { + actors.emplace_back(std::move(*actor_info.mutable_ray_namespace()), + std::move(*actor_info.mutable_name())); + } + return status; +} + +void ActorInfoAccessor::AsyncRestartActorForLineageReconstruction( + const ray::ActorID &actor_id, + uint64_t num_restarts_due_to_lineage_reconstruction, + const ray::gcs::StatusCallback &callback, + int64_t timeout_ms) { + rpc::RestartActorForLineageReconstructionRequest request; + request.set_actor_id(actor_id.Binary()); + request.set_num_restarts_due_to_lineage_reconstruction( + num_restarts_due_to_lineage_reconstruction); + context_->GetGcsRpcClient().RestartActorForLineageReconstruction( + std::move(request), + [callback](const Status &status, + rpc::RestartActorForLineageReconstructionReply &&reply) { + callback(status); + }, + timeout_ms); +} + +namespace { + +// TODO(dayshah): Yes this is temporary. https://github.com/ray-project/ray/issues/54327 +Status ComputeGcsStatus(const Status &grpc_status, const rpc::GcsStatus &gcs_status) { + // If gRPC status is ok return the GCS status, otherwise return the gRPC status. + if (grpc_status.ok()) { + return gcs_status.code() == static_cast(StatusCode::OK) + ? 
Status::OK() + : Status(StatusCode(gcs_status.code()), gcs_status.message()); + } else { + return grpc_status; + } +} + +} // namespace + +void ActorInfoAccessor::AsyncRegisterActor(const ray::TaskSpecification &task_spec, + const ray::gcs::StatusCallback &callback, + int64_t timeout_ms) { + RAY_CHECK(task_spec.IsActorCreationTask() && callback); + rpc::RegisterActorRequest request; + request.mutable_task_spec()->CopyFrom(task_spec.GetMessage()); + context_->GetGcsRpcClient().RegisterActor( + std::move(request), + [callback](const Status &status, rpc::RegisterActorReply &&reply) { + callback(ComputeGcsStatus(status, reply.status())); + }, + timeout_ms); +} + +Status ActorInfoAccessor::SyncRegisterActor(const ray::TaskSpecification &task_spec) { + RAY_CHECK(task_spec.IsActorCreationTask()); + rpc::RegisterActorRequest request; + rpc::RegisterActorReply reply; + request.mutable_task_spec()->CopyFrom(task_spec.GetMessage()); + auto status = context_->GetGcsRpcClient().SyncRegisterActor( + std::move(request), &reply, rpc::GetGcsTimeoutMs()); + return ComputeGcsStatus(status, reply.status()); +} + +void ActorInfoAccessor::AsyncKillActor(const ActorID &actor_id, + bool force_kill, + bool no_restart, + const ray::gcs::StatusCallback &callback, + int64_t timeout_ms) { + rpc::KillActorViaGcsRequest request; + request.set_actor_id(actor_id.Binary()); + request.set_force_kill(force_kill); + request.set_no_restart(no_restart); + context_->GetGcsRpcClient().KillActorViaGcs( + std::move(request), + [callback](const Status &status, rpc::KillActorViaGcsReply &&reply) { + if (callback) { + callback(status); + } + }, + timeout_ms); +} + +void ActorInfoAccessor::AsyncCreateActor( + const ray::TaskSpecification &task_spec, + const rpc::ClientCallback &callback) { + RAY_CHECK(task_spec.IsActorCreationTask() && callback); + rpc::CreateActorRequest request; + request.mutable_task_spec()->CopyFrom(task_spec.GetMessage()); + context_->GetGcsRpcClient().CreateActor( + std::move(request), + 
[callback](const Status &status, rpc::CreateActorReply &&reply) { + callback(status, std::move(reply)); + }); +} + +void ActorInfoAccessor::AsyncReportActorOutOfScope( + const ActorID &actor_id, + uint64_t num_restarts_due_to_lineage_reconstruction, + const StatusCallback &callback, + int64_t timeout_ms) { + rpc::ReportActorOutOfScopeRequest request; + request.set_actor_id(actor_id.Binary()); + request.set_num_restarts_due_to_lineage_reconstruction( + num_restarts_due_to_lineage_reconstruction); + context_->GetGcsRpcClient().ReportActorOutOfScope( + std::move(request), + [callback](const Status &status, rpc::ReportActorOutOfScopeReply &&reply) { + if (callback) { + callback(status); + } + }, + timeout_ms); +} + +void ActorInfoAccessor::AsyncSubscribe( + const ActorID &actor_id, + const SubscribeCallback &subscribe, + const StatusCallback &done) { + RAY_LOG(DEBUG).WithField(actor_id).WithField(actor_id.JobId()) + << "Subscribing update operations of actor"; + RAY_CHECK(subscribe != nullptr) << "Failed to subscribe actor, actor id = " << actor_id; + + auto fetch_data_operation = + [this, actor_id, subscribe](const StatusCallback &fetch_done) { + auto callback = [actor_id, subscribe, fetch_done]( + const Status &status, + std::optional &&result) { + if (result) { + subscribe(actor_id, std::move(*result)); + } + if (fetch_done) { + fetch_done(status); + } + }; + AsyncGet(actor_id, callback); + }; + + { + absl::MutexLock lock(&mutex_); + resubscribe_operations_[actor_id] = [this, actor_id, subscribe]( + const StatusCallback &subscribe_done) { + context_->GetGcsSubscriber().SubscribeActor(actor_id, subscribe, subscribe_done); + }; + fetch_data_operations_[actor_id] = fetch_data_operation; + } + + context_->GetGcsSubscriber().SubscribeActor( + actor_id, subscribe, [fetch_data_operation, done](const Status &) { + fetch_data_operation(done); + }); +} + +void ActorInfoAccessor::AsyncUnsubscribe(const ActorID &actor_id) { + 
RAY_LOG(DEBUG).WithField(actor_id).WithField(actor_id.JobId()) + << "Cancelling subscription to an actor"; + context_->GetGcsSubscriber().UnsubscribeActor(actor_id); + absl::MutexLock lock(&mutex_); + resubscribe_operations_.erase(actor_id); + fetch_data_operations_.erase(actor_id); + RAY_LOG(DEBUG).WithField(actor_id).WithField(actor_id.JobId()) + << "Finished cancelling subscription to an actor"; +} + +void ActorInfoAccessor::AsyncResubscribe() { + RAY_LOG(DEBUG) << "Reestablishing subscription for actor info."; + // If only the GCS sever has restarted, we only need to fetch data from the GCS server. + // If the pub-sub server has also restarted, we need to resubscribe to the pub-sub + // server first, then fetch data from the GCS server. + absl::MutexLock lock(&mutex_); + for (auto &[actor_id, resubscribe_op] : resubscribe_operations_) { + resubscribe_op([this, id = actor_id](const Status &status) { + absl::MutexLock callback_lock(&mutex_); + auto fetch_data_operation = fetch_data_operations_[id]; + // `fetch_data_operation` is called in the callback function of subscribe. + // Before that, if the user calls `AsyncUnsubscribe` function, the corresponding + // fetch function will be deleted, so we need to check if it's null. + if (fetch_data_operation != nullptr) { + fetch_data_operation(nullptr); + } + }); + } +} + +bool ActorInfoAccessor::IsActorUnsubscribed(const ActorID &actor_id) { + return context_->GetGcsSubscriber().IsActorUnsubscribed(actor_id); +} + +} // namespace gcs +} // namespace ray diff --git a/src/ray/gcs_rpc_client/accessors/actor_info_accessor.h b/src/ray/gcs_rpc_client/accessors/actor_info_accessor.h new file mode 100644 index 000000000000..2572b9adb7f9 --- /dev/null +++ b/src/ray/gcs_rpc_client/accessors/actor_info_accessor.h @@ -0,0 +1,247 @@ +// Copyright 2025 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "absl/synchronization/mutex.h" +#include "ray/common/gcs_callback_types.h" +#include "ray/gcs_rpc_client/accessors/actor_info_accessor_interface.h" +#include "ray/gcs_rpc_client/gcs_client_context.h" +#include "ray/rpc/rpc_callback_types.h" +#include "ray/util/sequencer.h" + +namespace ray { +namespace gcs { + +using SubscribeOperation = std::function; +using FetchDataOperation = std::function; + +/** + @class ActorInfoAccessor + + Implementation of ActorInfoAccessorInterface that accesses actor information by querying + the GCS. + */ +class ActorInfoAccessor : public ActorInfoAccessorInterface { + public: + ActorInfoAccessor() = default; + explicit ActorInfoAccessor(GcsClientContext *context); + virtual ~ActorInfoAccessor() = default; + /** + Get actor specification from GCS asynchronously. + + @param actor_id The ID of actor to look up in the GCS. + @param callback Callback that will be called after lookup finishes. + */ + void AsyncGet(const ActorID &actor_id, + const OptionalItemCallback &callback) override; + + /** + Get all actor specification from the GCS asynchronously. + + @param actor_id To filter actors by actor_id. + @param job_id To filter actors by job_id. + @param actor_state_name To filter actors based on actor state. + @param callback Callback that will be called after lookup finishes. + @param timeout_ms request timeout, defaults to -1 for infinite timeout. 
+ */ + void AsyncGetAllByFilter(const std::optional &actor_id, + const std::optional &job_id, + const std::optional &actor_state_name, + const MultiItemCallback &callback, + int64_t timeout_ms = -1) override; + + /** + Get actor specification for a named actor from the GCS asynchronously. + + @param name The name of the actor to look up in the GCS. + @param ray_namespace The namespace to filter to. + @param callback Callback that will be called after lookup finishes. + @param timeout_ms RPC timeout in milliseconds. -1 means the default. + */ + void AsyncGetByName(const std::string &name, + const std::string &ray_namespace, + const OptionalItemCallback &callback, + int64_t timeout_ms = -1) override; + + /** + Get actor specification for a named actor from the GCS synchronously. + The RPC will timeout after the default GCS RPC timeout is exceeded. + + @param name The name of the actor to look up in the GCS. + @param ray_namespace The namespace to filter to. NotFound if the name doesn't exist. + + @return Status::OK + @return Status::TimedOut if the method is timed out. + */ + Status SyncGetByName(const std::string &name, + const std::string &ray_namespace, + rpc::ActorTableData &actor_table_data, + rpc::TaskSpec &task_spec) override; + + /** + List all named actors from the GCS synchronously. + + The RPC will timeout after the default GCS RPC timeout is exceeded. + + @param all_namespaces Whether to include actors from all Ray namespaces. + @param ray_namespace The namespace to filter to if all_namespaces is false. + @param[out] actors The pair of list of named actors. Each pair includes the + namespace and name of the actor. + @return Status::OK + @return Status::TimedOut if the method is timed out. + */ + Status SyncListNamedActors( + bool all_namespaces, + const std::string &ray_namespace, + std::vector> &actors) override; + + /** + Report actor out of scope asynchronously. + + @param actor_id The ID of the actor. 
+ @param num_restarts_due_to_lineage_reconstruction Number of restarts due to lineage + reconstruction. + @param callback Callback that will be called after the operation completes. + @param timeout_ms RPC timeout in milliseconds. -1 means the default. + */ + void AsyncReportActorOutOfScope(const ActorID &actor_id, + uint64_t num_restarts_due_to_lineage_reconstruction, + const StatusCallback &callback, + int64_t timeout_ms = -1) override; + + /** + Register actor to GCS asynchronously. + + @param task_spec The specification for the actor creation task. + @param callback Callback that will be called after the actor info is written to GCS. + @param timeout_ms RPC timeout ms. -1 means there's no timeout. + */ + void AsyncRegisterActor(const TaskSpecification &task_spec, + const StatusCallback &callback, + int64_t timeout_ms = -1) override; + + /** + Restart actor for lineage reconstruction asynchronously. + + @param actor_id The ID of the actor. + @param num_restarts_due_to_lineage_reconstructions Number of restarts due to lineage + reconstructions. + @param callback Callback that will be called after the operation completes. + @param timeout_ms RPC timeout in milliseconds. -1 means the default. + */ + void AsyncRestartActorForLineageReconstruction( + const ActorID &actor_id, + uint64_t num_restarts_due_to_lineage_reconstructions, + const StatusCallback &callback, + int64_t timeout_ms = -1) override; + + /** + Register actor to GCS synchronously. + + The RPC will timeout after the default GCS RPC timeout is exceeded. + + @param task_spec The specification for the actor creation task. + @return Status::OK + @return Status. TimedOut if actor is not registered by the global + GCS timeout. + */ + Status SyncRegisterActor(const ray::TaskSpecification &task_spec) override; + + /** + Kill actor via GCS asynchronously. + + @param actor_id The ID of actor to destroy. + @param force_kill Whether to force kill an actor by killing the worker. 
+ @param no_restart If set to true, the killed actor will not be restarted anymore. + @param callback Callback that will be called after the actor is destroyed. + @param timeout_ms RPC timeout in milliseconds. -1 means infinite. + */ + void AsyncKillActor(const ActorID &actor_id, + bool force_kill, + bool no_restart, + const StatusCallback &callback, + int64_t timeout_ms = -1) override; + + /** + Asynchronously request GCS to create the actor. + + This should be called after the worker has resolved the actor dependencies. + TODO(...): Currently this request will only reply after the actor is created. + We should change it to reply immediately after GCS has persisted the actor + dependencies in storage. + + @param task_spec The specification for the actor creation task. + @param callback Callback that will be called after the actor info is written to GCS. + */ + void AsyncCreateActor( + const TaskSpecification &task_spec, + const rpc::ClientCallback &callback) override; + + /** + Subscribe to any update operations of an actor. + + @param actor_id The ID of actor to be subscribed to. + @param subscribe Callback that will be called each time when the actor is updated. + @param done Callback that will be called when subscription is complete. + */ + void AsyncSubscribe(const ActorID &actor_id, + const SubscribeCallback &subscribe, + const StatusCallback &done) override; + + /** + Cancel subscription to an actor. + + @param actor_id The ID of the actor to be unsubscribed to. + */ + void AsyncUnsubscribe(const ActorID &actor_id) override; + + /** + Reestablish subscription. + + This should be called when GCS server restarts from a failure. + PubSub server restart will cause GCS server restart. In this case, we need to + resubscribe from PubSub server, otherwise we only need to fetch data from GCS + server. + */ + void AsyncResubscribe() override; + + /** + Check if the specified actor is unsubscribed. + + @param actor_id The ID of the actor. 
+ @return Whether the specified actor is unsubscribed. + */ + bool IsActorUnsubscribed(const ActorID &actor_id) override; + + private: + // Mutex to protect the resubscribe_operations_ field and fetch_data_operations_ field. + absl::Mutex mutex_; + + /// Resubscribe operations for actors. + absl::flat_hash_map resubscribe_operations_ + ABSL_GUARDED_BY(mutex_); + + /// Save the fetch data operation of actors. + absl::flat_hash_map fetch_data_operations_ + ABSL_GUARDED_BY(mutex_); + + GcsClientContext *context_; +}; + +} // namespace gcs +} // namespace ray diff --git a/src/ray/gcs_rpc_client/accessors/actor_info_accessor_interface.h b/src/ray/gcs_rpc_client/accessors/actor_info_accessor_interface.h new file mode 100644 index 000000000000..3b87368d6887 --- /dev/null +++ b/src/ray/gcs_rpc_client/accessors/actor_info_accessor_interface.h @@ -0,0 +1,216 @@ +// Copyright 2025 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include + +#include "ray/common/gcs_callback_types.h" +#include "ray/common/id.h" +#include "ray/common/task/task_spec.h" +#include "ray/rpc/rpc_callback_types.h" +#include "src/ray/protobuf/gcs.pb.h" +#include "src/ray/protobuf/gcs_service.pb.h" + +namespace ray { +namespace gcs { + +/** + @interface ActorInfoAccessorInterface + + Interface for accessing actor information and managing actor lifecycle + */ +class ActorInfoAccessorInterface { + public: + virtual ~ActorInfoAccessorInterface() = default; + + /** + Get actor specification asynchronously. + + @param actor_id The ID of actor to look up. + @param callback Callback that will be called after lookup finishes. + */ + virtual void AsyncGet(const ActorID &actor_id, + const OptionalItemCallback &callback) = 0; + + /** + Get all actor specifications asynchronously. + + @param actor_id To filter actors by actor_id. + @param job_id To filter actors by job_id. + @param actor_state_name To filter actors based on actor state. + @param callback Callback that will be called after lookup finishes. + @param timeout_ms request timeout, defaults to -1 for infinite timeout. + */ + virtual void AsyncGetAllByFilter(const std::optional &actor_id, + const std::optional &job_id, + const std::optional &actor_state_name, + const MultiItemCallback &callback, + int64_t timeout_ms = -1) = 0; + + /** + Get actor specification for a named actor asynchronously. + + @param name The name of the actor to look up. + @param ray_namespace The namespace to filter to. + @param callback Callback that will be called after lookup finishes. + @param timeout_ms RPC timeout in milliseconds. -1 means the default. + */ + virtual void AsyncGetByName(const std::string &name, + const std::string &ray_namespace, + const OptionalItemCallback &callback, + int64_t timeout_ms = -1) = 0; + + /** + Get actor specification for a named actor synchronously. + + @param name The name of the actor to look up. 
+ @param ray_namespace The namespace to filter to. + @return Status::OK on success; NotFound if the name doesn't exist. + @return Status::TimedOut if the method is timed out. + */ + virtual Status SyncGetByName(const std::string &name, + const std::string &ray_namespace, + rpc::ActorTableData &actor_table_data, + rpc::TaskSpec &task_spec) = 0; + + /** + List all named actors synchronously. + + @param all_namespaces Whether to include actors from all Ray namespaces. + @param ray_namespace The namespace to filter to if all_namespaces is false. + @param[out] actors The list of named actors. Each pair includes the + namespace and name of the actor. + @return Status::OK + @return Status::TimedOut if the method is timed out. + */ + virtual Status SyncListNamedActors( + bool all_namespaces, + const std::string &ray_namespace, + std::vector> &actors) = 0; + + /** + Report actor out of scope asynchronously. + + @param actor_id The ID of the actor. + @param num_restarts_due_to_lineage_reconstruction Number of restarts due to lineage + reconstruction. + @param callback Callback that will be called after the operation completes. + @param timeout_ms Timeout in milliseconds. -1 means the default. + */ + virtual void AsyncReportActorOutOfScope( + const ActorID &actor_id, + uint64_t num_restarts_due_to_lineage_reconstruction, + const StatusCallback &callback, + int64_t timeout_ms = -1) = 0; + + /** + Register actor asynchronously. + + @param task_spec The specification for the actor creation task. + @param callback Callback that will be called after the actor info is written. + @param timeout_ms Timeout ms. -1 means there's no timeout. + */ + virtual void AsyncRegisterActor(const TaskSpecification &task_spec, + const StatusCallback &callback, + int64_t timeout_ms = -1) = 0; + + /** + Restart actor for lineage reconstruction asynchronously. + + @param actor_id The ID of the actor. + @param num_restarts_due_to_lineage_reconstructions Number of restarts due to lineage + reconstructions. 
+ @param callback Callback that will be called after the operation completes. + @param timeout_ms Timeout in milliseconds. -1 means the default. + */ + virtual void AsyncRestartActorForLineageReconstruction( + const ActorID &actor_id, + uint64_t num_restarts_due_to_lineage_reconstructions, + const StatusCallback &callback, + int64_t timeout_ms = -1) = 0; + + /** + Register actor to GCS synchronously. + + The RPC will timeout after the default GCS RPC timeout is exceeded. + + @param task_spec The specification for the actor creation task. + @return Status::OK + @return Status::TimedOut if actor is not registered by the global + GCS timeout. + */ + virtual Status SyncRegisterActor(const ray::TaskSpecification &task_spec) = 0; + + /** + Kill actor asynchronously. + + @param actor_id The ID of actor to destroy. + @param force_kill Whether to force kill an actor by killing the worker. + @param no_restart If set to true, the killed actor will not be restarted anymore. + @param callback Callback that will be called after the actor is destroyed. + @param timeout_ms RPC timeout in milliseconds. -1 means infinite. + */ + virtual void AsyncKillActor(const ActorID &actor_id, + bool force_kill, + bool no_restart, + const StatusCallback &callback, + int64_t timeout_ms = -1) = 0; + + /** + Asynchronously request to create the actor. + + @param task_spec The specification for the actor creation task. + @param callback Callback that will be called after the actor info is written. + */ + virtual void AsyncCreateActor( + const TaskSpecification &task_spec, + const rpc::ClientCallback &callback) = 0; + + /** + Subscribe to any update operations of an actor. + + @param actor_id The ID of actor to be subscribed to. + @param subscribe Callback that will be called each time when the actor is updated. + @param done Callback that will be called when subscription is complete. 
+ */ + virtual void AsyncSubscribe( + const ActorID &actor_id, + const SubscribeCallback &subscribe, + const StatusCallback &done) = 0; + + /** + Cancel subscription to an actor. + + @param actor_id The ID of the actor to be unsubscribed to. + */ + virtual void AsyncUnsubscribe(const ActorID &actor_id) = 0; + + /** + Reestablish subscription. + */ + virtual void AsyncResubscribe() = 0; + + /** + Check if the specified actor is unsubscribed. + + @param actor_id The ID of the actor. + @return Whether the specified actor is unsubscribed. + */ + virtual bool IsActorUnsubscribed(const ActorID &actor_id) = 0; +}; + +} // namespace gcs +} // namespace ray diff --git a/src/ray/gcs_rpc_client/default_accessor_factory.cc b/src/ray/gcs_rpc_client/default_accessor_factory.cc new file mode 100644 index 000000000000..277421dee0c1 --- /dev/null +++ b/src/ray/gcs_rpc_client/default_accessor_factory.cc @@ -0,0 +1,28 @@ +// Copyright 2025 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "ray/gcs_rpc_client/default_accessor_factory.h" + +#include "ray/gcs_rpc_client/accessors/actor_info_accessor.h" + +namespace ray { +namespace gcs { + +std::unique_ptr +DefaultAccessorFactory::CreateActorInfoAccessor(GcsClientContext *context) { + return std::make_unique(context); +} + +} // namespace gcs +} // namespace ray diff --git a/src/ray/gcs_rpc_client/default_accessor_factory.h b/src/ray/gcs_rpc_client/default_accessor_factory.h new file mode 100644 index 000000000000..b76b369c870e --- /dev/null +++ b/src/ray/gcs_rpc_client/default_accessor_factory.h @@ -0,0 +1,38 @@ +// Copyright 2025 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "ray/gcs_rpc_client/accessor_factory_interface.h" + +namespace ray { +namespace gcs { + +/** +@interface DefaultAccessorFactory + +Default implementation of the AccessorFactoryInterface. Creates the standard +implementation of each accessor. 
+*/ +class DefaultAccessorFactory : public AccessorFactoryInterface { + public: + DefaultAccessorFactory() = default; + ~DefaultAccessorFactory() override = default; + + std::unique_ptr CreateActorInfoAccessor( + GcsClientContext *context) override; +}; + +} // namespace gcs +} // namespace ray diff --git a/src/ray/gcs_rpc_client/default_gcs_client_context.cc b/src/ray/gcs_rpc_client/default_gcs_client_context.cc new file mode 100644 index 000000000000..82a94999a2b7 --- /dev/null +++ b/src/ray/gcs_rpc_client/default_gcs_client_context.cc @@ -0,0 +1,45 @@ +// Copyright 2025 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "ray/gcs_rpc_client/default_gcs_client_context.h" + +#include "ray/gcs_rpc_client/rpc_client.h" +#include "ray/pubsub/gcs_subscriber.h" + +namespace ray { +namespace gcs { +pubsub::GcsSubscriber &DefaultGcsClientContext::GetGcsSubscriber() { + return *subscriber_; +} + +rpc::GcsRpcClient &DefaultGcsClientContext::GetGcsRpcClient() { return *client_; } + +bool DefaultGcsClientContext::IsInitialized() const { return client_ != nullptr; } + +void DefaultGcsClientContext::SetGcsRpcClient(std::shared_ptr client) { + client_ = client; +} +void DefaultGcsClientContext::SetGcsSubscriber( + std::unique_ptr subscriber) { + subscriber_ = std::move(subscriber); +} + +void DefaultGcsClientContext::Disconnect() { + if (client_) { + client_.reset(); + } +}; + +} // namespace gcs +} // namespace ray diff --git a/src/ray/gcs_rpc_client/default_gcs_client_context.h b/src/ray/gcs_rpc_client/default_gcs_client_context.h new file mode 100644 index 000000000000..a25cce04b48e --- /dev/null +++ b/src/ray/gcs_rpc_client/default_gcs_client_context.h @@ -0,0 +1,63 @@ +// Copyright 2025 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once +#include + +#include "ray/gcs_rpc_client/gcs_client_context.h" +#include "ray/gcs_rpc_client/rpc_client.h" +#include "ray/pubsub/gcs_subscriber.h" + +namespace ray { +namespace gcs { + +class DefaultGcsClientContext : public GcsClientContext { + public: + DefaultGcsClientContext() = default; + ~DefaultGcsClientContext() override = default; + + /** + Get the GCS subscriber for pubsub operations. + */ + pubsub::GcsSubscriber &GetGcsSubscriber() override; + + /** + Get the GCS RPC client for making RPC calls. + */ + rpc::GcsRpcClient &GetGcsRpcClient() override; + + /** + Check if the RPC client has been initialized + */ + bool IsInitialized() const override; + + /** + Set the GCS RPC client for making RPC calls. + */ + void SetGcsRpcClient(std::shared_ptr client) override; + + /** + Set the GCS subscriber for pubsub operations. + */ + void SetGcsSubscriber(std::unique_ptr subscriber) override; + + void Disconnect() override; + + private: + std::shared_ptr client_ = nullptr; + std::unique_ptr subscriber_ = nullptr; +}; + +} // namespace gcs +} // namespace ray diff --git a/src/ray/gcs_rpc_client/gcs_client.cc b/src/ray/gcs_rpc_client/gcs_client.cc index 6cc2dae8f44c..a97fe6534398 100644 --- a/src/ray/gcs_rpc_client/gcs_client.cc +++ b/src/ray/gcs_rpc_client/gcs_client.cc @@ -23,6 +23,10 @@ #include "ray/common/asio/asio_util.h" #include "ray/common/ray_config.h" #include "ray/gcs_rpc_client/accessor.h" +#include "ray/gcs_rpc_client/accessor_factory_interface.h" +#include "ray/gcs_rpc_client/accessors/actor_info_accessor.h" +#include "ray/gcs_rpc_client/default_accessor_factory.h" +#include "ray/gcs_rpc_client/default_gcs_client_context.h" #include "ray/pubsub/subscriber.h" #include "ray/util/network_util.h" @@ -96,10 +100,23 @@ bool GcsClientOptions::ShouldFetchClusterId(ClusterID cluster_id, GcsClient::GcsClient(GcsClientOptions options, std::string local_address, - UniqueID gcs_client_id) + UniqueID gcs_client_id, + std::unique_ptr accessor_factory, 
+ std::unique_ptr client_context) : options_(std::move(options)), gcs_client_id_(gcs_client_id), - local_address_(std::move(local_address)) {} + local_address_(std::move(local_address)) { + if (accessor_factory == nullptr) { + accessor_factory_ = std::make_unique(); + } else { + accessor_factory_ = std::move(accessor_factory); + } + if (client_context == nullptr) { + client_context_ = std::make_unique(); + } else { + client_context_ = std::move(client_context); + } +} Status GcsClient::Connect(instrumented_io_context &io_service, int64_t timeout_ms) { if (timeout_ms < 0) { @@ -110,45 +127,42 @@ Status GcsClient::Connect(instrumented_io_context &io_service, int64_t timeout_m /*record_stats=*/false, local_address_, options_.cluster_id_); - gcs_rpc_client_ = std::make_shared( - options_.gcs_address_, options_.gcs_port_, *client_call_manager_); - - resubscribe_func_ = [this]() { - RAY_LOG(INFO) << "Resubscribing to GCS tables."; - job_accessor_->AsyncResubscribe(); - actor_accessor_->AsyncResubscribe(); - node_accessor_->AsyncResubscribe(); - worker_accessor_->AsyncResubscribe(); - }; - rpc::Address gcs_address; - gcs_address.set_ip_address(options_.gcs_address_); - gcs_address.set_port(options_.gcs_port_); - /// TODO(mwtian): refactor pubsub::Subscriber to avoid faking worker ID. - gcs_address.set_worker_id(UniqueID::FromRandom().Binary()); - - auto subscriber = std::make_unique( - /*subscriber_id=*/gcs_client_id_, - /*channels=*/ - std::vector{ - rpc::ChannelType::GCS_ACTOR_CHANNEL, - rpc::ChannelType::GCS_JOB_CHANNEL, - rpc::ChannelType::GCS_NODE_INFO_CHANNEL, - rpc::ChannelType::GCS_NODE_ADDRESS_AND_LIVENESS_CHANNEL, - rpc::ChannelType::GCS_WORKER_DELTA_CHANNEL}, - /*max_command_batch_size*/ RayConfig::instance().max_command_batch_size(), - /*get_client=*/ - [this](const rpc::Address &) { - return std::make_shared(gcs_rpc_client_); - }, - /*callback_service*/ &io_service); - - // Init GCS subscriber instance. 
- gcs_subscriber_ = - std::make_unique(gcs_address, std::move(subscriber)); + // Only initialize the RPC client and subscriber if needed + if (!client_context_->IsInitialized()) { + auto gcs_rpc_client = std::make_shared( + options_.gcs_address_, options_.gcs_port_, *client_call_manager_); + + rpc::Address gcs_address; + gcs_address.set_ip_address(options_.gcs_address_); + gcs_address.set_port(options_.gcs_port_); + /// TODO(mwtian): refactor pubsub::Subscriber to avoid faking worker ID. + gcs_address.set_worker_id(UniqueID::FromRandom().Binary()); + + auto subscriber = std::make_unique( + /*subscriber_id=*/gcs_client_id_, + /*channels=*/ + std::vector{ + rpc::ChannelType::GCS_ACTOR_CHANNEL, + rpc::ChannelType::GCS_JOB_CHANNEL, + rpc::ChannelType::GCS_NODE_INFO_CHANNEL, + rpc::ChannelType::GCS_NODE_ADDRESS_AND_LIVENESS_CHANNEL, + rpc::ChannelType::GCS_WORKER_DELTA_CHANNEL}, + /*max_command_batch_size*/ RayConfig::instance().max_command_batch_size(), + /*get_client=*/ + [gcs_rpc_client](const rpc::Address &) { + return std::make_shared(gcs_rpc_client); + }, + /*callback_service*/ &io_service); + + // Init GCS subscriber instance. 
+ client_context_->SetGcsSubscriber( + std::make_unique(gcs_address, std::move(subscriber))); + client_context_->SetGcsRpcClient(gcs_rpc_client); + } + actor_accessor_ = accessor_factory_->CreateActorInfoAccessor(client_context_.get()); job_accessor_ = std::make_unique(this); - actor_accessor_ = std::make_unique(this); node_accessor_ = std::make_unique(this); node_resource_accessor_ = std::make_unique(this); error_accessor_ = std::make_unique(this); @@ -160,6 +174,14 @@ Status GcsClient::Connect(instrumented_io_context &io_service, int64_t timeout_m autoscaler_state_accessor_ = std::make_unique(this); publisher_accessor_ = std::make_unique(this); + resubscribe_func_ = [this]() { + RAY_LOG(INFO) << "Resubscribing to GCS tables."; + job_accessor_->AsyncResubscribe(); + actor_accessor_->AsyncResubscribe(); + node_accessor_->AsyncResubscribe(); + worker_accessor_->AsyncResubscribe(); + }; + RAY_LOG(DEBUG) << "GcsClient connected " << BuildAddress(options_.gcs_address_, options_.gcs_port_); @@ -177,10 +199,11 @@ Status GcsClient::FetchClusterId(int64_t timeout_ms) { rpc::GetClusterIdReply reply; RAY_LOG(DEBUG) << "Cluster ID is nil, getting cluster ID from GCS server."; - Status s = gcs_rpc_client_->SyncGetClusterId(std::move(request), &reply, timeout_ms); + Status s = client_context_->GetGcsRpcClient().SyncGetClusterId( + std::move(request), &reply, timeout_ms); if (!s.ok()) { RAY_LOG(WARNING) << "Failed to get cluster ID from GCS server: " << s; - gcs_rpc_client_.reset(); + client_context_->Disconnect(); client_call_manager_.reset(); return s; } @@ -191,13 +214,13 @@ Status GcsClient::FetchClusterId(int64_t timeout_ms) { } void GcsClient::Disconnect() { - if (gcs_rpc_client_) { - gcs_rpc_client_.reset(); + if (client_context_) { + client_context_->Disconnect(); } } std::pair GcsClient::GetGcsServerAddress() const { - return gcs_rpc_client_->GetAddress(); + return client_context_->GetGcsRpcClient().GetAddress(); } ClusterID GcsClient::GetClusterId() const { diff --git 
a/src/ray/gcs_rpc_client/gcs_client.h b/src/ray/gcs_rpc_client/gcs_client.h index a80f290c2d6d..14b2bbc50744 100644 --- a/src/ray/gcs_rpc_client/gcs_client.h +++ b/src/ray/gcs_rpc_client/gcs_client.h @@ -26,6 +26,9 @@ #include "ray/common/id.h" #include "ray/common/status.h" #include "ray/gcs_rpc_client/accessor.h" +#include "ray/gcs_rpc_client/accessor_factory_interface.h" +#include "ray/gcs_rpc_client/accessors/actor_info_accessor_interface.h" +#include "ray/gcs_rpc_client/gcs_client_context.h" #include "ray/gcs_rpc_client/rpc_client.h" #include "ray/pubsub/gcs_subscriber.h" #include "ray/util/logging.h" @@ -93,16 +96,21 @@ class GcsClientOptions { /// Before exit, `Disconnect()` must be called. class RAY_EXPORT GcsClient : public std::enable_shared_from_this { public: - GcsClient() = default; /// Constructor of GcsClient. /// /// \param options Options for client. /// \param local_address The local address of the client. (Used to decide whether to /// inject RPC failures for testing) /// \param gcs_client_id This is used to give subscribers Unique ID's. + /// \param accessor_factory A factory which supplies accessors to this gcs_client + /// instance + /// \param client_context A context which supplies lower level functionality + /// like an rpc client and/or a subscriber explicit GcsClient(GcsClientOptions options, std::string local_address = "", - UniqueID gcs_client_id = UniqueID::FromRandom()); + UniqueID gcs_client_id = UniqueID::FromRandom(), + std::unique_ptr accessor_factory = nullptr, + std::unique_ptr client_context = nullptr); GcsClient(const GcsClient &) = delete; GcsClient &operator=(const GcsClient &) = delete; @@ -149,7 +157,7 @@ class RAY_EXPORT GcsClient : public std::enable_shared_from_this { /// Get the sub-interface for accessing actor information in GCS. /// This function is thread safe. 
- ActorInfoAccessor &Actors() { + ActorInfoAccessorInterface &Actors() { RAY_CHECK(actor_accessor_ != nullptr); return *actor_accessor_; } @@ -223,14 +231,20 @@ class RAY_EXPORT GcsClient : public std::enable_shared_from_this { /// This function is thread safe. virtual InternalKVAccessor &InternalKV() { return *internal_kv_accessor_; } - virtual pubsub::GcsSubscriber &GetGcsSubscriber() { return *gcs_subscriber_; } + virtual rpc::GcsRpcClient &GetGcsRpcClient() { + return client_context_->GetGcsRpcClient(); + } - virtual rpc::GcsRpcClient &GetGcsRpcClient() { return *gcs_rpc_client_; } + virtual pubsub::GcsSubscriber &GetGcsSubscriber() { + return client_context_->GetGcsSubscriber(); + } protected: GcsClientOptions options_; - std::unique_ptr actor_accessor_; + std::unique_ptr client_context_; + std::unique_ptr accessor_factory_; + std::unique_ptr actor_accessor_; std::unique_ptr job_accessor_; std::unique_ptr node_accessor_; std::unique_ptr node_resource_accessor_; @@ -250,10 +264,7 @@ class RAY_EXPORT GcsClient : public std::enable_shared_from_this { const UniqueID gcs_client_id_ = UniqueID::FromRandom(); - std::unique_ptr gcs_subscriber_; - // Gcs rpc client - std::shared_ptr gcs_rpc_client_; std::unique_ptr client_call_manager_; std::function resubscribe_func_; std::string local_address_; diff --git a/src/ray/gcs_rpc_client/gcs_client_context.h b/src/ray/gcs_rpc_client/gcs_client_context.h new file mode 100644 index 000000000000..550613c5a22c --- /dev/null +++ b/src/ray/gcs_rpc_client/gcs_client_context.h @@ -0,0 +1,67 @@ +// Copyright 2025 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +namespace ray { + +namespace pubsub { +class GcsSubscriber; +} + +namespace rpc { +class GcsRpcClient; +} + +namespace gcs { + +/** +@class GcsClientContext +Minimal interface providing access to RPC client and subscriber. This allows accessor +implementations to access GCS services without depending on the full GcsClient class, +breaking circular dependencies. +*/ +class GcsClientContext { + public: + virtual ~GcsClientContext() = default; + + /** + Get the GCS subscriber for pubsub operations. + */ + virtual pubsub::GcsSubscriber &GetGcsSubscriber() = 0; + + /** + Get the GCS RPC client for making RPC calls. + */ + virtual rpc::GcsRpcClient &GetGcsRpcClient() = 0; + + /** + Check if the RPC client has been initialized + */ + virtual bool IsInitialized() const = 0; + + /** + Set the GCS RPC client for making RPC calls. + */ + virtual void SetGcsRpcClient(std::shared_ptr client) = 0; + /** + Set the GCS subscriber for pubsub operations. 
+ */ + virtual void SetGcsSubscriber(std::unique_ptr subscriber) = 0; + + virtual void Disconnect() = 0; +}; + +} // namespace gcs +} // namespace ray diff --git a/src/ray/gcs_rpc_client/global_state_accessor.cc b/src/ray/gcs_rpc_client/global_state_accessor.cc index 56f7e067fde5..fefb228343a0 100644 --- a/src/ray/gcs_rpc_client/global_state_accessor.cc +++ b/src/ray/gcs_rpc_client/global_state_accessor.cc @@ -377,7 +377,7 @@ std::unique_ptr GlobalStateAccessor::GetInternalKV(const std::strin absl::ReaderMutexLock lock(&mutex_); std::string value; - Status status = gcs_client_->InternalKV().Get(ns, key, GetGcsTimeoutMs(), value); + Status status = gcs_client_->InternalKV().Get(ns, key, rpc::GetGcsTimeoutMs(), value); return status.ok() ? std::make_unique(value) : nullptr; } diff --git a/src/ray/gcs_rpc_client/rpc_client.h b/src/ray/gcs_rpc_client/rpc_client.h index 22fe8745f2b3..f5981d2b3897 100644 --- a/src/ray/gcs_rpc_client/rpc_client.h +++ b/src/ray/gcs_rpc_client/rpc_client.h @@ -618,6 +618,10 @@ class GcsRpcClient { friend class GcsClientReconnectionTest; FRIEND_TEST(GcsClientReconnectionTest, ReconnectionBackoff); }; +inline int64_t GetGcsTimeoutMs() { + return absl::ToInt64Milliseconds( + absl::Seconds(::RayConfig::instance().gcs_server_request_timeout_seconds())); +} } // namespace rpc } // namespace ray diff --git a/src/ray/gcs_rpc_client/tests/BUILD.bazel b/src/ray/gcs_rpc_client/tests/BUILD.bazel index d8dfd62e842a..232c8a6702ab 100644 --- a/src/ray/gcs_rpc_client/tests/BUILD.bazel +++ b/src/ray/gcs_rpc_client/tests/BUILD.bazel @@ -100,3 +100,16 @@ ray_cc_test( "@com_google_googletest//:gtest_main", ], ) + +ray_cc_test( + name = "gcs_client_injectable_test", + size = "small", + srcs = [ + "gcs_client_injectable_test.cc", + ], + tags = ["team:core"], + deps = [ + "//src/ray/gcs_rpc_client:gcs_client", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/src/ray/gcs_rpc_client/tests/gcs_client_injectable_test.cc 
b/src/ray/gcs_rpc_client/tests/gcs_client_injectable_test.cc new file mode 100644 index 000000000000..8781d533f9c6 --- /dev/null +++ b/src/ray/gcs_rpc_client/tests/gcs_client_injectable_test.cc @@ -0,0 +1,185 @@ +// Copyright 2025 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gtest/gtest.h" +#include "ray/common/id.h" +#include "ray/gcs_rpc_client/accessor_factory_interface.h" +#include "ray/gcs_rpc_client/accessors/actor_info_accessor.h" +#include "ray/gcs_rpc_client/accessors/actor_info_accessor_interface.h" +#include "ray/gcs_rpc_client/gcs_client.h" +#include "ray/gcs_rpc_client/gcs_client_context.h" + +namespace ray { + +namespace pubsub { +class GcsSubscriber; +} + +namespace rpc { +class GcsRpcClient; +} + +namespace gcs { + +// Mock GcsRpcClient - empty class for testing +class TestGcsRpcClient { + public: + TestGcsRpcClient() = default; +}; + +// Mock GcsSubscriber - empty class for testing +class TestGcsSubscriber { + public: + TestGcsSubscriber() = default; +}; + +// Test implementation of GcsClientContext +class FakeGcsClientContext : public GcsClientContext { + public: + FakeGcsClientContext(std::shared_ptr rpc_client, + std::unique_ptr subscriber) + : rpc_client_(rpc_client), subscriber_(std::move(subscriber)) {} + + pubsub::GcsSubscriber &GetGcsSubscriber() override { + // Cast our mock to the expected type + return *reinterpret_cast(subscriber_.get()); + } + + rpc::GcsRpcClient &GetGcsRpcClient() override { + 
// Cast our mock to the expected type + return *reinterpret_cast(rpc_client_.get()); + } + + bool IsInitialized() const override { return true; } + + void Disconnect() override {} + + void SetGcsRpcClient(std::shared_ptr client) override {} + + void SetGcsSubscriber(std::unique_ptr subscriber) override {} + + private: + std::shared_ptr rpc_client_; + std::unique_ptr subscriber_; +}; + +// Test ActorInfoAccessor implementation +class TestActorInfoAccessor : public ActorInfoAccessorInterface { + public: + explicit TestActorInfoAccessor(GcsClientContext *client_impl) : is_fake_(true) {} + ~TestActorInfoAccessor() override = default; + + bool IsFake() const { return is_fake_; } + + void AsyncGet(const ActorID &actor_id, + const OptionalItemCallback &callback) override {} + void AsyncGetAllByFilter(const std::optional &actor_id, + const std::optional &job_id, + const std::optional &actor_state_name, + const MultiItemCallback &callback, + int64_t timeout_ms = -1) override {} + void AsyncGetByName(const std::string &name, + const std::string &ray_namespace, + const OptionalItemCallback &callback, + int64_t timeout_ms = -1) override {} + Status SyncGetByName(const std::string &name, + const std::string &ray_namespace, + rpc::ActorTableData &actor_table_data, + rpc::TaskSpec &task_spec) override { + return Status::OK(); + } + Status SyncListNamedActors( + bool all_namespaces, + const std::string &ray_namespace, + std::vector> &actors) override { + return Status::OK(); + } + void AsyncReportActorOutOfScope(const ActorID &actor_id, + uint64_t num_restarts_due_to_lineage_reconstruction, + const StatusCallback &callback, + int64_t timeout_ms = -1) override {} + void AsyncRegisterActor(const TaskSpecification &task_spec, + const StatusCallback &callback, + int64_t timeout_ms = -1) override {} + void AsyncRestartActorForLineageReconstruction( + const ActorID &actor_id, + uint64_t num_restarts_due_to_lineage_reconstructions, + const StatusCallback &callback, + int64_t timeout_ms = 
-1) override {} + Status SyncRegisterActor(const ray::TaskSpecification &task_spec) override { + return Status::OK(); + } + void AsyncKillActor(const ActorID &actor_id, + bool force_kill, + bool no_restart, + const StatusCallback &callback, + int64_t timeout_ms = -1) override {} + void AsyncCreateActor( + const TaskSpecification &task_spec, + const rpc::ClientCallback &callback) override {} + void AsyncSubscribe(const ActorID &actor_id, + const SubscribeCallback &subscribe, + const StatusCallback &done) override {} + void AsyncResubscribe() override {} + void AsyncUnsubscribe(const ActorID &actor_id) override {} + bool IsActorUnsubscribed(const ActorID &actor_id) override { return false; } + + private: + bool is_fake_; +}; + +// Custom AccessorFactory that provides TestActorInfoAccessor +class MixedAccessorFactory : public AccessorFactoryInterface { + public: + MixedAccessorFactory() = default; + ~MixedAccessorFactory() override = default; + + std::unique_ptr CreateActorInfoAccessor( + GcsClientContext *client_impl) override { + // Return mock implementation + return std::make_unique(client_impl); + } +}; + +TEST(GcsClientInjectableTest, AccessorFactoryReturnsInjectedAccessorIfDefaultOverriden) { + // Create mock RPC client and subscriber + auto rpc_client = std::make_shared(); + auto subscriber = std::make_unique(); + + // Create GCS client context + auto context = + std::make_unique(rpc_client, std::move(subscriber)); + + // Create custom accessor factory + auto factory = std::make_unique(); + + // Create GcsClient with custom context and factory + GcsClientOptions options; + GcsClient gcs_client( + options, "", UniqueID::FromRandom(), std::move(factory), std::move(context)); + + // Connect the client + instrumented_io_context io_service; + Status status = gcs_client.Connect(io_service, -1); + ASSERT_TRUE(status.ok()); + + // Verify that ActorInfoAccessor is the fake implementation + auto &actor_accessor = gcs_client.Actors(); + auto fake_actor_accessor = 
dynamic_cast(&actor_accessor); + ASSERT_NE(fake_actor_accessor, nullptr); + EXPECT_TRUE(fake_actor_accessor->IsFake()); +} + +} // namespace gcs +} // namespace ray diff --git a/src/ray/gcs_rpc_client/tests/gcs_client_reconnection_test.cc b/src/ray/gcs_rpc_client/tests/gcs_client_reconnection_test.cc index 2fa4e184acd3..c412f030ad45 100644 --- a/src/ray/gcs_rpc_client/tests/gcs_client_reconnection_test.cc +++ b/src/ray/gcs_rpc_client/tests/gcs_client_reconnection_test.cc @@ -24,7 +24,6 @@ #include "ray/common/asio/instrumented_io_context.h" #include "ray/common/test_utils.h" #include "ray/gcs/gcs_server.h" -#include "ray/gcs_rpc_client/accessor.h" #include "ray/gcs_rpc_client/gcs_client.h" #include "ray/gcs_rpc_client/rpc_client.h" #include "ray/observability/fake_metric.h" @@ -228,7 +227,7 @@ TEST_F(GcsClientReconnectionTest, ReconnectionBasic) { std::promise p0; auto f0 = p0.get_future(); RAY_UNUSED(client->InternalKV().AsyncInternalKVPut( - "", "A", "B", false, gcs::GetGcsTimeoutMs(), [&p0](auto status, auto) { + "", "A", "B", false, rpc::GetGcsTimeoutMs(), [&p0](auto status, auto) { ASSERT_TRUE(status.ok()) << status.ToString(); p0.set_value(); })); @@ -241,7 +240,7 @@ TEST_F(GcsClientReconnectionTest, ReconnectionBasic) { std::promise p1; auto f1 = p1.get_future(); RAY_UNUSED(client->InternalKV().AsyncInternalKVGet( - "", "A", gcs::GetGcsTimeoutMs(), [&p1](auto status, auto p) { + "", "A", rpc::GetGcsTimeoutMs(), [&p1](auto status, auto p) { ASSERT_TRUE(status.ok()) << status.ToString(); p1.set_value(*p); })); @@ -276,7 +275,7 @@ TEST_F(GcsClientReconnectionTest, ReconnectionBackoff) { std::promise p1; auto f1 = p1.get_future(); RAY_UNUSED(client->InternalKV().AsyncInternalKVPut( - "", "A", "B", false, gcs::GetGcsTimeoutMs(), [&p1](auto status, auto) { + "", "A", "B", false, rpc::GetGcsTimeoutMs(), [&p1](auto status, auto) { ASSERT_TRUE(status.ok()) << status.ToString(); p1.set_value(); })); @@ -290,7 +289,7 @@ TEST_F(GcsClientReconnectionTest, 
ReconnectionBackoff) { std::promise p2; auto f2 = p2.get_future(); RAY_UNUSED(client->InternalKV().AsyncInternalKVPut( - "", "A", "B", false, gcs::GetGcsTimeoutMs(), [&p2](auto status, auto) { + "", "A", "B", false, rpc::GetGcsTimeoutMs(), [&p2](auto status, auto) { ASSERT_TRUE(status.ok()) << status.ToString(); p2.set_value(); })); @@ -347,7 +346,7 @@ TEST_F(GcsClientReconnectionTest, QueueingAndBlocking) { std::promise p1; auto f1 = p1.get_future(); RAY_UNUSED(client->InternalKV().AsyncInternalKVPut( - "", "A", "B", false, gcs::GetGcsTimeoutMs(), [&p1](auto status, auto) { + "", "A", "B", false, rpc::GetGcsTimeoutMs(), [&p1](auto status, auto) { ASSERT_TRUE(status.ok()) << status.ToString(); p1.set_value(); })); @@ -359,7 +358,7 @@ TEST_F(GcsClientReconnectionTest, QueueingAndBlocking) { std::promise p2; auto f2 = p2.get_future(); RAY_UNUSED(client->InternalKV().AsyncInternalKVPut( - "", "A", "B", false, gcs::GetGcsTimeoutMs(), [&p2](auto status, auto) { + "", "A", "B", false, rpc::GetGcsTimeoutMs(), [&p2](auto status, auto) { ASSERT_TRUE(status.ok()) << status.ToString(); p2.set_value(); })); @@ -375,7 +374,7 @@ TEST_F(GcsClientReconnectionTest, QueueingAndBlocking) { std::promise p4; auto f4 = p4.get_future(); RAY_UNUSED(client->InternalKV().AsyncInternalKVPut( - "", "A", "B", false, gcs::GetGcsTimeoutMs(), [&p4](auto status, auto) { + "", "A", "B", false, rpc::GetGcsTimeoutMs(), [&p4](auto status, auto) { ASSERT_TRUE(status.ok()) << status.ToString(); p4.set_value(); })); @@ -406,17 +405,17 @@ TEST_F(GcsClientReconnectionTest, Timeout) { auto client = CreateGCSClient(); bool added = false; ASSERT_TRUE( - client->InternalKV().Put("", "A", "B", false, gcs::GetGcsTimeoutMs(), added).ok()); + client->InternalKV().Put("", "A", "B", false, rpc::GetGcsTimeoutMs(), added).ok()); ASSERT_TRUE(added); ShutdownGCS(); std::vector values; ASSERT_TRUE( - client->InternalKV().Keys("", "A", gcs::GetGcsTimeoutMs(), values).IsTimedOut()); + client->InternalKV().Keys("", "A", 
rpc::GetGcsTimeoutMs(), values).IsTimedOut()); ASSERT_TRUE(values.empty()); StartGCS(); - ASSERT_TRUE(client->InternalKV().Keys("", "A", gcs::GetGcsTimeoutMs(), values).ok()); + ASSERT_TRUE(client->InternalKV().Keys("", "A", rpc::GetGcsTimeoutMs(), values).ok()); ASSERT_EQ(std::vector{"A"}, values); } diff --git a/src/ray/gcs_rpc_client/tests/gcs_client_test.cc b/src/ray/gcs_rpc_client/tests/gcs_client_test.cc index ec5106418fc4..7d1efd6f1279 100644 --- a/src/ray/gcs_rpc_client/tests/gcs_client_test.cc +++ b/src/ray/gcs_rpc_client/tests/gcs_client_test.cc @@ -24,7 +24,7 @@ #include "ray/common/asio/instrumented_io_context.h" #include "ray/common/test_utils.h" #include "ray/gcs/gcs_server.h" -#include "ray/gcs_rpc_client/accessor.h" +#include "ray/gcs_rpc_client/accessors/actor_info_accessor.h" #include "ray/gcs_rpc_client/rpc_client.h" #include "ray/observability/fake_metric.h" #include "ray/util/network_util.h" @@ -403,7 +403,7 @@ class GcsClientTest : public ::testing::TestWithParam { nodes = std::move(result); promise.set_value(status.ok()); }, - gcs::GetGcsTimeoutMs()); + rpc::GetGcsTimeoutMs()); EXPECT_TRUE(WaitReady(promise.get_future(), timeout_ms_)); return nodes; } diff --git a/src/ray/raylet/tests/local_object_manager_test.cc b/src/ray/raylet/tests/local_object_manager_test.cc index f03b980adb14..a91807ecde7f 100644 --- a/src/ray/raylet/tests/local_object_manager_test.cc +++ b/src/ray/raylet/tests/local_object_manager_test.cc @@ -30,7 +30,6 @@ #include "ray/common/id.h" #include "ray/core_worker_rpc_client/core_worker_client_pool.h" #include "ray/core_worker_rpc_client/fake_core_worker_client.h" -#include "ray/gcs_rpc_client/accessor.h" #include "ray/object_manager/ownership_object_directory.h" #include "ray/observability/fake_metric.h" #include "ray/pubsub/subscriber.h" diff --git a/src/ray/raylet_rpc_client/tests/raylet_client_pool_test.cc b/src/ray/raylet_rpc_client/tests/raylet_client_pool_test.cc index 1104763d2a66..6dc002604a54 100644 --- 
a/src/ray/raylet_rpc_client/tests/raylet_client_pool_test.cc +++ b/src/ray/raylet_rpc_client/tests/raylet_client_pool_test.cc @@ -77,7 +77,9 @@ class MockGcsClientNodeAccessor : public gcs::NodeInfoAccessor { class MockGcsClient : public gcs::GcsClient { public: - explicit MockGcsClient(bool is_subscribed_to_node_change) { + explicit MockGcsClient(bool is_subscribed_to_node_change, + gcs::GcsClientOptions &options) + : GcsClient(options) { this->node_accessor_ = std::make_unique(is_subscribed_to_node_change); } @@ -91,7 +93,12 @@ class DefaultUnavailableTimeoutCallbackTest : public ::testing::TestWithParam([this](const Address &addr) { return std::make_shared( @@ -100,6 +107,7 @@ class DefaultUnavailableTimeoutCallbackTest : public ::testing::TestWithParam raylet_client_pool_; };