diff --git a/Ray.jl/src/Ray.jl b/Ray.jl/src/Ray.jl index 4b308b2e..e4d8e45c 100644 --- a/Ray.jl/src/Ray.jl +++ b/Ray.jl/src/Ray.jl @@ -7,7 +7,7 @@ module Ray using ArgParse using Base64 -using CxxWrap: CxxRef +using CxxWrap: CxxPtr, CxxRef, StdVector using CxxWrap.StdLib: SharedPtr using JSON3 using Logging diff --git a/Ray.jl/src/object_store.jl b/Ray.jl/src/object_store.jl index 6f931efd..5de4ff76 100644 --- a/Ray.jl/src/object_store.jl +++ b/Ray.jl/src/object_store.jl @@ -4,7 +4,11 @@ Store `data` in the object store. Returns an object reference which can used to retrieve the `data` with [`Ray.get`](@ref). """ -put(data) = ray_jll.put(to_serialized_buffer(data)) +function put(data) + bytes = serialize_to_bytes(data) + buffer = ray_jll.LocalMemoryBuffer(bytes, sizeof(bytes), true) + return ray_jll.put(buffer) +end """ Ray.get(object_id::ObjectIDAllocated) @@ -16,24 +20,24 @@ if run in an `@async` task. If the task that generated the `ObjectID` failed with a Julia exception, the captured exception will be thrown on `get`. """ -get(oid::ray_jll.ObjectIDAllocated) = _get(ray_jll.get(oid)) -get(obj::SharedPtr{ray_jll.RayObject}) = _get(ray_jll.GetData(obj[])) +get(oid::ray_jll.ObjectIDAllocated) = _get(take!(ray_jll.get(oid))) +get(obj::SharedPtr{ray_jll.RayObject}) = _get(take!(ray_jll.GetData(obj[]))) get(x) = x -function _get(buffer) - result = from_serialized_buffer(buffer) +function _get(bytes) + result = deserialize_from_bytes(bytes) # TODO: add an option to not rethrow # https://github.com/beacon-biosignals/ray_core_worker_julia_jll.jl/issues/58 result isa RayRemoteException ? 
throw(result) : return result end -function to_serialized_buffer(data) +function serialize_to_bytes(x) bytes = Vector{UInt8}() io = IOBuffer(bytes; write=true) - serialize(io, data) - return ray_jll.LocalMemoryBuffer(bytes, sizeof(bytes), true) + serialize(io, x) + return bytes end -function from_serialized_buffer(buffer) - result = deserialize(IOBuffer(take!(buffer))) +function deserialize_from_bytes(bytes) + return deserialize(IOBuffer(bytes)) end diff --git a/Ray.jl/src/runtime.jl b/Ray.jl/src/runtime.jl index 3b193a4e..cc8fc96f 100644 --- a/Ray.jl/src/runtime.jl +++ b/Ray.jl/src/runtime.jl @@ -180,12 +180,13 @@ end initialize_coreworker_driver(args...) = ray_jll.initialize_coreworker_driver(args...) +# TODO: Move task related code into a "task.jl" file function submit_task(f::Function, args::Tuple, kwargs::NamedTuple=NamedTuple(); runtime_env::Union{RuntimeEnv,Nothing}=nothing, resources::Dict{String,Float64}=Dict("CPU" => 1.0)) export_function!(FUNCTION_MANAGER[], f, get_job_id()) fd = ray_jll.function_descriptor(f) - arg_oids = map(Ray.put, flatten_args(args, kwargs)) + task_args = serialize_args(flatten_args(args, kwargs)) serialized_runtime_env_info = if !isnothing(runtime_env) _serialize(RuntimeEnvInfo(runtime_env)) @@ -193,10 +194,66 @@ function submit_task(f::Function, args::Tuple, kwargs::NamedTuple=NamedTuple(); "" end - return GC.@preserve args ray_jll._submit_task(fd, - arg_oids, - serialized_runtime_env_info, - resources) + GC.@preserve task_args begin + return ray_jll._submit_task(fd, + transform_task_args(task_args), + serialized_runtime_env_info, + resources) + end +end + +# Adapted from `prepare_args_internal`: +# https://github.com/ray-project/ray/blob/ray-2.5.1/python/ray/_raylet.pyx#L673 +function serialize_args(args) + ray_config = ray_jll.RayConfigInstance() + put_threshold = ray_jll.max_direct_call_object_size(ray_config) + rpc_inline_threshold = ray_jll.task_rpc_inlined_bytes_limit(ray_config) + record_call_site = 
ray_jll.record_ref_creation_sites(ray_config) + + worker = ray_jll.GetCoreWorker() + rpc_address = ray_jll.GetRpcAddress(worker) + + total_inlined = 0 + + # TODO: Ideally would be `ray_jll.TaskArg[]`: + # https://github.com/beacon-biosignals/ray_core_worker_julia_jll.jl/issues/79 + task_args = Any[] + for arg in args + # Note: The Python `prepare_args_internal` function checks if the `arg` is an + # `ObjectRef` and in that case uses the object ID to directly make a + # `TaskArgByReference`. However, as the `args` here are flattened the `arg` will + # always be a `Pair` (or a list in Python). I suspect this Python code path is just + # dead code so we'll exclude it from ours. + + serialized_arg = serialize_to_bytes(arg) + serialized_arg_size = sizeof(serialized_arg) + buffer = ray_jll.LocalMemoryBuffer(serialized_arg, serialized_arg_size, true) + + # Inline arguments which are small and if there is room + task_arg = if (serialized_arg_size <= put_threshold && + serialized_arg_size + total_inlined <= rpc_inline_threshold) + + total_inlined += serialized_arg_size + ray_jll.TaskArgByValue(ray_jll.RayObject(buffer)) + else + oid = ray_jll.put(buffer) + # TODO: Add test for populating `call_site` + call_site = record_call_site ? 
sprint(Base.show_backtrace, backtrace()) : "" + ray_jll.TaskArgByReference(oid, rpc_address, call_site) + end + + push!(task_args, task_arg) + end + + return task_args +end + +function transform_task_args(task_args) + task_arg_ptrs = StdVector{CxxPtr{ray_jll.TaskArg}}() + for task_arg in task_args + push!(task_arg_ptrs, CxxPtr(task_arg)) + end + return task_arg_ptrs end function task_executor(ray_function, returns_ptr, task_args_ptr, task_name, @@ -255,7 +312,9 @@ function task_executor(ray_function, returns_ptr, task_args_ptr, task_name, # TODO: support multiple return values # https://github.com/beacon-biosignals/ray_core_worker_julia_jll.jl/issues/54 - push!(returns, to_serialized_buffer(result)) + bytes = serialize_to_bytes(result) + buffer = ray_jll.LocalMemoryBuffer(bytes, sizeof(bytes), true) + push!(returns, buffer) return nothing end diff --git a/Ray.jl/test/runtests.jl b/Ray.jl/test/runtests.jl index 001ab1f1..095a40ee 100644 --- a/Ray.jl/test/runtests.jl +++ b/Ray.jl/test/runtests.jl @@ -1,4 +1,6 @@ using Aqua +using CxxWrap: CxxPtr, StdVector +using CxxWrap.StdLib: UniquePtr using Ray using Serialization using Test diff --git a/Ray.jl/test/runtime.jl b/Ray.jl/test/runtime.jl index 97e2de7e..a4df331e 100644 --- a/Ray.jl/test/runtime.jl +++ b/Ray.jl/test/runtime.jl @@ -78,5 +78,47 @@ @test !contains(stderr_logs, "Constructing CoreWorkerProcess") end end +end + +@testset "serialize_args" begin + ray_config = ray_jll.RayConfigInstance() + put_threshold = ray_jll.max_direct_call_object_size(ray_config) + rpc_inline_threshold = ray_jll.task_rpc_inlined_bytes_limit(ray_config) + + # The `flatten_args` function uses `:_` as the key for positional arguments. 
+ serialization_overhead = begin + sizeof(Ray.serialize_to_bytes(:_ => zeros(UInt8, put_threshold))) - put_threshold + end + + @testset "put threshold" begin + a = :_ => zeros(UInt8, put_threshold - serialization_overhead) + b = :_ => zeros(UInt8, put_threshold - serialization_overhead + 1) + task_args = Ray.serialize_args([a, b]) + @test length(task_args) == 2 + @test task_args[1] isa ray_jll.TaskArgByValue + @test task_args[2] isa ray_jll.TaskArgByReference + map(UniquePtr ∘ CxxPtr , task_args) # Add finalizer for memory cleanup + + task_args = Ray.serialize_args([b, a]) + @test length(task_args) == 2 + @test task_args[1] isa ray_jll.TaskArgByReference + @test task_args[2] isa ray_jll.TaskArgByValue + map(UniquePtr ∘ CxxPtr , task_args) # Add finalizer for memory cleanup + end + + @testset "inline threshold" begin + a = :_ => zeros(UInt8, put_threshold - serialization_overhead) + args = fill(a, rpc_inline_threshold ÷ put_threshold + 1) + task_args = Ray.serialize_args(args) + @test all(t -> t isa ray_jll.TaskArgByValue, task_args[1:(end - 1)]) + @test task_args[end] isa ray_jll.TaskArgByReference + map(UniquePtr ∘ CxxPtr , task_args) # Add finalizer for memory cleanup + end +end +@testset "transform_task_args" begin + task_args = Ray.serialize_args(Ray.flatten_args([1, 2, 3], (;))) + result = Ray.transform_task_args(task_args) + @test result isa StdVector{CxxPtr{ray_jll.TaskArg}} + map(UniquePtr ∘ CxxPtr , task_args) # Add finalizer for memory cleanup end diff --git a/deps/wrapper.cc b/deps/wrapper.cc index 78e427cc..54a61ee2 100644 --- a/deps/wrapper.cc +++ b/deps/wrapper.cc @@ -140,7 +140,7 @@ std::vector> cast_to_task_args(void *ptr) { } ObjectID _submit_task(const ray::JuliaFunctionDescriptor &jl_func_descriptor, - const std::vector &object_ids, + const std::vector &task_args, const std::string &serialized_runtime_env_info, const std::unordered_map &resources) { @@ -149,9 +149,11 @@ ObjectID _submit_task(const ray::JuliaFunctionDescriptor &jl_func_descriptor, 
ray::FunctionDescriptor func_descriptor = std::make_shared(jl_func_descriptor); RayFunction func(Language::JULIA, func_descriptor); + // TODO: Passing in a `std::vector<std::unique_ptr<TaskArg>>` from Julia may currently be impossible due to: + // https://github.com/JuliaInterop/CxxWrap.jl/issues/370 std::vector> args; - for (auto & obj_id : object_ids) { - args.emplace_back(new TaskArgByReference(obj_id, worker.GetRpcAddress(), /*call-site*/"")); + for (auto &task_arg : task_args) { + args.emplace_back(task_arg); } // TaskOptions: https://github.com/ray-project/ray/blob/ray-2.5.1/src/ray/core_worker/common.h#L62-L87 @@ -355,16 +357,24 @@ std::unordered_map get_task_required_resources() { return worker_context.GetCurrentTask()->GetRequiredResources().GetResourceUnorderedMap(); } +void _push_back(std::vector &vector, TaskArg *el) { + vector.push_back(el); +} + namespace jlcxx { // Needed for upcasting template<> struct SuperType { typedef Buffer type; }; template<> struct SuperType { typedef FunctionDescriptorInterface type; }; + template<> struct SuperType { typedef TaskArg type; }; + template<> struct SuperType { typedef TaskArg type; }; // Disable generated constructors // https://github.com/JuliaInterop/CxxWrap.jl/issues/141#issuecomment-491373720 template<> struct DefaultConstructible : std::false_type {}; + template<> struct DefaultConstructible : std::false_type {}; // template<> struct DefaultConstructible : std::false_type {}; + template<> struct DefaultConstructible : std::false_type {}; // Custom finalizer to show what is being deleted. 
Can be useful in tracking down // segmentation faults due to double deallocations @@ -420,12 +430,6 @@ JLCXX_MODULE define_julia_module(jlcxx::Module& mod) .method("Binary", &TaskID::Binary) .method("Hex", &TaskID::Hex); - // https://github.com/ray-project/ray/blob/ray-2.5.1/src/ray/core_worker/core_worker.h#L284 - mod.add_type("CoreWorker") - .method("GetCurrentJobId", &ray::core::CoreWorker::GetCurrentJobId) - .method("GetCurrentTaskId", &ray::core::CoreWorker::GetCurrentTaskId); - mod.method("_GetCoreWorker", &_GetCoreWorker); - mod.method("initialize_driver", &initialize_driver); mod.method("shutdown_driver", &shutdown_driver); mod.method("initialize_worker", &initialize_worker); @@ -497,15 +501,51 @@ JLCXX_MODULE define_julia_module(jlcxx::Module& mod) mod.method("put", &put); mod.method("get", &get); - mod.method("_submit_task", &_submit_task); + + // message Address + // https://github.com/ray-project/ray/blob/ray-2.5.1/src/ray/protobuf/common.proto#L86 + mod.add_type("Address") + .constructor<>() + .method("SerializeToString", [](const rpc::Address &addr) { + std::string serialized; + addr.SerializeToString(&serialized); + return serialized; + }) + .method("MessageToJsonString", [](const rpc::Address &addr) { + std::string json; + google::protobuf::util::MessageToJsonString(addr, &json); + return json; + }); + + // https://github.com/ray-project/ray/blob/ray-2.5.1/src/ray/core_worker/core_worker.h#L284 + mod.add_type("CoreWorker") + .method("GetCurrentJobId", &ray::core::CoreWorker::GetCurrentJobId) + .method("GetCurrentTaskId", &ray::core::CoreWorker::GetCurrentTaskId) + .method("GetRpcAddress", &ray::core::CoreWorker::GetRpcAddress); + mod.method("_GetCoreWorker", &_GetCoreWorker); // message ObjectReference // https://github.com/ray-project/ray/blob/ray-2.5.1/src/ray/protobuf/common.proto#L500 mod.add_type("ObjectReference"); jlcxx::stl::apply_stl(mod); + // class RayObject + // 
https://github.com/ray-project/ray/blob/ray-2.5.1/src/ray/common/ray_object.h#L28 mod.add_type("RayObject") .method("GetData", &RayObject::GetData); + + // Julia RayObject constructors make shared_ptrs + mod.method("RayObject", [] ( + const std::shared_ptr &data, + const std::shared_ptr &metadata, + const std::vector &nested_refs, + bool copy_data = false) { + + return std::make_shared(data, metadata, nested_refs, copy_data); + }); + mod.method("RayObject", [] (const std::shared_ptr &data) { + return std::make_shared(data, nullptr, std::vector(), false); + }); jlcxx::stl::apply_stl>(mod); mod.add_type("Status") @@ -536,4 +576,50 @@ JLCXX_MODULE define_julia_module(jlcxx::Module& mod) mod.method("serialize_job_config_json", &serialize_job_config_json); mod.method("get_job_serialized_runtime_env", &get_job_serialized_runtime_env); mod.method("get_task_required_resources", &get_task_required_resources); + + // class RayConfig + // https://github.com/ray-project/ray/blob/ray-2.5.1/src/ray/common/ray_config.h#L60 + // + // Lambdas required here as otherwise we see the following error: + // "error: call to non-static member function without an object argument" + mod.add_type("RayConfig") + .method("RayConfigInstance", &RayConfig::instance) + .method("max_direct_call_object_size", [](RayConfig &config) { + return config.max_direct_call_object_size(); + }) + .method("task_rpc_inlined_bytes_limit", [](RayConfig &config) { + return config.task_rpc_inlined_bytes_limit(); + }) + .method("record_ref_creation_sites", [](RayConfig &config) { + return config.record_ref_creation_sites(); + }); + + // https://github.com/ray-project/ray/blob/ray-2.5.1/src/ray/common/task/task_util.h + mod.add_type("TaskArg"); + mod.method("_push_back", &_push_back); + + // The Julia types `TaskArgByReference` and `TaskArgByValue` have their default finalizers + // disabled as these will later be used as `std::unique_ptr`. 
If these finalizers were enabled + // we would see segmentation faults due to double deallocations. + // + // Note: It is possible to create `std::unique_ptr`s in C++ and return them to Julia however + // CxxWrap is unable to compile any wrapped functions using `std::vector<std::unique_ptr<TaskArg>>`. + // We're working around this by using `std::vector<TaskArg *>`. + // https://github.com/JuliaInterop/CxxWrap.jl/issues/370 + + mod.add_type("TaskArgByReference", jlcxx::julia_base_type()) + .constructor(false) + .method("unique_ptr", [](TaskArgByReference *t) { + return std::unique_ptr(t); + }); + + mod.add_type("TaskArgByValue", jlcxx::julia_base_type()) + .constructor &/*value*/>(false) + .method("unique_ptr", [](TaskArgByValue *t) { + return std::unique_ptr(t); + }); + + mod.method("_submit_task", &_submit_task); } diff --git a/src/wrappers/any.jl b/src/wrappers/any.jl index 2f1cb716..03c2d0a4 100644 --- a/src/wrappers/any.jl +++ b/src/wrappers/any.jl @@ -135,6 +135,16 @@ function GetCoreWorker() return CORE_WORKER[] end +##### +##### TaskArg +##### + +function CxxWrap.StdLib.UniquePtr(ptr::Union{Ptr{Nothing}, + CxxPtr{<:TaskArgByReference}, + CxxPtr{<:TaskArgByValue}}) + return unique_ptr(ptr) +end + ##### ##### Upstream fixes ##### @@ -157,12 +167,17 @@ function Base.push!(v::CxxPtr{StdVector{T}}, el::T) where T <: SharedPtr{LocalMe return push!(v, CxxRef(el)) end +# Work around CxxWrap's `push!` always dereferencing our value via `@cxxdereference` +# https://github.com/JuliaInterop/CxxWrap.jl/blob/0de5fbc5673367adc7e725cfc6e1fc6a8f9240a0/src/StdLib.jl#L78-L81 +function Base.push!(v::StdVector{CxxPtr{TaskArg}}, el::CxxPtr{<:TaskArg}) + _push_back(v, el) + return v +end + # XXX: Need to convert julia vectors to StdVector and build the # `std::unordered_map` for resources. 
This function helps us avoid having # CxxWrap as a direct dependency in Ray.jl -function _submit_task(fd, oids::AbstractVector, serialized_runtime_env_info, resources) - # https://github.com/JuliaInterop/CxxWrap.jl/issues/367 - args = isempty(oids) ? StdVector{ObjectID}() : StdVector(oids) +function _submit_task(fd, args, serialized_runtime_env_info, resources::AbstractDict) @debug "task resources: " resources resources = build_resource_requests(resources) return _submit_task(fd, args, serialized_runtime_env_info, resources)