
Commit f76f633

[core] add cython wrapper for raylet client
Signed-off-by: tianyi-ge <tianyig@outlook.com>
1 parent 5927027 commit f76f633

File tree: 11 files changed, +138 −61 lines

python/ray/_raylet.pyx

Lines changed: 1 addition & 0 deletions
@@ -199,6 +199,7 @@ include "includes/libcoreworker.pxi"
 include "includes/global_state_accessor.pxi"
 include "includes/metric.pxi"
 include "includes/setproctitle.pxi"
+include "includes/raylet_client.pxi"

 import ray
 from ray.exceptions import (

python/ray/dashboard/consts.py

Lines changed: 1 addition & 1 deletion
@@ -103,4 +103,4 @@
     "RAY_DASHBOARD_SUBPROCESS_MODULE_WAIT_READY_TIMEOUT", 30.0
 )

-NODE_MANAGER_RPC_TIMEOUT_SECONDS = 1
+RAYLET_RPC_TIMEOUT_SECONDS = 1

python/ray/dashboard/modules/reporter/reporter_agent.py

Lines changed: 10 additions & 39 deletions
@@ -25,7 +25,7 @@
 import ray._private.prometheus_exporter as prometheus_exporter
 import ray.dashboard.modules.reporter.reporter_consts as reporter_consts
 import ray.dashboard.utils as dashboard_utils
-from ray._common.network_utils import build_address, parse_address
+from ray._common.network_utils import parse_address
 from ray._common.utils import (
     get_or_create_event_loop,
     get_user_temp_dir,
@@ -34,18 +34,15 @@
 from ray._private.metrics_agent import Gauge, MetricsAgent, Record
 from ray._private.ray_constants import (
     DEBUG_AUTOSCALING_STATUS,
-    GLOBAL_GRPC_OPTIONS,
     RAY_ENABLE_OPEN_TELEMETRY,
     env_integer,
 )
 from ray._private.telemetry.open_telemetry_metric_recorder import (
     OpenTelemetryMetricRecorder,
 )
 from ray._private.utils import get_system_memory
-from ray._raylet import GCS_PID_KEY, WorkerID
+from ray._raylet import GCS_PID_KEY, RayletClient, WorkerID
 from ray.core.generated import (
-    node_manager_pb2,
-    node_manager_pb2_grpc,
     reporter_pb2,
     reporter_pb2_grpc,
 )
@@ -56,8 +53,8 @@
     COMPONENT_METRICS_TAG_KEYS,
     GCS_RPC_TIMEOUT_SECONDS,
     GPU_TAG_KEYS,
-    NODE_MANAGER_RPC_TIMEOUT_SECONDS,
     NODE_TAG_KEYS,
+    RAYLET_RPC_TIMEOUT_SECONDS,
     TPU_TAG_KEYS,
 )
 from ray.dashboard.modules.reporter.gpu_profile_manager import GpuProfilingManager
@@ -490,10 +487,6 @@ def __init__(self, dashboard_agent):
         # Create GPU metric provider instance
         self._gpu_metric_provider = GpuMetricProvider()

-        self._node_manager_address = build_address(
-            self._ip, self._dashboard_agent.node_manager_port
-        )
-
     async def GetTraceback(self, request, context):
         pid = request.pid
         native = request.native
@@ -896,17 +889,14 @@ def _get_disk_io_stats():
             stats.write_count,
         )

-    async def _get_worker_pids_from_raylet(self):
-        channel = ray._private.utils.init_grpc_channel(
-            self._node_manager_address, GLOBAL_GRPC_OPTIONS, asynchronous=True
+    def _get_worker_pids_from_raylet(self) -> Optional[List[int]]:
+        # Get worker pids from raylet via gRPC.
+        timeout = RAYLET_RPC_TIMEOUT_SECONDS * 1000  # in milliseconds
+        raylet_client = RayletClient(
+            ip_address=self._ip, port=self._dashboard_agent.node_manager_port
         )
-        timeout = NODE_MANAGER_RPC_TIMEOUT_SECONDS
-        stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel)
         try:
-            reply = await stub.GetDriverAndWorkerPids(
-                node_manager_pb2.GetDriverAndWorkerPidsRequest(), timeout=timeout
-            )
-            return reply.pids
+            return raylet_client.get_worker_pids(timeout_ms=timeout)
         except Exception as e:
             logger.debug(f"Failed to get worker pids from raylet via gRPC: {e}")
             return None
@@ -920,7 +910,7 @@ def _generate_worker_key(self, proc: psutil.Process) -> Tuple[int, float]:
         return (proc.pid, proc.create_time())

     def _get_worker_processes(self):
-        pids = asyncio.run(self._get_worker_pids_from_raylet())
+        pids = self._get_worker_pids_from_raylet()
         if pids is not None:
             workers = {}
             for pid in pids:
@@ -931,25 +921,6 @@ def _get_worker_processes(self):
                 continue
             return workers

-        logger.debug("fallback to get worker processes from raylet children")
-        raylet_proc = self._get_raylet_proc()
-        if raylet_proc is None:
-            return []
-        else:
-            workers = {}
-            if sys.platform == "win32":
-                # windows, get the child process not the runner
-                for child in raylet_proc.children():
-                    if child.children():
-                        child = child.children()[0]
-                    workers[self._generate_worker_key(child)] = child
-            else:
-                workers = {
-                    self._generate_worker_key(proc): proc
-                    for proc in raylet_proc.children()
-                }
-            return workers
-
     def _get_workers(self, gpus: Optional[List[GpuUtilizationInfo]] = None):
         workers = self._get_worker_processes()
         if not workers:

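The fallback that walked the raylet's process tree is gone; `_get_worker_processes` now relies solely on the RPC reply. The pid-to-process mapping it keeps reduces to the following sketch (illustrative only, not the agent's exact code; `pids` stands in for the raylet reply):

    import psutil

    def processes_from_pids(pids):
        # Map each reported pid to a live psutil.Process, keyed the same way
        # as _generate_worker_key; skip workers that exited after the reply.
        workers = {}
        for pid in pids:
            try:
                proc = psutil.Process(pid)
                workers[(proc.pid, proc.create_time())] = proc
            except psutil.NoSuchProcess:
                continue
        return workers
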
python/ray/includes/common.pxd

Lines changed: 5 additions & 0 deletions
@@ -763,6 +763,11 @@ cdef extern from "src/ray/protobuf/autoscaler.pb.h" nogil:
         void ParseFromString(const c_string &serialized)
         const c_string &SerializeAsString() const

+cdef extern from "ray/raylet_rpc_client/raylet_client.h" nogil:
+    cdef cppclass CRayletClient "ray::rpc::RayletClient":
+        CRayletClient(const c_string &ip_address, int port)
+        CRayStatus GetWorkerPIDs(c_vector[int32_t] &worker_pids, int64_t timeout_ms)
+
 cdef extern from "ray/common/task/task_spec.h" nogil:
     cdef cppclass CConcurrencyGroup "ray::ConcurrencyGroup":
         CConcurrencyGroup(
python/ray/includes/raylet_client.pxi

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
+from libcpp.vector cimport vector as c_vector
+from libcpp.string cimport string as c_string
+from libc.stdint cimport int32_t as c_int32_t
+from libcpp.memory cimport unique_ptr, make_unique
+from ray.includes.common cimport CRayletClient, CRayStatus, CAddress
+
+cdef class RayletClient:
+    cdef:
+        unique_ptr[CRayletClient] inner
+
+    def __cinit__(self, ip_address: str, port: int):
+        cdef:
+            c_string c_ip_address
+            c_int32_t c_port
+        c_ip_address = ip_address.encode('utf-8')
+        c_port = <c_int32_t>port
+        self.inner = make_unique[CRayletClient](c_ip_address, c_port)
+
+    def get_worker_pids(self, timeout_ms: int) -> list:
+        cdef:
+            c_vector[c_int32_t] pids
+            CRayStatus status
+        assert self.inner.get() is not NULL
+        status = self.inner.get().GetWorkerPIDs(pids, timeout_ms)
+        if not status.ok():
+            raise RuntimeError("Failed to get worker PIDs from raylet: " + status.message().decode())
+        return [pid for pid in pids]

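Taken together, the declarations above expose a small synchronous client to Python. A minimal usage sketch (the address and port values are illustrative; `RayletClient` and `get_worker_pids` are the names introduced in this commit):

    from ray._raylet import RayletClient

    # Connect to the local raylet's node manager port (values illustrative).
    client = RayletClient(ip_address="127.0.0.1", port=8076)
    try:
        pids = client.get_worker_pids(timeout_ms=1000)
        print(f"worker and driver PIDs: {pids}")
    except RuntimeError as err:
        # get_worker_pids raises when GetWorkerPIDs returns a non-OK status.
        print(f"raylet query failed: {err}")
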
src/ray/protobuf/node_manager.proto

Lines changed: 7 additions & 7 deletions
@@ -399,11 +399,11 @@ message IsLocalWorkerDeadReply {
   bool is_dead = 1;
 }

-message GetDriverAndWorkerPidsRequest {}
+message GetWorkerPIDsRequest {}

-message GetDriverAndWorkerPidsReply {
+message GetWorkerPIDsReply {
   // PIDs of all drivers and workers managed by the local raylet.
-  repeated uint32 pids = 1;
+  repeated int32 pids = 1;
 }

 // Service for inter-node-manager communication.
@@ -520,8 +520,8 @@ service NodeManagerService {
   // Failure: Is currently only used when grpc channel is unavailable for retryable core
   // worker clients. The unavailable callback will eventually be retried so if this fails.
   rpc IsLocalWorkerDead(IsLocalWorkerDeadRequest) returns (IsLocalWorkerDeadReply);
-  // Get the worker managed by local raylet.
-  // Failure: Sends to local raylet, so should never fail.
-  rpc GetDriverAndWorkerPids(GetDriverAndWorkerPidsRequest)
-      returns (GetDriverAndWorkerPidsReply);
+  // Get the PIDs of all workers currently alive that are managed by the local Raylet.
+  // This includes connected driver processes.
+  // Failure: Will retry on failure with logging.
+  rpc GetWorkerPIDs(GetWorkerPIDsRequest) returns (GetWorkerPIDsReply);
 }

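Once the stubs are regenerated, Python callers see the renamed messages. A quick sketch against the generated module (assuming a regenerated ray.core.generated.node_manager_pb2):

    from ray.core.generated import node_manager_pb2

    request = node_manager_pb2.GetWorkerPIDsRequest()  # empty message
    reply = node_manager_pb2.GetWorkerPIDsReply(pids=[1234, 5678])
    assert list(reply.pids) == [1234, 5678]  # pids is now repeated int32
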
src/ray/raylet/node_manager.cc

Lines changed: 9 additions & 8 deletions
@@ -2723,14 +2723,15 @@ void NodeManager::TriggerGlobalGC() {
   should_local_gc_ = true;
 }

-void NodeManager::HandleGetDriverAndWorkerPids(
-    rpc::GetDriverAndWorkerPidsRequest request,
-    rpc::GetDriverAndWorkerPidsReply *reply,
-    rpc::SendReplyCallback send_reply_callback) {
-  auto all_workers = worker_pool_.GetAllRegisteredWorkers(/* filter_dead_worker */ true);
-  const auto &drivers =
-      worker_pool_.GetAllRegisteredDrivers(/* filter_dead_driver */ true);
-  all_workers.insert(all_workers.end(), drivers.begin(), drivers.end());
+void NodeManager::HandleGetWorkerPIDs(rpc::GetWorkerPIDsRequest request,
+                                      rpc::GetWorkerPIDsReply *reply,
+                                      rpc::SendReplyCallback send_reply_callback) {
+  auto all_workers = worker_pool_.GetAllRegisteredWorkers(/* filter_dead_worker */ true,
+                                                          /* filter_io_workers */ true);
+  auto drivers = worker_pool_.GetAllRegisteredDrivers(/* filter_dead_driver */ true);
+  all_workers.insert(all_workers.end(),
+                     std::make_move_iterator(drivers.begin()),
+                     std::make_move_iterator(drivers.end()));
   for (const auto &worker : all_workers) {
     reply->add_pids(worker->GetProcess().GetId());
   }

src/ray/raylet/node_manager.h

Lines changed: 6 additions & 0 deletions
@@ -636,6 +636,12 @@ class NodeManager : public rpc::NodeManagerServiceHandler,
                         rpc::NotifyGCSRestartReply *reply,
                         rpc::SendReplyCallback send_reply_callback) override;

+  /// Handle a `GetWorkerPIDs` request.
+  void HandleGetWorkerPIDs(
+      rpc::GetWorkerPIDsRequest request,
+      rpc::GetWorkerPIDsReply *reply,
+      rpc::SendReplyCallback send_reply_callback) override;
+
   /// Checks the local socket connection for all registered workers and drivers.
   /// If any of them have disconnected unexpectedly (i.e., we receive a SIGHUP),
   /// we disconnect and kill the worker process.

src/ray/raylet_rpc_client/raylet_client.cc

Lines changed: 51 additions & 0 deletions
@@ -47,6 +47,32 @@ RayletClient::RayletClient(const rpc::Address &address,
           std::move(raylet_unavailable_timeout_callback),
           /*server_name=*/std::string("Raylet ") + address.ip_address())) {}

+RayletClient::RayletClient(const std::string &ip_address, int port) {
+  io_service_ = std::make_unique<instrumented_io_context>();
+  client_call_manager_ =
+      std::make_unique<rpc::ClientCallManager>(*io_service_, /*record_stats=*/false);
+  grpc_client_ = std::make_unique<rpc::GrpcClient<rpc::NodeManagerService>>(
+      ip_address, port, *client_call_manager_);
+  auto raylet_unavailable_timeout_callback = []() {
+    RAY_LOG(WARNING) << "Raylet is unavailable for "
+                     << ::RayConfig::instance().raylet_rpc_server_reconnect_timeout_s()
+                     << "s";
+  };
+  retryable_grpc_client_ = rpc::RetryableGrpcClient::Create(
+      grpc_client_->Channel(),
+      client_call_manager_->GetMainService(),
+      /*max_pending_requests_bytes=*/
+      std::numeric_limits<uint64_t>::max(),
+      /*check_channel_status_interval_milliseconds=*/
+      ::RayConfig::instance().grpc_client_check_connection_status_interval_milliseconds(),
+      /*server_unavailable_timeout_seconds=*/
+      ::RayConfig::instance().raylet_rpc_server_reconnect_timeout_s(),
+      /*server_unavailable_timeout_callback=*/
+      raylet_unavailable_timeout_callback,
+      /*server_name=*/
+      std::string("Raylet ") + ip_address);
+}
+
 void RayletClient::RequestWorkerLease(
     const rpc::LeaseSpec &lease_spec,
     bool grant_or_reject,
@@ -465,5 +491,30 @@ void RayletClient::GetNodeStats(
       /*method_timeout_ms*/ -1);
 }

+Status RayletClient::GetWorkerPIDs(std::vector<int32_t> &worker_pids,
+                                   int64_t timeout_ms) {
+  rpc::GetWorkerPIDsRequest request;
+  std::promise<Status> promise;
+  auto future = promise.get_future();
+  auto callback = [&promise, &worker_pids](const Status &status,
+                                           rpc::GetWorkerPIDsReply &&reply) {
+    if (status.ok()) {
+      worker_pids = std::vector<int32_t>(reply.pids().begin(), reply.pids().end());
+    }
+    promise.set_value(status);
+  };
+  INVOKE_RPC_CALL(NodeManagerService,
+                  GetWorkerPIDs,
+                  request,
+                  callback,
+                  grpc_client_,
+                  /*method_timeout_ms*/ timeout_ms);
+  if (future.wait_for(std::chrono::milliseconds(timeout_ms)) ==
+      std::future_status::timeout) {
+    return Status::TimedOut("Timed out getting worker PIDs from raylet");
+  }
+  return future.get();
+}
+
 }  // namespace rpc
 }  // namespace ray

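Unlike the other methods in this file, `GetWorkerPIDs` is synchronous: it parks the caller on a promise/future pair until the RPC callback fires or the deadline passes. The same sync-over-async pattern, sketched in Python for clarity (all names below are illustrative, not Ray APIs):

    import concurrent.futures

    def call_sync(send_async_rpc, timeout_s):
        # Bridge a callback-based RPC into a blocking call: the callback
        # fulfills the future, and the caller waits on it with a deadline.
        fut = concurrent.futures.Future()

        def on_reply(ok, pids):
            fut.set_result((ok, pids))

        send_async_rpc(on_reply)
        try:
            ok, pids = fut.result(timeout=timeout_s)
        except concurrent.futures.TimeoutError:
            raise TimeoutError("timed out waiting for the raylet reply")
        if not ok:
            raise RuntimeError("GetWorkerPIDs returned a non-OK status")
        return pids
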
src/ray/raylet_rpc_client/raylet_client.h

Lines changed: 17 additions & 2 deletions
@@ -43,13 +43,20 @@ class RayletClient : public RayletClientInterface {
   /// Connect to the raylet.
   ///
   /// \param address The IP address of the worker.
-  /// \param port The port that the worker should listen on for gRPC requests. If
-  /// 0, the worker should choose a random port.
   /// \param client_call_manager The client call manager to use for the grpc connection.
+  /// \param raylet_unavailable_timeout_callback Callback to be called when the raylet
+  /// is unavailable for a certain period of time.
   explicit RayletClient(const rpc::Address &address,
                         rpc::ClientCallManager &client_call_manager,
                         std::function<void()> raylet_unavailable_timeout_callback);

+  /// Connect to the raylet. Only used by the Cython wrapper (`CRayletClient`);
+  /// the `client_call_manager` is created internally.
+  ///
+  /// \param ip_address The IP address of the raylet.
+  /// \param port The port of the raylet.
+  explicit RayletClient(const std::string &ip_address, int port);
+
   std::shared_ptr<grpc::Channel> GetChannel() const override;

   void RequestWorkerLease(
@@ -163,6 +170,8 @@ class RayletClient : public RayletClientInterface {
   void GetNodeStats(const rpc::GetNodeStatsRequest &request,
                     const rpc::ClientCallback<rpc::GetNodeStatsReply> &callback) override;

+  Status GetWorkerPIDs(std::vector<int32_t> &worker_pids, int64_t timeout_ms);
+
  private:
   /// gRPC client to the NodeManagerService.
   std::shared_ptr<rpc::GrpcClient<rpc::NodeManagerService>> grpc_client_;
@@ -177,6 +186,12 @@ class RayletClient : public RayletClientInterface {

   /// The number of object ID pin RPCs currently in flight.
   std::atomic<int64_t> pins_in_flight_ = 0;
+
+ private:
+  /// If the io context and client call manager are created inside the raylet
+  /// client, they must be kept alive for the whole lifetime of the client.
+  std::unique_ptr<instrumented_io_context> io_service_;
+  std::unique_ptr<rpc::ClientCallManager> client_call_manager_;
 };

 }  // namespace rpc
