-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support inlining task arguments #72
Changes from 20 commits
30d2abe
1dd4e56
57cb3fa
f36b06a
5185e71
c8e7875
f572822
1f40349
feb71cf
a0852b0
a48db48
9e27f4c
35b7e4c
68b5372
2344b1f
026b152
b5738cb
2727fd3
48b1b65
7623be0
64c71cb
327dd2b
cf8da68
8eec7a3
c66bea3
399c917
987e22a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -173,23 +173,76 @@ end | |
|
||
initialize_coreworker_driver(args...) = ray_jll.initialize_coreworker_driver(args...) | ||
|
||
# TODO: Move task related code into a "task.jl" file | ||
function submit_task(f::Function, args::Tuple, kwargs::NamedTuple=NamedTuple(); | ||
runtime_env::Union{RuntimeEnv,Nothing}=nothing, | ||
resources::Dict{String,Float64}=Dict("CPU" => 1.0)) | ||
export_function!(FUNCTION_MANAGER[], f, get_current_job_id()) | ||
fd = ray_jll.function_descriptor(f) | ||
arg_oids = map(Ray.put, flatten_args(args, kwargs)) | ||
task_args = serialize_args(flatten_args(args, kwargs)) | ||
|
||
serialized_runtime_env_info = if !isnothing(runtime_env) | ||
_serialize(RuntimeEnvInfo(runtime_env)) | ||
else | ||
"" | ||
end | ||
|
||
return GC.@preserve args ray_jll._submit_task(fd, | ||
arg_oids, | ||
serialized_runtime_env_info, | ||
resources) | ||
GC.@preserve task_args begin | ||
return ray_jll._submit_task(fd, | ||
transform_task_args(task_args), | ||
serialized_runtime_env_info, | ||
resources) | ||
end | ||
end | ||
|
||
# Adapted from `prepare_args_internal`: | ||
# https://github.com/ray-project/ray/blob/ray-2.5.1/python/ray/_raylet.pyx#L673 | ||
function serialize_args(args) | ||
ray_config = ray_jll.RayConfigInstance() | ||
put_threshold = ray_jll.max_direct_call_object_size(ray_config) | ||
rpc_inline_threshold = ray_jll.task_rpc_inlined_bytes_limit(ray_config) | ||
record_call_site = ray_jll.record_ref_creation_sites(ray_config) | ||
|
||
rpc_address = ray_jll.GetRpcAddress() | ||
|
||
total_inlined = 0 | ||
task_args = Any[] | ||
for arg in args | ||
# Note: The Python `prepare_args_internal` function checks if the `arg` is an | ||
# `ObjectRef` and in that case uses the object ID to directly make a | ||
# `TaskArgByReference`. However, as the `args` here are flattened the `arg` will | ||
# always be a `Pair` (or a list in Python). I suspect this Python code path just | ||
# dead code so we'll exclude it from ours. | ||
|
||
serialized_arg = serialize_to_bytes(arg) | ||
serialized_arg_size = sizeof(serialized_arg) | ||
buffer = ray_jll.LocalMemoryBuffer(serialized_arg, serialized_arg_size, true) | ||
|
||
# Inline arguments which are small and if there is room | ||
task_arg = if (serialized_arg_size <= put_threshold && | ||
serialized_arg_size + total_inlined <= rpc_inline_threshold) | ||
omus marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
total_inlined += serialized_arg_size | ||
ray_jll.TaskArgByValue(ray_jll.RayObject(buffer)) | ||
else | ||
oid = ray_jll.put(buffer) | ||
# TODO: Add test for populating `call_site` | ||
call_site = record_call_site ? sprint(Base.show_backtrace, backtrace()) : "" | ||
ray_jll.TaskArgByReference(oid, rpc_address, call_site) | ||
end | ||
|
||
push!(task_args, task_arg) | ||
end | ||
|
||
return task_args | ||
end | ||
|
||
function transform_task_args(task_args) | ||
task_arg_ptrs = StdVector{CxxPtr{ray_jll.TaskArg}}() | ||
for task_arg in task_args | ||
push!(task_arg_ptrs, CxxRef(task_arg)) | ||
omus marked this conversation as resolved.
Show resolved
Hide resolved
|
||
end | ||
return task_arg_ptrs | ||
end | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. One liner doesn't currently work: julia> using Ray, CxxWrap
julia> Ray.init()
julia> task_args = Ray.serialize_args([1,2,3])
3-element Vector{Any}:
ray_core_worker_julia_jll.TaskArgByValueAllocated(Ptr{Nothing} @0x0000600000042100)
ray_core_worker_julia_jll.TaskArgByValueAllocated(Ptr{Nothing} @0x0000600000043be0)
ray_core_worker_julia_jll.TaskArgByValueAllocated(Ptr{Nothing} @0x0000600000042560)
julia> CxxRef.(convert.(ray_jll.TaskArg, task_args))
3-element Vector{CxxRef{ray_core_worker_julia_jll.TaskArg}}:
CxxRef{ray_core_worker_julia_jll.TaskArg}(Ptr{ray_core_worker_julia_jll.TaskArg} @0x0000600000051300)
CxxRef{ray_core_worker_julia_jll.TaskArg}(Ptr{ray_core_worker_julia_jll.TaskArg} @0x000060000004de60)
CxxRef{ray_core_worker_julia_jll.TaskArg}(Ptr{ray_core_worker_julia_jll.TaskArg} @0x000060000004d0e0)
julia> StdVector{CxxPtr{ray_jll.TaskArg}}(CxxRef.(convert.(ray_jll.TaskArg, task_args)))
ERROR: MethodError: no method matching StdVector{CxxPtr{ray_core_worker_julia_jll.TaskArg}}(::Vector{CxxRef{ray_core_worker_julia_jll.TaskArg}})
Closest candidates are:
StdVector{CxxPtr{ray_core_worker_julia_jll.TaskArg}}()
@ ray_core_worker_julia_jll ~/.julia/packages/CxxWrap/aXNBY/src/CxxWrap.jl:624
Stacktrace:
[1] top-level scope
@ REPL[15]:1 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
||
function task_executor(ray_function, returns_ptr, task_args_ptr, task_name, | ||
|
@@ -248,7 +301,9 @@ function task_executor(ray_function, returns_ptr, task_args_ptr, task_name, | |
|
||
# TODO: support multiple return values | ||
# https://github.com/beacon-biosignals/ray_core_worker_julia_jll.jl/issues/54 | ||
push!(returns, to_serialized_buffer(result)) | ||
bytes = serialize_to_bytes(result) | ||
buffer = ray_jll.LocalMemoryBuffer(bytes, sizeof(bytes), true) | ||
push!(returns, buffer) | ||
|
||
return nothing | ||
end | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
using Aqua | ||
using CxxWrap: CxxPtr, StdVector | ||
using Ray | ||
using Serialization | ||
using Test | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -140,7 +140,7 @@ std::vector<std::shared_ptr<RayObject>> cast_to_task_args(void *ptr) { | |
} | ||
|
||
ObjectID _submit_task(const ray::JuliaFunctionDescriptor &jl_func_descriptor, | ||
const std::vector<ObjectID> &object_ids, | ||
const std::vector<TaskArg *> &task_args, | ||
const std::string &serialized_runtime_env_info, | ||
const std::unordered_map<std::string, double> &resources) { | ||
|
||
|
@@ -149,9 +149,11 @@ ObjectID _submit_task(const ray::JuliaFunctionDescriptor &jl_func_descriptor, | |
ray::FunctionDescriptor func_descriptor = std::make_shared<ray::JuliaFunctionDescriptor>(jl_func_descriptor); | ||
RayFunction func(Language::JULIA, func_descriptor); | ||
|
||
// TODO: Passing in a `std::vector<std::unique_ptr<TaskArg>>` from Julia may currently be impossible due to: | ||
// https://github.com/JuliaInterop/CxxWrap.jl/issues/370 | ||
std::vector<std::unique_ptr<TaskArg>> args; | ||
for (auto & obj_id : object_ids) { | ||
args.emplace_back(new TaskArgByReference(obj_id, worker.GetRpcAddress(), /*call-site*/"")); | ||
for (auto &task_arg : task_args) { | ||
args.emplace_back(task_arg); | ||
Comment on lines
+155
to
+156
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. since the element type is There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ah wait I always forget, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Just to avoid any possible confusion the |
||
} | ||
|
||
// TaskOptions: https://github.com/ray-project/ray/blob/ray-2.5.1/src/ray/core_worker/common.h#L62-L87 | ||
|
@@ -189,6 +191,11 @@ TaskID GetCurrentTaskId() { | |
return worker.GetCurrentTaskId(); | ||
} | ||
|
||
const rpc::Address &GetRpcAddress() { | ||
auto &worker = CoreWorkerProcess::GetCoreWorker(); | ||
return worker.GetRpcAddress(); | ||
} | ||
|
||
// https://github.com/ray-project/ray/blob/a4a8389a3053b9ef0e8409a55e2fae618bfca2be/src/ray/core_worker/test/core_worker_test.cc#L224-L237 | ||
ObjectID put(std::shared_ptr<Buffer> buffer) { | ||
auto &driver = CoreWorkerProcess::GetCoreWorker(); | ||
|
@@ -364,16 +371,24 @@ std::unordered_map<std::string, double> get_task_required_resources() { | |
return worker_context.GetCurrentTask()->GetRequiredResources().GetResourceUnorderedMap(); | ||
} | ||
|
||
void _push_back(std::vector<TaskArg *> &vector, TaskArg &el) { | ||
vector.push_back(&el); | ||
} | ||
|
||
namespace jlcxx | ||
{ | ||
// Needed for upcasting | ||
template<> struct SuperType<LocalMemoryBuffer> { typedef Buffer type; }; | ||
template<> struct SuperType<JuliaFunctionDescriptor> { typedef FunctionDescriptorInterface type; }; | ||
template<> struct SuperType<TaskArgByReference> { typedef TaskArg type; }; | ||
template<> struct SuperType<TaskArgByValue> { typedef TaskArg type; }; | ||
|
||
// Disable generated constructors | ||
// https://github.com/JuliaInterop/CxxWrap.jl/issues/141#issuecomment-491373720 | ||
template<> struct DefaultConstructible<LocalMemoryBuffer> : std::false_type {}; | ||
template<> struct DefaultConstructible<RayObject> : std::false_type {}; | ||
// template<> struct DefaultConstructible<JuliaFunctionDescriptor> : std::false_type {}; | ||
template<> struct DefaultConstructible<TaskArg> : std::false_type {}; | ||
|
||
// Custom finalizer to show what is being deleted. Can be useful in tracking down | ||
// segmentation faults due to double deallocations | ||
|
@@ -503,14 +518,45 @@ JLCXX_MODULE define_julia_module(jlcxx::Module& mod) | |
|
||
mod.method("put", &put); | ||
mod.method("get", &get); | ||
mod.method("_submit_task", &_submit_task); | ||
|
||
// class ObjectReference | ||
// message Address | ||
// https://github.com/ray-project/ray/blob/ray-2.5.1/src/ray/protobuf/common.proto#L86 | ||
mod.add_type<rpc::Address>("Address") | ||
.constructor<>() | ||
.method("SerializeToString", [](const rpc::Address &addr) { | ||
std::string serialized; | ||
addr.SerializeToString(&serialized); | ||
return serialized; | ||
}) | ||
.method("MessageToJsonString", [](const rpc::Address &addr) { | ||
std::string json; | ||
google::protobuf::util::MessageToJsonString(addr, &json); | ||
return json; | ||
}); | ||
mod.method("GetRpcAddress", &GetRpcAddress); | ||
|
||
// message ObjectReference | ||
// https://github.com/ray-project/ray/blob/ray-2.5.1/src/ray/protobuf/common.proto#L500 | ||
mod.add_type<rpc::ObjectReference>("ObjectReference"); | ||
jlcxx::stl::apply_stl<rpc::ObjectReference>(mod); | ||
|
||
// class RayObject | ||
// https://github.com/ray-project/ray/blob/ray-2.5.1/src/ray/common/ray_object.h#L28 | ||
mod.add_type<RayObject>("RayObject") | ||
.method("GetData", &RayObject::GetData); | ||
|
||
// Julia RayObject constructors make shared_ptrs | ||
mod.method("RayObject", [] ( | ||
const std::shared_ptr<Buffer> &data, | ||
const std::shared_ptr<Buffer> &metadata, | ||
const std::vector<rpc::ObjectReference> &nested_refs, | ||
bool copy_data = false) { | ||
|
||
return std::make_shared<RayObject>(data, metadata, nested_refs, copy_data); | ||
}); | ||
mod.method("RayObject", [] (const std::shared_ptr<Buffer> &data) { | ||
return std::make_shared<RayObject>(data, nullptr, std::vector<rpc::ObjectReference>(), false); | ||
}); | ||
jlcxx::stl::apply_stl<std::shared_ptr<RayObject>>(mod); | ||
|
||
mod.add_type<Status>("Status") | ||
|
@@ -541,4 +587,39 @@ JLCXX_MODULE define_julia_module(jlcxx::Module& mod) | |
mod.method("serialize_job_config_json", &serialize_job_config_json); | ||
mod.method("get_job_serialized_runtime_env", &get_job_serialized_runtime_env); | ||
mod.method("get_task_required_resources", &get_task_required_resources); | ||
|
||
mod.add_type<RayConfig>("RayConfig") | ||
.method("RayConfigInstance", &RayConfig::instance) | ||
.method("max_direct_call_object_size", [](RayConfig &config) { | ||
return config.max_direct_call_object_size(); | ||
}) | ||
.method("task_rpc_inlined_bytes_limit", [](RayConfig &config) { | ||
return config.task_rpc_inlined_bytes_limit(); | ||
}) | ||
.method("record_ref_creation_sites", [](RayConfig &config) { | ||
return config.record_ref_creation_sites(); | ||
}); | ||
Comment on lines
+587
to
+595
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why do these need to be lambdas? seem pretty straightforward... There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure now. I'll try this without the lambdas There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fails with this:
|
||
|
||
// https://github.com/ray-project/ray/blob/ray-2.5.1/src/ray/common/task/task_util.h | ||
mod.add_type<TaskArg>("TaskArg"); | ||
mod.method("_push_back", &_push_back); | ||
|
||
// The Julia types `TaskArgByReference` and `TaskArgByValue` have their default finalizers | ||
// disabled as these will later be used as `std::unique_ptr`. If these finalizers were enabled | ||
// we would see segmentation faults due to double deallocations. | ||
// | ||
// Note: It is possible to create `std::unique_ptr`s in C++ and return them to Julia however | ||
// CxxWrap is unable to compile any wrapped functions using `std::vector<std::unique_ptr<TaskArg>>`. | ||
// We're working around this by using `std::vector<TaskArg *>`. | ||
// https://github.com/JuliaInterop/CxxWrap.jl/issues/370 | ||
|
||
mod.add_type<TaskArgByReference>("TaskArgByReference", jlcxx::julia_base_type<TaskArg>()) | ||
.constructor<const ObjectID &/*object_id*/, | ||
const rpc::Address &/*owner_address*/, | ||
const std::string &/*call_site*/>(false); | ||
|
||
mod.add_type<TaskArgByValue>("TaskArgByValue", jlcxx::julia_base_type<TaskArg>()) | ||
.constructor<const std::shared_ptr<RayObject> &/*value*/>(false); | ||
|
||
mod.method("_submit_task", &_submit_task); | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -148,12 +148,17 @@ function Base.push!(v::CxxPtr{StdVector{T}}, el::T) where T <: SharedPtr{LocalMe | |
return push!(v, CxxRef(el)) | ||
end | ||
|
||
# Work around CxxWrap's `push!` always dereferencing our value via `@cxxdereference` | ||
# https://github.com/JuliaInterop/CxxWrap.jl/blob/0de5fbc5673367adc7e725cfc6e1fc6a8f9240a0/src/StdLib.jl#L78-L81 | ||
function Base.push!(v::StdVector{CxxPtr{TaskArg}}, el::CxxRef{<:TaskArg}) | ||
_push_back(v, el) | ||
return v | ||
end | ||
|
||
# XXX: Need to convert julia vectors to StdVector and build the | ||
# `std::unordered_map` for resources. This function helps us avoid having | ||
# CxxWrap as a direct dependency in Ray.jl | ||
function _submit_task(fd, oids::AbstractVector, serialized_runtime_env_info, resources) | ||
# https://github.com/JuliaInterop/CxxWrap.jl/issues/367 | ||
args = isempty(oids) ? StdVector{ObjectID}() : StdVector(oids) | ||
function _submit_task(fd, args, serialized_runtime_env_info, resources::AbstractDict) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. any reason to remove the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The |
||
@debug "task resources: " resources | ||
resources = build_resource_requests(resources) | ||
return _submit_task(fd, args, serialized_runtime_env_info, resources) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should we be doing this though? the values will be automatically dereferenced on the the task execution side so I'm not so sure; then again, if we don't do it then users may be surprised if they do have to de-reference in their work functions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The Python code I'm referring to is attempting to avoid creating a new
ObjectRef
here if the user has passed in one as an argument. In all cases Python will dereference the arguments inside the called task. What I was specifically calling out here is that due toflatten_args
this optimization will never occur as anarg
is never aObjectRef
. At best it would be alist
containing anObjectRef
as the second value.