Skip to content

Commit

Permalink
support resource requests and use CPU: 1 for default tasks (#66)
Browse files Browse the repository at this point in the history
* silly get/setindex and string-double map for resource requests

* build resource dict and pass via `submit_task`

* merge fail

* tests
  • Loading branch information
kleinschmidt authored Aug 30, 2023
1 parent 40e0f7e commit 5b92046
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 6 deletions.
8 changes: 6 additions & 2 deletions Ray.jl/src/runtime.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,8 @@ end
initialize_coreworker_driver(args...) = ray_jll.initialize_coreworker_driver(args...)

function submit_task(f::Function, args::Tuple, kwargs::NamedTuple=NamedTuple();
runtime_env::Union{RuntimeEnv,Nothing}=nothing)
runtime_env::Union{RuntimeEnv,Nothing}=nothing,
resources::Dict{String,Float64}=Dict("CPU" => 1.0))
export_function!(FUNCTION_MANAGER[], f, get_current_job_id())
fd = ray_jll.function_descriptor(f)
arg_oids = map(Ray.put, flatten_args(args, kwargs))
Expand All @@ -166,7 +167,10 @@ function submit_task(f::Function, args::Tuple, kwargs::NamedTuple=NamedTuple();
""
end

return GC.@preserve args ray_jll._submit_task(fd, arg_oids, serialized_runtime_env_info)
return GC.@preserve args ray_jll._submit_task(fd,
arg_oids,
serialized_runtime_env_info,
resources)
end

function task_executor(ray_function, returns_ptr, task_args_ptr, task_name,
Expand Down
14 changes: 14 additions & 0 deletions Ray.jl/test/task.jl
Original file line number Diff line number Diff line change
Expand Up @@ -176,3 +176,17 @@ end
pids2 = Ray.get.([submit_task(getpid, ()) for _ in 1:n_tasks])
@test !isempty(intersect(pids, pids2))
end

@testset "task resource requests" begin
    # Executed remotely as a task: reads the resources required by the currently running
    # task through the `ray_jll` wrappers and returns them as a plain Julia `Dict`, so the
    # result can be compared on the driver side.
    function gimme_resources()
        required = Ray.ray_jll.get_task_required_resources()
        names = Ray.ray_jll._keys(required)
        amount(name) = float(Ray.ray_jll._getindex(required, name))
        return Dict(String(name) => amount(name) for name in names)
    end

    # No `resources` keyword given: the task should report the default CPU request.
    default_ref = submit_task(gimme_resources, ())
    default_resources = Ray.get(default_ref)
    @test default_resources["CPU"] == 1.0

    # An explicit fractional CPU request should be what the task reports.
    custom_ref = submit_task(gimme_resources, (); resources=Dict("CPU" => 0.5))
    custom_resources = Ray.get(custom_ref)
    @test custom_resources["CPU"] == 0.5
end
30 changes: 29 additions & 1 deletion deps/wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,8 @@ std::vector<std::shared_ptr<RayObject>> cast_to_task_args(void *ptr) {

ObjectID _submit_task(const ray::JuliaFunctionDescriptor &jl_func_descriptor,
const std::vector<ObjectID> &object_ids,
const std::string &serialized_runtime_env_info) {
const std::string &serialized_runtime_env_info,
const std::unordered_map<std::string, double> &resources) {

auto &worker = CoreWorkerProcess::GetCoreWorker();

Expand All @@ -154,6 +155,7 @@ ObjectID _submit_task(const ray::JuliaFunctionDescriptor &jl_func_descriptor,
// TaskOptions: https://github.com/ray-project/ray/blob/ray-2.5.1/src/ray/core_worker/common.h#L62-L87
TaskOptions options;
options.serialized_runtime_env_info = serialized_runtime_env_info;
options.resources = resources;

rpc::SchedulingStrategy scheduling_strategy;
scheduling_strategy.mutable_default_scheduling_strategy();
Expand Down Expand Up @@ -354,6 +356,12 @@ std::string get_job_serialized_runtime_env() {
return job_serialized_runtime_env;
}

// Return the resources required by the task currently being executed by this worker, as a
// map from resource name to requested amount.
//
// Returns an empty map when the worker context reports no current task, instead of
// dereferencing a null pointer and crashing the host process.
// NOTE(review): presumably `GetCurrentTask()` is null when called outside of a running
// task (e.g. directly from a driver) -- confirm against the Ray `WorkerContext` sources.
std::unordered_map<std::string, double> get_task_required_resources() {
    auto &worker = CoreWorkerProcess::GetCoreWorker();
    auto &worker_context = worker.GetWorkerContext();

    // Hold on to the task handle so it can be checked before it is dereferenced.
    auto current_task = worker_context.GetCurrentTask();
    if (!current_task) {
        return {};
    }
    return current_task->GetRequiredResources().GetResourceUnorderedMap();
}

namespace jlcxx
{
// Needed for upcasting
Expand Down Expand Up @@ -388,6 +396,25 @@ JLCXX_MODULE define_julia_module(jlcxx::Module& mod)
// the function. If you fail to do this you'll get a "No appropriate factory for type" upon
// attempting to use the shared library in Julia.

// Minimal Julia-facing wrapper around `std::unordered_map<std::string, double>`. It is
// exposed to Julia as `CxxMapStringDouble` and used to carry task resource requests
// (resource name -> amount).
mod.add_type<std::unordered_map<std::string, double>>("CxxMapStringDouble");

// `_setindex!(map, val, key)`: insert or overwrite the entry for `key`. Note that the
// value comes before the key in the argument list. The lambda's deduced return type is
// the map by value, so the caller receives a copy of the updated map.
mod.method("_setindex!", [](std::unordered_map<std::string, double> &map,
                            double val,
                            std::string key) {
    map[key] = val;
    return map;
});

// `_getindex(map, key)`: return the value stored for `key`, or 0.0 when `key` is absent.
// Uses `find` rather than `operator[]` so that a lookup of a missing key no longer
// inserts a default entry into the map as a side effect; the returned value is the same.
mod.method("_getindex", [](std::unordered_map<std::string, double> &map,
                           std::string key) {
    auto it = map.find(key);
    return it == map.end() ? 0.0 : it->second;
});

// `_keys(map)`: return every key in the map (in unspecified order).
mod.method("_keys", [](std::unordered_map<std::string, double> &map) {
    // `reserve` only allocates capacity. Constructing the vector with `map.size()`
    // elements would pre-fill it with that many empty strings, and the `push_back` calls
    // below would then append the real keys after them.
    std::vector<std::string> keys;
    keys.reserve(map.size());
    for (const auto &kv : map) {
        keys.push_back(kv.first);
    }
    return keys;
});

// TODO: Make `JobID` a subclass of `BaseID`. The use of templating makes this more work
// than normal.
// https://github.com/ray-project/ray/blob/ray-2.5.1/src/ray/common/id.h#L106
Expand Down Expand Up @@ -511,4 +538,5 @@ JLCXX_MODULE define_julia_module(jlcxx::Module& mod)

mod.method("serialize_job_config_json", &serialize_job_config_json);
mod.method("get_job_serialized_runtime_env", &get_job_serialized_runtime_env);
mod.method("get_task_required_resources", &get_task_required_resources);
}
18 changes: 15 additions & 3 deletions src/wrappers/any.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,12 +148,24 @@ function Base.push!(v::CxxPtr{StdVector{T}}, el::T) where T <: SharedPtr{LocalMe
return push!(v, CxxRef(el))
end

# XXX: Need to convert julia vectors to StdVector. This function helps us avoid having
# XXX: Need to convert julia vectors to StdVector and build the
# `std::unordered_map` for resources. This function helps us avoid having
# CxxWrap as a direct dependency in Ray.jl
function _submit_task(fd, oids::AbstractVector, serialized_runtime_env_info)
function _submit_task(fd, oids::AbstractVector, serialized_runtime_env_info, resources)
# https://github.com/JuliaInterop/CxxWrap.jl/issues/367
args = isempty(oids) ? StdVector{ObjectID}() : StdVector(oids)
return _submit_task(fd, args, serialized_runtime_env_info)
@debug "task resources: " resources
resources = build_resource_requests(resources)
return _submit_task(fd, args, serialized_runtime_env_info, resources)
end

# CxxWrap provides no ready-made wrapper for `std::unordered_map`, so the C++ map is
# assembled here entry by entry through the wrapped `_setindex!` method.
#
# Takes a Julia `Dict` of resource name => amount and returns a `CxxMapStringDouble`
# holding the same entries, with each amount passed through `float`.
function build_resource_requests(resources::Dict{<:AbstractString,<:Number})
    requests = CxxMapStringDouble()
    foreach(pairs(resources)) do (name, amount)
        # `_setindex!` takes the value before the key.
        _setindex!(requests, float(amount), name)
    end
    return requests
end

#####
Expand Down

0 comments on commit 5b92046

Please sign in to comment.