Skip to content

Commit

Permalink
Working submit_task
Browse files Browse the repository at this point in the history
  • Loading branch information
omus committed Aug 31, 2023
1 parent d18b8b0 commit c6b0298
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 51 deletions.
17 changes: 7 additions & 10 deletions Ray.jl/src/runtime.jl
Original file line number Diff line number Diff line change
Expand Up @@ -161,19 +161,14 @@ function submit_task(f::Function, args::Tuple, kwargs::NamedTuple=NamedTuple();
fd = ray_jll.function_descriptor(f)
task_args = prepare_task_args(flatten_args(args, kwargs))

task_args_alt = StdVector{CxxPtr{ray_jll.TaskArg}}()
for task_arg in task_args
push!(task_args_alt, CxxRef(task_arg))
end

serialized_runtime_env_info = if !isnothing(runtime_env)
_serialize(RuntimeEnvInfo(runtime_env))
else
""
end

return GC.@preserve args ray_jll._submit_task(fd,
task_args_alt,
task_args,
serialized_runtime_env_info,
resources)
end
Expand All @@ -198,7 +193,7 @@ function prepare_task_args(args)

# TODO: put_arg_call_site

task_args = []
task_args = StdVector{CxxPtr{ray_jll.TaskArg}}()
for arg in args
# if arg isa ObjectRef
# oid = arg.oid
Expand All @@ -216,16 +211,18 @@ function prepare_task_args(args)
serialized_arg_size = sizeof(serialized_arg)
buffer = ray_jll.LocalMemoryBuffer(serialized_arg, serialized_arg_size, true)

if (serialized_arg_size <= put_threshold &&
task_arg = if (serialized_arg_size <= put_threshold &&
serialized_arg_size + total_inlined <= rpc_inline_threshold)

push!(task_args, ray_jll.TaskArgByValue(ray_jll.RayObject(buffer)))
total_inlined += serialized_arg_size
ray_jll.TaskArgByValue(ray_jll.RayObject(buffer))
else
oid = ray_jll.put(buffer)
call_site = ""
push!(task_args, ray_jll.TaskArgByReference(oid, rpc_address, call_site))
ray_jll.TaskArgByReference(oid, rpc_address, call_site)
end

push!(task_args, CxxRef(task_arg))
end

return task_args
Expand Down
53 changes: 12 additions & 41 deletions deps/wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -373,19 +373,10 @@ ray::core::CoreWorker &GetCoreWorker() {
return CoreWorkerProcess::GetCoreWorker();
}

size_t demo(const std::vector<TaskArg *> &vector) {
std::vector<std::unique_ptr<TaskArg>> args;
for (auto arg : vector) {
args.emplace_back(arg);
}
return args.size();
}

void _push_back(std::vector<TaskArg *> &vector, TaskArg &el) {
vector.push_back(&el);
}


namespace jlcxx
{
// Needed for upcasting
Expand All @@ -400,8 +391,6 @@ namespace jlcxx
template<> struct DefaultConstructible<RayObject> : std::false_type {};
// template<> struct DefaultConstructible<JuliaFunctionDescriptor> : std::false_type {};
template<> struct DefaultConstructible<TaskArg> : std::false_type {};
// template<> struct DefaultConstructible<TaskArgByReference> : std::false_type {};
// template<> struct DefaultConstructible<TaskArgByValue> : std::false_type {};

// Custom finalizer to show what is being deleted. Can be useful in tracking down
// segmentation faults due to double deallocations
Expand Down Expand Up @@ -616,42 +605,24 @@ JLCXX_MODULE define_julia_module(jlcxx::Module& mod)

// https://github.com/ray-project/ray/blob/ray-2.5.1/src/ray/common/task/task_util.h
mod.add_type<TaskArg>("TaskArg");
// .method("ToProto", &TaskArg::ToProto);
jlcxx::stl::apply_stl<std::shared_ptr<TaskArg>>(mod);
// jlcxx::stl::apply_stl<std::unique_ptr<std::string>>(mod);
mod.method("_push_back", &_push_back);

// The Julia types `TaskArgByReference` and `TaskArgByValue` have their default finalizers
// disabled as these will later be used as `std::unique_ptr`. If these finalizers were enabled
// we would see segmentation faults due to double deallocations.
//
// Note: It is possible to create `std::unique_ptr`s in C++ and return them to Julia however
// CxxWrap is unable to compile any wrapped functions using `std::vector<std::unique_ptr<TaskArg>>`.
// We're working around this by using `std::vector<TaskArg *>`.
// https://github.com/JuliaInterop/CxxWrap.jl/issues/370

mod.add_type<TaskArgByReference>("TaskArgByReference", jlcxx::julia_base_type<TaskArg>())
.constructor<const ObjectID &/*object_id*/,
const rpc::Address &/*owner_address*/,
const std::string &/*call_site*/>();
mod.method("SharedPtrTaskArgByReference", [] (
const ObjectID &object_id,
const rpc::Address &owner_address,
const std::string &call_site) {

return std::make_shared<TaskArgByReference>(object_id, owner_address, call_site);
});
// mod.method("TaskArgByReference", [] (
// const ObjectID &object_id,
// const rpc::Address &owner_address,
// const std::string &call_site) {

// return std::make_unique<TaskArgByReference>(object_id, owner_address, call_site);
// });
// jlcxx::stl::apply_stl<TaskArgByReference>(mod);
const std::string &/*call_site*/>(false);

mod.add_type<TaskArgByValue>("TaskArgByValue", jlcxx::julia_base_type<TaskArg>())
.constructor<const std::shared_ptr<RayObject> &/*value*/>();
mod.method("SharedPtrTaskArgByValue", [] (const std::shared_ptr<RayObject> &value) {
return std::make_shared<TaskArgByValue>(value);
});
// mod.method("TaskArgByValue", [] (const std::shared_ptr<RayObject> &value) {
// return std::make_unique<TaskArgByValue>(value);
// });
// jlcxx::stl::apply_stl<TaskArgByValue>(mod);
.constructor<const std::shared_ptr<RayObject> &/*value*/>(false);

mod.method("_submit_task", &_submit_task);

mod.method("demo", &demo);
mod.method("_push_back", &_push_back);
}

0 comments on commit c6b0298

Please sign in to comment.