Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
omus committed Sep 1, 2023
1 parent 7c9e252 commit d171b9d
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 13 deletions.
37 changes: 24 additions & 13 deletions Ray.jl/src/runtime.jl
Original file line number Diff line number Diff line change
Expand Up @@ -195,38 +195,41 @@ function submit_task(f::Function, args::Tuple, kwargs::NamedTuple=NamedTuple();
resources::Dict{String,Float64}=Dict("CPU" => 1.0))
export_function!(FUNCTION_MANAGER[], f, get_job_id())
fd = ray_jll.function_descriptor(f)
task_args = prepare_task_args(flatten_args(args, kwargs))
task_args = serialize_args(flatten_args(args, kwargs))

serialized_runtime_env_info = if !isnothing(runtime_env)
_serialize(RuntimeEnvInfo(runtime_env))
else
""
end

return GC.@preserve args ray_jll._submit_task(fd,
task_args,
serialized_runtime_env_info,
resources)
GC.@preserve task_args begin
return ray_jll._submit_task(fd,
transform_task_args(task_args),
serialized_runtime_env_info,
resources)
end
end

# TODO: be smarter about handling flattened args
# Adapted from `prepare_args_internal`:
# https://github.com/ray-project/ray/blob/ray-2.5.1/python/ray/_raylet.pyx#L673
function prepare_task_args(args)
function serialize_args(args)
ray_config = ray_jll.RayConfigInstance()
put_threshold = ray_jll.max_direct_call_object_size(ray_config)
rpc_inline_threshold = ray_jll.task_rpc_inlined_bytes_limit(ray_config)
record_call_site = ray_jll.record_ref_creation_sites(ray_config)

rpc_address = ray_jll.GetRpcAddress(GetCoreWorker())
worker = GetCoreWorker()
rpc_address = ray_jll.GetRpcAddress(worker)

total_inlined = 0
task_args = StdVector{CxxPtr{ray_jll.TaskArg}}()
task_args = Any[]
for arg in args
# Note: The Python `prepare_args_internal` function checks if the `arg` is an
# `ObjectRef` and in that case uses the object ID to directly make a
# `TaskArgByReference`. However, the `arg` here will always be a `Pair`
# (or a list in Python) so I expect this special case is never used.
# `TaskArgByReference`. However, as the `args` here are flattened, the `arg` will
# always be a `Pair` (or a list in Python). I suspect this Python code path is just
# dead code, so we'll exclude it from ours.

serialized_arg = serialize_to_bytes(arg)
serialized_arg_size = sizeof(serialized_arg)
Expand All @@ -239,18 +242,26 @@ function prepare_task_args(args)
total_inlined += serialized_arg_size
ray_jll.TaskArgByValue(ray_jll.RayObject(buffer))
else
oid = ray_jll.put(buffer)
oid = ray_jll.put(worker, buffer)
# TODO: Add test for populating `call_site`
call_site = record_call_site ? sprint(Base.show_backtrace, backtrace()) : ""
ray_jll.TaskArgByReference(oid, rpc_address, call_site)
end

push!(task_args, CxxRef(task_arg))
push!(task_args, task_arg)
end

return task_args
end

"""
    transform_task_args(task_args)

Return a `StdVector{CxxPtr{ray_jll.TaskArg}}` containing one `CxxRef` per element of
`task_args`, in iteration order.

The result only refers to the elements of `task_args`; `submit_task` wraps its use of
this function in `GC.@preserve task_args`, and other callers presumably need to keep
`task_args` alive in the same way while the returned vector is in use.
"""
function transform_task_args(task_args)
    ptrs = StdVector{CxxPtr{ray_jll.TaskArg}}()
    foreach(arg -> push!(ptrs, CxxRef(arg)), task_args)
    return ptrs
end

function task_executor(ray_function, returns_ptr, task_args_ptr, task_name,
application_error, is_retryable_error)
returns = ray_jll.cast_to_returns(returns_ptr)
Expand Down
1 change: 1 addition & 0 deletions Ray.jl/test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Aqua
using CxxWrap: CxxPtr, StdVector
using Ray
using Serialization
using Test
Expand Down
36 changes: 36 additions & 0 deletions Ray.jl/test/runtime.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,5 +78,41 @@
@test !contains(stderr_logs, "Constructing CoreWorkerProcess")
end
end
end

# Checks which kind of task argument `Ray.serialize_args` produces for payloads sized
# around the two limits read from the Ray configuration.
@testset "serialize_args" begin
    config = ray_jll.RayConfigInstance()
    put_threshold = ray_jll.max_direct_call_object_size(config)
    rpc_inline_threshold = ray_jll.task_rpc_inlined_bytes_limit(config)

    # Extra bytes serialization adds on top of the raw payload, measured on a
    # `put_threshold`-byte payload. The cases below assume this overhead is the same
    # for payloads of nearby sizes.
    probe = :_ => zeros(UInt8, put_threshold)
    serialization_overhead = sizeof(Ray.serialize_to_bytes(probe)) - put_threshold

    # Build a `Pair` argument carrying `n` zero bytes.
    payload(n) = :_ => zeros(UInt8, n)

    @testset "put threshold" begin
        at_limit = payload(put_threshold - serialization_overhead)
        over_limit = payload(put_threshold - serialization_overhead + 1)

        # The at-limit payload is expected by value, the one-byte-larger by reference.
        task_args = Ray.serialize_args([at_limit, over_limit])
        @test length(task_args) == 2
        @test task_args[1] isa ray_jll.TaskArgByValue
        @test task_args[2] isa ray_jll.TaskArgByReference

        # The expectation follows each argument when the order is swapped.
        task_args = Ray.serialize_args([over_limit, at_limit])
        @test length(task_args) == 2
        @test task_args[1] isa ray_jll.TaskArgByReference
        @test task_args[2] isa ray_jll.TaskArgByValue
    end

    @testset "inline threshold" begin
        at_limit = payload(put_threshold - serialization_overhead)
        # One more at-limit argument than fits within `rpc_inline_threshold`.
        nargs = rpc_inline_threshold ÷ put_threshold + 1
        task_args = Ray.serialize_args(fill(at_limit, nargs))

        leading = task_args[1:(end - 1)]
        @test all(t -> t isa ray_jll.TaskArgByValue, leading)
        @test task_args[end] isa ray_jll.TaskArgByReference
    end
end

# Verifies that `Ray.transform_task_args` wraps serialized task arguments in the C++
# pointer vector type and keeps one entry per input argument.
@testset "transform_task_args" begin
    task_args = Ray.serialize_args(Ray.flatten_args([1, 2, 3], (;)))
    result = Ray.transform_task_args(task_args)
    @test result isa StdVector{CxxPtr{ray_jll.TaskArg}}
    # A type-only check would also pass for an empty vector; make sure every
    # serialized argument produced an entry.
    @test length(result) == length(task_args)
end

0 comments on commit d171b9d

Please sign in to comment.