Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions python/ray/includes/common.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -423,8 +423,8 @@ cdef extern from "ray/common/python_callbacks.h" namespace "ray":
void (object, object) nogil,
object) nogil

cdef extern from "ray/gcs_rpc_client/accessor.h" nogil:
cdef cppclass CActorInfoAccessor "ray::gcs::ActorInfoAccessor":
cdef extern from "ray/gcs_rpc_client/accessors/actor_info_accessor_interface.h" nogil:
cdef cppclass CActorInfoAccessorInterface "ray::gcs::ActorInfoAccessorInterface":
void AsyncGetAllByFilter(
const optional[CActorID] &actor_id,
const optional[CJobID] &job_id,
Expand All @@ -438,6 +438,7 @@ cdef extern from "ray/gcs_rpc_client/accessor.h" nogil:
const StatusPyCallback &callback,
int64_t timeout_ms)

cdef extern from "ray/gcs_rpc_client/accessor.h" nogil:
cdef cppclass CJobInfoAccessor "ray::gcs::JobInfoAccessor":
CRayStatus GetAll(
const optional[c_string] &job_or_submission_id,
Expand Down Expand Up @@ -649,7 +650,7 @@ cdef extern from "ray/gcs_rpc_client/gcs_client.h" nogil:
c_pair[c_string, int] GetGcsServerAddress() const
CClusterID GetClusterId() const

CActorInfoAccessor& Actors()
CActorInfoAccessorInterface& Actors()
CJobInfoAccessor& Jobs()
CInternalKVAccessor& InternalKV()
CNodeInfoAccessor& Nodes()
Expand Down
62 changes: 0 additions & 62 deletions src/mock/ray/gcs_client/accessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,68 +18,6 @@
namespace ray {
namespace gcs {

class MockActorInfoAccessor : public ActorInfoAccessor {
public:
MOCK_METHOD(void,
AsyncGet,
(const ActorID &actor_id,
const OptionalItemCallback<rpc::ActorTableData> &callback),
(override));
MOCK_METHOD(void,
AsyncGetAllByFilter,
(const std::optional<ActorID> &actor_id,
const std::optional<JobID> &job_id,
const std::optional<std::string> &actor_state_name,
const MultiItemCallback<rpc::ActorTableData> &callback,
int64_t timeout_ms),
(override));
MOCK_METHOD(void,
AsyncGetByName,
(const std::string &name,
const std::string &ray_namespace,
const OptionalItemCallback<rpc::ActorTableData> &callback,
int64_t timeout_ms),
(override));
MOCK_METHOD(void,
AsyncRegisterActor,
(const TaskSpecification &task_spec,
const StatusCallback &callback,
int64_t timeout_ms),
(override));
MOCK_METHOD(Status,
SyncRegisterActor,
(const TaskSpecification &task_spec),
(override));
MOCK_METHOD(void,
AsyncKillActor,
(const ActorID &actor_id,
bool force_kill,
bool no_restart,
const StatusCallback &callback,
int64_t timeout_ms),
(override));
MOCK_METHOD(void,
AsyncCreateActor,
(const TaskSpecification &task_spec,
const rpc::ClientCallback<rpc::CreateActorReply> &callback),
(override));
MOCK_METHOD(void,
AsyncSubscribe,
(const ActorID &actor_id,
(const SubscribeCallback<ActorID, rpc::ActorTableData> &subscribe),
const StatusCallback &done),
(override));
MOCK_METHOD(void, AsyncUnsubscribe, (const ActorID &actor_id), (override));
MOCK_METHOD(void, AsyncResubscribe, (), (override));
MOCK_METHOD(bool, IsActorUnsubscribed, (const ActorID &actor_id), (override));
};

} // namespace gcs
} // namespace ray

namespace ray {
namespace gcs {

class MockJobInfoAccessor : public JobInfoAccessor {
public:
MOCK_METHOD(void,
Expand Down
138 changes: 138 additions & 0 deletions src/mock/ray/gcs_client/accessors/actor_info_accessor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
// Copyright 2025 The Ray Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "absl/container/flat_hash_map.h"
#include "ray/common/id.h"
#include "ray/common/status.h"
#include "ray/common/task/task_spec.h"
#include "ray/gcs_rpc_client/accessors/actor_info_accessor_interface.h"
#include "src/ray/protobuf/gcs.pb.h"

namespace ray {
namespace gcs {

class FakeActorInfoAccessor : public gcs::ActorInfoAccessorInterface {
public:
FakeActorInfoAccessor() = default;

~FakeActorInfoAccessor() {}

// Stub implementations for interface methods not used by this test
void AsyncGet(const ActorID &,
const gcs::OptionalItemCallback<rpc::ActorTableData> &) override {}
void AsyncGetAllByFilter(const std::optional<ActorID> &,
const std::optional<JobID> &,
const std::optional<std::string> &,
const gcs::MultiItemCallback<rpc::ActorTableData> &,
int64_t = -1) override {}
void AsyncGetByName(const std::string &,
const std::string &,
const gcs::OptionalItemCallback<rpc::ActorTableData> &,
int64_t = -1) override {}
Status SyncGetByName(const std::string &,
const std::string &,
rpc::ActorTableData &,
rpc::TaskSpec &) override {
return Status::OK();
}
Status SyncListNamedActors(
bool,
const std::string &,
std::vector<std::pair<std::string, std::string>> &) override {
return Status::OK();
}
void AsyncReportActorOutOfScope(const ActorID &,
uint64_t,
const gcs::StatusCallback &,
int64_t = -1) override {}
void AsyncRegisterActor(const TaskSpecification &task_spec,
const gcs::StatusCallback &callback,
int64_t = -1) override {
async_register_actor_callback_ = callback;
}
void AsyncRestartActorForLineageReconstruction(const ActorID &,
uint64_t,
const gcs::StatusCallback &,
int64_t = -1) override {}
Status SyncRegisterActor(const TaskSpecification &) override { return Status::OK(); }
void AsyncKillActor(
const ActorID &, bool, bool, const gcs::StatusCallback &, int64_t = -1) override {}
void AsyncCreateActor(
const TaskSpecification &task_spec,
const rpc::ClientCallback<rpc::CreateActorReply> &callback) override {
async_create_actor_callback_ = callback;
}

void AsyncSubscribe(
const ActorID &actor_id,
const gcs::SubscribeCallback<ActorID, rpc::ActorTableData> &subscribe,
const gcs::StatusCallback &done) override {
auto callback_entry = std::make_pair(actor_id, subscribe);
callback_map_.emplace(actor_id, subscribe);
subscribe_finished_callback_map_[actor_id] = done;
actor_subscribed_times_[actor_id]++;
}

void AsyncUnsubscribe(const ActorID &) override {}
void AsyncResubscribe() override {}
bool IsActorUnsubscribed(const ActorID &) override { return false; }

bool ActorStateNotificationPublished(const ActorID &actor_id,
const rpc::ActorTableData &actor_data) {
auto it = callback_map_.find(actor_id);
if (it == callback_map_.end()) return false;
auto actor_state_notification_callback = it->second;
auto copied = actor_data;
actor_state_notification_callback(actor_id, std::move(copied));
return true;
}

bool CheckSubscriptionRequested(const ActorID &actor_id) {
return callback_map_.find(actor_id) != callback_map_.end();
}

// Mock the logic of subscribe finished. see `ActorInfoAccessor::AsyncSubscribe`
bool ActorSubscribeFinished(const ActorID &actor_id,
const rpc::ActorTableData &actor_data) {
auto subscribe_finished_callback_it = subscribe_finished_callback_map_.find(actor_id);
if (subscribe_finished_callback_it == subscribe_finished_callback_map_.end()) {
return false;
}

auto copied = actor_data;
if (!ActorStateNotificationPublished(actor_id, std::move(copied))) {
return false;
}

auto subscribe_finished_callback = subscribe_finished_callback_it->second;
subscribe_finished_callback(Status::OK());
// Erase callback when actor subscribe is finished.
subscribe_finished_callback_map_.erase(subscribe_finished_callback_it);
return true;
}

absl::flat_hash_map<ActorID, gcs::SubscribeCallback<ActorID, rpc::ActorTableData>>
callback_map_;
absl::flat_hash_map<ActorID, gcs::StatusCallback> subscribe_finished_callback_map_;
absl::flat_hash_map<ActorID, uint32_t> actor_subscribed_times_;

// Callbacks for AsyncCreateActor and AsyncRegisterActor
rpc::ClientCallback<rpc::CreateActorReply> async_create_actor_callback_;
gcs::StatusCallback async_register_actor_callback_;
};

} // namespace gcs
} // namespace ray
7 changes: 4 additions & 3 deletions src/mock/ray/gcs_client/gcs_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#pragma once

#include "accessors/actor_info_accessor.h"
#include "mock/ray/gcs_client/accessor.h"
#include "ray/gcs_rpc_client/gcs_client.h"

Expand All @@ -40,9 +41,9 @@ class MockGcsClient : public GcsClient {
MOCK_METHOD((std::pair<std::string, int>), GetGcsServerAddress, (), (const, override));
MOCK_METHOD(std::string, DebugString, (), (const, override));

MockGcsClient() {
MockGcsClient() : GcsClient(MockGcsClientOptions()) {
mock_job_accessor = new MockJobInfoAccessor();
mock_actor_accessor = new MockActorInfoAccessor();
mock_actor_accessor = new FakeActorInfoAccessor();
mock_node_accessor = new MockNodeInfoAccessor();
mock_node_resource_accessor = new MockNodeResourceInfoAccessor();
mock_error_accessor = new MockErrorInfoAccessor();
Expand All @@ -61,7 +62,7 @@ class MockGcsClient : public GcsClient {
GcsClient::internal_kv_accessor_.reset(mock_internal_kv_accessor);
GcsClient::task_accessor_.reset(mock_task_accessor);
}
MockActorInfoAccessor *mock_actor_accessor;
FakeActorInfoAccessor *mock_actor_accessor;
MockJobInfoAccessor *mock_job_accessor;
MockNodeInfoAccessor *mock_node_accessor;
MockNodeResourceInfoAccessor *mock_node_resource_accessor;
Expand Down
6 changes: 3 additions & 3 deletions src/ray/core_worker/actor_creator.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
#include <utility>
#include <vector>

#include "ray/gcs_rpc_client/accessor.h"
#include "ray/gcs_rpc_client/accessors/actor_info_accessor_interface.h"
#include "ray/util/thread_utils.h"

namespace ray {
Expand Down Expand Up @@ -74,7 +74,7 @@ class ActorCreatorInterface {

class ActorCreator : public ActorCreatorInterface {
public:
explicit ActorCreator(gcs::ActorInfoAccessor &actor_client)
explicit ActorCreator(gcs::ActorInfoAccessorInterface &actor_client)
: actor_client_(actor_client) {}

Status RegisterActor(const TaskSpecification &task_spec) const override;
Expand All @@ -101,7 +101,7 @@ class ActorCreator : public ActorCreatorInterface {
const rpc::ClientCallback<rpc::CreateActorReply> &callback) override;

private:
gcs::ActorInfoAccessor &actor_client_;
gcs::ActorInfoAccessorInterface &actor_client_;
using RegisteringActorType =
absl::flat_hash_map<ActorID, std::vector<ray::gcs::StatusCallback>>;
ThreadPrivate<RegisteringActorType> registering_actors_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,9 @@ TEST_F(DirectTaskTransportTest, ActorCreationOk) {
auto actor_id = ActorID::FromHex("f4ce02420592ca68c1738a0d01000000");
auto creation_task_spec = GetActorCreationTaskSpec(actor_id);
EXPECT_CALL(*task_manager, CompletePendingTask(creation_task_spec.TaskId(), _, _, _));
rpc::ClientCallback<rpc::CreateActorReply> create_cb;
EXPECT_CALL(*gcs_client->mock_actor_accessor,
AsyncCreateActor(creation_task_spec, ::testing::_))
.WillOnce(::testing::DoAll(::testing::SaveArg<1>(&create_cb)));
actor_task_submitter->SubmitActorCreationTask(creation_task_spec);
create_cb(Status::OK(), rpc::CreateActorReply());
gcs_client->mock_actor_accessor->async_create_actor_callback_(Status::OK(),
rpc::CreateActorReply());
}

TEST_F(DirectTaskTransportTest, ActorCreationFail) {
Expand All @@ -108,12 +105,9 @@ TEST_F(DirectTaskTransportTest, ActorCreationFail) {
*task_manager,
FailPendingTask(
creation_task_spec.TaskId(), rpc::ErrorType::ACTOR_CREATION_FAILED, _, _));
rpc::ClientCallback<rpc::CreateActorReply> create_cb;
EXPECT_CALL(*gcs_client->mock_actor_accessor,
AsyncCreateActor(creation_task_spec, ::testing::_))
.WillOnce(::testing::DoAll(::testing::SaveArg<1>(&create_cb)));
actor_task_submitter->SubmitActorCreationTask(creation_task_spec);
create_cb(Status::IOError(""), rpc::CreateActorReply());
gcs_client->mock_actor_accessor->async_create_actor_callback_(Status::IOError(""),
rpc::CreateActorReply());
}

TEST_F(DirectTaskTransportTest, ActorRegisterFailure) {
Expand All @@ -125,10 +119,6 @@ TEST_F(DirectTaskTransportTest, ActorRegisterFailure) {
auto task_arg = task_spec.GetMutableMessage().add_args();
auto inline_obj_ref = task_arg->add_nested_inlined_refs();
inline_obj_ref->set_object_id(ObjectID::ForActorHandle(actor_id).Binary());
std::function<void(Status)> register_cb;
EXPECT_CALL(*gcs_client->mock_actor_accessor,
AsyncRegisterActor(creation_task_spec, ::testing::_, ::testing::_))
.WillOnce(::testing::DoAll(::testing::SaveArg<1>(&register_cb)));
actor_creator->AsyncRegisterActor(creation_task_spec, nullptr);
ASSERT_TRUE(actor_creator->IsActorInRegistering(actor_id));
actor_task_submitter->AddActorQueueIfNotExists(actor_id,
Expand All @@ -141,7 +131,7 @@ TEST_F(DirectTaskTransportTest, ActorRegisterFailure) {
*task_manager,
FailOrRetryPendingTask(
task_spec.TaskId(), rpc::ErrorType::DEPENDENCY_RESOLUTION_FAILED, _, _, _, _));
register_cb(Status::IOError(""));
gcs_client->mock_actor_accessor->async_register_actor_callback_(Status::IOError(""));
}

TEST_F(DirectTaskTransportTest, ActorRegisterOk) {
Expand All @@ -153,10 +143,6 @@ TEST_F(DirectTaskTransportTest, ActorRegisterOk) {
auto task_arg = task_spec.GetMutableMessage().add_args();
auto inline_obj_ref = task_arg->add_nested_inlined_refs();
inline_obj_ref->set_object_id(ObjectID::ForActorHandle(actor_id).Binary());
std::function<void(Status)> register_cb;
EXPECT_CALL(*gcs_client->mock_actor_accessor,
AsyncRegisterActor(creation_task_spec, ::testing::_, ::testing::_))
.WillOnce(::testing::DoAll(::testing::SaveArg<1>(&register_cb)));
actor_creator->AsyncRegisterActor(creation_task_spec, nullptr);
ASSERT_TRUE(actor_creator->IsActorInRegistering(actor_id));
actor_task_submitter->AddActorQueueIfNotExists(actor_id,
Expand All @@ -166,7 +152,7 @@ TEST_F(DirectTaskTransportTest, ActorRegisterOk) {
/*owned*/ false);
ASSERT_TRUE(CheckSubmitTask(task_spec));
EXPECT_CALL(*task_manager, FailOrRetryPendingTask(_, _, _, _, _, _)).Times(0);
register_cb(Status::OK());
gcs_client->mock_actor_accessor->async_register_actor_callback_(Status::OK());
}

} // namespace core
Expand Down
12 changes: 2 additions & 10 deletions src/ray/core_worker/tests/actor_creator_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,23 +49,15 @@ TEST_F(ActorCreatorTest, IsRegister) {
auto actor_id = ActorID::FromHex("f4ce02420592ca68c1738a0d01000000");
ASSERT_FALSE(actor_creator->IsActorInRegistering(actor_id));
auto task_spec = GetTaskSpec(actor_id);
std::function<void(Status)> cb;
EXPECT_CALL(*gcs_client->mock_actor_accessor,
AsyncRegisterActor(task_spec, ::testing::_, ::testing::_))
.WillOnce(::testing::DoAll(::testing::SaveArg<1>(&cb)));
actor_creator->AsyncRegisterActor(task_spec, nullptr);
ASSERT_TRUE(actor_creator->IsActorInRegistering(actor_id));
cb(Status::OK());
gcs_client->mock_actor_accessor->async_register_actor_callback_(Status::OK());
ASSERT_FALSE(actor_creator->IsActorInRegistering(actor_id));
}

TEST_F(ActorCreatorTest, AsyncWaitForFinish) {
auto actor_id = ActorID::FromHex("f4ce02420592ca68c1738a0d01000000");
auto task_spec = GetTaskSpec(actor_id);
std::function<void(Status)> cb;
EXPECT_CALL(*gcs_client->mock_actor_accessor,
AsyncRegisterActor(::testing::_, ::testing::_, ::testing::_))
.WillRepeatedly(::testing::DoAll(::testing::SaveArg<1>(&cb)));
int count = 0;
auto per_finish_cb = [&count](Status status) {
ASSERT_TRUE(status.ok());
Expand All @@ -76,7 +68,7 @@ TEST_F(ActorCreatorTest, AsyncWaitForFinish) {
for (int i = 0; i < 10; ++i) {
actor_creator->AsyncWaitForActorRegisterFinish(actor_id, per_finish_cb);
}
cb(Status::OK());
gcs_client->mock_actor_accessor->async_register_actor_callback_(Status::OK());
ASSERT_FALSE(actor_creator->IsActorInRegistering(actor_id));
ASSERT_EQ(11, count);
}
Expand Down
Loading