Skip to content

Commit

Permalink
Grab bag of minor changes (#65)
Browse files Browse the repository at this point in the history
* Create internal _put function

* Move _submit_task closer to task_executor_callback

* Rename rayjll to ray_jll

* fixup! Rename rayjll to ray_jll

* Rename arg to _mib_string

* Rename _put to to_serialized_buffer
  • Loading branch information
omus authored Aug 29, 2023
1 parent 389087c commit 40e0f7e
Show file tree
Hide file tree
Showing 8 changed files with 95 additions and 99 deletions.
2 changes: 1 addition & 1 deletion Ray.jl/src/Ray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ using LoggingExtras
using Pkg
using Serialization

import ray_core_worker_julia_jll as rayjll
import ray_core_worker_julia_jll as ray_jll

export start_worker, submit_task, @ray_import

Expand Down
36 changes: 16 additions & 20 deletions Ray.jl/src/function_manager.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,11 @@
# gcs client
# ~~maybe job id?~~ this is managed by the core worker process

using ray_core_worker_julia_jll: JuliaGcsClient, Exists, Put, Get,
JuliaFunctionDescriptor, function_descriptor

# https://github.com/beacon-biosignals/ray/blob/1c0cddc478fa33d4c244d3c30aba861a77b0def9/python/ray/_private/ray_constants.py#L122-L123
const FUNCTION_SIZE_WARN_THRESHOLD = 10_000_000 # in bytes
const FUNCTION_SIZE_ERROR_THRESHOLD = 100_000_000 # in bytes

_mib_string(len) = string(div(len, 1024 * 1024), " MiB")
_mib_string(num_bytes) = string(div(num_bytes, 1024 * 1024), " MiB")
# https://github.com/beacon-biosignals/ray/blob/1c0cddc478fa33d4c244d3c30aba861a77b0def9/python/ray/_private/utils.py#L744-L746
const _check_msg = "Check that its definition is not implicitly capturing a large " *
"array or other object in scope. Tip: use `Ray.put()` to put large " *
Expand All @@ -40,12 +37,12 @@ const _check_msg = "Check that its definition is not implicitly capturing a larg
function check_oversized_function(serialized, function_descriptor)
len = length(serialized)
if len > FUNCTION_SIZE_ERROR_THRESHOLD
msg = "The function $(rayjll.CallString(function_descriptor)) is too " *
msg = "The function $(ray_jll.CallString(function_descriptor)) is too " *
"large ($(_mib_string(len))); FUNCTION_SIZE_ERROR_THRESHOLD=" *
"$(_mib_string(FUNCTION_SIZE_ERROR_THRESHOLD)). " * _check_msg
throw(ArgumentError(msg))
elseif len > FUNCTION_SIZE_WARN_THRESHOLD
msg = "The function $(rayjll.CallString(function_descriptor)) is very " *
msg = "The function $(ray_jll.CallString(function_descriptor)) is very " *
"large ($(_mib_string(len))). " * _check_msg
@warn msg
# TODO: push warning message to driver if this is a worker
Expand All @@ -59,49 +56,48 @@ end
const FUNCTION_MANAGER_NAMESPACE = "jlfun"

Base.@kwdef struct FunctionManager
gcs_client::JuliaGcsClient
gcs_client::ray_jll.JuliaGcsClient
functions::Dict{String,Any}
end

const FUNCTION_MANAGER = Ref{FunctionManager}()

function _init_global_function_manager(gcs_address)
@info "connecting function manager to GCS at $gcs_address..."
gcs_client = JuliaGcsClient(gcs_address)
rayjll.Connect(gcs_client)
gcs_client = ray_jll.JuliaGcsClient(gcs_address)
ray_jll.Connect(gcs_client)
FUNCTION_MANAGER[] = FunctionManager(; gcs_client,
functions=Dict{String,Any}())
end

function function_key(fd::JuliaFunctionDescriptor, job_id=get_current_job_id())
function function_key(fd::ray_jll.JuliaFunctionDescriptor, job_id=get_current_job_id())
return string("RemoteFunction:", job_id, ":", fd.function_hash)
end

function export_function!(fm::FunctionManager, f, job_id=get_current_job_id())
fd = function_descriptor(f)
fd = ray_jll.function_descriptor(f)
key = function_key(fd, job_id)
@debug "exporting function to function store:" fd key
if Exists(fm.gcs_client, FUNCTION_MANAGER_NAMESPACE,
deepcopy(key), # DFK: I _think_ the string memory may be mangled
# if we don't copy. not sure but it can't hurt
-1)
# DFK: I _think_ the string memory may be mangled if we don't `deepcopy`. Not sure but
# it can't hurt
if ray_jll.Exists(fm.gcs_client, FUNCTION_MANAGER_NAMESPACE, deepcopy(key), -1)
@debug "function already present in GCS store:" fd key f
else
@debug "exporting function to GCS store:" fd key f
val = base64encode(serialize, f)
check_oversized_function(val, fd)
Put(fm.gcs_client, FUNCTION_MANAGER_NAMESPACE, key, val, true, -1)
ray_jll.Put(fm.gcs_client, FUNCTION_MANAGER_NAMESPACE, key, val, true, -1)
end
end

function wait_for_function(fm::FunctionManager, fd::JuliaFunctionDescriptor,
function wait_for_function(fm::FunctionManager, fd::ray_jll.JuliaFunctionDescriptor,
job_id=get_current_job_id();
pollint_s=0.01, timeout_s=10)
key = function_key(fd, job_id)
status = timedwait(timeout_s; pollint=pollint_s) do
# timeout the Exists query to the same timeout we use here so we don't
# deadlock.
Exists(fm.gcs_client, FUNCTION_MANAGER_NAMESPACE, key, timeout_s)
ray_jll.Exists(fm.gcs_client, FUNCTION_MANAGER_NAMESPACE, key, timeout_s)
end
return status
end
Expand All @@ -111,12 +107,12 @@ end
# something like `eval(Meta.parse(CallString(fd)))`), falling back to the function
# store only if needed.
# https://github.com/beacon-biosignals/ray_core_worker_julia_jll.jl/issues/60
function import_function!(fm::FunctionManager, fd::JuliaFunctionDescriptor,
function import_function!(fm::FunctionManager, fd::ray_jll.JuliaFunctionDescriptor,
job_id=get_current_job_id())
return get!(fm.functions, fd.function_hash) do
key = function_key(fd, job_id)
@debug "function not found locally, retrieving from function store" fd key
val = Get(fm.gcs_client, FUNCTION_MANAGER_NAMESPACE, key, -1)
val = ray_jll.Get(fm.gcs_client, FUNCTION_MANAGER_NAMESPACE, key, -1)
try
io = IOBuffer()
iob64 = Base64DecodePipe(io)
Expand Down
29 changes: 16 additions & 13 deletions Ray.jl/src/object_store.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,7 @@
Store `data` in the object store. Returns an object reference which can used to retrieve
the `data` with [`Ray.get`](@ref).
"""
function put(data)
bytes = Vector{UInt8}()
io = IOBuffer(bytes; write=true)
serialize(io, data)
buffer_ptr = Ptr{Nothing}(pointer(bytes))
buffer_size = sizeof(bytes)
buffer = rayjll.LocalMemoryBuffer(buffer_ptr, buffer_size, true)
return rayjll.put(buffer)
end
put(data) = ray_jll.put(to_serialized_buffer(data))

"""
Ray.get(object_id::ObjectIDAllocated)
Expand All @@ -24,13 +16,24 @@ if run in an `@async` task.
If the task that generated the `ObjectID` failed with a Julia exception, the
captured exception will be thrown on `get`.
"""
get(oid::rayjll.ObjectIDAllocated) = _get(take!(rayjll.get(oid)))
get(obj::SharedPtr{rayjll.RayObject}) = _get(take!(rayjll.GetData(obj[])))
get(oid::ray_jll.ObjectIDAllocated) = _get(ray_jll.get(oid))
get(obj::SharedPtr{ray_jll.RayObject}) = _get(ray_jll.GetData(obj[]))
get(x) = x

function _get(data::Vector{UInt8})
result = deserialize(IOBuffer(data))
function _get(buffer)
result = from_serialized_buffer(buffer)
# TODO: add an option to not rethrow
# https://github.com/beacon-biosignals/ray_core_worker_julia_jll.jl/issues/58
result isa RayRemoteException ? throw(result) : return result
end

"""
    to_serialized_buffer(data)

Serialize `data` with `serialize` and wrap the resulting bytes in a
`ray_jll.LocalMemoryBuffer`, which is returned.
"""
function to_serialized_buffer(data)
    io = IOBuffer()
    serialize(io, data)
    # Pull the bytes out of the buffer with `take!` instead of expecting the `IOBuffer`
    # to grow a caller-supplied vector in place: whether the wrapped vector is resized
    # is an implementation detail of `IOBuffer` and differs between Julia versions.
    bytes = take!(io)
    # NOTE(review): the trailing `true` is presumably the "copy data" flag of
    # `LocalMemoryBuffer`, so `bytes` need not outlive this call — confirm against the
    # JLL wrapper.
    return ray_jll.LocalMemoryBuffer(bytes, sizeof(bytes), true)
end

"""
    from_serialized_buffer(buffer)

Return the value obtained by deserializing the bytes of `buffer`.

The bytes are extracted with `take!(buffer)`, so `buffer` must be of a type that has a
`take!` method producing data accepted by `IOBuffer`.
"""
function from_serialized_buffer(buffer)
    bytes = take!(buffer)
    return deserialize(IOBuffer(bytes))
end
51 changes: 24 additions & 27 deletions Ray.jl/src/runtime.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,14 @@ function Base.showerror(io::IO, re::RayRemoteException)
end

"""
const GLOBAL_STATE_ACCESSOR::Ref{rayjll.GlobalStateAccessor}
const GLOBAL_STATE_ACCESSOR::Ref{ray_jll.GlobalStateAccessor}
Global binding for GCS client interface to access global state information.
Currently only used to get the next job ID.
This is set during `init` and used there to get the Job ID for the driver.
"""
const GLOBAL_STATE_ACCESSOR = Ref{rayjll.GlobalStateAccessor}()
const GLOBAL_STATE_ACCESSOR = Ref{ray_jll.GlobalStateAccessor}()

function init(runtime_env::Union{RuntimeEnv,Nothing}=nothing)
# XXX: this is at best an EXTREMELY IMPERFECT check. we should do something
Expand Down Expand Up @@ -70,19 +70,19 @@ function init(runtime_env::Union{RuntimeEnv,Nothing}=nothing)
args = parse_ray_args_from_raylet_out()
gcs_address = args[3]

opts = rayjll.GcsClientOptions(gcs_address)
GLOBAL_STATE_ACCESSOR[] = rayjll.GlobalStateAccessor(opts)
rayjll.Connect(GLOBAL_STATE_ACCESSOR[]) ||
opts = ray_jll.GcsClientOptions(gcs_address)
GLOBAL_STATE_ACCESSOR[] = ray_jll.GlobalStateAccessor(opts)
ray_jll.Connect(GLOBAL_STATE_ACCESSOR[]) ||
error("Failed to connect to Ray GCS at $(gcs_address)")
atexit(() -> rayjll.Disconnect(GLOBAL_STATE_ACCESSOR[]))
atexit(() -> ray_jll.Disconnect(GLOBAL_STATE_ACCESSOR[]))

job_id = rayjll.GetNextJobID(GLOBAL_STATE_ACCESSOR[])
job_id = ray_jll.GetNextJobID(GLOBAL_STATE_ACCESSOR[])

job_config = JobConfig(RuntimeEnvInfo(runtime_env))
serialized_job_config = _serialize(job_config)

rayjll.initialize_driver(args..., job_id, serialized_job_config)
atexit(rayjll.shutdown_driver)
ray_jll.initialize_driver(args..., job_id, serialized_job_config)
atexit(ray_jll.shutdown_driver)

_init_global_function_manager(gcs_address)

Expand All @@ -91,14 +91,14 @@ end

# this could go in JLL but if/when global worker is hosted here it's better to
# keep it local
get_current_job_id() = rayjll.ToInt(rayjll.GetCurrentJobId())
get_current_job_id() = ray_jll.ToInt(ray_jll.GetCurrentJobId())

"""
get_task_id() -> String
Get the current task ID for this worker in hex format.
"""
get_task_id() = String(rayjll.Hex(rayjll.GetCurrentTaskId()))
get_task_id() = String(ray_jll.Hex(ray_jll.GetCurrentTaskId()))

function parse_ray_args_from_raylet_out()
#=
Expand Down Expand Up @@ -152,12 +152,12 @@ function parse_ray_args_from_raylet_out()
return (raylet, store, gcs_address, node_ip, node_port)
end

initialize_coreworker_driver(args...) = rayjll.initialize_coreworker_driver(args...)
initialize_coreworker_driver(args...) = ray_jll.initialize_coreworker_driver(args...)

function submit_task(f::Function, args::Tuple, kwargs::NamedTuple=NamedTuple();
runtime_env::Union{RuntimeEnv,Nothing}=nothing)
export_function!(FUNCTION_MANAGER[], f, get_current_job_id())
fd = function_descriptor(f)
fd = ray_jll.function_descriptor(f)
arg_oids = map(Ray.put, flatten_args(args, kwargs))

serialized_runtime_env_info = if !isnothing(runtime_env)
Expand All @@ -166,22 +166,22 @@ function submit_task(f::Function, args::Tuple, kwargs::NamedTuple=NamedTuple();
""
end

return GC.@preserve args rayjll._submit_task(fd, arg_oids, serialized_runtime_env_info)
return GC.@preserve args ray_jll._submit_task(fd, arg_oids, serialized_runtime_env_info)
end

function task_executor(ray_function, returns_ptr, task_args_ptr, task_name,
application_error, is_retryable_error)
returns = rayjll.cast_to_returns(returns_ptr)
task_args = rayjll.cast_to_task_args(task_args_ptr)
returns = ray_jll.cast_to_returns(returns_ptr)
task_args = ray_jll.cast_to_task_args(task_args_ptr)

local result
try
@info "task_executor: called for JobID $(rayjll.GetCurrentJobId())"
fd = rayjll.GetFunctionDescriptor(ray_function)
@info "task_executor: called for JobID $(ray_jll.GetCurrentJobId())"
fd = ray_jll.GetFunctionDescriptor(ray_function)
# TODO: may need to wait for function here...
@debug "task_executor: importing function" fd
func = import_function!(FUNCTION_MANAGER[],
rayjll.unwrap_function_descriptor(fd),
ray_jll.unwrap_function_descriptor(fd),
get_current_job_id())

flattened = map(Ray.get, task_args)
Expand Down Expand Up @@ -210,10 +210,10 @@ function task_executor(ray_function, returns_ptr, task_args_ptr, task_name,
# so we use a cpp function whose only job is to assign the value to the
# pointer
err_msg = sprint(showerror, captured)
status = rayjll.report_error(application_error, err_msg, timestamp)
status = ray_jll.report_error(application_error, err_msg, timestamp)
# XXX: we _can_ set _this_ return pointer here for some reason, and it
# was _harder_ to toss it back over the fence to the wrapper C++ code
is_retryable_error[] = rayjll.CxxBool(false)
is_retryable_error[] = ray_jll.CxxBool(false)
@debug "push error status: $status"

result = RayRemoteException(getpid(), task_name, captured)
Expand All @@ -225,17 +225,14 @@ function task_executor(ray_function, returns_ptr, task_args_ptr, task_name,

# TODO: support multiple return values
# https://github.com/beacon-biosignals/ray_core_worker_julia_jll.jl/issues/54
buffer_data = Vector{UInt8}(sprint(serialize, result))
buffer_size = sizeof(buffer_data)
buffer = rayjll.LocalMemoryBuffer(buffer_data, buffer_size, true)
push!(returns, buffer)
push!(returns, to_serialized_buffer(result))

return nothing
end

#=
julia -e sleep(120) -- \
/Users/cvogt/.julia/dev/rayjll/venv/lib/python3.10/site-packages/ray/cpp/default_worker \
/Users/cvogt/.julia/dev/ray_core_worker_julia_jll/venv/lib/python3.10/site-packages/ray/cpp/default_worker \
--ray_plasma_store_socket_name=/tmp/ray/session_2023-08-09_14-14-28_230005_27400/sockets/plasma_store \
--ray_raylet_socket_name=/tmp/ray/session_2023-08-09_14-14-28_230005_27400/sockets/raylet \
--ray_node_manager_port=57236 \
Expand Down Expand Up @@ -317,7 +314,7 @@ function start_worker(args=ARGS)

@info "Starting Julia worker runtime with args" parsed_args

return rayjll.initialize_worker(parsed_args["raylet_socket"],
return ray_jll.initialize_worker(parsed_args["raylet_socket"],
parsed_args["store_socket"],
parsed_args["address"],
parsed_args["node_ip_address"],
Expand Down
2 changes: 1 addition & 1 deletion Ray.jl/src/runtime_env.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ end
# https://github.com/beacon-biosignals/ray_core_worker_julia_jll.jl/issues/57
function _serialize(job_config::JobConfig)
job_config_json = JSON3.write(json_dict(job_config))
return rayjll.serialize_job_config_json(job_config_json)
return ray_jll.serialize_job_config_json(job_config_json)
end

function process_import_statements(ex::Expr)
Expand Down
2 changes: 1 addition & 1 deletion Ray.jl/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ using Ray
using Serialization
using Test

import ray_core_worker_julia_jll as rayjll
import ray_core_worker_julia_jll as ray_jll

include("setup.jl")
include("utils.jl")
Expand Down
2 changes: 1 addition & 1 deletion Ray.jl/test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ function setup_core_worker(body)
try
body()
finally
rayjll.shutdown_driver()
ray_jll.shutdown_driver()
end
end

Expand Down
Loading

0 comments on commit 40e0f7e

Please sign in to comment.