Skip to content

Commit

Permalink
Support inlining task arguments (#72)
Browse files Browse the repository at this point in the history
* Support RayConfig

* Rough first pass on prepare_task_args

* Support rpc::Address

* Support TaskArg

* Support constructable RayObject

* Support GetCoreWorker

* Attempt to pass back unique_ptr TaskArg

* Lots of experimentation

* Functional with segfaults

* Disable default finalizer

* Use TaskArgByReference in prepare_task_args

* Cleanup

* Experiment with shared_ptr

Trying to get `push!` for free but converting a shared_ptr to a
unique_ptr is trouble

* Extend push! with our custom _push_back

* Working submit_task

* Refactoring

* Generate call_site string

* Add tests

* Drop unused GetCoreWorker

* Add TODO about task_args being an Any vector

* Manual cleanup of TaskArgs

* Add comment about using lambdas for RayConfig

* Comment about use of underscore in serialize_args tests

* Indentation

Co-authored-by: Dave Kleinschmidt <dave.f.kleinschmidt@gmail.com>

* Push CxxPtr instead of CxxRef

---------

Co-authored-by: Dave Kleinschmidt <dave.f.kleinschmidt@gmail.com>
  • Loading branch information
omus and kleinschmidt authored Sep 5, 2023
1 parent c32e03d commit c630b0f
Show file tree
Hide file tree
Showing 7 changed files with 238 additions and 30 deletions.
2 changes: 1 addition & 1 deletion Ray.jl/src/Ray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ module Ray

using ArgParse
using Base64
using CxxWrap: CxxRef
using CxxWrap: CxxPtr, CxxRef, StdVector
using CxxWrap.StdLib: SharedPtr
using JSON3
using Logging
Expand Down
24 changes: 14 additions & 10 deletions Ray.jl/src/object_store.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
Store `data` in the object store. Returns an object reference which can be used to retrieve
the `data` with [`Ray.get`](@ref).
"""
put(data) = ray_jll.put(to_serialized_buffer(data))
function put(data)
    # Serialize up front so the byte count handed to the buffer is that of the payload.
    payload = serialize_to_bytes(data)
    # The trailing `true` is forwarded to `ray_jll.LocalMemoryBuffer` as-is
    # (presumably the copy-data flag; confirm against the wrapper).
    return ray_jll.put(ray_jll.LocalMemoryBuffer(payload, sizeof(payload), true))
end

"""
Ray.get(object_id::ObjectIDAllocated)
Expand All @@ -16,24 +20,24 @@ if run in an `@async` task.
If the task that generated the `ObjectID` failed with a Julia exception, the
captured exception will be thrown on `get`.
"""
get(oid::ray_jll.ObjectIDAllocated) = _get(ray_jll.get(oid))
get(obj::SharedPtr{ray_jll.RayObject}) = _get(ray_jll.GetData(obj[]))
get(oid::ray_jll.ObjectIDAllocated) = _get(take!(ray_jll.get(oid)))
get(obj::SharedPtr{ray_jll.RayObject}) = _get(take!(ray_jll.GetData(obj[])))
get(x) = x

function _get(buffer)
result = from_serialized_buffer(buffer)
function _get(bytes)
result = deserialize_from_bytes(bytes)
# TODO: add an option to not rethrow
# https://github.com/beacon-biosignals/ray_core_worker_julia_jll.jl/issues/58
result isa RayRemoteException ? throw(result) : return result
end

function to_serialized_buffer(data)
function serialize_to_bytes(x)
bytes = Vector{UInt8}()
io = IOBuffer(bytes; write=true)
serialize(io, data)
return ray_jll.LocalMemoryBuffer(bytes, sizeof(bytes), true)
serialize(io, x)
return bytes
end

function from_serialized_buffer(buffer)
result = deserialize(IOBuffer(take!(buffer)))
function deserialize_from_bytes(bytes)
return deserialize(IOBuffer(bytes))
end
71 changes: 65 additions & 6 deletions Ray.jl/src/runtime.jl
Original file line number Diff line number Diff line change
Expand Up @@ -180,23 +180,80 @@ end

initialize_coreworker_driver(args...) = ray_jll.initialize_coreworker_driver(args...)

# TODO: Move task related code into a "task.jl" file
function submit_task(f::Function, args::Tuple, kwargs::NamedTuple=NamedTuple();
runtime_env::Union{RuntimeEnv,Nothing}=nothing,
resources::Dict{String,Float64}=Dict("CPU" => 1.0))
export_function!(FUNCTION_MANAGER[], f, get_job_id())
fd = ray_jll.function_descriptor(f)
arg_oids = map(Ray.put, flatten_args(args, kwargs))
task_args = serialize_args(flatten_args(args, kwargs))

serialized_runtime_env_info = if !isnothing(runtime_env)
_serialize(RuntimeEnvInfo(runtime_env))
else
""
end

return GC.@preserve args ray_jll._submit_task(fd,
arg_oids,
serialized_runtime_env_info,
resources)
GC.@preserve task_args begin
return ray_jll._submit_task(fd,
transform_task_args(task_args),
serialized_runtime_env_info,
resources)
end
end

# Adapted from `prepare_args_internal`:
# https://github.com/ray-project/ray/blob/ray-2.5.1/python/ray/_raylet.pyx#L673
"""
    serialize_args(args)

Serialize every element of `args` and wrap it as a Ray task argument.

Each argument is serialized with `serialize_to_bytes` and placed in a
`ray_jll.LocalMemoryBuffer`. It becomes a `ray_jll.TaskArgByValue` when its serialized
size is at most `max_direct_call_object_size` and adding it keeps the running total of
inlined bytes within `task_rpc_inlined_bytes_limit`. Otherwise the buffer is stored via
`ray_jll.put` and a `ray_jll.TaskArgByReference` to the resulting object ID is used.

Returns a `Vector{Any}` holding one task argument per element of `args`, in order.
"""
function serialize_args(args)
    config = ray_jll.RayConfigInstance()
    max_direct_size = ray_jll.max_direct_call_object_size(config)
    max_inlined_total = ray_jll.task_rpc_inlined_bytes_limit(config)
    want_call_site = ray_jll.record_ref_creation_sites(config)

    core_worker = ray_jll.GetCoreWorker()
    owner_address = ray_jll.GetRpcAddress(core_worker)

    # Running count of bytes already inlined by earlier arguments.
    inlined_so_far = 0

    # TODO: Ideally would be `ray_jll.TaskArg[]`:
    # https://github.com/beacon-biosignals/ray_core_worker_julia_jll.jl/issues/79
    wrapped = Any[]
    for arg in args
        # Note: The Python `prepare_args_internal` function checks if the `arg` is an
        # `ObjectRef` and in that case uses the object ID to directly make a
        # `TaskArgByReference`. However, as the `args` here are flattened the `arg` will
        # always be a `Pair` (or a list in Python). I suspect this Python code path is
        # just dead code so we'll exclude it from ours.

        payload = serialize_to_bytes(arg)
        nbytes = sizeof(payload)
        buffer = ray_jll.LocalMemoryBuffer(payload, nbytes, true)

        # Inline arguments which are small and if there is room
        fits_inline = nbytes <= max_direct_size &&
                      nbytes + inlined_so_far <= max_inlined_total

        if fits_inline
            inlined_so_far += nbytes
            push!(wrapped, ray_jll.TaskArgByValue(ray_jll.RayObject(buffer)))
        else
            oid = ray_jll.put(buffer)
            # TODO: Add test for populating `call_site`
            call_site = want_call_site ? sprint(Base.show_backtrace, backtrace()) : ""
            push!(wrapped, ray_jll.TaskArgByReference(oid, owner_address, call_site))
        end
    end

    return wrapped
end

"""
    transform_task_args(task_args)

Build a `StdVector{CxxPtr{ray_jll.TaskArg}}` containing a `CxxPtr` to each element of
`task_args`, in iteration order.
"""
function transform_task_args(task_args)
    ptrs = StdVector{CxxPtr{ray_jll.TaskArg}}()
    foreach(arg -> push!(ptrs, CxxPtr(arg)), task_args)
    return ptrs
end

function task_executor(ray_function, returns_ptr, task_args_ptr, task_name,
Expand Down Expand Up @@ -255,7 +312,9 @@ function task_executor(ray_function, returns_ptr, task_args_ptr, task_name,

# TODO: support multiple return values
# https://github.com/beacon-biosignals/ray_core_worker_julia_jll.jl/issues/54
push!(returns, to_serialized_buffer(result))
bytes = serialize_to_bytes(result)
buffer = ray_jll.LocalMemoryBuffer(bytes, sizeof(bytes), true)
push!(returns, buffer)

return nothing
end
Expand Down
2 changes: 2 additions & 0 deletions Ray.jl/test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
using Aqua
using CxxWrap: CxxPtr, StdVector
using CxxWrap.StdLib: UniquePtr
using Ray
using Serialization
using Test
Expand Down
42 changes: 42 additions & 0 deletions Ray.jl/test/runtime.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,5 +78,47 @@
@test !contains(stderr_logs, "Constructing CoreWorkerProcess")
end
end
end

@testset "serialize_args" begin
    ray_config = ray_jll.RayConfigInstance()
    put_threshold = ray_jll.max_direct_call_object_size(ray_config)
    rpc_inline_threshold = ray_jll.task_rpc_inlined_bytes_limit(ray_config)

    # The `flatten_args` function uses `:_` as the key for positional arguments.
    # Measure how many bytes serialization adds on top of the raw payload so the
    # payloads below can be sized to land exactly on / just past the threshold.
    serialization_overhead = begin
        sizeof(Ray.serialize_to_bytes(:_ => zeros(UInt8, put_threshold))) - put_threshold
    end

    @testset "put threshold" begin
        # `a` serializes to exactly `put_threshold` bytes; `b` is one byte larger.
        a = :_ => zeros(UInt8, put_threshold - serialization_overhead)
        b = :_ => zeros(UInt8, put_threshold - serialization_overhead + 1)
        task_args = Ray.serialize_args([a, b])
        @test length(task_args) == 2
        @test task_args[1] isa ray_jll.TaskArgByValue
        @test task_args[2] isa ray_jll.TaskArgByReference
        map(UniquePtr ∘ CxxPtr, task_args)  # Add finalizer for memory cleanup

        task_args = Ray.serialize_args([b, a])
        @test length(task_args) == 2
        @test task_args[1] isa ray_jll.TaskArgByReference
        @test task_args[2] isa ray_jll.TaskArgByValue
        map(UniquePtr ∘ CxxPtr, task_args)  # Add finalizer for memory cleanup
    end

    @testset "inline threshold" begin
        # One more threshold-sized argument than fits within the inline byte limit, so
        # only the final argument is expected to be passed by reference.
        a = :_ => zeros(UInt8, put_threshold - serialization_overhead)
        args = fill(a, rpc_inline_threshold ÷ put_threshold + 1)
        task_args = Ray.serialize_args(args)
        @test all(t -> t isa ray_jll.TaskArgByValue, task_args[1:(end - 1)])
        @test task_args[end] isa ray_jll.TaskArgByReference
        map(UniquePtr ∘ CxxPtr, task_args)  # Add finalizer for memory cleanup
    end
end

@testset "transform_task_args" begin
    task_args = Ray.serialize_args(Ray.flatten_args([1, 2, 3], (;)))
    result = Ray.transform_task_args(task_args)
    @test result isa StdVector{CxxPtr{ray_jll.TaskArg}}
    map(UniquePtr ∘ CxxPtr, task_args)  # Add finalizer for memory cleanup
end
106 changes: 96 additions & 10 deletions deps/wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ std::vector<std::shared_ptr<RayObject>> cast_to_task_args(void *ptr) {
}

ObjectID _submit_task(const ray::JuliaFunctionDescriptor &jl_func_descriptor,
const std::vector<ObjectID> &object_ids,
const std::vector<TaskArg *> &task_args,
const std::string &serialized_runtime_env_info,
const std::unordered_map<std::string, double> &resources) {

Expand All @@ -149,9 +149,11 @@ ObjectID _submit_task(const ray::JuliaFunctionDescriptor &jl_func_descriptor,
ray::FunctionDescriptor func_descriptor = std::make_shared<ray::JuliaFunctionDescriptor>(jl_func_descriptor);
RayFunction func(Language::JULIA, func_descriptor);

// TODO: Passing in a `std::vector<std::unique_ptr<TaskArg>>` from Julia may currently be impossible due to:
// https://github.com/JuliaInterop/CxxWrap.jl/issues/370
std::vector<std::unique_ptr<TaskArg>> args;
for (auto & obj_id : object_ids) {
args.emplace_back(new TaskArgByReference(obj_id, worker.GetRpcAddress(), /*call-site*/""));
for (auto &task_arg : task_args) {
args.emplace_back(task_arg);
}

// TaskOptions: https://github.com/ray-project/ray/blob/ray-2.5.1/src/ray/core_worker/common.h#L62-L87
Expand Down Expand Up @@ -355,16 +357,24 @@ std::unordered_map<std::string, double> get_task_required_resources() {
return worker_context.GetCurrentTask()->GetRequiredResources().GetResourceUnorderedMap();
}

void _push_back(std::vector<TaskArg *> &vector, TaskArg *el) {
vector.push_back(el);
}

namespace jlcxx
{
// Needed for upcasting
template<> struct SuperType<LocalMemoryBuffer> { typedef Buffer type; };
template<> struct SuperType<JuliaFunctionDescriptor> { typedef FunctionDescriptorInterface type; };
template<> struct SuperType<TaskArgByReference> { typedef TaskArg type; };
template<> struct SuperType<TaskArgByValue> { typedef TaskArg type; };

// Disable generated constructors
// https://github.com/JuliaInterop/CxxWrap.jl/issues/141#issuecomment-491373720
template<> struct DefaultConstructible<LocalMemoryBuffer> : std::false_type {};
template<> struct DefaultConstructible<RayObject> : std::false_type {};
// template<> struct DefaultConstructible<JuliaFunctionDescriptor> : std::false_type {};
template<> struct DefaultConstructible<TaskArg> : std::false_type {};

// Custom finalizer to show what is being deleted. Can be useful in tracking down
// segmentation faults due to double deallocations
Expand Down Expand Up @@ -420,12 +430,6 @@ JLCXX_MODULE define_julia_module(jlcxx::Module& mod)
.method("Binary", &TaskID::Binary)
.method("Hex", &TaskID::Hex);

// https://github.com/ray-project/ray/blob/ray-2.5.1/src/ray/core_worker/core_worker.h#L284
mod.add_type<ray::core::CoreWorker>("CoreWorker")
.method("GetCurrentJobId", &ray::core::CoreWorker::GetCurrentJobId)
.method("GetCurrentTaskId", &ray::core::CoreWorker::GetCurrentTaskId);
mod.method("_GetCoreWorker", &_GetCoreWorker);

mod.method("initialize_driver", &initialize_driver);
mod.method("shutdown_driver", &shutdown_driver);
mod.method("initialize_worker", &initialize_worker);
Expand Down Expand Up @@ -497,15 +501,51 @@ JLCXX_MODULE define_julia_module(jlcxx::Module& mod)

mod.method("put", &put);
mod.method("get", &get);
mod.method("_submit_task", &_submit_task);

// message Address
// https://github.com/ray-project/ray/blob/ray-2.5.1/src/ray/protobuf/common.proto#L86
mod.add_type<rpc::Address>("Address")
.constructor<>()
.method("SerializeToString", [](const rpc::Address &addr) {
std::string serialized;
addr.SerializeToString(&serialized);
return serialized;
})
.method("MessageToJsonString", [](const rpc::Address &addr) {
std::string json;
google::protobuf::util::MessageToJsonString(addr, &json);
return json;
});

// https://github.com/ray-project/ray/blob/ray-2.5.1/src/ray/core_worker/core_worker.h#L284
mod.add_type<ray::core::CoreWorker>("CoreWorker")
.method("GetCurrentJobId", &ray::core::CoreWorker::GetCurrentJobId)
.method("GetCurrentTaskId", &ray::core::CoreWorker::GetCurrentTaskId)
.method("GetRpcAddress", &ray::core::CoreWorker::GetRpcAddress);
mod.method("_GetCoreWorker", &_GetCoreWorker);

// message ObjectReference
// https://github.com/ray-project/ray/blob/ray-2.5.1/src/ray/protobuf/common.proto#L500
mod.add_type<rpc::ObjectReference>("ObjectReference");
jlcxx::stl::apply_stl<rpc::ObjectReference>(mod);

// class RayObject
// https://github.com/ray-project/ray/blob/ray-2.5.1/src/ray/common/ray_object.h#L28
mod.add_type<RayObject>("RayObject")
.method("GetData", &RayObject::GetData);

// Julia RayObject constructors make shared_ptrs
mod.method("RayObject", [] (
const std::shared_ptr<Buffer> &data,
const std::shared_ptr<Buffer> &metadata,
const std::vector<rpc::ObjectReference> &nested_refs,
bool copy_data = false) {

return std::make_shared<RayObject>(data, metadata, nested_refs, copy_data);
});
mod.method("RayObject", [] (const std::shared_ptr<Buffer> &data) {
return std::make_shared<RayObject>(data, nullptr, std::vector<rpc::ObjectReference>(), false);
});
jlcxx::stl::apply_stl<std::shared_ptr<RayObject>>(mod);

mod.add_type<Status>("Status")
Expand Down Expand Up @@ -536,4 +576,50 @@ JLCXX_MODULE define_julia_module(jlcxx::Module& mod)
mod.method("serialize_job_config_json", &serialize_job_config_json);
mod.method("get_job_serialized_runtime_env", &get_job_serialized_runtime_env);
mod.method("get_task_required_resources", &get_task_required_resources);

// class RayConfig
// https://github.com/ray-project/ray/blob/ray-2.5.1/src/ray/common/ray_config.h#L60
//
// Lambdas required here as otherwise we see the following error:
// "error: call to non-static member function without an object argument"
mod.add_type<RayConfig>("RayConfig")
.method("RayConfigInstance", &RayConfig::instance)
.method("max_direct_call_object_size", [](RayConfig &config) {
return config.max_direct_call_object_size();
})
.method("task_rpc_inlined_bytes_limit", [](RayConfig &config) {
return config.task_rpc_inlined_bytes_limit();
})
.method("record_ref_creation_sites", [](RayConfig &config) {
return config.record_ref_creation_sites();
});

// https://github.com/ray-project/ray/blob/ray-2.5.1/src/ray/common/task/task_util.h
mod.add_type<TaskArg>("TaskArg");
mod.method("_push_back", &_push_back);

// The Julia types `TaskArgByReference` and `TaskArgByValue` have their default finalizers
// disabled as these will later be used as `std::unique_ptr`. If these finalizers were enabled
// we would see segmentation faults due to double deallocations.
//
// Note: It is possible to create `std::unique_ptr`s in C++ and return them to Julia however
// CxxWrap is unable to compile any wrapped functions using `std::vector<std::unique_ptr<TaskArg>>`.
// We're working around this by using `std::vector<TaskArg *>`.
// https://github.com/JuliaInterop/CxxWrap.jl/issues/370

mod.add_type<TaskArgByReference>("TaskArgByReference", jlcxx::julia_base_type<TaskArg>())
.constructor<const ObjectID &/*object_id*/,
const rpc::Address &/*owner_address*/,
const std::string &/*call_site*/>(false)
.method("unique_ptr", [](TaskArgByReference *t) {
return std::unique_ptr<TaskArgByReference>(t);
});

mod.add_type<TaskArgByValue>("TaskArgByValue", jlcxx::julia_base_type<TaskArg>())
.constructor<const std::shared_ptr<RayObject> &/*value*/>(false)
.method("unique_ptr", [](TaskArgByValue *t) {
return std::unique_ptr<TaskArgByValue>(t);
});

mod.method("_submit_task", &_submit_task);
}
Loading

0 comments on commit c630b0f

Please sign in to comment.