Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Capture more operation states and receivers by reference in sender adaptors to use SBO of unique_function #1192

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 14 additions & 12 deletions libs/pika/async_mpi/include/pika/async_mpi/mpi_helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ namespace pika::mpi::experimental::detail {
// adds a request callback to the mpi polling code which will call
// notify_one to wake up a suspended task
template <typename OperationState>
void add_suspend_resume_request_callback(MPI_Request request, OperationState& op_state)
void add_suspend_resume_request_callback(OperationState& op_state)
{
PIKA_ASSERT(op_state.completed == false);
detail::add_request_callback(
Expand All @@ -111,52 +111,54 @@ namespace pika::mpi::experimental::detail {
}
op_state.cond_var.notify_one();
},
request);
op_state.request);
}

// -----------------------------------------------------------------
// handler_method::new_task
// adds a request callback to the mpi polling code which will call
// set_value/error on the receiver
template <typename Receiver>
void add_new_task_request_callback(
MPI_Request request, execution::thread_priority p, Receiver&& receiver)
template <typename OperationState>
void add_new_task_request_callback(OperationState& op_state)
{
detail::add_request_callback(
[receiver = PIKA_MOVE(receiver), p](int status) mutable {
[&op_state](int status) mutable {
PIKA_DETAIL_DP(mpi_tran<5>, debug(str<>("schedule_task_callback")));
if (status != MPI_SUCCESS)
{
ex::set_error(PIKA_FORWARD(Receiver, receiver),
ex::set_error(PIKA_MOVE(op_state.receiver),
std::make_exception_ptr(mpi_exception(status)));
}
else
{
// pass the result onto a new task and invoke the continuation
execution::thread_priority p = use_priority_boost(op_state.mode_flags) ?
execution::thread_priority::boost :
execution::thread_priority::normal;
auto snd0 = ex::transfer_just(default_pool_scheduler(p)) |
ex::then([receiver = PIKA_FORWARD(Receiver, receiver)]() mutable {
ex::then([&op_state]() mutable {
PIKA_DETAIL_DP(mpi_tran<5>, debug(str<>("set_value")));
ex::set_value(PIKA_MOVE(receiver));
ex::set_value(PIKA_MOVE(op_state.receiver));
});
ex::start_detached(PIKA_MOVE(snd0));
}
},
request);
op_state.request);
}
msimberg marked this conversation as resolved.
Show resolved Hide resolved

// -----------------------------------------------------------------
// handler_method::continuation
// adds a request callback to the mpi polling code which will call
// the set_value/set_error helper using the void return signature
template <typename OperationState>
void add_continuation_request_callback(MPI_Request request, OperationState& op_state)
void add_continuation_request_callback(OperationState& op_state)
{
detail::add_request_callback(
[&op_state](int status) mutable {
PIKA_DETAIL_DP(mpi_tran<5>, debug(str<>("callback_void")));
set_value_error_helper(status, PIKA_MOVE(op_state.receiver));
},
request);
op_state.request);
}

// -----------------------------------------------------------------
Expand Down
8 changes: 3 additions & 5 deletions libs/pika/async_mpi/include/pika/async_mpi/trigger_mpi.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,7 @@ namespace pika::mpi::experimental::detail {
// the callback will resume _this_ thread
{
std::unique_lock l{r.op_state.mutex};
add_suspend_resume_request_callback(
r.op_state.request, r.op_state);
add_suspend_resume_request_callback(r.op_state);
if (use_priority_boost(r.op_state.mode_flags))
{
threads::detail::thread_data::scoped_thread_priority
Expand All @@ -185,15 +184,14 @@ namespace pika::mpi::experimental::detail {
{
// The callback will call set_value/set_error inside a new task
// and execution will continue on that thread
add_new_task_request_callback(
r.op_state.request, p, PIKA_MOVE(r.op_state.receiver));
add_new_task_request_callback(r.op_state);
break;
}
case handler_method::continuation:
{
// The callback will call set_value/set_error
// execution will continue on the callback thread
add_continuation_request_callback<>(r.op_state.request, r.op_state);
add_continuation_request_callback(r.op_state);
break;
}
case handler_method::mpix_continuation:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -301,10 +301,7 @@ namespace pika::ensure_started_detail {
}

template <typename Receiver>
void add_continuation(Receiver& receiver) = delete;

template <typename Receiver>
void add_continuation(Receiver&& receiver)
void add_continuation(Receiver& receiver)
{
PIKA_ASSERT(!continuation.has_value());

Expand Down Expand Up @@ -341,11 +338,10 @@ namespace pika::ensure_started_detail {
// continuation. This has to be done while holding
// the lock since predecessor signalling completion
// may otherwise not see the continuation.
continuation.emplace(
[this, receiver = PIKA_FORWARD(Receiver, receiver)]() mutable {
pika::detail::visit(
stopped_error_value_visitor<Receiver>{receiver}, PIKA_MOVE(v));
});
continuation.emplace([this, &receiver]() mutable {
pika::detail::visit(
stopped_error_value_visitor<Receiver>{receiver}, PIKA_MOVE(v));
});
}
}
}
Expand Down Expand Up @@ -420,7 +416,7 @@ namespace pika::ensure_started_detail {
friend void tag_invoke(
pika::execution::experimental::start_t, operation_state& os) noexcept
{
os.state->add_continuation(PIKA_MOVE(os.receiver));
os.state->add_continuation(os.receiver);
}
};

Expand Down
15 changes: 5 additions & 10 deletions libs/pika/execution/include/pika/execution/algorithms/split.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -301,10 +301,7 @@ namespace pika::split_detail {
}

template <typename Receiver>
void add_continuation(Receiver& receiver) = delete;

template <typename Receiver>
void add_continuation(Receiver&& receiver)
void add_continuation(Receiver& receiver)
{
if (predecessor_done)
{
Expand Down Expand Up @@ -341,11 +338,9 @@ namespace pika::split_detail {
// to the vector and the vector is not threadsafe in
// itself. The continuation will be called later
// when set_error/set_stopped/set_value is called.
continuations.emplace_back(
[this, receiver = PIKA_FORWARD(Receiver, receiver)]() mutable {
pika::detail::visit(
stopped_error_value_visitor<Receiver>{receiver}, v);
});
continuations.emplace_back([this, &receiver]() mutable {
pika::detail::visit(stopped_error_value_visitor<Receiver>{receiver}, v);
});
}
}
}
Expand Down Expand Up @@ -419,7 +414,7 @@ namespace pika::split_detail {
pika::execution::experimental::start_t, operation_state& os) noexcept
{
os.state->start();
os.state->add_continuation(PIKA_MOVE(os.receiver));
os.state->add_continuation(os.receiver);
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -282,10 +282,7 @@ namespace pika::split_tuple_detail {
}

template <std::size_t Index, typename Receiver>
void add_continuation(Receiver& receiver) = delete;

template <std::size_t Index, typename Receiver>
void add_continuation(Receiver&& receiver)
void add_continuation(Receiver& receiver)
{
if (predecessor_done)
{
Expand Down Expand Up @@ -320,11 +317,10 @@ namespace pika::split_tuple_detail {
// to the vector and the vector is not threadsafe in
// itself. The continuation will be called later
// when set_error/set_stopped/set_value is called.
continuations[Index] =
[this, receiver = PIKA_FORWARD(Receiver, receiver)]() mutable {
pika::detail::visit(
stopped_error_value_visitor<Index, Receiver>{receiver}, v);
};
continuations[Index] = [this, &receiver]() mutable {
pika::detail::visit(
stopped_error_value_visitor<Index, Receiver>{receiver}, v);
};
}
}
}
Expand Down Expand Up @@ -457,7 +453,7 @@ namespace pika::split_tuple_detail {
pika::execution::experimental::start_t, operation_state& os) noexcept
{
os.state->start();
os.state->template add_continuation<Index>(PIKA_MOVE(os.receiver));
os.state->template add_continuation<Index>(os.receiver);
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,8 @@ namespace pika::execution::experimental {
pika::detail::try_catch_exception_ptr(
[&]() {
os.scheduler.execute(
[receiver = PIKA_MOVE(os.receiver)]() mutable {
pika::execution::experimental::set_value(PIKA_MOVE(receiver));
[&os]() mutable {
pika::execution::experimental::set_value(PIKA_MOVE(os.receiver));
},
os.fallback_annotation);
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -440,16 +440,16 @@ namespace pika::execution::experimental {
"async_rw_lock::sender::operation_state state is empty, was the sender "
"already started?");

auto continuation = [r = PIKA_MOVE(os.r)](shared_state_ptr_type state) mutable {
auto continuation = [&os](shared_state_ptr_type state) mutable {
try
{
pika::execution::experimental::set_value(
PIKA_MOVE(r), access_type{PIKA_MOVE(state)});
PIKA_MOVE(os.r), access_type{PIKA_MOVE(state)});
}
catch (...)
{
pika::execution::experimental::set_error(
PIKA_MOVE(r), std::current_exception());
PIKA_MOVE(os.r), std::current_exception());
}
};

Expand Down Expand Up @@ -632,16 +632,16 @@ namespace pika::execution::experimental {
"async_rw_lock::sender::operation_state state is empty, was the sender "
"already started?");

auto continuation = [r = PIKA_MOVE(os.r)](shared_state_ptr_type state) mutable {
auto continuation = [&os](shared_state_ptr_type state) mutable {
try
{
pika::execution::experimental::set_value(
PIKA_MOVE(r), access_type{PIKA_MOVE(state)});
PIKA_MOVE(os.r), access_type{PIKA_MOVE(state)});
}
catch (...)
{
pika::execution::experimental::set_error(
PIKA_MOVE(r), std::current_exception());
PIKA_MOVE(os.r), std::current_exception());
}
};

Expand Down
Loading