Skip to content

Commit

Permalink
Execute user's callback after setting request status (#32)
Browse files Browse the repository at this point in the history
Once the user-defined callback is executed, it's important that the status for the request is already set, this may be used by the callback itself, as for example in the `ucxx::RequestTagMulti` where the request status must be known to ensure it did result in an error.

Authors:
  - Peter Andreas Entschev (https://github.com/pentschev)

Approvers:
  - Lawrence Mitchell (https://github.com/wence-)
  - AJ Schmidt (https://github.com/ajschmidt8)

URL: #32
  • Loading branch information
pentschev authored May 2, 2023
1 parent c96d347 commit 4acbcf2
Show file tree
Hide file tree
Showing 15 changed files with 139 additions and 107 deletions.
4 changes: 3 additions & 1 deletion build_and_run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,9 @@ run_py_benchmark() {
}

if [[ $RUN_CPP_TESTS != 0 ]]; then
${BINARY_PATH}/gtests/libucxx/UCXX_TEST
# UCX_TCP_CM_REUSEADDR=y to be able to bind immediately to the same port before
# `TIME_WAIT` timeout
UCX_TCP_CM_REUSEADDR=y ${BINARY_PATH}/gtests/libucxx/UCXX_TEST
fi
if [[ $RUN_CPP_BENCH != 0 ]]; then
# run_cpp_benchmark PROGRESS_MODE
Expand Down
4 changes: 2 additions & 2 deletions ci/test_cpp.sh
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ print_system_stats
BINARY_PATH=${CONDA_PREFIX}/bin

run_tests() {
CMD_LINE="UCX_TCP_CM_REUSEADDR=y timeout 10m ${BINARY_PATH}/gtests/libucxx/UCXX_TEST --gtest_filter=-*DelayedSubmission*ProgressTagMulti*:ListenerTest.CloseCallback:ListenerTest.IsAlive:ListenerTest.RaiseOnError"
CMD_LINE="UCX_TCP_CM_REUSEADDR=y timeout 10m ${BINARY_PATH}/gtests/libucxx/UCXX_TEST --gtest_filter=-ListenerTest.CloseCallback:ListenerTest.IsAlive:ListenerTest.RaiseOnError"

rapids-logger "Running: \n - ${CMD_LINE}"

UCX_TCP_CM_REUSEADDR=y timeout 10m ${BINARY_PATH}/gtests/libucxx/UCXX_TEST --gtest_filter=-*DelayedSubmission*ProgressTagMulti*:ListenerTest.CloseCallback:ListenerTest.IsAlive:ListenerTest.RaiseOnError
UCX_TCP_CM_REUSEADDR=y timeout 10m ${BINARY_PATH}/gtests/libucxx/UCXX_TEST --gtest_filter=-ListenerTest.CloseCallback:ListenerTest.IsAlive:ListenerTest.RaiseOnError
}

run_benchmark() {
Expand Down
17 changes: 8 additions & 9 deletions cpp/include/ucxx/constructors.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,15 +59,14 @@ std::shared_ptr<RequestStream> createRequestStream(std::shared_ptr<Endpoint> end
size_t length,
const bool enablePythonFuture);

std::shared_ptr<RequestTag> createRequestTag(
std::shared_ptr<Component> endpointOrWorker,
bool send,
void* buffer,
size_t length,
ucp_tag_t tag,
const bool enablePythonFuture,
std::function<void(std::shared_ptr<void>)> callbackFunction,
std::shared_ptr<void> callbackData);
std::shared_ptr<RequestTag> createRequestTag(std::shared_ptr<Component> endpointOrWorker,
bool send,
void* buffer,
size_t length,
ucp_tag_t tag,
const bool enablePythonFuture,
RequestCallbackUserFunction callbackFunction,
RequestCallbackUserData callbackData);

std::shared_ptr<RequestTagMulti> createRequestTagMultiSend(std::shared_ptr<Endpoint> endpoint,
const std::vector<void*>& buffer,
Expand Down
26 changes: 12 additions & 14 deletions cpp/include/ucxx/endpoint.h
Original file line number Diff line number Diff line change
Expand Up @@ -319,13 +319,12 @@ class Endpoint : public Component {
*
* @returns Request to be subsequently checked for the completion and its state.
*/
std::shared_ptr<Request> tagSend(
void* buffer,
size_t length,
ucp_tag_t tag,
const bool enablePythonFuture = false,
std::function<void(std::shared_ptr<void>)> callbackFunction = nullptr,
std::shared_ptr<void> callbackData = nullptr);
std::shared_ptr<Request> tagSend(void* buffer,
size_t length,
ucp_tag_t tag,
const bool enablePythonFuture = false,
RequestCallbackUserFunction callbackFunction = nullptr,
RequestCallbackUserData callbackData = nullptr);

/**
* @brief Enqueue a tag receive operation.
Expand All @@ -350,13 +349,12 @@ class Endpoint : public Component {
*
* @returns Request to be subsequently checked for the completion and its state.
*/
std::shared_ptr<Request> tagRecv(
void* buffer,
size_t length,
ucp_tag_t tag,
const bool enablePythonFuture = false,
std::function<void(std::shared_ptr<void>)> callbackFunction = nullptr,
std::shared_ptr<void> callbackData = nullptr);
std::shared_ptr<Request> tagRecv(void* buffer,
size_t length,
ucp_tag_t tag,
const bool enablePythonFuture = false,
RequestCallbackUserFunction callbackFunction = nullptr,
RequestCallbackUserData callbackData = nullptr);

/**
* @brief Enqueue a multi-buffer tag send operation.
Expand Down
4 changes: 2 additions & 2 deletions cpp/include/ucxx/request.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ class Request : public Component {
std::string _status_msg{}; ///< Human-readable status message
void* _request{nullptr}; ///< Pointer to UCP request
std::shared_ptr<Future> _future{nullptr}; ///< Future to notify upon completion
std::function<void(std::shared_ptr<void>)> _callback{nullptr}; ///< Completion callback
std::shared_ptr<void> _callbackData{nullptr}; ///< Completion callback data
RequestCallbackUserFunction _callback{nullptr}; ///< Completion callback
RequestCallbackUserData _callbackData{nullptr}; ///< Completion callback data
std::shared_ptr<Worker> _worker{
nullptr}; ///< Worker that generated request (if not from endpoint)
std::shared_ptr<Endpoint> _endpoint{
Expand Down
23 changes: 11 additions & 12 deletions cpp/include/ucxx/request_tag.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ class RequestTag : public Request {
void* buffer,
size_t length,
ucp_tag_t tag,
const bool enablePythonFuture = false,
std::function<void(std::shared_ptr<void>)> callbackFunction = nullptr,
std::shared_ptr<void> callbackData = nullptr);
const bool enablePythonFuture = false,
RequestCallbackUserFunction callbackFunction = nullptr,
RequestCallbackUserData callbackData = nullptr);

public:
/**
Expand Down Expand Up @@ -85,15 +85,14 @@ class RequestTag : public Request {
*
* @returns The `shared_ptr<ucxx::RequestTag>` object
*/
friend std::shared_ptr<RequestTag> createRequestTag(
std::shared_ptr<Component> endpointOrWorker,
bool send,
void* buffer,
size_t length,
ucp_tag_t tag,
const bool enablePythonFuture,
std::function<void(std::shared_ptr<void>)> callbackFunction,
std::shared_ptr<void> callbackData);
friend std::shared_ptr<RequestTag> createRequestTag(std::shared_ptr<Component> endpointOrWorker,
bool send,
void* buffer,
size_t length,
ucp_tag_t tag,
const bool enablePythonFuture,
RequestCallbackUserFunction callbackFunction,
RequestCallbackUserData callbackData);

virtual void populateDelayedSubmission();

Expand Down
8 changes: 6 additions & 2 deletions cpp/include/ucxx/request_tag_multi.h
Original file line number Diff line number Diff line change
Expand Up @@ -216,9 +216,10 @@ class RequestTagMulti : public std::enable_shared_from_this<RequestTagMulti> {
* which will be later used to evaluate if all frames completed and set the final status
* of the multi-transfer request and the Python future, if enabled.
*
* @param[in] status the status of the request being completed.
* @param[in] request the `ucxx::BufferRequest` object containing a single tag .
*/
void markCompleted(std::shared_ptr<void> request);
void markCompleted(ucs_status_t status, RequestCallbackUserData request);

/**
* @brief Callback to submit request to receive new header or frames.
Expand All @@ -231,9 +232,12 @@ class RequestTagMulti : public std::enable_shared_from_this<RequestTagMulti> {
* containing the `next` flag set, then the next request is another header. Otherwise, the
* next incoming message(s) is(are) frame(s).
*
* When called, the callback receives a single argument, the status of the current request.
*
* @param[in] status the status of the request being completed.
* @throws std::runtime_error if called by a send request.
*/
void callback();
void callback(ucs_status_t status);

/**
* @brief Return the status of the request.
Expand Down
4 changes: 4 additions & 0 deletions cpp/include/ucxx/typedefs.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#pragma once

#include <functional>
#include <memory>
#include <string>
#include <unordered_map>

Expand Down Expand Up @@ -32,4 +33,7 @@ typedef enum {

typedef std::unordered_map<std::string, std::string> ConfigMap;

typedef std::function<void(ucs_status_t, std::shared_ptr<void>)> RequestCallbackUserFunction;
typedef std::shared_ptr<void> RequestCallbackUserData;

} // namespace ucxx
13 changes: 6 additions & 7 deletions cpp/include/ucxx/worker.h
Original file line number Diff line number Diff line change
Expand Up @@ -544,13 +544,12 @@ class Worker : public Component {
*
* @returns Request to be subsequently checked for the completion and its state.
*/
std::shared_ptr<Request> tagRecv(
void* buffer,
size_t length,
ucp_tag_t tag,
const bool enableFuture = false,
std::function<void(std::shared_ptr<void>)> callbackFunction = nullptr,
std::shared_ptr<void> callbackData = nullptr);
std::shared_ptr<Request> tagRecv(void* buffer,
size_t length,
ucp_tag_t tag,
const bool enableFuture = false,
RequestCallbackUserFunction callbackFunction = nullptr,
RequestCallbackUserData callbackData = nullptr);

/**
* @brief Get the address of the UCX worker object.
Expand Down
26 changes: 12 additions & 14 deletions cpp/src/endpoint.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,26 +209,24 @@ std::shared_ptr<Request> Endpoint::streamRecv(void* buffer,
createRequestStream(endpoint, false, buffer, length, enablePythonFuture));
}

std::shared_ptr<Request> Endpoint::tagSend(
void* buffer,
size_t length,
ucp_tag_t tag,
const bool enablePythonFuture,
std::function<void(std::shared_ptr<void>)> callbackFunction,
std::shared_ptr<void> callbackData)
std::shared_ptr<Request> Endpoint::tagSend(void* buffer,
size_t length,
ucp_tag_t tag,
const bool enablePythonFuture,
RequestCallbackUserFunction callbackFunction,
RequestCallbackUserData callbackData)
{
auto endpoint = std::dynamic_pointer_cast<Endpoint>(shared_from_this());
return registerInflightRequest(createRequestTag(
endpoint, true, buffer, length, tag, enablePythonFuture, callbackFunction, callbackData));
}

std::shared_ptr<Request> Endpoint::tagRecv(
void* buffer,
size_t length,
ucp_tag_t tag,
const bool enablePythonFuture,
std::function<void(std::shared_ptr<void>)> callbackFunction,
std::shared_ptr<void> callbackData)
std::shared_ptr<Request> Endpoint::tagRecv(void* buffer,
size_t length,
ucp_tag_t tag,
const bool enablePythonFuture,
RequestCallbackUserFunction callbackFunction,
RequestCallbackUserData callbackData)
{
auto endpoint = std::dynamic_pointer_cast<Endpoint>(shared_from_this());
return registerInflightRequest(createRequestTag(
Expand Down
20 changes: 10 additions & 10 deletions cpp/src/request.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,17 +98,17 @@ bool Request::isCompleted() { return _status != UCS_INPROGRESS; }

void Request::callback(void* request, ucs_status_t status)
{
setStatus(status);

ucxx_trace_req_f(_ownerString.c_str(),
request,
_operationName.c_str(),
"callback %p",
_callback.target<void (*)(void)>());
if (_callback) _callback(_callbackData);
if (_callback) _callback(status, _callbackData);

ucp_request_free(request);

ucxx_trace("Request completed: %p, handle: %p", this, request);
setStatus(status);
}

void Request::process()
Expand Down Expand Up @@ -138,13 +138,6 @@ void Request::process()
status,
ucs_status_string(status));

ucxx_trace_req_f(_ownerString.c_str(),
_request,
_operationName.c_str(),
"callback %p",
_callback.target<void (*)(void)>());
if (_callback) _callback(_callbackData);

if (status != UCS_OK) {
ucxx_error(
"error on %s with status %d (%s)", _operationName.c_str(), status, ucs_status_string(status));
Expand All @@ -153,6 +146,13 @@ void Request::process()
_ownerString.c_str(), _request, _operationName.c_str(), "completed immediately");
}

ucxx_trace_req_f(_ownerString.c_str(),
_request,
_operationName.c_str(),
"callback %p",
_callback.target<void (*)(void)>());
if (_callback) _callback(status, _callbackData);

setStatus(status);
}

Expand Down
21 changes: 10 additions & 11 deletions cpp/src/request_tag.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,14 @@

namespace ucxx {

std::shared_ptr<RequestTag> createRequestTag(
std::shared_ptr<Component> endpointOrWorker,
bool send,
void* buffer,
size_t length,
ucp_tag_t tag,
const bool enablePythonFuture = false,
std::function<void(std::shared_ptr<void>)> callbackFunction = nullptr,
std::shared_ptr<void> callbackData = nullptr)
std::shared_ptr<RequestTag> createRequestTag(std::shared_ptr<Component> endpointOrWorker,
bool send,
void* buffer,
size_t length,
ucp_tag_t tag,
const bool enablePythonFuture = false,
RequestCallbackUserFunction callbackFunction = nullptr,
RequestCallbackUserData callbackData = nullptr)
{
return std::shared_ptr<RequestTag>(new RequestTag(endpointOrWorker,
send,
Expand All @@ -39,8 +38,8 @@ RequestTag::RequestTag(std::shared_ptr<Component> endpointOrWorker,
size_t length,
ucp_tag_t tag,
const bool enablePythonFuture,
std::function<void(std::shared_ptr<void>)> callbackFunction,
std::shared_ptr<void> callbackData)
RequestCallbackUserFunction callbackFunction,
RequestCallbackUserData callbackData)
: Request(endpointOrWorker,
std::make_shared<DelayedSubmission>(send, buffer, length, tag),
std::string(send ? "tagSend" : "tagRecv"),
Expand Down
Loading

0 comments on commit 4acbcf2

Please sign in to comment.