Support inlining task arguments #72

Merged
merged 27 commits into from
Sep 5, 2023
Changes from 20 commits
1 change: 1 addition & 0 deletions Ray.jl/src/Ray.jl
@@ -7,6 +7,7 @@ module Ray

using ArgParse
using Base64
using CxxWrap: StdVector, CxxPtr, CxxRef
using CxxWrap.StdLib: SharedPtr
using JSON3
using Logging
24 changes: 14 additions & 10 deletions Ray.jl/src/object_store.jl
@@ -4,7 +4,11 @@
Store `data` in the object store. Returns an object reference which can be used to retrieve
the `data` with [`Ray.get`](@ref).
"""
put(data) = ray_jll.put(to_serialized_buffer(data))
function put(data)
bytes = serialize_to_bytes(data)
buffer = ray_jll.LocalMemoryBuffer(bytes, sizeof(bytes), true)
return ray_jll.put(buffer)
end

"""
Ray.get(object_id::ObjectIDAllocated)
@@ -16,24 +20,24 @@ if run in an `@async` task.
If the task that generated the `ObjectID` failed with a Julia exception, the
captured exception will be thrown on `get`.
"""
get(oid::ray_jll.ObjectIDAllocated) = _get(ray_jll.get(oid))
get(obj::SharedPtr{ray_jll.RayObject}) = _get(ray_jll.GetData(obj[]))
get(oid::ray_jll.ObjectIDAllocated) = _get(take!(ray_jll.get(oid)))
get(obj::SharedPtr{ray_jll.RayObject}) = _get(take!(ray_jll.GetData(obj[])))
get(x) = x

function _get(buffer)
result = from_serialized_buffer(buffer)
function _get(bytes)
result = deserialize_from_bytes(bytes)
# TODO: add an option to not rethrow
# https://github.com/beacon-biosignals/ray_core_worker_julia_jll.jl/issues/58
result isa RayRemoteException ? throw(result) : return result
end

function to_serialized_buffer(data)
function serialize_to_bytes(x)
bytes = Vector{UInt8}()
io = IOBuffer(bytes; write=true)
serialize(io, data)
return ray_jll.LocalMemoryBuffer(bytes, sizeof(bytes), true)
serialize(io, x)
return bytes
end

function from_serialized_buffer(buffer)
result = deserialize(IOBuffer(take!(buffer)))
function deserialize_from_bytes(bytes)
return deserialize(IOBuffer(bytes))
end
67 changes: 61 additions & 6 deletions Ray.jl/src/runtime.jl
@@ -173,23 +173,76 @@ end

initialize_coreworker_driver(args...) = ray_jll.initialize_coreworker_driver(args...)

# TODO: Move task related code into a "task.jl" file
function submit_task(f::Function, args::Tuple, kwargs::NamedTuple=NamedTuple();
runtime_env::Union{RuntimeEnv,Nothing}=nothing,
resources::Dict{String,Float64}=Dict("CPU" => 1.0))
export_function!(FUNCTION_MANAGER[], f, get_current_job_id())
fd = ray_jll.function_descriptor(f)
arg_oids = map(Ray.put, flatten_args(args, kwargs))
task_args = serialize_args(flatten_args(args, kwargs))

serialized_runtime_env_info = if !isnothing(runtime_env)
_serialize(RuntimeEnvInfo(runtime_env))
else
""
end

return GC.@preserve args ray_jll._submit_task(fd,
arg_oids,
serialized_runtime_env_info,
resources)
GC.@preserve task_args begin
return ray_jll._submit_task(fd,
transform_task_args(task_args),
serialized_runtime_env_info,
resources)
end
end

# Adapted from `prepare_args_internal`:
# https://github.com/ray-project/ray/blob/ray-2.5.1/python/ray/_raylet.pyx#L673
function serialize_args(args)
ray_config = ray_jll.RayConfigInstance()
put_threshold = ray_jll.max_direct_call_object_size(ray_config)
rpc_inline_threshold = ray_jll.task_rpc_inlined_bytes_limit(ray_config)
record_call_site = ray_jll.record_ref_creation_sites(ray_config)

rpc_address = ray_jll.GetRpcAddress()

total_inlined = 0
task_args = Any[]
for arg in args
# Note: The Python `prepare_args_internal` function checks if the `arg` is an
# `ObjectRef` and in that case uses the object ID to directly make a
# `TaskArgByReference`. However, as the `args` here are flattened the `arg` will
# always be a `Pair` (or a list in Python). I suspect this Python code path is just
# dead code so we'll exclude it from ours.
Comment on lines +222 to +226
Member

should we be doing this though? the values will be automatically dereferenced on the task execution side so I'm not so sure; then again, if we don't do it then users may be surprised if they do have to de-reference in their work functions.

Member Author

The Python code I'm referring to is attempting to avoid creating a new ObjectRef here if the user has passed one in as an argument. In all cases Python will dereference the arguments inside the called task. What I was specifically calling out here is that, due to flatten_args, this optimization will never occur because an arg is never an ObjectRef. At best it would be a list containing an ObjectRef as the second value.


serialized_arg = serialize_to_bytes(arg)
serialized_arg_size = sizeof(serialized_arg)
buffer = ray_jll.LocalMemoryBuffer(serialized_arg, serialized_arg_size, true)

# Inline arguments which are small and if there is room
task_arg = if (serialized_arg_size <= put_threshold &&
serialized_arg_size + total_inlined <= rpc_inline_threshold)
total_inlined += serialized_arg_size
ray_jll.TaskArgByValue(ray_jll.RayObject(buffer))
else
oid = ray_jll.put(buffer)
# TODO: Add test for populating `call_site`
call_site = record_call_site ? sprint(Base.show_backtrace, backtrace()) : ""
ray_jll.TaskArgByReference(oid, rpc_address, call_site)
end

push!(task_args, task_arg)
end

return task_args
end
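
The inlining rule in `serialize_args` can be read in isolation from the Ray bindings. What follows is a minimal C++ sketch of that decision only, not the PR's implementation. `plan_args` and `ArgPlacement` are hypothetical names introduced for illustration, and the thresholds are plain parameters here, whereas the PR reads them from `RayConfig`.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the two outcomes in `serialize_args`:
// `TaskArgByValue` (inlined) or `TaskArgByReference` (put in the object store).
enum class ArgPlacement { ByValue, ByReference };

// Decide, in order, how each serialized argument would be passed.
// `put_threshold` plays the role of `max_direct_call_object_size` and
// `rpc_inline_threshold` the role of `task_rpc_inlined_bytes_limit`.
std::vector<ArgPlacement> plan_args(const std::vector<std::size_t> &serialized_sizes,
                                    std::size_t put_threshold,
                                    std::size_t rpc_inline_threshold) {
    std::vector<ArgPlacement> plan;
    plan.reserve(serialized_sizes.size());
    std::size_t total_inlined = 0;
    for (std::size_t size : serialized_sizes) {
        // Inline only if the argument is small and the RPC still has room.
        if (size <= put_threshold && size + total_inlined <= rpc_inline_threshold) {
            total_inlined += size;
            plan.push_back(ArgPlacement::ByValue);
        } else {
            plan.push_back(ArgPlacement::ByReference);
        }
    }
    return plan;
}
```

With a `put_threshold` of 100 and an `rpc_inline_threshold` of 250, three 100-byte arguments would come out as by-value, by-value, by-reference. Note that the running total only grows when an argument is inlined, so a large argument in the middle does not use up inline room for later small ones.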

function transform_task_args(task_args)
task_arg_ptrs = StdVector{CxxPtr{ray_jll.TaskArg}}()
for task_arg in task_args
push!(task_arg_ptrs, CxxRef(task_arg))
end
return task_arg_ptrs
end
Member Author
@omus, Sep 1, 2023

One liner doesn't currently work:

julia> using Ray, CxxWrap

julia> Ray.init()

julia> task_args = Ray.serialize_args([1,2,3])
3-element Vector{Any}:
 ray_core_worker_julia_jll.TaskArgByValueAllocated(Ptr{Nothing} @0x0000600000042100)
 ray_core_worker_julia_jll.TaskArgByValueAllocated(Ptr{Nothing} @0x0000600000043be0)
 ray_core_worker_julia_jll.TaskArgByValueAllocated(Ptr{Nothing} @0x0000600000042560)

julia> CxxRef.(convert.(ray_jll.TaskArg, task_args))
3-element Vector{CxxRef{ray_core_worker_julia_jll.TaskArg}}:
 CxxRef{ray_core_worker_julia_jll.TaskArg}(Ptr{ray_core_worker_julia_jll.TaskArg} @0x0000600000051300)
 CxxRef{ray_core_worker_julia_jll.TaskArg}(Ptr{ray_core_worker_julia_jll.TaskArg} @0x000060000004de60)
 CxxRef{ray_core_worker_julia_jll.TaskArg}(Ptr{ray_core_worker_julia_jll.TaskArg} @0x000060000004d0e0)

julia> StdVector{CxxPtr{ray_jll.TaskArg}}(CxxRef.(convert.(ray_jll.TaskArg, task_args)))
ERROR: MethodError: no method matching StdVector{CxxPtr{ray_core_worker_julia_jll.TaskArg}}(::Vector{CxxRef{ray_core_worker_julia_jll.TaskArg}})

Closest candidates are:
  StdVector{CxxPtr{ray_core_worker_julia_jll.TaskArg}}()
   @ ray_core_worker_julia_jll ~/.julia/packages/CxxWrap/aXNBY/src/CxxWrap.jl:624

Stacktrace:
 [1] top-level scope
   @ REPL[15]:1

Member Author

#80


function task_executor(ray_function, returns_ptr, task_args_ptr, task_name,
@@ -248,7 +301,9 @@ function task_executor(ray_function, returns_ptr, task_args_ptr, task_name,

# TODO: support multiple return values
# https://github.com/beacon-biosignals/ray_core_worker_julia_jll.jl/issues/54
push!(returns, to_serialized_buffer(result))
bytes = serialize_to_bytes(result)
buffer = ray_jll.LocalMemoryBuffer(bytes, sizeof(bytes), true)
push!(returns, buffer)

return nothing
end
1 change: 1 addition & 0 deletions Ray.jl/test/runtests.jl
@@ -1,4 +1,5 @@
using Aqua
using CxxWrap: CxxPtr, StdVector
using Ray
using Serialization
using Test
36 changes: 36 additions & 0 deletions Ray.jl/test/runtime.jl
@@ -78,5 +78,41 @@
@test !contains(stderr_logs, "Constructing CoreWorkerProcess")
end
end
end

@testset "serialize_args" begin
ray_config = ray_jll.RayConfigInstance()
put_threshold = ray_jll.max_direct_call_object_size(ray_config)
rpc_inline_threshold = ray_jll.task_rpc_inlined_bytes_limit(ray_config)
serialization_overhead = begin
sizeof(Ray.serialize_to_bytes(:_ => zeros(UInt8, put_threshold))) - put_threshold
end

@testset "put threshold" begin
a = :_ => zeros(UInt8, put_threshold - serialization_overhead)
b = :_ => zeros(UInt8, put_threshold - serialization_overhead + 1)
task_args = Ray.serialize_args([a, b])
@test length(task_args) == 2
@test task_args[1] isa ray_jll.TaskArgByValue
@test task_args[2] isa ray_jll.TaskArgByReference

task_args = Ray.serialize_args([b, a])
@test length(task_args) == 2
@test task_args[1] isa ray_jll.TaskArgByReference
@test task_args[2] isa ray_jll.TaskArgByValue
end

@testset "inline threshold" begin
a = :_ => zeros(UInt8, put_threshold - serialization_overhead)
args = fill(a, rpc_inline_threshold ÷ put_threshold + 1)
task_args = Ray.serialize_args(args)
@test all(t -> t isa ray_jll.TaskArgByValue, task_args[1:(end - 1)])
@test task_args[end] isa ray_jll.TaskArgByReference
end
end

@testset "transform_task_args" begin
task_args = Ray.serialize_args(Ray.flatten_args([1, 2, 3], (;)))
result = Ray.transform_task_args(task_args)
@test result isa StdVector{CxxPtr{ray_jll.TaskArg}}
end
91 changes: 86 additions & 5 deletions deps/wrapper.cc
@@ -140,7 +140,7 @@ std::vector<std::shared_ptr<RayObject>> cast_to_task_args(void *ptr) {
}

ObjectID _submit_task(const ray::JuliaFunctionDescriptor &jl_func_descriptor,
const std::vector<ObjectID> &object_ids,
const std::vector<TaskArg *> &task_args,
const std::string &serialized_runtime_env_info,
const std::unordered_map<std::string, double> &resources) {

@@ -149,9 +149,11 @@ ObjectID _submit_task(const ray::JuliaFunctionDescriptor &jl_func_descriptor,
ray::FunctionDescriptor func_descriptor = std::make_shared<ray::JuliaFunctionDescriptor>(jl_func_descriptor);
RayFunction func(Language::JULIA, func_descriptor);

// TODO: Passing in a `std::vector<std::unique_ptr<TaskArg>>` from Julia may currently be impossible due to:
// https://github.com/JuliaInterop/CxxWrap.jl/issues/370
std::vector<std::unique_ptr<TaskArg>> args;
for (auto & obj_id : object_ids) {
args.emplace_back(new TaskArgByReference(obj_id, worker.GetRpcAddress(), /*call-site*/""));
for (auto &task_arg : task_args) {
args.emplace_back(task_arg);
Comment on lines +155 to +156
Member

since the element type is unique_ptr already I'm assuming this will auto-cast to unique_ptr?

Member

ah wait I always forget, emplace_back constructs and pushes...

Member Author

since the element type is unique_ptr already I'm assuming this will auto-cast to unique_ptr?

Just to avoid any possible confusion: the task_arg element is a raw pointer and is not already a unique_ptr
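
To make the ownership point concrete, here is a small standalone C++ sketch, independent of Ray and CxxWrap. `Tracked`, `live_count`, and `adopt_all` are hypothetical names introduced for illustration. It shows that `emplace_back` on a `std::vector<std::unique_ptr<T>>` forwards the raw pointer to the explicit `std::unique_ptr` constructor, so the vector takes ownership and the object is deleted exactly once. `push_back(raw_ptr)` would not compile, because that constructor is explicit.

```cpp
#include <cassert>
#include <memory>
#include <vector>

// Hypothetical counter of live `Tracked` objects, so ownership is observable.
int live_count = 0;

struct Tracked {
    Tracked() { ++live_count; }
    ~Tracked() { --live_count; }
};

// Take ownership of every raw pointer, similar in shape to the loop in `_submit_task`.
std::vector<std::unique_ptr<Tracked>> adopt_all(const std::vector<Tracked *> &raw) {
    std::vector<std::unique_ptr<Tracked>> owned;
    // Reserve up front so `emplace_back` cannot throw on reallocation
    // before the raw pointer has been wrapped.
    owned.reserve(raw.size());
    for (Tracked *ptr : raw) {
        assert(ptr != nullptr);
        // Constructs std::unique_ptr<Tracked>(ptr) in place; no copy of the object.
        owned.emplace_back(ptr);
    }
    return owned;
}
```

After `adopt_all` returns, the caller must not delete the raw pointers itself. The same reasoning is behind disabling the Julia-side finalizers for `TaskArgByReference` and `TaskArgByValue` in this PR: if both Julia and the `unique_ptr` freed the same object, it would be deallocated twice.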

}

// TaskOptions: https://github.com/ray-project/ray/blob/ray-2.5.1/src/ray/core_worker/common.h#L62-L87
@@ -189,6 +191,11 @@ TaskID GetCurrentTaskId() {
return worker.GetCurrentTaskId();
}

const rpc::Address &GetRpcAddress() {
auto &worker = CoreWorkerProcess::GetCoreWorker();
return worker.GetRpcAddress();
}

// https://github.com/ray-project/ray/blob/a4a8389a3053b9ef0e8409a55e2fae618bfca2be/src/ray/core_worker/test/core_worker_test.cc#L224-L237
ObjectID put(std::shared_ptr<Buffer> buffer) {
auto &driver = CoreWorkerProcess::GetCoreWorker();
@@ -364,16 +371,24 @@ std::unordered_map<std::string, double> get_task_required_resources() {
return worker_context.GetCurrentTask()->GetRequiredResources().GetResourceUnorderedMap();
}

void _push_back(std::vector<TaskArg *> &vector, TaskArg &el) {
vector.push_back(&el);
}

namespace jlcxx
{
// Needed for upcasting
template<> struct SuperType<LocalMemoryBuffer> { typedef Buffer type; };
template<> struct SuperType<JuliaFunctionDescriptor> { typedef FunctionDescriptorInterface type; };
template<> struct SuperType<TaskArgByReference> { typedef TaskArg type; };
template<> struct SuperType<TaskArgByValue> { typedef TaskArg type; };

// Disable generated constructors
// https://github.com/JuliaInterop/CxxWrap.jl/issues/141#issuecomment-491373720
template<> struct DefaultConstructible<LocalMemoryBuffer> : std::false_type {};
template<> struct DefaultConstructible<RayObject> : std::false_type {};
// template<> struct DefaultConstructible<JuliaFunctionDescriptor> : std::false_type {};
template<> struct DefaultConstructible<TaskArg> : std::false_type {};

// Custom finalizer to show what is being deleted. Can be useful in tracking down
// segmentation faults due to double deallocations
@@ -503,14 +518,45 @@ JLCXX_MODULE define_julia_module(jlcxx::Module& mod)

mod.method("put", &put);
mod.method("get", &get);
mod.method("_submit_task", &_submit_task);

// class ObjectReference
// message Address
// https://github.com/ray-project/ray/blob/ray-2.5.1/src/ray/protobuf/common.proto#L86
mod.add_type<rpc::Address>("Address")
.constructor<>()
.method("SerializeToString", [](const rpc::Address &addr) {
std::string serialized;
addr.SerializeToString(&serialized);
return serialized;
})
.method("MessageToJsonString", [](const rpc::Address &addr) {
std::string json;
google::protobuf::util::MessageToJsonString(addr, &json);
return json;
});
mod.method("GetRpcAddress", &GetRpcAddress);

// message ObjectReference
// https://github.com/ray-project/ray/blob/ray-2.5.1/src/ray/protobuf/common.proto#L500
mod.add_type<rpc::ObjectReference>("ObjectReference");
jlcxx::stl::apply_stl<rpc::ObjectReference>(mod);

// class RayObject
// https://github.com/ray-project/ray/blob/ray-2.5.1/src/ray/common/ray_object.h#L28
mod.add_type<RayObject>("RayObject")
.method("GetData", &RayObject::GetData);

// Julia RayObject constructors make shared_ptrs
mod.method("RayObject", [] (
const std::shared_ptr<Buffer> &data,
const std::shared_ptr<Buffer> &metadata,
const std::vector<rpc::ObjectReference> &nested_refs,
bool copy_data = false) {

return std::make_shared<RayObject>(data, metadata, nested_refs, copy_data);
});
mod.method("RayObject", [] (const std::shared_ptr<Buffer> &data) {
return std::make_shared<RayObject>(data, nullptr, std::vector<rpc::ObjectReference>(), false);
});
jlcxx::stl::apply_stl<std::shared_ptr<RayObject>>(mod);

mod.add_type<Status>("Status")
@@ -541,4 +587,39 @@ JLCXX_MODULE define_julia_module(jlcxx::Module& mod)
mod.method("serialize_job_config_json", &serialize_job_config_json);
mod.method("get_job_serialized_runtime_env", &get_job_serialized_runtime_env);
mod.method("get_task_required_resources", &get_task_required_resources);

mod.add_type<RayConfig>("RayConfig")
.method("RayConfigInstance", &RayConfig::instance)
.method("max_direct_call_object_size", [](RayConfig &config) {
return config.max_direct_call_object_size();
})
.method("task_rpc_inlined_bytes_limit", [](RayConfig &config) {
return config.task_rpc_inlined_bytes_limit();
})
.method("record_ref_creation_sites", [](RayConfig &config) {
return config.record_ref_creation_sites();
});
Comment on lines +587 to +595
Member

why do these need to be lambdas? seem pretty straightforward...

Member Author

I'm not sure now. I'll try this without the lambdas

Member Author

Fails with this:

wrapper.cc:582:59: error: call to non-static member function without an object argument
        .method("max_direct_call_object_size", RayConfig::max_direct_call_object_size)
                                               ~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~
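
For what it's worth, the compiler error quoted above is about C++ syntax rather than CxxWrap: naming a non-static member function without `&` is parsed as an attempt to call it. A pointer to a member function is written `&Class::method`. The sketch below uses a hypothetical `Config` class, not `RayConfig`, and only core language features. Whether CxxWrap's `.method` would accept `&RayConfig::max_direct_call_object_size` directly is untested here and would depend on how `RayConfig` declares that accessor.

```cpp
#include <cstdint>

// Hypothetical stand-in for a config class with a non-static accessor.
class Config {
public:
    explicit Config(std::int64_t limit) : limit_(limit) {}
    std::int64_t max_size() const { return limit_; }

private:
    std::int64_t limit_;
};

// Option 1: a lambda wrapper, the approach the PR takes for `RayConfig`.
inline std::int64_t via_lambda(const Config &config) {
    auto getter = [](const Config &c) { return c.max_size(); };
    return getter(config);
}

// Option 2: a pointer to member function; note the required `&`.
// Writing `Config::max_size` without `&` is rejected by the compiler
// with an error like the one quoted above.
inline std::int64_t via_member_pointer(const Config &config) {
    std::int64_t (Config::*getter)() const = &Config::max_size;
    return (config.*getter)();
}
```

Both helpers return the same value for the same `Config`, so the choice between them is about what the binding layer accepts rather than about behavior.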


// https://github.com/ray-project/ray/blob/ray-2.5.1/src/ray/common/task/task_util.h
mod.add_type<TaskArg>("TaskArg");
mod.method("_push_back", &_push_back);

// The Julia types `TaskArgByReference` and `TaskArgByValue` have their default finalizers
// disabled as these will later be used as `std::unique_ptr`. If these finalizers were enabled
// we would see segmentation faults due to double deallocations.
//
// Note: It is possible to create `std::unique_ptr`s in C++ and return them to Julia however
// CxxWrap is unable to compile any wrapped functions using `std::vector<std::unique_ptr<TaskArg>>`.
// We're working around this by using `std::vector<TaskArg *>`.
// https://github.com/JuliaInterop/CxxWrap.jl/issues/370

mod.add_type<TaskArgByReference>("TaskArgByReference", jlcxx::julia_base_type<TaskArg>())
.constructor<const ObjectID &/*object_id*/,
const rpc::Address &/*owner_address*/,
const std::string &/*call_site*/>(false);

mod.add_type<TaskArgByValue>("TaskArgByValue", jlcxx::julia_base_type<TaskArg>())
.constructor<const std::shared_ptr<RayObject> &/*value*/>(false);

mod.method("_submit_task", &_submit_task);
}
11 changes: 8 additions & 3 deletions src/wrappers/any.jl
@@ -148,12 +148,17 @@ function Base.push!(v::CxxPtr{StdVector{T}}, el::T) where T <: SharedPtr{LocalMe
return push!(v, CxxRef(el))
end

# Work around CxxWrap's `push!` always dereferencing our value via `@cxxdereference`
# https://github.com/JuliaInterop/CxxWrap.jl/blob/0de5fbc5673367adc7e725cfc6e1fc6a8f9240a0/src/StdLib.jl#L78-L81
function Base.push!(v::StdVector{CxxPtr{TaskArg}}, el::CxxRef{<:TaskArg})
_push_back(v, el)
return v
end

# XXX: Need to convert julia vectors to StdVector and build the
# `std::unordered_map` for resources. This function helps us avoid having
# CxxWrap as a direct dependency in Ray.jl
function _submit_task(fd, oids::AbstractVector, serialized_runtime_env_info, resources)
# https://github.com/JuliaInterop/CxxWrap.jl/issues/367
args = isempty(oids) ? StdVector{ObjectID}() : StdVector(oids)
function _submit_task(fd, args, serialized_runtime_env_info, resources::AbstractDict)
Member

any reason to remove the AbstractVector restriction here?

Member Author

The args passed in here are a StdVector, so if we left the type assertion in we'd have to update it. Mainly this assertion was dropped because this method should only be called if the resources argument is an AbstractDict

@debug "task resources: " resources
resources = build_resource_requests(resources)
return _submit_task(fd, args, serialized_runtime_env_info, resources)