Skip to content

Commit 6563648

Browse files
dayshah, sampan
authored and committed
[core] Revamp GetAllNodeInfo to be more efficient and take multiple NodeID's (#55115)
Needed for #55112, since GetAllNodeInfo needs to be able to take multiple node IDs as arguments so that information for multiple nodes can be fetched in one RPC to the GCS. This revamps the request to fit that model: you can select multiple nodes, or get all nodes if you don't use any selectors. It also adds an "optimized path": if all the selectors are just node IDs, we can do lookups in the map and add the results to the reply instead of iterating through all of it. Also adds some extra tests for the new behavior. --------- Signed-off-by: dayshah <dhyey2019@gmail.com> Signed-off-by: sampan <sampan@anyscale.com>
1 parent 2ddcc8c commit 6563648

File tree

14 files changed

+218
-145
lines changed

14 files changed

+218
-145
lines changed

python/ray/includes/common.pxd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -455,7 +455,7 @@ cdef extern from "ray/gcs/gcs_client/accessor.h" nogil:
455455
void AsyncGetAll(
456456
const MultiItemPyCallback[CGcsNodeInfo] &callback,
457457
int64_t timeout_ms,
458-
optional[CNodeID] node_id)
458+
c_vector[CNodeID] node_ids)
459459

460460
cdef cppclass CNodeResourceInfoAccessor "ray::gcs::NodeResourceInfoAccessor":
461461
CRayStatus GetAllResourceUsage(

python/ray/includes/gcs_client.pxi

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -337,18 +337,18 @@ cdef class InnerGcsClient:
337337
) -> Future[Dict[NodeID, gcs_pb2.GcsNodeInfo]]:
338338
cdef:
339339
int64_t timeout_ms = round(1000 * timeout) if timeout else -1
340-
optional[CNodeID] c_node_id
340+
c_vector[CNodeID] c_node_ids
341341
fut = incremented_fut()
342342
if node_id:
343-
c_node_id = (<NodeID>node_id).native()
343+
c_node_ids.push_back((<NodeID>node_id).native())
344344
with nogil:
345345
self.inner.get().Nodes().AsyncGetAll(
346346
MultiItemPyCallback[CGcsNodeInfo](
347347
convert_get_all_node_info,
348348
assign_and_decrement_fut,
349349
fut),
350350
timeout_ms,
351-
c_node_id)
351+
c_node_ids)
352352
return asyncio.wrap_future(fut)
353353

354354
#############################################################

python/ray/util/state/state_manager.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -312,26 +312,33 @@ async def get_all_node_info(
312312
if filters is None:
313313
filters = []
314314

315-
req_filters = GetAllNodeInfoRequest.Filters()
315+
node_selectors = []
316+
state_filter = None
316317
for filter in filters:
317318
key, predicate, value = filter
318319
if predicate != "=":
319320
# We only support EQUAL predicate for source side filtering.
320321
continue
321322

322323
if key == "node_id":
323-
req_filters.node_id = NodeID(hex_to_binary(value)).binary()
324+
node_selector = GetAllNodeInfoRequest.NodeSelector()
325+
node_selector.node_id = NodeID(hex_to_binary(value)).binary()
326+
node_selectors.append(node_selector)
324327
elif key == "state":
325328
value = value.upper()
326329
if value not in GcsNodeInfo.GcsNodeState.keys():
327330
raise ValueError(f"Invalid node state for filtering: {value}")
328-
req_filters.state = GcsNodeInfo.GcsNodeState.Value(value)
331+
state_filter = GcsNodeInfo.GcsNodeState.Value(value)
329332
elif key == "node_name":
330-
req_filters.node_name = value
333+
node_selector = GetAllNodeInfoRequest.NodeSelector()
334+
node_selector.node_name = value
335+
node_selectors.append(node_selector)
331336
else:
332337
continue
333338

334-
request = GetAllNodeInfoRequest(limit=limit, filters=req_filters)
339+
request = GetAllNodeInfoRequest(
340+
limit=limit, node_selectors=node_selectors, state_filter=state_filter
341+
)
335342
reply = await self._gcs_node_info_stub.GetAllNodeInfo(request, timeout=timeout)
336343
return reply
337344

src/mock/ray/gcs/gcs_client/accessor.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ class MockNodeInfoAccessor : public NodeInfoAccessor {
140140
AsyncGetAll,
141141
(const MultiItemCallback<rpc::GcsNodeInfo> &callback,
142142
int64_t timeout_ms,
143-
std::optional<NodeID> node_id),
143+
const std::vector<NodeID> &node_ids),
144144
(override));
145145
MOCK_METHOD(void,
146146
AsyncSubscribeToNodeChange,

src/ray/gcs/gcs_client/accessor.cc

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -594,11 +594,11 @@ Status NodeInfoAccessor::DrainNodes(const std::vector<NodeID> &node_ids,
594594

595595
void NodeInfoAccessor::AsyncGetAll(const MultiItemCallback<rpc::GcsNodeInfo> &callback,
596596
int64_t timeout_ms,
597-
std::optional<NodeID> node_id) {
597+
const std::vector<NodeID> &node_ids) {
598598
RAY_LOG(DEBUG) << "Getting information of all nodes.";
599599
rpc::GetAllNodeInfoRequest request;
600-
if (node_id) {
601-
request.mutable_filters()->set_node_id(node_id->Binary());
600+
for (const auto &node_id : node_ids) {
601+
request.add_node_selectors()->set_node_id(node_id.Binary());
602602
}
603603
client_impl_->GetGcsRpcClient().GetAllNodeInfo(
604604
request,
@@ -683,9 +683,12 @@ Status NodeInfoAccessor::GetAllNoCache(int64_t timeout_ms,
683683
}
684684

685685
StatusOr<std::vector<rpc::GcsNodeInfo>> NodeInfoAccessor::GetAllNoCacheWithFilters(
686-
int64_t timeout_ms, rpc::GetAllNodeInfoRequest_Filters filters) {
686+
int64_t timeout_ms,
687+
rpc::GcsNodeInfo::GcsNodeState state_filter,
688+
rpc::GetAllNodeInfoRequest::NodeSelector node_selector) {
687689
rpc::GetAllNodeInfoRequest request;
688-
*request.mutable_filters() = std::move(filters);
690+
*request.add_node_selectors() = std::move(node_selector);
691+
request.set_state_filter(state_filter);
689692
rpc::GetAllNodeInfoReply reply;
690693
RAY_RETURN_NOT_OK(
691694
client_impl_->GetGcsRpcClient().SyncGetAllNodeInfo(request, &reply, timeout_ms));

src/ray/gcs/gcs_client/accessor.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -349,10 +349,11 @@ class NodeInfoAccessor {
349349
///
350350
/// \param callback Callback that will be called after lookup finishes.
351351
/// \param timeout_ms The timeout for this request.
352-
/// \param node_id If not nullopt, only return the node info of the specified node.
352+
/// \param node_ids If this is not empty, only return the node info of the specified
353+
/// nodes.
353354
virtual void AsyncGetAll(const MultiItemCallback<rpc::GcsNodeInfo> &callback,
354355
int64_t timeout_ms,
355-
std::optional<NodeID> node_id = std::nullopt);
356+
const std::vector<NodeID> &node_ids = {});
356357

357358
/// Subscribe to node addition and removal events from GCS and cache those information.
358359
///
@@ -393,7 +394,9 @@ class NodeInfoAccessor {
393394
///
394395
/// \return All nodes that match the given filters from the gcs without the cache.
395396
virtual StatusOr<std::vector<rpc::GcsNodeInfo>> GetAllNoCacheWithFilters(
396-
int64_t timeout_ms, rpc::GetAllNodeInfoRequest_Filters filters);
397+
int64_t timeout_ms,
398+
rpc::GcsNodeInfo::GcsNodeState state_filter,
399+
rpc::GetAllNodeInfoRequest::NodeSelector node_selector);
397400

398401
/// Send a check alive request to GCS for the liveness of some nodes.
399402
///

src/ray/gcs/gcs_client/global_state_accessor.cc

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -408,16 +408,15 @@ ray::Status GlobalStateAccessor::GetNode(const std::string &node_id_hex_str,
408408

409409
std::vector<rpc::GcsNodeInfo> node_infos;
410410
while (true) {
411-
rpc::GetAllNodeInfoRequest_Filters filters;
412-
filters.set_state(rpc::GcsNodeInfo_GcsNodeState::GcsNodeInfo_GcsNodeState_ALIVE);
413-
filters.set_node_id(node_id_binary);
411+
rpc::GetAllNodeInfoRequest::NodeSelector selector;
412+
selector.set_node_id(node_id_binary);
414413
{
415414
absl::ReaderMutexLock lock(&mutex_);
416415
auto timeout_ms =
417416
std::max(end_time_point - current_time_ms(), static_cast<int64_t>(0));
418-
RAY_ASSIGN_OR_RETURN(
419-
node_infos,
420-
gcs_client_->Nodes().GetAllNoCacheWithFilters(timeout_ms, std::move(filters)));
417+
RAY_ASSIGN_OR_RETURN(node_infos,
418+
gcs_client_->Nodes().GetAllNoCacheWithFilters(
419+
timeout_ms, rpc::GcsNodeInfo::ALIVE, std::move(selector)));
421420
}
422421
if (!node_infos.empty()) {
423422
*node_info = node_infos[0].SerializeAsString();
@@ -442,16 +441,16 @@ ray::Status GlobalStateAccessor::GetNodeToConnectForDriver(
442441
current_time_ms() + RayConfig::instance().raylet_start_wait_time_s() * 1000;
443442

444443
std::vector<rpc::GcsNodeInfo> node_infos;
445-
rpc::GetAllNodeInfoRequest_Filters filters;
446-
filters.set_state(rpc::GcsNodeInfo_GcsNodeState::GcsNodeInfo_GcsNodeState_ALIVE);
447-
filters.set_node_ip_address(node_ip_address);
444+
rpc::GetAllNodeInfoRequest::NodeSelector selector;
445+
selector.set_node_ip_address(node_ip_address);
448446
while (true) {
449447
{
450448
absl::ReaderMutexLock lock(&mutex_);
451449
auto timeout_ms =
452450
std::max(end_time_point - current_time_ms(), static_cast<int64_t>(0));
453-
RAY_ASSIGN_OR_RETURN(
454-
node_infos, gcs_client_->Nodes().GetAllNoCacheWithFilters(timeout_ms, filters));
451+
RAY_ASSIGN_OR_RETURN(node_infos,
452+
gcs_client_->Nodes().GetAllNoCacheWithFilters(
453+
timeout_ms, rpc::GcsNodeInfo::ALIVE, selector));
455454
}
456455
if (!node_infos.empty()) {
457456
*node_to_connect = node_infos[0].SerializeAsString();
@@ -464,22 +463,23 @@ ray::Status GlobalStateAccessor::GetNodeToConnectForDriver(
464463
auto [address, _] = gcs_client_->GetGcsServerAddress();
465464
gcs_address = std::move(address);
466465
}
467-
filters.set_node_ip_address(gcs_address);
466+
selector.set_node_ip_address(gcs_address);
468467
{
469468
absl::ReaderMutexLock lock(&mutex_);
470469
auto timeout_ms = end_time_point - current_time_ms();
471-
RAY_ASSIGN_OR_RETURN(
472-
node_infos, gcs_client_->Nodes().GetAllNoCacheWithFilters(timeout_ms, filters));
470+
RAY_ASSIGN_OR_RETURN(node_infos,
471+
gcs_client_->Nodes().GetAllNoCacheWithFilters(
472+
timeout_ms, rpc::GcsNodeInfo::ALIVE, selector));
473473
}
474474
if (node_infos.empty() && node_ip_address == gcs_address) {
475-
filters.set_node_ip_address("127.0.0.1");
475+
selector.set_node_ip_address("127.0.0.1");
476476
{
477477
absl::ReaderMutexLock lock(&mutex_);
478478
auto timeout_ms =
479479
std::max(end_time_point - current_time_ms(), static_cast<int64_t>(0));
480-
RAY_ASSIGN_OR_RETURN(
481-
node_infos,
482-
gcs_client_->Nodes().GetAllNoCacheWithFilters(timeout_ms, filters));
480+
RAY_ASSIGN_OR_RETURN(node_infos,
481+
gcs_client_->Nodes().GetAllNoCacheWithFilters(
482+
timeout_ms, rpc::GcsNodeInfo::ALIVE, selector));
483483
}
484484
}
485485
if (!node_infos.empty()) {

0 commit comments

Comments
 (0)