[GCS] Optimize GetAllJobInfo API for performance (ray-project#47530)
Signed-off-by: liuxsh9 <liuxiaoshuang4@huawei.com>
Signed-off-by: ujjawal-khare <ujjawal.khare@dream11.com>
liuxsh9 authored and ujjawal-khare committed Oct 15, 2024
1 parent 4038aea commit 0e567a1
Showing 13 changed files with 139 additions and 76 deletions.
10 changes: 9 additions & 1 deletion python/ray/_private/gcs_aio_client.py
@@ -206,10 +206,18 @@ async def internal_kv_keys(

async def get_all_job_info(
self,
*,
job_or_submission_id: Optional[str] = None,
skip_submission_job_info_field: bool = False,
skip_is_running_tasks_field: bool = False,
timeout: Optional[float] = None,
) -> Dict[JobID, gcs_pb2.JobTableData]:
"""
Return dict key: bytes of job_id; value: JobTableData pb message.
"""
return await self._async_proxy.get_all_job_info(job_or_submission_id, timeout)
return await self._async_proxy.get_all_job_info(
job_or_submission_id=job_or_submission_id,
skip_submission_job_info_field=skip_submission_job_info_field,
skip_is_running_tasks_field=skip_is_running_tasks_field,
timeout=timeout,
)
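A minimal usage sketch of the updated async client call, for orientation. It is not part of this diff; the helper name and the already-connected GcsAioClient it takes are illustrative assumptions. Passing both skip flags lets a caller that only needs core job metadata avoid the submission-info internal-KV lookup and the per-driver is_running_tasks RPC that this change makes optional.

    # Illustrative sketch only; assumes an already-connected GcsAioClient.
    from typing import List

    async def list_job_ids(gcs_aio_client) -> List[str]:
        # Only core job metadata is needed, so opt out of both of the
        # expensive fields that this change makes skippable.
        job_infos = await gcs_aio_client.get_all_job_info(
            skip_submission_job_info_field=True,
            skip_is_running_tasks_field=True,
            timeout=5,
        )
        # Values are JobTableData protobuf messages; job_id is a bytes field.
        return [info.job_id.hex() for info in job_infos.values()]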
9 changes: 6 additions & 3 deletions python/ray/_raylet.pyx
@@ -2910,8 +2910,10 @@ cdef class OldGcsClient:
return result

@_auto_reconnect
def get_all_job_info(self, job_or_submission_id: str = None,
timeout=None) -> Dict[JobID, JobTableData]:
def get_all_job_info(
self, *, job_or_submission_id: str = None, skip_submission_job_info_field=False,
skip_is_running_tasks_field=False, timeout=None
) -> Dict[JobID, JobTableData]:
# Ideally we should use json_format.MessageToDict(job_info),
# but `job_info` is a cpp pb message not a python one.
# Manually converting each and every protobuf field is out of question,
@@ -2931,7 +2933,8 @@
make_optional[c_string](c_job_or_submission_id)
with nogil:
check_status(self.inner.get().GetAllJobInfo(
c_optional_job_or_submission_id, timeout_ms, c_job_infos))
c_optional_job_or_submission_id, c_skip_submission_job_info_field,
c_skip_is_running_tasks_field, timeout_ms, c_job_infos))
for c_job_info in c_job_infos:
serialized_job_infos.push_back(c_job_info.SerializeAsString())
result = {}
5 changes: 4 additions & 1 deletion python/ray/dashboard/modules/job/utils.py
@@ -161,7 +161,10 @@ async def get_driver_jobs(
jobs with the job id or submission id.
"""
job_infos = await gcs_aio_client.get_all_job_info(
job_or_submission_id=job_or_submission_id, timeout=timeout
job_or_submission_id=job_or_submission_id,
skip_submission_job_info_field=True,
skip_is_running_tasks_field=True,
timeout=timeout,
)
# Sort jobs from GCS to follow convention of returning only last driver
# of submission job.
9 changes: 7 additions & 2 deletions python/ray/includes/common.pxd
@@ -407,11 +407,15 @@ cdef extern from "ray/gcs/gcs_client/accessor.h" nogil:
cdef cppclass CJobInfoAccessor "ray::gcs::JobInfoAccessor":
CRayStatus GetAll(
const optional[c_string] &job_or_submission_id,
c_bool skip_submission_job_info_field,
c_bool skip_is_running_tasks_field,
c_vector[CJobTableData] &result,
int64_t timeout_ms)

CRayStatus AsyncGetAll(
const optional[c_string] &job_or_submission_id,
c_bool skip_submission_job_info_field,
c_bool skip_is_running_tasks_field,
const MultiItemPyCallback[CJobTableData] &callback,
int64_t timeout_ms)

@@ -625,8 +629,9 @@ cdef extern from "ray/gcs/gcs_client/gcs_client.h" nogil:
CRayStatus GetAllNodeInfo(
int64_t timeout_ms, c_vector[CGcsNodeInfo]& result)
CRayStatus GetAllJobInfo(
const optional[c_string] &job_or_submission_id, int64_t timeout_ms,
c_vector[CJobTableData]& result)
const optional[c_string] &job_or_submission_id,
c_bool skip_submission_job_info_field, c_bool skip_is_running_tasks_field,
int64_t timeout_ms, c_vector[CJobTableData]& result)
CRayStatus GetAllResourceUsage(
int64_t timeout_ms, c_string& serialized_reply)
CRayStatus RequestClusterResourceConstraint(
13 changes: 10 additions & 3 deletions python/ray/includes/gcs_client.pxi
@@ -433,7 +433,9 @@ cdef class NewGcsClient:
#############################################################

def get_all_job_info(
self, job_or_submission_id: Optional[str] = None,
self, *, job_or_submission_id: Optional[str] = None,
skip_submission_job_info_field: bool = False,
skip_is_running_tasks_field: bool = False,
timeout: Optional[float] = None
) -> Dict[JobID, gcs_pb2.JobTableData]:
cdef c_string c_job_or_submission_id
@@ -449,11 +451,14 @@
make_optional[c_string](c_job_or_submission_id)
with nogil:
status = self.inner.get().Jobs().GetAll(
c_optional_job_or_submission_id, reply, timeout_ms)
c_optional_job_or_submission_id, c_skip_submission_job_info_field,
c_skip_is_running_tasks_field, reply, timeout_ms)
return raise_or_return((convert_get_all_job_info(status, move(reply))))

def async_get_all_job_info(
self, job_or_submission_id: Optional[str] = None,
self, *, job_or_submission_id: Optional[str] = None,
skip_submission_job_info_field: bool = False,
skip_is_running_tasks_field: bool = False,
timeout: Optional[float] = None
) -> Future[Dict[JobID, gcs_pb2.JobTableData]]:
cdef:
@@ -471,6 +476,8 @@
check_status_timeout_as_rpc_error(
self.inner.get().Jobs().AsyncGetAll(
c_optional_job_or_submission_id,
c_skip_submission_job_info_field,
c_skip_is_running_tasks_field,
MultiItemPyCallback[CJobTableData](
&convert_get_all_job_info,
assign_and_decrement_fut,
2 changes: 2 additions & 0 deletions src/mock/ray/gcs/gcs_client/accessor.h
@@ -106,6 +106,8 @@ class MockJobInfoAccessor : public JobInfoAccessor {
MOCK_METHOD(Status,
AsyncGetAll,
(const std::optional<std::string> &job_or_submission_id,
bool skip_submission_job_info_field,
bool skip_is_running_tasks_field,
const MultiItemCallback<rpc::JobTableData> &callback,
int64_t timeout_ms),
(override));
15 changes: 13 additions & 2 deletions src/ray/gcs/gcs_client/accessor.cc
@@ -81,8 +81,11 @@ Status JobInfoAccessor::AsyncSubscribeAll(
done(status);
}
};
RAY_CHECK_OK(
AsyncGetAll(/*job_or_submission_id=*/std::nullopt, callback, /*timeout_ms=*/-1));
RAY_CHECK_OK(AsyncGetAll(/*job_or_submission_id=*/std::nullopt,
/*skip_submission_job_info_field=*/true,
/*skip_is_running_tasks_field=*/true,
callback,
/*timeout_ms=*/-1));
};
subscribe_operation_ = [this, subscribe](const StatusCallback &done) {
return client_impl_->GetGcsSubscriber().SubscribeAllJobs(subscribe, done);
@@ -107,11 +110,15 @@ void JobInfoAccessor::AsyncResubscribe() {

Status JobInfoAccessor::AsyncGetAll(
const std::optional<std::string> &job_or_submission_id,
bool skip_submission_job_info_field,
bool skip_is_running_tasks_field,
const MultiItemCallback<rpc::JobTableData> &callback,
int64_t timeout_ms) {
RAY_LOG(DEBUG) << "Getting all job info.";
RAY_CHECK(callback);
rpc::GetAllJobInfoRequest request;
request.set_skip_submission_job_info_field(skip_submission_job_info_field);
request.set_skip_is_running_tasks_field(skip_is_running_tasks_field);
if (job_or_submission_id.has_value()) {
request.set_job_or_submission_id(job_or_submission_id.value());
}
@@ -126,9 +133,13 @@ Status JobInfoAccessor::AsyncGetAll(
}

Status JobInfoAccessor::GetAll(const std::optional<std::string> &job_or_submission_id,
bool skip_submission_job_info_field,
bool skip_is_running_tasks_field,
std::vector<rpc::JobTableData> &job_data_list,
int64_t timeout_ms) {
rpc::GetAllJobInfoRequest request;
request.set_skip_submission_job_info_field(skip_submission_job_info_field);
request.set_skip_is_running_tasks_field(skip_is_running_tasks_field);
if (job_or_submission_id.has_value()) {
request.set_job_or_submission_id(job_or_submission_id.value());
}
4 changes: 4 additions & 0 deletions src/ray/gcs/gcs_client/accessor.h
@@ -262,6 +262,8 @@ class JobInfoAccessor {
/// \param callback Callback that will be called after lookup finished.
/// \return Status
virtual Status AsyncGetAll(const std::optional<std::string> &job_or_submission_id,
bool skip_submission_job_info_field,
bool skip_is_running_tasks_field,
const MultiItemCallback<rpc::JobTableData> &callback,
int64_t timeout_ms);

@@ -272,6 +274,8 @@ class JobInfoAccessor {
/// \param timeout_ms -1 means infinite.
/// \return Status
virtual Status GetAll(const std::optional<std::string> &job_or_submission_id,
bool skip_submission_job_info_field,
bool skip_is_running_tasks_field,
std::vector<rpc::JobTableData> &job_data_list,
int64_t timeout_ms);

4 changes: 4 additions & 0 deletions src/ray/gcs/gcs_client/gcs_client.cc
@@ -504,13 +504,17 @@ Status PythonGcsClient::GetAllNodeInfo(int64_t timeout_ms,

Status PythonGcsClient::GetAllJobInfo(
const std::optional<std::string> &job_or_submission_id,
bool skip_submission_job_info_field,
bool skip_is_running_tasks_field,
int64_t timeout_ms,
std::vector<rpc::JobTableData> &result) {
grpc::ClientContext context;
PrepareContext(context, timeout_ms);

absl::ReaderMutexLock lock(&mutex_);
rpc::GetAllJobInfoRequest request;
request.set_skip_submission_job_info_field(skip_submission_job_info_field);
request.set_skip_is_running_tasks_field(skip_is_running_tasks_field);
if (job_or_submission_id.has_value()) {
request.set_job_or_submission_id(job_or_submission_id.value());
}
2 changes: 2 additions & 0 deletions src/ray/gcs/gcs_client/gcs_client.h
@@ -298,6 +298,8 @@ class RAY_EXPORT PythonGcsClient {
Status PinRuntimeEnvUri(const std::string &uri, int expiration_s, int64_t timeout_ms);
Status GetAllNodeInfo(int64_t timeout_ms, std::vector<rpc::GcsNodeInfo> &result);
Status GetAllJobInfo(const std::optional<std::string> &job_or_submission_id,
bool skip_submission_job_info_field,
bool skip_is_running_tasks_field,
int64_t timeout_ms,
std::vector<rpc::JobTableData> &result);
Status GetAllResourceUsage(int64_t timeout_ms, std::string &serialized_reply);
2 changes: 2 additions & 0 deletions src/ray/gcs/gcs_client/global_state_accessor.cc
@@ -69,6 +69,8 @@ std::vector<std::string> GlobalStateAccessor::GetAllJobInfo(
absl::ReaderMutexLock lock(&mutex_);
RAY_CHECK_OK(gcs_client_->Jobs().AsyncGetAll(
/*job_or_submission_id=*/std::nullopt,
skip_submission_job_info_field,
skip_is_running_tasks_field,
TransformForMultiItemCallback<rpc::JobTableData>(job_table_data, promise),
/*timeout_ms=*/-1));
}
138 changes: 74 additions & 64 deletions src/ray/gcs/gcs_server/gcs_job_manager.cc
@@ -263,9 +263,8 @@ void GcsJobManager::HandleGetAllJobInfo(rpc::GetAllJobInfoRequest request,
}
return false;
};

auto on_done = [this, filter_ok, reply, send_reply_callback, limit](
absl::flat_hash_map<JobID, JobTableData> &&result) {
auto on_done = [this, filter_ok, request, reply, send_reply_callback, limit](
const absl::flat_hash_map<JobID, JobTableData> &&result) {
// Internal KV keys for jobs that were submitted via the Ray Job API.
std::vector<std::string> job_api_data_keys;

@@ -298,77 +297,88 @@
job_data_key_to_indices[job_data_key].push_back(i);
}

JobID job_id = data.first;
WorkerID worker_id = WorkerID::FromBinary(data.second.driver_address().worker_id());

// If job is not dead, get is_running_tasks from the core worker for the driver.
if (data.second.is_dead()) {
reply->mutable_job_info_list(i)->set_is_running_tasks(false);
core_worker_clients_.Disconnect(worker_id);
if (!request.skip_is_running_tasks_field()) {
JobID job_id = data.first;
WorkerID worker_id =
WorkerID::FromBinary(data.second.driver_address().worker_id());

// If job is not dead, get is_running_tasks from the core worker for the driver.
if (data.second.is_dead()) {
reply->mutable_job_info_list(i)->set_is_running_tasks(false);
core_worker_clients_.Disconnect(worker_id);
(*num_processed_jobs)++;
try_send_reply();
} else {
// Get is_running_tasks from the core worker for the driver.
auto client = core_worker_clients_.GetOrConnect(data.second.driver_address());
auto request = std::make_unique<rpc::NumPendingTasksRequest>();
constexpr int64_t kNumPendingTasksRequestTimeoutMs = 1000;
RAY_LOG(DEBUG) << "Send NumPendingTasksRequest to worker " << worker_id
<< ", timeout " << kNumPendingTasksRequestTimeoutMs << " ms.";
client->NumPendingTasks(
std::move(request),
[job_id, worker_id, reply, i, num_processed_jobs, try_send_reply](
const Status &status,
const rpc::NumPendingTasksReply &num_pending_tasks_reply) {
RAY_LOG(DEBUG).WithField(worker_id)
<< "Received NumPendingTasksReply from worker.";
if (!status.ok()) {
RAY_LOG(WARNING).WithField(job_id).WithField(worker_id)
<< "Failed to get num_pending_tasks from core worker: " << status
<< ", is_running_tasks is unset.";
reply->mutable_job_info_list(i)->clear_is_running_tasks();
} else {
bool is_running_tasks = num_pending_tasks_reply.num_pending_tasks() > 0;
reply->mutable_job_info_list(i)->set_is_running_tasks(is_running_tasks);
}
(*num_processed_jobs)++;
try_send_reply();
},
kNumPendingTasksRequestTimeoutMs);
}
} else {
(*num_processed_jobs)++;
try_send_reply();
} else {
// Get is_running_tasks from the core worker for the driver.
auto client = core_worker_clients_.GetOrConnect(data.second.driver_address());
auto request = std::make_unique<rpc::NumPendingTasksRequest>();
constexpr int64_t kNumPendingTasksRequestTimeoutMs = 1000;
RAY_LOG(DEBUG) << "Send NumPendingTasksRequest to worker " << worker_id
<< ", timeout " << kNumPendingTasksRequestTimeoutMs << " ms.";
client->NumPendingTasks(
std::move(request),
[job_id, worker_id, reply, i, num_processed_jobs, try_send_reply](
const Status &status,
const rpc::NumPendingTasksReply &num_pending_tasks_reply) {
RAY_LOG(DEBUG).WithField(worker_id)
<< "Received NumPendingTasksReply from worker.";
if (!status.ok()) {
RAY_LOG(WARNING).WithField(job_id).WithField(worker_id)
<< "Failed to get num_pending_tasks from core worker: " << status
<< ", is_running_tasks is unset.";
reply->mutable_job_info_list(i)->clear_is_running_tasks();
} else {
bool is_running_tasks = num_pending_tasks_reply.num_pending_tasks() > 0;
reply->mutable_job_info_list(i)->set_is_running_tasks(is_running_tasks);
}
(*num_processed_jobs)++;
try_send_reply();
},
kNumPendingTasksRequestTimeoutMs);
}
i++;
}

// Load the JobInfo for jobs submitted via the Ray Job API.
auto kv_multi_get_callback =
[reply,
send_reply_callback,
job_data_key_to_indices,
kv_callback_done,
try_send_reply](std::unordered_map<std::string, std::string> &&result) {
for (const auto &data : result) {
const std::string &job_data_key = data.first;
// The JobInfo stored by the Ray Job API.
const std::string &job_info_json = data.second;
if (!job_info_json.empty()) {
// Parse the JSON into a JobsAPIInfo proto.
rpc::JobsAPIInfo jobs_api_info;
auto status = google::protobuf::util::JsonStringToMessage(job_info_json,
&jobs_api_info);
if (!status.ok()) {
RAY_LOG(ERROR)
<< "Failed to parse JobInfo JSON into JobsAPIInfo protobuf. JSON: "
<< job_info_json << " Error: " << status.message();
}
// Add the JobInfo to the correct indices in the reply.
for (int i : job_data_key_to_indices.at(job_data_key)) {
reply->mutable_job_info_list(i)->mutable_job_info()->CopyFrom(
jobs_api_info);
if (!request.skip_submission_job_info_field()) {
// Load the JobInfo for jobs submitted via the Ray Job API.
auto kv_multi_get_callback =
[reply,
send_reply_callback,
job_data_key_to_indices,
kv_callback_done,
try_send_reply](std::unordered_map<std::string, std::string> &&result) {
for (const auto &data : result) {
const std::string &job_data_key = data.first;
// The JobInfo stored by the Ray Job API.
const std::string &job_info_json = data.second;
if (!job_info_json.empty()) {
// Parse the JSON into a JobsAPIInfo proto.
rpc::JobsAPIInfo jobs_api_info;
auto status = google::protobuf::util::JsonStringToMessage(job_info_json,
&jobs_api_info);
if (!status.ok()) {
RAY_LOG(ERROR)
<< "Failed to parse JobInfo JSON into JobsAPIInfo protobuf. JSON: "
<< job_info_json << " Error: " << status.message();
}
// Add the JobInfo to the correct indices in the reply.
for (int i : job_data_key_to_indices.at(job_data_key)) {
reply->mutable_job_info_list(i)->mutable_job_info()->CopyFrom(
jobs_api_info);
}
}
}
size_t updated_finished_tasks = num_finished_tasks->fetch_add(1) + 1;
try_send_reply(updated_finished_tasks);
*kv_callback_done = true;
try_send_reply();
};
internal_kv_.MultiGet("job", job_api_data_keys, kv_multi_get_callback);
} else {
*kv_callback_done = true;
try_send_reply();
}
};
Status status = gcs_table_storage_->JobTable().GetAll(on_done);
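To make the reply bookkeeping in the handler above easier to follow, here is a small Python model of the gating logic as read from this hunk: the reply is sent only after every per-job callback has run and the internal-KV callback has either completed or been skipped. This is an interpretation sketch, not Ray code; the names mirror num_processed_jobs and kv_callback_done from the C++ above, and the exact condition inside try_send_reply is inferred, since its definition sits above the shown hunk.

    # Interpretation sketch of the reply gating in HandleGetAllJobInfo.
    from typing import Callable

    class ReplyGate:
        def __init__(self, num_jobs: int, send_reply: Callable[[], None]):
            self.num_jobs = num_jobs
            self.num_processed_jobs = 0
            self.kv_callback_done = False
            self._send_reply = send_reply
            self._sent = False

        def _try_send_reply(self) -> None:
            if (not self._sent
                    and self.num_processed_jobs == self.num_jobs
                    and self.kv_callback_done):
                self._sent = True
                self._send_reply()

        def job_processed(self) -> None:
            # Called once per job: immediately when skip_is_running_tasks_field
            # is set or the driver is dead, otherwise from the NumPendingTasks
            # reply callback.
            self.num_processed_jobs += 1
            self._try_send_reply()

        def kv_done(self) -> None:
            # Called from the internal-KV MultiGet callback, or immediately
            # when skip_submission_job_info_field is set.
            self.kv_callback_done = True
            self._try_send_reply()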
2 changes: 2 additions & 0 deletions src/ray/protobuf/gcs_service.proto
@@ -42,6 +42,8 @@ message GetAllJobInfoRequest {
// If set, only return the job with that job id in hex, or the job with that
// job_submission_id.
optional string job_or_submission_id = 2;
optional bool skip_submission_job_info_field = 3;
optional bool skip_is_running_tasks_field = 4;
}

message GetAllJobInfoReply {
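For completeness, a sketch of a request carrying the new flags when built directly from the generated protobuf bindings. The generated-module path (ray.core.generated.gcs_service_pb2) and the example submission id are assumptions made for illustration; only the field names come from the message definition above.

    # Illustrative only; the import path and the submission id are assumed.
    from ray.core.generated import gcs_service_pb2

    request = gcs_service_pb2.GetAllJobInfoRequest()
    request.job_or_submission_id = "raysubmit_example"  # hypothetical id
    # New optional flags: ask the GCS to skip the internal-KV fetch of Ray Job
    # API metadata and/or the per-driver NumPendingTasks RPC.
    request.skip_submission_job_info_field = True
    request.skip_is_running_tasks_field = True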
