diff --git a/Ray.jl/src/Ray.jl b/Ray.jl/src/Ray.jl index 78b66cb1..8491ceeb 100644 --- a/Ray.jl/src/Ray.jl +++ b/Ray.jl/src/Ray.jl @@ -14,7 +14,7 @@ using LoggingExtras using Pkg using Serialization -import ray_core_worker_julia_jll as rayjll +import ray_core_worker_julia_jll as ray_jll export start_worker, submit_task, @ray_import diff --git a/Ray.jl/src/function_manager.jl b/Ray.jl/src/function_manager.jl index 9d169b99..c3d5bafc 100644 --- a/Ray.jl/src/function_manager.jl +++ b/Ray.jl/src/function_manager.jl @@ -24,14 +24,11 @@ # gcs client # ~~maybe job id?~~ this is managed by the core worker process -using ray_core_worker_julia_jll: JuliaGcsClient, Exists, Put, Get, - JuliaFunctionDescriptor, function_descriptor - # https://github.com/beacon-biosignals/ray/blob/1c0cddc478fa33d4c244d3c30aba861a77b0def9/python/ray/_private/ray_constants.py#L122-L123 const FUNCTION_SIZE_WARN_THRESHOLD = 10_000_000 # in bytes const FUNCTION_SIZE_ERROR_THRESHOLD = 100_000_000 # in bytes -_mib_string(len) = string(div(len, 1024 * 1024), " MiB") +_mib_string(num_bytes) = string(div(num_bytes, 1024 * 1024), " MiB") # https://github.com/beacon-biosignals/ray/blob/1c0cddc478fa33d4c244d3c30aba861a77b0def9/python/ray/_private/utils.py#L744-L746 const _check_msg = "Check that its definition is not implicitly capturing a large " * "array or other object in scope. Tip: use `Ray.put()` to put large " * @@ -40,12 +37,12 @@ const _check_msg = "Check that its definition is not implicitly capturing a larg function check_oversized_function(serialized, function_descriptor) len = length(serialized) if len > FUNCTION_SIZE_ERROR_THRESHOLD - msg = "The function $(rayjll.CallString(function_descriptor)) is too " * + msg = "The function $(ray_jll.CallString(function_descriptor)) is too " * "large ($(_mib_string(len))); FUNCTION_SIZE_ERROR_THRESHOLD=" * "$(_mib_string(FUNCTION_SIZE_ERROR_THRESHOLD)). " * _check_msg throw(ArgumentError(msg)) elseif len > FUNCTION_SIZE_WARN_THRESHOLD - msg = "The function $(rayjll.CallString(function_descriptor)) is very " * + msg = "The function $(ray_jll.CallString(function_descriptor)) is very " * "large ($(_mib_string(len))). " * _check_msg @warn msg # TODO: push warning message to driver if this is a worker @@ -59,7 +56,7 @@ end const FUNCTION_MANAGER_NAMESPACE = "jlfun" Base.@kwdef struct FunctionManager - gcs_client::JuliaGcsClient + gcs_client::ray_jll.JuliaGcsClient functions::Dict{String,Any} end @@ -67,41 +64,40 @@ const FUNCTION_MANAGER = Ref{FunctionManager}() function _init_global_function_manager(gcs_address) @info "connecting function manager to GCS at $gcs_address..." - gcs_client = JuliaGcsClient(gcs_address) - rayjll.Connect(gcs_client) + gcs_client = ray_jll.JuliaGcsClient(gcs_address) + ray_jll.Connect(gcs_client) FUNCTION_MANAGER[] = FunctionManager(; gcs_client, functions=Dict{String,Any}()) end -function function_key(fd::JuliaFunctionDescriptor, job_id=get_current_job_id()) +function function_key(fd::ray_jll.JuliaFunctionDescriptor, job_id=get_current_job_id()) return string("RemoteFunction:", job_id, ":", fd.function_hash) end function export_function!(fm::FunctionManager, f, job_id=get_current_job_id()) - fd = function_descriptor(f) + fd = ray_jll.function_descriptor(f) key = function_key(fd, job_id) @debug "exporting function to function store:" fd key - if Exists(fm.gcs_client, FUNCTION_MANAGER_NAMESPACE, - deepcopy(key), # DFK: I _think_ the string memory may be mangled - # if we don't copy. not sure but it can't hurt - -1) + # DFK: I _think_ the string memory may be mangled if we don't `deepcopy`. Not sure but + # it can't hurt + if ray_jll.Exists(fm.gcs_client, FUNCTION_MANAGER_NAMESPACE, deepcopy(key), -1) @debug "function already present in GCS store:" fd key f else @debug "exporting function to GCS store:" fd key f val = base64encode(serialize, f) check_oversized_function(val, fd) - Put(fm.gcs_client, FUNCTION_MANAGER_NAMESPACE, key, val, true, -1) + ray_jll.Put(fm.gcs_client, FUNCTION_MANAGER_NAMESPACE, key, val, true, -1) end end -function wait_for_function(fm::FunctionManager, fd::JuliaFunctionDescriptor, +function wait_for_function(fm::FunctionManager, fd::ray_jll.JuliaFunctionDescriptor, job_id=get_current_job_id(); pollint_s=0.01, timeout_s=10) key = function_key(fd, job_id) status = timedwait(timeout_s; pollint=pollint_s) do # timeout the Exists query to the same timeout we use here so we don't # deadlock. - Exists(fm.gcs_client, FUNCTION_MANAGER_NAMESPACE, key, timeout_s) + ray_jll.Exists(fm.gcs_client, FUNCTION_MANAGER_NAMESPACE, key, timeout_s) end return status end @@ -111,12 +107,12 @@ end # somthing like `eval(Meta.parse(CallString(fd)))`), falling back to the function # store only if needed. # https://github.com/beacon-biosignals/ray_core_worker_julia_jll.jl/issues/60 -function import_function!(fm::FunctionManager, fd::JuliaFunctionDescriptor, +function import_function!(fm::FunctionManager, fd::ray_jll.JuliaFunctionDescriptor, job_id=get_current_job_id()) return get!(fm.functions, fd.function_hash) do key = function_key(fd, job_id) @debug "function not found locally, retrieving from function store" fd key - val = Get(fm.gcs_client, FUNCTION_MANAGER_NAMESPACE, key, -1) + val = ray_jll.Get(fm.gcs_client, FUNCTION_MANAGER_NAMESPACE, key, -1) try io = IOBuffer() iob64 = Base64DecodePipe(io) diff --git a/Ray.jl/src/object_store.jl b/Ray.jl/src/object_store.jl index 7de26b2e..6f931efd 100644 --- a/Ray.jl/src/object_store.jl +++ b/Ray.jl/src/object_store.jl @@ -4,15 +4,7 @@ Store `data` in the object store. Returns an object reference which can used to retrieve the `data` with [`Ray.get`](@ref). """ -function put(data) - bytes = Vector{UInt8}() - io = IOBuffer(bytes; write=true) - serialize(io, data) - buffer_ptr = Ptr{Nothing}(pointer(bytes)) - buffer_size = sizeof(bytes) - buffer = rayjll.LocalMemoryBuffer(buffer_ptr, buffer_size, true) - return rayjll.put(buffer) -end +put(data) = ray_jll.put(to_serialized_buffer(data)) """ Ray.get(object_id::ObjectIDAllocated) @@ -24,13 +16,24 @@ if run in an `@async` task. If the task that generated the `ObjectID` failed with a Julia exception, the captured exception will be thrown on `get`. """ -get(oid::rayjll.ObjectIDAllocated) = _get(take!(rayjll.get(oid))) -get(obj::SharedPtr{rayjll.RayObject}) = _get(take!(rayjll.GetData(obj[]))) +get(oid::ray_jll.ObjectIDAllocated) = _get(ray_jll.get(oid)) +get(obj::SharedPtr{ray_jll.RayObject}) = _get(ray_jll.GetData(obj[])) get(x) = x -function _get(data::Vector{UInt8}) - result = deserialize(IOBuffer(data)) +function _get(buffer) + result = from_serialized_buffer(buffer) # TODO: add an option to not rethrow # https://github.com/beacon-biosignals/ray_core_worker_julia_jll.jl/issues/58 result isa RayRemoteException ? throw(result) : return result end + +function to_serialized_buffer(data) + bytes = Vector{UInt8}() + io = IOBuffer(bytes; write=true) + serialize(io, data) + return ray_jll.LocalMemoryBuffer(bytes, sizeof(bytes), true) +end + +function from_serialized_buffer(buffer) + result = deserialize(IOBuffer(take!(buffer))) +end diff --git a/Ray.jl/src/runtime.jl b/Ray.jl/src/runtime.jl index 280682b8..2eabc2e6 100644 --- a/Ray.jl/src/runtime.jl +++ b/Ray.jl/src/runtime.jl @@ -32,14 +32,14 @@ function Base.showerror(io::IO, re::RayRemoteException) end """ - const GLOBAL_STATE_ACCESSOR::Ref{rayjll.GlobalStateAccessor} + const GLOBAL_STATE_ACCESSOR::Ref{ray_jll.GlobalStateAccessor} Global binding for GCS client interface to access global state information. Currently only used to get the next job ID. This is set during `init` and used there to get the Job ID for the driver. """ -const GLOBAL_STATE_ACCESSOR = Ref{rayjll.GlobalStateAccessor}() +const GLOBAL_STATE_ACCESSOR = Ref{ray_jll.GlobalStateAccessor}() function init(runtime_env::Union{RuntimeEnv,Nothing}=nothing) # XXX: this is at best EXREMELY IMPERFECT check. we should do something @@ -70,19 +70,19 @@ function init(runtime_env::Union{RuntimeEnv,Nothing}=nothing) args = parse_ray_args_from_raylet_out() gcs_address = args[3] - opts = rayjll.GcsClientOptions(gcs_address) - GLOBAL_STATE_ACCESSOR[] = rayjll.GlobalStateAccessor(opts) - rayjll.Connect(GLOBAL_STATE_ACCESSOR[]) || + opts = ray_jll.GcsClientOptions(gcs_address) + GLOBAL_STATE_ACCESSOR[] = ray_jll.GlobalStateAccessor(opts) + ray_jll.Connect(GLOBAL_STATE_ACCESSOR[]) || error("Failed to connect to Ray GCS at $(gcs_address)") - atexit(() -> rayjll.Disconnect(GLOBAL_STATE_ACCESSOR[])) + atexit(() -> ray_jll.Disconnect(GLOBAL_STATE_ACCESSOR[])) - job_id = rayjll.GetNextJobID(GLOBAL_STATE_ACCESSOR[]) + job_id = ray_jll.GetNextJobID(GLOBAL_STATE_ACCESSOR[]) job_config = JobConfig(RuntimeEnvInfo(runtime_env)) serialized_job_config = _serialize(job_config) - rayjll.initialize_driver(args..., job_id, serialized_job_config) - atexit(rayjll.shutdown_driver) + ray_jll.initialize_driver(args..., job_id, serialized_job_config) + atexit(ray_jll.shutdown_driver) _init_global_function_manager(gcs_address) @@ -91,14 +91,14 @@ end # this could go in JLL but if/when global worker is hosted here it's better to # keep it local -get_current_job_id() = rayjll.ToInt(rayjll.GetCurrentJobId()) +get_current_job_id() = ray_jll.ToInt(ray_jll.GetCurrentJobId()) """ get_task_id() -> String Get the current task ID for this worker in hex format. """ -get_task_id() = String(rayjll.Hex(rayjll.GetCurrentTaskId())) +get_task_id() = String(ray_jll.Hex(ray_jll.GetCurrentTaskId())) function parse_ray_args_from_raylet_out() #= @@ -152,12 +152,12 @@ function parse_ray_args_from_raylet_out() return (raylet, store, gcs_address, node_ip, node_port) end -initialize_coreworker_driver(args...) = rayjll.initialize_coreworker_driver(args...) +initialize_coreworker_driver(args...) = ray_jll.initialize_coreworker_driver(args...) function submit_task(f::Function, args::Tuple, kwargs::NamedTuple=NamedTuple(); runtime_env::Union{RuntimeEnv,Nothing}=nothing) export_function!(FUNCTION_MANAGER[], f, get_current_job_id()) - fd = function_descriptor(f) + fd = ray_jll.function_descriptor(f) arg_oids = map(Ray.put, flatten_args(args, kwargs)) serialized_runtime_env_info = if !isnothing(runtime_env) @@ -166,22 +166,22 @@ function submit_task(f::Function, args::Tuple, kwargs::NamedTuple=NamedTuple(); "" end - return GC.@preserve args rayjll._submit_task(fd, arg_oids, serialized_runtime_env_info) + return GC.@preserve args ray_jll._submit_task(fd, arg_oids, serialized_runtime_env_info) end function task_executor(ray_function, returns_ptr, task_args_ptr, task_name, application_error, is_retryable_error) - returns = rayjll.cast_to_returns(returns_ptr) - task_args = rayjll.cast_to_task_args(task_args_ptr) + returns = ray_jll.cast_to_returns(returns_ptr) + task_args = ray_jll.cast_to_task_args(task_args_ptr) local result try - @info "task_executor: called for JobID $(rayjll.GetCurrentJobId())" - fd = rayjll.GetFunctionDescriptor(ray_function) + @info "task_executor: called for JobID $(ray_jll.GetCurrentJobId())" + fd = ray_jll.GetFunctionDescriptor(ray_function) # TODO: may need to wait for function here... @debug "task_executor: importing function" fd func = import_function!(FUNCTION_MANAGER[], - rayjll.unwrap_function_descriptor(fd), + ray_jll.unwrap_function_descriptor(fd), get_current_job_id()) flattened = map(Ray.get, task_args) @@ -210,10 +210,10 @@ function task_executor(ray_function, returns_ptr, task_args_ptr, task_name, # so we use a cpp function whose only job is to assign the value to the # pointer err_msg = sprint(showerror, captured) - status = rayjll.report_error(application_error, err_msg, timestamp) + status = ray_jll.report_error(application_error, err_msg, timestamp) # XXX: we _can_ set _this_ return pointer here for some reason, and it # was _harder_ to toss it back over the fence to the wrapper C++ code - is_retryable_error[] = rayjll.CxxBool(false) + is_retryable_error[] = ray_jll.CxxBool(false) @debug "push error status: $status" result = RayRemoteException(getpid(), task_name, captured) @@ -225,17 +225,14 @@ function task_executor(ray_function, returns_ptr, task_args_ptr, task_name, # TODO: support multiple return values # https://github.com/beacon-biosignals/ray_core_worker_julia_jll.jl/issues/54 - buffer_data = Vector{UInt8}(sprint(serialize, result)) - buffer_size = sizeof(buffer_data) - buffer = rayjll.LocalMemoryBuffer(buffer_data, buffer_size, true) - push!(returns, buffer) + push!(returns, to_serialized_buffer(result)) return nothing end #= julia -e sleep(120) -- \ - /Users/cvogt/.julia/dev/rayjll/venv/lib/python3.10/site-packages/ray/cpp/default_worker \ + /Users/cvogt/.julia/dev/ray_core_worker_julia_jll/venv/lib/python3.10/site-packages/ray/cpp/default_worker \ --ray_plasma_store_socket_name=/tmp/ray/session_2023-08-09_14-14-28_230005_27400/sockets/plasma_store \ --ray_raylet_socket_name=/tmp/ray/session_2023-08-09_14-14-28_230005_27400/sockets/raylet \ --ray_node_manager_port=57236 \ @@ -317,7 +314,7 @@ function start_worker(args=ARGS) @info "Starting Julia worker runtime with args" parsed_args - return rayjll.initialize_worker(parsed_args["raylet_socket"], + return ray_jll.initialize_worker(parsed_args["raylet_socket"], parsed_args["store_socket"], parsed_args["address"], parsed_args["node_ip_address"], diff --git a/Ray.jl/src/runtime_env.jl b/Ray.jl/src/runtime_env.jl index b7db4d88..8885a28b 100644 --- a/Ray.jl/src/runtime_env.jl +++ b/Ray.jl/src/runtime_env.jl @@ -61,7 +61,7 @@ end # https://github.com/beacon-biosignals/ray_core_worker_julia_jll.jl/issues/57 function _serialize(job_config::JobConfig) job_config_json = JSON3.write(json_dict(job_config)) - return rayjll.serialize_job_config_json(job_config_json) + return ray_jll.serialize_job_config_json(job_config_json) end function process_import_statements(ex::Expr) diff --git a/Ray.jl/test/runtests.jl b/Ray.jl/test/runtests.jl index a959513d..3cadee37 100644 --- a/Ray.jl/test/runtests.jl +++ b/Ray.jl/test/runtests.jl @@ -3,7 +3,7 @@ using Ray using Serialization using Test -import ray_core_worker_julia_jll as rayjll +import ray_core_worker_julia_jll as ray_jll include("setup.jl") include("utils.jl") diff --git a/Ray.jl/test/utils.jl b/Ray.jl/test/utils.jl index 3d4cc9ed..ab23f2ed 100644 --- a/Ray.jl/test/utils.jl +++ b/Ray.jl/test/utils.jl @@ -18,7 +18,7 @@ function setup_core_worker(body) try body() finally - rayjll.shutdown_driver() + ray_jll.shutdown_driver() end end diff --git a/deps/wrapper.cc b/deps/wrapper.cc index 5a316671..7991876d 100644 --- a/deps/wrapper.cc +++ b/deps/wrapper.cc @@ -137,6 +137,41 @@ std::vector> cast_to_task_args(void *ptr) { return *rayobj_ptr; } +ObjectID _submit_task(const ray::JuliaFunctionDescriptor &jl_func_descriptor, + const std::vector &object_ids, + const std::string &serialized_runtime_env_info) { + + auto &worker = CoreWorkerProcess::GetCoreWorker(); + + ray::FunctionDescriptor func_descriptor = std::make_shared(jl_func_descriptor); + RayFunction func(Language::JULIA, func_descriptor); + + std::vector> args; + for (auto & obj_id : object_ids) { + args.emplace_back(new TaskArgByReference(obj_id, worker.GetRpcAddress(), /*call-site*/"")); + } + + // TaskOptions: https://github.com/ray-project/ray/blob/ray-2.5.1/src/ray/core_worker/common.h#L62-L87 + TaskOptions options; + options.serialized_runtime_env_info = serialized_runtime_env_info; + + rpc::SchedulingStrategy scheduling_strategy; + scheduling_strategy.mutable_default_scheduling_strategy(); + + // https://github.com/ray-project/ray/blob/4e9e8913a6c9cc3533fe27478f30bdee1deffaf5/src/ray/core_worker/test/core_worker_test.cc#L79 + auto return_refs = worker.SubmitTask( + func, + args, + options, + /*max_retries=*/0, + /*retry_exceptions=*/false, + scheduling_strategy, + /*debugger_breakpoint=*/"" + ); + + return ObjectRefsToIds(return_refs)[0]; +} + // TODO: probably makes more sense to have a global worker rather than calling // GetCoreWorker() over and over again...(here and below) // https://github.com/beacon-biosignals/ray_core_worker_julia_jll.jl/issues/61 @@ -273,41 +308,6 @@ bool JuliaGcsClient::Exists(const std::string &ns, return exists; } -ObjectID _submit_task(const ray::JuliaFunctionDescriptor &jl_func_descriptor, - const std::vector &object_ids, - const std::string &serialized_runtime_env_info) { - - auto &worker = CoreWorkerProcess::GetCoreWorker(); - - ray::FunctionDescriptor func_descriptor = std::make_shared(jl_func_descriptor); - RayFunction func(Language::JULIA, func_descriptor); - - std::vector> args; - for (auto & obj_id : object_ids) { - args.emplace_back(new TaskArgByReference(obj_id, worker.GetRpcAddress(), /*call-site*/"")); - } - - // TaskOptions: https://github.com/ray-project/ray/blob/ray-2.5.1/src/ray/core_worker/common.h#L62-L87 - TaskOptions options; - options.serialized_runtime_env_info = serialized_runtime_env_info; - - rpc::SchedulingStrategy scheduling_strategy; - scheduling_strategy.mutable_default_scheduling_strategy(); - - // https://github.com/ray-project/ray/blob/4e9e8913a6c9cc3533fe27478f30bdee1deffaf5/src/ray/core_worker/test/core_worker_test.cc#L79 - auto return_refs = worker.SubmitTask( - func, - args, - options, - /*max_retries=*/0, - /*retry_exceptions=*/false, - scheduling_strategy, - /*debugger_breakpoint=*/"" - ); - - return ObjectRefsToIds(return_refs)[0]; -} - Status report_error(std::string *application_error, const std::string &err_msg, double timestamp) {