Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[core][state][dashboard] Use main threads's task id or actor creation task id for parent's task id in state API #32157

Merged
merged 5 commits into from
Feb 1, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion cpp/src/ray/runtime/task/local_mode_task_submitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ ObjectID LocalModeTaskSubmitter::Submit(InvocationSpec &invocation,
required_resources,
required_placement_resources,
"",
/*depth=*/0);
/*depth=*/0,
local_mode_ray_tuntime_.GetCurrentTaskId());
if (invocation.task_type == TaskType::NORMAL_TASK) {
} else if (invocation.task_type == TaskType::ACTOR_CREATION_TASK) {
invocation.actor_id = local_mode_ray_tuntime_.GetNextActorID();
Expand Down
206 changes: 206 additions & 0 deletions python/ray/tests/test_task_events.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from collections import defaultdict
from typing import Dict
import pytest
import threading
import time

import ray
from ray.experimental.state.common import ListApiOptions, StateResource
from ray._private.test_utils import (
raw_metrics,
run_string_as_driver,
run_string_as_driver_nonblocking,
wait_for_condition,
)
Expand Down Expand Up @@ -124,6 +126,210 @@ def verify():
)


def test_parent_task_id_threaded_task(shutdown_only):
ray.init(_system_config=_SYSTEM_CONFIG)

# Task starts a thread
@ray.remote
def main_task():
def thd_task():
@ray.remote
def thd_task():
pass

ray.get(thd_task.remote())

thd = threading.Thread(target=thd_task)
thd.start()
thd.join()

ray.get(main_task.remote())

def verify():
tasks = list_tasks()
assert len(tasks) == 2
expect_parent_task_id = None
actual_parent_task_id = None
for task in tasks:
if task["name"] == "main_task":
expect_parent_task_id = task["task_id"]
elif task["name"] == "thd_task":
actual_parent_task_id = task["parent_task_id"]
assert actual_parent_task_id is not None
assert expect_parent_task_id == actual_parent_task_id

return True

wait_for_condition(verify)


def test_parent_task_id_actor(shutdown_only):
ray.init(_system_config=_SYSTEM_CONFIG)

def run_task_in_thread():
def thd_task():
@ray.remote
def thd_task():
pass

ray.get(thd_task.remote())

thd = threading.Thread(target=thd_task)
thd.start()
thd.join()

@ray.remote
class Actor:
def main_task(self):
run_task_in_thread()

a = Actor.remote()
ray.get(a.main_task.remote())

def verify():
tasks = list_tasks()
expect_parent_task_id = None
actual_parent_task_id = None
for task in tasks:
if "main_task" in task["name"]:
expect_parent_task_id = task["task_id"]
elif "thd_task" in task["name"]:
actual_parent_task_id = task["parent_task_id"]
print(tasks)
assert actual_parent_task_id is not None
assert expect_parent_task_id == actual_parent_task_id

return True

wait_for_condition(verify)


@pytest.mark.parametrize("actor_concurrency", [3, 10])
def test_parent_task_id_multi_thread_actors(shutdown_only, actor_concurrency):
ray.init(_system_config=_SYSTEM_CONFIG)

def run_task_in_thread(name, i):
def thd_task():
@ray.remote
def thd_task():
pass

ray.get(thd_task.options(name=f"{name}_{i}").remote())

thd = threading.Thread(target=thd_task)
thd.start()
thd.join()

@ray.remote
class AsyncActor:
async def main_task(self, i):
run_task_in_thread("async_thd_task", i)

@ray.remote
class ThreadedActor:
def main_task(self, i):
run_task_in_thread("threaded_thd_task", i)

def verify(actor_method_name, actor_class_name):
tasks = list_tasks()
print(tasks)
expect_parent_task_id = None
actual_parent_task_id = None
for task in tasks:
if f"{actor_class_name}.__init__" in task["name"]:
expect_parent_task_id = task["task_id"]

assert expect_parent_task_id is not None
for task in tasks:
if f"{actor_method_name}" in task["name"]:
actual_parent_task_id = task["parent_task_id"]
assert expect_parent_task_id == actual_parent_task_id, task

return True

async_actor = AsyncActor.options(max_concurrency=actor_concurrency).remote()
ray.get([async_actor.main_task.remote(i) for i in range(20)])
wait_for_condition(
verify, actor_class_name="AsyncActor", actor_method_name="async_thd_task"
)

thd_actor = ThreadedActor.options(max_concurrency=actor_concurrency).remote()
ray.get([thd_actor.main_task.remote(i) for i in range(20)])
wait_for_condition(
verify, actor_class_name="ThreadedActor", actor_method_name="threaded_thd_task"
)


def test_parent_task_id_tune_e2e(shutdown_only):
ray.init(_system_config=_SYSTEM_CONFIG)
job_id = ray.get_runtime_context().get_job_id()
script = """
import numpy as np
import ray
from ray import tune
import time

ray.init("auto")

@ray.remote
def train_step_1():
time.sleep(0.5)
return 1

def train_function(config):
for i in range(5):
loss = config["mean"] * np.random.randn() + ray.get(
train_step_1.remote())
tune.report(loss=loss, nodes=ray.nodes())


def tune_function():
analysis = tune.run(
train_function,
metric="loss",
mode="min",
config={
"mean": tune.grid_search([1, 2, 3, 4, 5]),
},
resources_per_trial=tune.PlacementGroupFactory([{
'CPU': 1.0
}] + [{
'CPU': 1.0
}] * 3),
)
return analysis.best_config


tune_function()
"""

run_string_as_driver(script)
client = StateApiClient()

def list_tasks():
return client.list(
StateResource.TASKS,
# Filter out this driver
options=ListApiOptions(
exclude_driver=False, filters=[("job_id", "!=", job_id)], limit=1000
),
raise_on_missing_output=True,
)

def verify():
tasks = list_tasks()

task_id_map = {task["task_id"]: task for task in tasks}
for task in tasks:
if task["type"] == "DRIVER_TASK":
continue
assert task_id_map.get(task["parent_task_id"], None) is not None, task

return True

wait_for_condition(verify)


def test_handle_driver_tasks(shutdown_only):
ray.init(_system_config=_SYSTEM_CONFIG)

Expand Down
7 changes: 7 additions & 0 deletions src/ray/common/task/task_spec.cc
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,13 @@ TaskID TaskSpecification::ParentTaskId() const {
return TaskID::FromBinary(message_->parent_task_id());
}

TaskID TaskSpecification::MainThreadParentTaskId() const {
if (message_->main_thread_parent_task_id().empty() /* e.g., empty proto default */) {
return TaskID::Nil();
}
return TaskID::FromBinary(message_->main_thread_parent_task_id());
}

size_t TaskSpecification::ParentCounter() const { return message_->parent_counter(); }

ray::FunctionDescriptor TaskSpecification::FunctionDescriptor() const {
Expand Down
2 changes: 2 additions & 0 deletions src/ray/common/task/task_spec.h
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,8 @@ class TaskSpecification : public MessageWrapper<rpc::TaskSpec> {

TaskID ParentTaskId() const;

TaskID MainThreadParentTaskId() const;

size_t ParentCounter() const;

ray::FunctionDescriptor FunctionDescriptor() const;
Expand Down
6 changes: 5 additions & 1 deletion src/ray/common/task/task_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ class TaskSpecBuilder {
const std::unordered_map<std::string, double> &required_placement_resources,
const std::string &debugger_breakpoint,
int64_t depth,
const TaskID &main_thread_parent_task_id,
const std::shared_ptr<rpc::RuntimeEnvInfo> runtime_env_info = nullptr,
const std::string &concurrency_group_name = "") {
message_->set_type(TaskType::NORMAL_TASK);
Expand All @@ -142,6 +143,7 @@ class TaskSpecBuilder {
}
message_->set_task_id(task_id.Binary());
message_->set_parent_task_id(parent_task_id.Binary());
message_->set_main_thread_parent_task_id(main_thread_parent_task_id.Binary());
message_->set_parent_counter(parent_counter);
message_->set_caller_id(caller_id.Binary());
message_->mutable_caller_address()->CopyFrom(caller_address);
Expand Down Expand Up @@ -182,12 +184,14 @@ class TaskSpecBuilder {
const JobID &job_id,
const TaskID &parent_task_id,
const TaskID &caller_id,
const rpc::Address &caller_address) {
const rpc::Address &caller_address,
const TaskID &main_thread_parent_task_id) {
message_->set_type(TaskType::DRIVER_TASK);
message_->set_language(language);
message_->set_job_id(job_id.Binary());
message_->set_task_id(task_id.Binary());
message_->set_parent_task_id(parent_task_id.Binary());
message_->set_main_thread_parent_task_id(main_thread_parent_task_id.Binary());
message_->set_parent_counter(0);
message_->set_caller_id(caller_id.Binary());
message_->mutable_caller_address()->CopyFrom(caller_address);
Expand Down
10 changes: 10 additions & 0 deletions src/ray/core_worker/context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,8 @@ WorkerContext::WorkerContext(WorkerType worker_type,
RAY_CHECK(!current_job_id_.IsNil());
GetThreadContext().SetCurrentTaskId(TaskID::ForDriverTask(job_id),
/*attempt_number=*/0);
// Driver runs in the main thread.
main_thread_current_task_id_ = TaskID::ForDriverTask(job_id);
}
}

Expand Down Expand Up @@ -266,6 +268,9 @@ void WorkerContext::SetTaskDepth(int64_t depth) { task_depth_ = depth; }
void WorkerContext::SetCurrentTask(const TaskSpecification &task_spec) {
GetThreadContext().SetCurrentTask(task_spec);
absl::WriterMutexLock lock(&mutex_);
if (CurrentThreadIsMain()) {
main_thread_current_task_id_ = task_spec.TaskId();
}
SetTaskDepth(task_spec.GetDepth());
RAY_CHECK(current_job_id_ == task_spec.JobId());
if (task_spec.IsNormalTask()) {
Expand Down Expand Up @@ -314,6 +319,11 @@ bool WorkerContext::CurrentThreadIsMain() const {
return boost::this_thread::get_id() == main_thread_id_;
}

const TaskID WorkerContext::GetMainThreadCurrentTaskID() const {
absl::ReaderMutexLock lock(&mutex_);
return main_thread_current_task_id_;
}

bool WorkerContext::ShouldReleaseResourcesOnBlockingCalls() const {
// Check if we need to release resources when we block:
// - Driver doesn't acquire resources and thus doesn't need to release.
Expand Down
5 changes: 5 additions & 0 deletions src/ray/core_worker/context.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ class WorkerContext {

const TaskID &GetCurrentTaskID() const;

const TaskID GetMainThreadCurrentTaskID() const;

const PlacementGroupID &GetCurrentPlacementGroupId() const LOCKS_EXCLUDED(mutex_);

bool ShouldCaptureChildTasksInPlacementGroup() const LOCKS_EXCLUDED(mutex_);
Expand Down Expand Up @@ -130,6 +132,9 @@ class WorkerContext {
std::shared_ptr<rpc::RuntimeEnvInfo> runtime_env_info_ GUARDED_BY(mutex_);
/// The id of the (main) thread that constructed this worker context.
const boost::thread::id main_thread_id_;
/// The currently executing main thread's task id. Used merely for observability
/// purposes to track task hierarchy.
TaskID main_thread_current_task_id_;
// To protect access to mutable members;
mutable absl::Mutex mutex_;

Expand Down
8 changes: 7 additions & 1 deletion src/ray/core_worker/core_worker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,8 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_
// Driver has no parent task
/* parent_task_id */ TaskID::Nil(),
GetCallerId(),
rpc_address_);
rpc_address_,
TaskID::Nil());
// Drivers are never re-executed.
SetCurrentTaskId(task_id, /*attempt_number=*/0, "driver");

Expand Down Expand Up @@ -1748,6 +1749,7 @@ void CoreWorker::BuildCommonTaskSpec(
const std::string &debugger_breakpoint,
int64_t depth,
const std::string &serialized_runtime_env_info,
const TaskID &main_thread_current_task_id,
const std::string &concurrency_group_name,
bool include_job_config) {
// Build common task spec.
Expand Down Expand Up @@ -1780,6 +1782,7 @@ void CoreWorker::BuildCommonTaskSpec(
required_placement_resources,
debugger_breakpoint,
depth,
main_thread_current_task_id,
override_runtime_env_info,
concurrency_group_name);
// Set task arguments.
Expand Down Expand Up @@ -1830,6 +1833,7 @@ std::vector<rpc::ObjectReference> CoreWorker::SubmitTask(
debugger_breakpoint,
depth,
task_options.serialized_runtime_env_info,
worker_context_.GetMainThreadCurrentTaskID(),
/*concurrency_group_name*/ "",
/*include_job_config*/ true);
builder.SetNormalTaskSpec(max_retries,
Expand Down Expand Up @@ -1913,6 +1917,7 @@ Status CoreWorker::CreateActor(const RayFunction &function,
"" /* debugger_breakpoint */,
depth,
actor_creation_options.serialized_runtime_env_info,
worker_context_.GetMainThreadCurrentTaskID(),
/*concurrency_group_name*/ "",
/*include_job_config*/ true);

Expand Down Expand Up @@ -2137,6 +2142,7 @@ std::optional<std::vector<rpc::ObjectReference>> CoreWorker::SubmitActorTask(
"", /* debugger_breakpoint */
depth, /*depth*/
"{}", /* serialized_runtime_env_info */
worker_context_.GetMainThreadCurrentTaskID(),
task_options.concurrency_group_name,
/*include_job_config*/ false);
// NOTE: placement_group_capture_child_tasks and runtime_env will
Expand Down
1 change: 1 addition & 0 deletions src/ray/core_worker/core_worker.h
Original file line number Diff line number Diff line change
Expand Up @@ -1145,6 +1145,7 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {
const std::string &debugger_breakpoint,
int64_t depth,
const std::string &serialized_runtime_env_info,
const TaskID &main_thread_current_task_id,
const std::string &concurrency_group_name = "",
bool include_job_config = false);
void SetCurrentTaskId(const TaskID &task_id,
Expand Down
Loading