Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions src/ray/raylet_rpc_client/raylet_client_pool.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ std::function<void()> RayletClientPool::GetDefaultUnavailableTimeoutCallback(
const NodeID node_id = NodeID::FromBinary(addr.node_id());

auto gcs_check_node_alive = [node_id, addr, raylet_client_pool, gcs_client]() {
gcs_client->Nodes().AsyncGetAll(
[addr, node_id, raylet_client_pool](const Status &status,
std::vector<rpc::GcsNodeInfo> &&nodes) {
gcs_client->Nodes().AsyncGetAllNodeAddressAndLiveness(
[addr, node_id, raylet_client_pool](
const Status &status, std::vector<rpc::GcsNodeAddressAndLiveness> &&nodes) {
if (!status.ok()) {
// Will try again when unavailable timeout callback is retried.
RAY_LOG(INFO) << "Failed to get node info from GCS";
Expand All @@ -56,7 +56,8 @@ std::function<void()> RayletClientPool::GetDefaultUnavailableTimeoutCallback(
};

if (gcs_client->Nodes().IsSubscribedToNodeChange()) {
auto *node_info = gcs_client->Nodes().Get(node_id, /*filter_dead_nodes=*/false);
auto *node_info = gcs_client->Nodes().GetNodeAddressAndLiveness(
node_id, /*filter_dead_nodes=*/false);
if (node_info == nullptr) {
// Node could be dead or info may have not made it to the subscriber cache yet.
// Check with the GCS to confirm if the node is dead.
Expand Down
52 changes: 32 additions & 20 deletions src/ray/raylet_rpc_client/tests/raylet_client_pool_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,14 @@ class MockGcsClientNodeAccessor : public gcs::NodeInfoAccessor {

bool IsSubscribedToNodeChange() const override { return is_subscribed_to_node_change_; }

MOCK_METHOD(const GcsNodeInfo *, Get, (const NodeID &, bool), (const, override));
MOCK_METHOD(const rpc::GcsNodeAddressAndLiveness *,
GetNodeAddressAndLiveness,
(const NodeID &, bool),
(const, override));

MOCK_METHOD(void,
AsyncGetAll,
(const gcs::MultiItemCallback<GcsNodeInfo> &,
AsyncGetAllNodeAddressAndLiveness,
(const gcs::MultiItemCallback<rpc::GcsNodeAddressAndLiveness> &,
int64_t,
const std::vector<NodeID> &),
(override));
Expand Down Expand Up @@ -118,13 +121,16 @@ TEST_P(DefaultUnavailableTimeoutCallbackTest, NodeDeath) {
// had to discard to keep its cache size in check, should disconnect.

auto &mock_node_accessor = gcs_client_.MockNodeAccessor();
auto invoke_with_node_info_vector = [](std::vector<GcsNodeInfo> node_info_vector) {
return Invoke([node_info_vector](const gcs::MultiItemCallback<GcsNodeInfo> &callback,
int64_t,
const std::vector<NodeID> &) {
callback(Status::OK(), node_info_vector);
});
};
auto invoke_with_node_info_vector =
[](std::vector<GcsNodeAddressAndLiveness> node_info_vector) {
return Invoke(
[node_info_vector](
const gcs::MultiItemCallback<rpc::GcsNodeAddressAndLiveness> &callback,
int64_t,
const std::vector<NodeID> &) {
callback(Status::OK(), node_info_vector);
});
};

auto raylet_client_1_address = CreateRandomAddress("1");
auto raylet_client_2_address = CreateRandomAddress("2");
Expand All @@ -140,33 +146,39 @@ TEST_P(DefaultUnavailableTimeoutCallbackTest, NodeDeath) {
ASSERT_TRUE(
CheckRayletClientPoolHasClient(*raylet_client_pool_, raylet_client_2_node_id));

GcsNodeInfo node_info_alive;
GcsNodeAddressAndLiveness node_info_alive;
node_info_alive.set_state(GcsNodeInfo::ALIVE);
GcsNodeInfo node_info_dead;
GcsNodeAddressAndLiveness node_info_dead;
node_info_dead.set_state(GcsNodeInfo::DEAD);
if (is_subscribed_to_node_change_) {
EXPECT_CALL(mock_node_accessor,
Get(raylet_client_1_node_id, /*filter_dead_nodes=*/false))
EXPECT_CALL(
mock_node_accessor,
GetNodeAddressAndLiveness(raylet_client_1_node_id, /*filter_dead_nodes=*/false))
.WillOnce(Return(nullptr))
.WillOnce(Return(&node_info_alive))
.WillOnce(Return(&node_info_dead));
EXPECT_CALL(mock_node_accessor,
AsyncGetAll(_, _, std::vector<NodeID>{raylet_client_1_node_id}))
AsyncGetAllNodeAddressAndLiveness(
_, _, std::vector<NodeID>{raylet_client_1_node_id}))
.WillOnce(invoke_with_node_info_vector({node_info_alive}));
EXPECT_CALL(mock_node_accessor,
Get(raylet_client_2_node_id, /*filter_dead_nodes=*/false))
EXPECT_CALL(
mock_node_accessor,
GetNodeAddressAndLiveness(raylet_client_2_node_id, /*filter_dead_nodes=*/false))
.WillOnce(Return(nullptr));
EXPECT_CALL(mock_node_accessor,
AsyncGetAll(_, _, std::vector<NodeID>{raylet_client_2_node_id}))
AsyncGetAllNodeAddressAndLiveness(
_, _, std::vector<NodeID>{raylet_client_2_node_id}))
.WillOnce(invoke_with_node_info_vector({}));
} else {
EXPECT_CALL(mock_node_accessor,
AsyncGetAll(_, _, std::vector<NodeID>{raylet_client_1_node_id}))
AsyncGetAllNodeAddressAndLiveness(
_, _, std::vector<NodeID>{raylet_client_1_node_id}))
.WillOnce(invoke_with_node_info_vector({node_info_alive}))
.WillOnce(invoke_with_node_info_vector({node_info_alive}))
.WillOnce(invoke_with_node_info_vector({node_info_dead}));
EXPECT_CALL(mock_node_accessor,
AsyncGetAll(_, _, std::vector<NodeID>{raylet_client_2_node_id}))
AsyncGetAllNodeAddressAndLiveness(
_, _, std::vector<NodeID>{raylet_client_2_node_id}))
.WillOnce(invoke_with_node_info_vector({}));
}

Expand Down