Skip to content

Commit

Permalink
Support serializing ObjectRef (#85)
Browse files Browse the repository at this point in the history
* Support serializing ObjectRef

* Use custom RaySerializer

* Move CoreWorker wrapper to support GetObjectRefs

* Add function for creation of Buffer nullptr

* Create RayObject argument with inlined refs

* Test using object ref argument

* Serialize with header

* Implement reset_state

* Refactor

* Rename file to ray_serializer.jl

* Add serialize_to_bytes/deserialize_from_bytes tests

* Add RaySerializer tests

* deserialize_from_bytes one liner

* Add note about indirect testing of serialized header

* Add hash support for ObjectRef

---------

Co-authored-by: Dave Kleinschmidt <dave.f.kleinschmidt@gmail.com>
  • Loading branch information
omus and kleinschmidt authored Sep 7, 2023
1 parent 9004702 commit 5c8bfe6
Show file tree
Hide file tree
Showing 12 changed files with 184 additions and 19 deletions.
19 changes: 12 additions & 7 deletions ray_julia_jll/deps/wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,11 @@ JLCXX_MODULE define_julia_module(jlcxx::Module& mod)
.method("Size", &Buffer::Size)
.method("OwnsData", &Buffer::OwnsData)
.method("IsPlasmaBuffer", &Buffer::IsPlasmaBuffer);
mod.method("BufferFromNull", [] () {
return std::shared_ptr<Buffer>(nullptr);
});
jlcxx::stl::apply_stl<std::shared_ptr<Buffer>>(mod);

mod.add_type<LocalMemoryBuffer>("LocalMemoryBuffer", jlcxx::julia_base_type<Buffer>());
mod.method("LocalMemoryBuffer", [] (uint8_t *data, size_t size, bool copy_data = false) {
return std::make_shared<LocalMemoryBuffer>(data, size, copy_data);
Expand All @@ -518,13 +522,6 @@ JLCXX_MODULE define_julia_module(jlcxx::Module& mod)
return json;
});

// https://github.com/ray-project/ray/blob/ray-2.5.1/src/ray/core_worker/core_worker.h#L284
mod.add_type<ray::core::CoreWorker>("CoreWorker")
.method("GetCurrentJobId", &ray::core::CoreWorker::GetCurrentJobId)
.method("GetCurrentTaskId", &ray::core::CoreWorker::GetCurrentTaskId)
.method("GetRpcAddress", &ray::core::CoreWorker::GetRpcAddress);
mod.method("_GetCoreWorker", &_GetCoreWorker);

// message ObjectReference
// https://github.com/ray-project/ray/blob/ray-2.5.1/src/ray/protobuf/common.proto#L500
mod.add_type<rpc::ObjectReference>("ObjectReference");
Expand All @@ -549,6 +546,14 @@ JLCXX_MODULE define_julia_module(jlcxx::Module& mod)
});
jlcxx::stl::apply_stl<std::shared_ptr<RayObject>>(mod);

// https://github.com/ray-project/ray/blob/ray-2.5.1/src/ray/core_worker/core_worker.h#L284
mod.add_type<ray::core::CoreWorker>("CoreWorker")
.method("GetCurrentJobId", &ray::core::CoreWorker::GetCurrentJobId)
.method("GetCurrentTaskId", &ray::core::CoreWorker::GetCurrentTaskId)
.method("GetRpcAddress", &ray::core::CoreWorker::GetRpcAddress)
.method("GetObjectRefs", &ray::core::CoreWorker::GetObjectRefs);
mod.method("_GetCoreWorker", &_GetCoreWorker);

mod.method("put", &put);
mod.method("get", &get);

Expand Down
6 changes: 6 additions & 0 deletions ray_julia_jll/src/wrappers/any.jl
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,12 @@ function GetCoreWorker()
return CORE_WORKER[]
end

#####
##### Buffer
#####

NullPtr(::Type{Buffer}) = BufferFromNull()

#####
##### ObjectID
#####
Expand Down
4 changes: 3 additions & 1 deletion src/Ray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ using JSON3
using Logging
using LoggingExtras
using Pkg
using Serialization
using Serialization: Serialization, AbstractSerializer, Serializer, deserialize,
reset_state, serialize, serialize_type, writeheader

import ray_julia_jll as ray_jll

Expand All @@ -24,6 +25,7 @@ include("runtime_env.jl")
include("remote_function.jl")
include("runtime.jl")
include("object_ref.jl")
include("ray_serializer.jl")
include("object_store.jl")

end
17 changes: 17 additions & 0 deletions src/object_ref.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,24 @@ ObjectRef(hex_str::AbstractString) = ObjectRef(ray_jll.FromHex(ray_jll.ObjectID,
hex_identifier(obj_ref::ObjectRef) = String(ray_jll.Hex(obj_ref.oid))
Base.:(==)(a::ObjectRef, b::ObjectRef) = hex_identifier(a) == hex_identifier(b)

function Base.hash(obj_ref::ObjectRef, h::UInt)
h = hash(ObjectRef, h)
h = hash(hex_identifier(obj_ref), h)
return h
end

function Base.show(io::IO, obj_ref::ObjectRef)
write(io, "ObjectRef(\"", hex_identifier(obj_ref), "\")")
return nothing
end

# We cannot serialize pointers between processes
function Serialization.serialize(s::AbstractSerializer, obj_ref::ObjectRef)
serialize_type(s, typeof(obj_ref))
serialize(s, hex_identifier(obj_ref))
end

function Serialization.deserialize(s::AbstractSerializer, ::Type{ObjectRef})
hex_str = deserialize(s)
return ObjectRef(hex_str)
end
8 changes: 0 additions & 8 deletions src/object_store.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,3 @@ function _get(bytes)
result isa RayRemoteException ? throw(result) : return result
end

function serialize_to_bytes(x)
bytes = Vector{UInt8}()
io = IOBuffer(bytes; write=true)
serialize(io, x)
return bytes
end

deserialize_from_bytes(bytes) = deserialize(IOBuffer(bytes))
48 changes: 48 additions & 0 deletions src/ray_serializer.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
mutable struct RaySerializer{I<:IO} <: AbstractSerializer
# Fields required by all AbstractSerializers
io::I
counter::Int
table::IdDict{Any,Any}
pending_refs::Vector{Int}

# Inlined object references encountered during serializing
object_refs::Set{ObjectRef}

function RaySerializer{I}(io::I) where I<:IO
return new(io, 0, IdDict(), Int[], Set{ObjectRef}())
end
end

RaySerializer(io::IO) = RaySerializer{typeof(io)}(io)
RaySerializer(bytes::Vector{UInt8}) = RaySerializer{IOBuffer}(IOBuffer(bytes; write=true))

function Base.getproperty(s::RaySerializer, f::Symbol)
if f === :object_ids
return Set(getproperty.(s.object_refs, :oid))
else
return getfield(s, f)
end
end

function Serialization.reset_state(s::RaySerializer)
empty!(s.object_refs)
return invoke(reset_state, Tuple{AbstractSerializer}, s)
end

function Serialization.serialize(s::RaySerializer, obj_ref::ObjectRef)
push!(s.object_refs, obj_ref)
return invoke(serialize, Tuple{AbstractSerializer, ObjectRef}, s, obj_ref)
end

# As we are just throwing away the Serializer we can just avoid collecting the inlined
# object references
function serialize_to_bytes(x)
bytes = Vector{UInt8}()
io = IOBuffer(bytes; write=true)
s = Serializer(io)
writeheader(s)
serialize(s, x)
return bytes
end

deserialize_from_bytes(bytes::Vector{UInt8}) = deserialize(Serializer(IOBuffer(bytes)))
11 changes: 9 additions & 2 deletions src/runtime.jl
Original file line number Diff line number Diff line change
Expand Up @@ -226,10 +226,17 @@ function serialize_args(args)
# always be a `Pair` (or a list in Python). I suspect this Python code path just
# dead code so we'll exclude it from ours.

serialized_arg = serialize_to_bytes(arg)
serialized_arg = Vector{UInt8}()
serializer = RaySerializer(serialized_arg)
writeheader(serializer)
serialize(serializer, arg)
serialized_arg_size = sizeof(serialized_arg)

buffer = ray_jll.LocalMemoryBuffer(serialized_arg, serialized_arg_size, true)
ray_obj = ray_jll.RayObject(buffer)
metadata = ray_jll.NullPtr(ray_jll.Buffer)
inlined_ids = collect(serializer.object_ids)
inlined_refs = ray_jll.GetObjectRefs(worker, StdVector(inlined_ids))
ray_obj = ray_jll.RayObject(buffer, metadata, inlined_refs, false)

# Inline arguments which are small and if there is room
task_arg = if (serialized_arg_size <= put_threshold &&
Expand Down
16 changes: 15 additions & 1 deletion test/object_ref.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,28 @@
function serialize_deserialize(x)
io = IOBuffer()
serialize(io, x)
seekstart(io)
return deserialize(io)
end

@testset "ObjectRef" begin
@testset "basic" begin
hex_str = "f" ^ (2 * 28)
obj_ref = ObjectRef(hex_str)
@test obj_ref == ObjectRef(hex_str)
@test Ray.hex_identifier(obj_ref) == hex_str
@test obj_ref == ObjectRef(hex_str)
@test hash(obj_ref) == hash(ObjectRef(hex_str))
end

@testset "show" begin
hex_str = "f" ^ (2 * 28)
obj_ref = ObjectRef(hex_str)
@test sprint(show, obj_ref) == "ObjectRef(\"$hex_str\")"
end

@testset "serialize/deserialize" begin
obj_ref1 = ObjectRef(ray_jll.FromRandom(ray_jll.ObjectID))
obj_ref2 = serialize_deserialize(obj_ref1)
@test obj_ref1 == obj_ref2
end
end
62 changes: 62 additions & 0 deletions test/ray_serializer.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
@testset "RaySerializer" begin
@testset "byte constructor" begin
bytes = Vector{UInt8}()
s = Ray.RaySerializer(bytes)
@test isempty(bytes)

serialize(s, 1)
@test !isempty(bytes)
end

@testset "object_ids property" begin
s = Ray.RaySerializer(IOBuffer())
@test s.object_ids isa Set{<:ray_jll.ObjectID}
end

@testset "inlined object refs" begin
oids = [ray_jll.FromRandom(ray_jll.ObjectID) for _ in 1:3]
obj_refs = map(ObjectRef, oids)
x = [1, 2, obj_refs...]

s = Ray.RaySerializer(IOBuffer())
serialize(s, x)

@test s.object_refs == Set(obj_refs)
@test s.object_ids == Set(oids)
end

@testset "reset_state" begin
obj_ref = ObjectRef(ray_jll.FromRandom(ray_jll.ObjectID))
s = Ray.RaySerializer(IOBuffer())
serialize(s, obj_ref)
@test !isempty(s.object_refs)

Serialization.reset_state(s)
@test isempty(s.object_refs)
end
end

@testset "serialize_to_bytes / deserialize_from_bytes" begin
@testset "roundtrip" begin
x = [1, 2, 3]
bytes = Ray.serialize_to_bytes(x)
@test bytes isa Vector{UInt8}
@test !isempty(bytes)

result = Ray.deserialize_from_bytes(bytes)
@test typeof(result) == typeof(x)
@test result == x
end

# TODO: Investigate if want to include the serialization header
@testset "serialize with header" begin
x = 123
bytes = Ray.serialize_to_bytes(x)

s = Serializer(IOBuffer(bytes))
b = Int32(read(s.io, UInt8)::UInt8)
@test b == Serialization.HEADER_TAG
Serialization.readheader(s) # Throws if header not present
@test deserialize(s) == x
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ include("utils.jl")
end

include("object_ref.jl")
include("ray_serializer.jl")
include("runtime_env.jl")
include("remote_function.jl")

Expand Down
4 changes: 4 additions & 0 deletions test/runtime.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ end
rpc_inline_threshold = ray_jll.task_rpc_inlined_bytes_limit(ray_config)

# The `flatten_args` function uses `:_` as the key for positional arguments.
#
# Note: We are indirectly testing that `serialize_args` and `serialize_to_bytes` both
# match each other regarding including a serialization header. If there is a mismatch
# then the tests below will fail.
serialization_overhead = begin
sizeof(Ray.serialize_to_bytes(:_ => zeros(UInt8, put_threshold))) - put_threshold
end
Expand Down
7 changes: 7 additions & 0 deletions test/task.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@
@test e isa Ray.RayRemoteException
@test e.captured.ex == ErrorException("AHHHHH")
end

# object refs as arguments
obj_ref1 = Ray.put(1)
obj_ref2 = Ray.submit_task(identity, (obj_ref1,))
@test obj_ref2 != obj_ref1
@test Ray.get(obj_ref2) == obj_ref1
@test Ray.get(Ray.get(obj_ref2)) == 1
end

@testset "Task spawning a task" begin
Expand Down

0 comments on commit 5c8bfe6

Please sign in to comment.