Skip to content

Commit

Permalink
[2/n] Stabilize GCS/Autoscaler interface: Drain and Kill Node API (#32002)
Browse files Browse the repository at this point in the history

This PR adds a DrainAndKillNode endpoint to the monitor service. It has the exact same semantics as the GcsNodeManager::HandleDrainNode.


---------

Co-authored-by: Alex <alex@anyscale.com>
  • Loading branch information
Alex Wu and Alex authored Jan 30, 2023
1 parent 56b7911 commit e331f6e
Show file tree
Hide file tree
Showing 10 changed files with 118 additions and 7 deletions.
51 changes: 51 additions & 0 deletions python/ray/tests/test_monitor_service.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import binascii
import pytest

import ray
import grpc
from ray.core.generated import monitor_pb2, monitor_pb2_grpc
from ray.cluster_utils import Cluster


@pytest.fixture
Expand All @@ -12,7 +14,56 @@ def monitor_stub(ray_start_regular_shared):
return monitor_pb2_grpc.MonitorGcsServiceStub(channel)


@pytest.fixture
def monitor_stub_with_cluster():
    """Yield a (MonitorGcsService stub, Cluster) pair backed by a fresh
    one-node cluster, tearing both down after the test."""
    test_cluster = Cluster()
    test_cluster.add_node(num_cpus=1)
    test_cluster.wait_for_nodes()

    grpc_channel = grpc.insecure_channel(test_cluster.gcs_address)
    service_stub = monitor_pb2_grpc.MonitorGcsServiceStub(grpc_channel)

    test_cluster.connect()

    yield service_stub, test_cluster
    # Teardown: disconnect the driver before tearing the cluster down.
    ray.shutdown()
    test_cluster.shutdown()


def test_ray_version(monitor_stub):
    """The GetRayVersion RPC must report the same version as the ray package."""
    version_request = monitor_pb2.GetRayVersionRequest()
    version_reply = monitor_stub.GetRayVersion(version_request)
    assert version_reply.version == ray.__version__


def count_live_nodes():
    """Return the number of cluster nodes currently reported as alive."""
    live_nodes = [node for node in ray.nodes() if node["Alive"]]
    return len(live_nodes)


def test_drain_and_kill_node(monitor_stub_with_cluster):
    """DrainAndKillNode removes the targeted worker node and is idempotent."""
    monitor_stub, cluster = monitor_stub_with_cluster

    head_node_id = ray.nodes()[0]["NodeID"]

    # Bring up a second (worker) node to drain.
    cluster.add_node(num_cpus=2)
    cluster.wait_for_nodes()

    assert count_live_nodes() == 2

    all_node_ids = {node["NodeID"] for node in ray.nodes()}
    worker_node_ids = all_node_ids - {head_node_id}
    assert len(worker_node_ids) == 1

    (worker_node_id,) = worker_node_ids

    # Node ids travel over the wire as raw bytes, not hex strings.
    request = monitor_pb2.DrainAndKillNodeRequest(
        node_ids=[binascii.unhexlify(worker_node_id)]
    )
    response = monitor_stub.DrainAndKillNode(request)

    assert response.drained_nodes == request.node_ids
    assert count_live_nodes() == 1

    # The RPC is idempotent: repeating the identical request succeeds,
    # reports the same drained set, and leaves the cluster unchanged.
    response = monitor_stub.DrainAndKillNode(request)
    assert response.drained_nodes == request.node_ids
    assert count_live_nodes() == 1
1 change: 1 addition & 0 deletions src/mock/ray/gcs/gcs_server/gcs_node_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class MockGcsNodeManager : public GcsNodeManager {
rpc::GetInternalConfigReply *reply,
rpc::SendReplyCallback send_reply_callback),
(override));
MOCK_METHOD(void, DrainNode, (const NodeID &node_id), (override));
};

} // namespace gcs
Expand Down
15 changes: 14 additions & 1 deletion src/ray/gcs/gcs_server/gcs_monitor_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
namespace ray {
namespace gcs {

GcsMonitorServer::GcsMonitorServer() {}
// The monitor server shares ownership of the node manager so it can forward
// DrainAndKillNode requests to GcsNodeManager::DrainNode.
GcsMonitorServer::GcsMonitorServer(std::shared_ptr<GcsNodeManager> gcs_node_manager)
    : gcs_node_manager_(gcs_node_manager) {}

void GcsMonitorServer::HandleGetRayVersion(rpc::GetRayVersionRequest request,
rpc::GetRayVersionReply *reply,
Expand All @@ -28,5 +29,17 @@ void GcsMonitorServer::HandleGetRayVersion(rpc::GetRayVersionRequest request,
send_reply_callback(Status::OK(), nullptr, nullptr);
}

// Drain and kill every node named in the request, echoing back the id of each
// node whose drain was initiated. DrainNode is idempotent, so already-removed
// nodes are still reported as drained.
void GcsMonitorServer::HandleDrainAndKillNode(
    rpc::DrainAndKillNodeRequest request,
    rpc::DrainAndKillNodeReply *reply,
    rpc::SendReplyCallback send_reply_callback) {
  for (const auto &raw_node_id : request.node_ids()) {
    gcs_node_manager_->DrainNode(NodeID::FromBinary(raw_node_id));
    reply->add_drained_nodes(raw_node_id);
  }
  send_reply_callback(Status::OK(), nullptr, nullptr);
}

} // namespace gcs
} // namespace ray
10 changes: 9 additions & 1 deletion src/ray/gcs/gcs_server/gcs_monitor_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#pragma once

#include "ray/gcs/gcs_server/gcs_node_manager.h"
#include "ray/rpc/gcs_server/gcs_rpc_server.h"

namespace ray {
Expand All @@ -23,11 +24,18 @@ namespace gcs {
/// GCS and `monitor.py`
class GcsMonitorServer : public rpc::MonitorServiceHandler {
 public:
  // NOTE(review): two constructor declarations appear here because this is a
  // diff view; the parameterless one is the pre-change line. Confirm against
  // the real header that only the shared_ptr overload remains.
  explicit GcsMonitorServer();
  explicit GcsMonitorServer(std::shared_ptr<GcsNodeManager> gcs_node_manager);

  // Report the Ray version this GCS was built with.
  void HandleGetRayVersion(rpc::GetRayVersionRequest request,
                           rpc::GetRayVersionReply *reply,
                           rpc::SendReplyCallback send_reply_callback) override;

  // Drain and kill the nodes listed in the request via the node manager.
  void HandleDrainAndKillNode(rpc::DrainAndKillNodeRequest request,
                              rpc::DrainAndKillNodeReply *reply,
                              rpc::SendReplyCallback send_reply_callback) override;

 private:
  // Shared with the rest of the GCS; used to perform the actual node drain.
  std::shared_ptr<GcsNodeManager> gcs_node_manager_;
};
} // namespace gcs
} // namespace ray
2 changes: 1 addition & 1 deletion src/ray/gcs/gcs_server/gcs_node_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ void GcsNodeManager::HandleDrainNode(rpc::DrainNodeRequest request,
const auto &node_drain_request = request.drain_node_data(i);
const auto node_id = NodeID::FromBinary(node_drain_request.node_id());

RAY_LOG(INFO) << "Draining node info, node id = " << node_id;
DrainNode(node_id);
auto drain_node_status = reply->add_drain_node_status();
drain_node_status->set_node_id(node_id.Binary());
Expand All @@ -86,6 +85,7 @@ void GcsNodeManager::HandleDrainNode(rpc::DrainNodeRequest request,
}

void GcsNodeManager::DrainNode(const NodeID &node_id) {
RAY_LOG(INFO) << "Draining node info, node id = " << node_id;
auto node = RemoveNode(node_id, /* is_intended = */ true);
if (!node) {
RAY_LOG(INFO) << "Node " << node_id << " is already removed";
Expand Down
2 changes: 1 addition & 1 deletion src/ray/gcs/gcs_server/gcs_node_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ class GcsNodeManager : public rpc::NodeInfoHandler {

/// Drain the given node.
/// Idempotent.
void DrainNode(const NodeID &node_id);
virtual void DrainNode(const NodeID &node_id);

private:
/// Add the dead node to the cache. If the cache is full, the earliest dead node is
Expand Down
2 changes: 1 addition & 1 deletion src/ray/gcs/gcs_server/gcs_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -590,7 +590,7 @@ void GcsServer::InitGcsTaskManager() {
}

void GcsServer::InitMonitorServer() {
monitor_server_ = std::make_unique<GcsMonitorServer>();
monitor_server_ = std::make_unique<GcsMonitorServer>(gcs_node_manager_);
monitor_grpc_service_.reset(
new rpc::MonitorGrpcService(main_service_, *monitor_server_));
rpc_server_.RegisterService(*monitor_grpc_service_);
Expand Down
24 changes: 22 additions & 2 deletions src/ray/gcs/gcs_server/test/gcs_monitor_server_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,20 @@
#include "ray/gcs/gcs_server/test/gcs_server_test_util.h"
#include "ray/gcs/test/gcs_test_util.h"
#include "ray/gcs/gcs_server/gcs_monitor_server.h"
#include "mock/ray/pubsub/publisher.h"
#include "mock/ray/gcs/gcs_server/gcs_node_manager.h"
// clang-format on

using namespace testing;

namespace ray {
// Test fixture wiring a GcsMonitorServer to a mocked node manager so drain
// calls can be verified without a real cluster.
class GcsMonitorServerTest : public ::testing::Test {
 public:
  // NOTE(review): two constructors appear here because this is a diff view;
  // the first is the pre-change line. Confirm only the mock-wiring one remains.
  GcsMonitorServerTest() : monitor_server_() {}
  GcsMonitorServerTest()
      : mock_node_manager_(std::make_shared<gcs::MockGcsNodeManager>()),
        monitor_server_(mock_node_manager_) {}

 protected:
  // Mock installed into the server; tests set EXPECT_CALLs on DrainNode.
  std::shared_ptr<gcs::MockGcsNodeManager> mock_node_manager_;
  gcs::GcsMonitorServer monitor_server_;
};

Expand All @@ -43,4 +48,19 @@ TEST_F(GcsMonitorServerTest, TestRayVersion) {
ASSERT_EQ(reply.version(), kRayVersion);
}

TEST_F(GcsMonitorServerTest, TestDrainAndKillNode) {
  // Name two random nodes; the handler should forward each to
  // GcsNodeManager::DrainNode and echo each id back in the reply.
  rpc::DrainAndKillNodeRequest request;
  rpc::DrainAndKillNodeReply reply;
  const auto noop_reply_callback =
      [](ray::Status, std::function<void()>, std::function<void()>) {};

  for (int i = 0; i < 2; ++i) {
    *request.add_node_ids() = NodeID::FromRandom().Binary();
  }

  EXPECT_CALL(*mock_node_manager_, DrainNode(_)).Times(2);
  monitor_server_.HandleDrainAndKillNode(request, &reply, noop_reply_callback);

  ASSERT_EQ(reply.drained_nodes().size(), 2);
}

} // namespace ray
13 changes: 13 additions & 0 deletions src/ray/protobuf/monitor.proto
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,22 @@ message GetRayVersionReply {
string version = 1;
}

// Request for MonitorGcsService.DrainAndKillNode.
message DrainAndKillNodeRequest {
  // The binary ids of the nodes to drain and kill.
  repeated bytes node_ids = 1;
}

// Reply for MonitorGcsService.DrainAndKillNode.
message DrainAndKillNodeReply {
  // The node ids which are beginning to drain.
  // NOTE(review): field numbering starts at 2 with no field 1 in sight —
  // presumably reserved for a future status field; confirm this is intentional.
  repeated bytes drained_nodes = 2;
}

// This service provides a stable interface for a monitor/autoscaler process to interact
// with Ray.
service MonitorGcsService {
  // Get the ray version of the service.
  rpc GetRayVersion(GetRayVersionRequest) returns (GetRayVersionReply);
  // Request that GCS drain and kill a node. This call is idempotent, and could
  // need to be retried if the head node fails. Nodes whose drain was initiated
  // (or that were already drained) are echoed back in `drained_nodes`.
  rpc DrainAndKillNode(DrainAndKillNodeRequest) returns (DrainAndKillNodeReply);
}
5 changes: 5 additions & 0 deletions src/ray/rpc/gcs_server/gcs_rpc_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,10 @@ class MonitorGcsServiceHandler {
virtual void HandleGetRayVersion(GetRayVersionRequest request,
GetRayVersionReply *reply,
SendReplyCallback send_reply_callback) = 0;

virtual void HandleDrainAndKillNode(DrainAndKillNodeRequest request,
DrainAndKillNodeReply *reply,
SendReplyCallback send_reply_callback) = 0;
};

/// The `GrpcService` for `MonitorServer`.
Expand All @@ -241,6 +245,7 @@ class MonitorGrpcService : public GrpcService {
const std::unique_ptr<grpc::ServerCompletionQueue> &cq,
std::vector<std::unique_ptr<ServerCallFactory>> *server_call_factories) override {
MONITOR_SERVICE_RPC_HANDLER(GetRayVersion);
MONITOR_SERVICE_RPC_HANDLER(DrainAndKillNode);
}

private:
Expand Down

0 comments on commit e331f6e

Please sign in to comment.