Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Execute user's callback after setting request status #32

Merged
merged 10 commits into from
May 2, 2023
4 changes: 3 additions & 1 deletion build_and_run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,9 @@ run_py_benchmark() {
}

if [[ $RUN_CPP_TESTS != 0 ]]; then
${BINARY_PATH}/gtests/libucxx/UCXX_TEST
# UCX_TCP_CM_REUSEADDR=y to be able to bind immediately to the same port before
# `TIME_WAIT` timeout
UCX_TCP_CM_REUSEADDR=y ${BINARY_PATH}/gtests/libucxx/UCXX_TEST
fi
if [[ $RUN_CPP_BENCH != 0 ]]; then
# run_cpp_benchmark PROGRESS_MODE
Expand Down
17 changes: 8 additions & 9 deletions cpp/include/ucxx/constructors.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,15 +59,14 @@ std::shared_ptr<RequestStream> createRequestStream(std::shared_ptr<Endpoint> end
size_t length,
const bool enablePythonFuture);

std::shared_ptr<RequestTag> createRequestTag(
std::shared_ptr<Component> endpointOrWorker,
bool send,
void* buffer,
size_t length,
ucp_tag_t tag,
const bool enablePythonFuture,
std::function<void(std::shared_ptr<void>)> callbackFunction,
std::shared_ptr<void> callbackData);
std::shared_ptr<RequestTag> createRequestTag(std::shared_ptr<Component> endpointOrWorker,
bool send,
void* buffer,
size_t length,
ucp_tag_t tag,
const bool enablePythonFuture,
RequestCallbackUserFunction callbackFunction,
RequestCallbackUserData callbackData);

std::shared_ptr<RequestTagMulti> createRequestTagMultiSend(std::shared_ptr<Endpoint> endpoint,
const std::vector<void*>& buffer,
Expand Down
26 changes: 12 additions & 14 deletions cpp/include/ucxx/endpoint.h
Original file line number Diff line number Diff line change
Expand Up @@ -319,13 +319,12 @@ class Endpoint : public Component {
*
* @returns Request to be subsequently checked for the completion and its state.
*/
std::shared_ptr<Request> tagSend(
void* buffer,
size_t length,
ucp_tag_t tag,
const bool enablePythonFuture = false,
std::function<void(std::shared_ptr<void>)> callbackFunction = nullptr,
std::shared_ptr<void> callbackData = nullptr);
std::shared_ptr<Request> tagSend(void* buffer,
size_t length,
ucp_tag_t tag,
const bool enablePythonFuture = false,
RequestCallbackUserFunction callbackFunction = nullptr,
RequestCallbackUserData callbackData = nullptr);

/**
* @brief Enqueue a tag receive operation.
Expand All @@ -350,13 +349,12 @@ class Endpoint : public Component {
*
* @returns Request to be subsequently checked for the completion and its state.
*/
std::shared_ptr<Request> tagRecv(
void* buffer,
size_t length,
ucp_tag_t tag,
const bool enablePythonFuture = false,
std::function<void(std::shared_ptr<void>)> callbackFunction = nullptr,
std::shared_ptr<void> callbackData = nullptr);
std::shared_ptr<Request> tagRecv(void* buffer,
size_t length,
ucp_tag_t tag,
const bool enablePythonFuture = false,
RequestCallbackUserFunction callbackFunction = nullptr,
RequestCallbackUserData callbackData = nullptr);

/**
* @brief Enqueue a multi-buffer tag send operation.
Expand Down
4 changes: 2 additions & 2 deletions cpp/include/ucxx/request.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ class Request : public Component {
std::string _status_msg{}; ///< Human-readable status message
void* _request{nullptr}; ///< Pointer to UCP request
std::shared_ptr<Future> _future{nullptr}; ///< Future to notify upon completion
std::function<void(std::shared_ptr<void>)> _callback{nullptr}; ///< Completion callback
std::shared_ptr<void> _callbackData{nullptr}; ///< Completion callback data
RequestCallbackUserFunction _callback{nullptr}; ///< Completion callback
RequestCallbackUserData _callbackData{nullptr}; ///< Completion callback data
std::shared_ptr<Worker> _worker{
nullptr}; ///< Worker that generated request (if not from endpoint)
std::shared_ptr<Endpoint> _endpoint{
Expand Down
23 changes: 11 additions & 12 deletions cpp/include/ucxx/request_tag.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ class RequestTag : public Request {
void* buffer,
size_t length,
ucp_tag_t tag,
const bool enablePythonFuture = false,
std::function<void(std::shared_ptr<void>)> callbackFunction = nullptr,
std::shared_ptr<void> callbackData = nullptr);
const bool enablePythonFuture = false,
RequestCallbackUserFunction callbackFunction = nullptr,
RequestCallbackUserData callbackData = nullptr);

public:
/**
Expand Down Expand Up @@ -85,15 +85,14 @@ class RequestTag : public Request {
*
* @returns The `shared_ptr<ucxx::RequestTag>` object
*/
friend std::shared_ptr<RequestTag> createRequestTag(
std::shared_ptr<Component> endpointOrWorker,
bool send,
void* buffer,
size_t length,
ucp_tag_t tag,
const bool enablePythonFuture,
std::function<void(std::shared_ptr<void>)> callbackFunction,
std::shared_ptr<void> callbackData);
friend std::shared_ptr<RequestTag> createRequestTag(std::shared_ptr<Component> endpointOrWorker,
bool send,
void* buffer,
size_t length,
ucp_tag_t tag,
const bool enablePythonFuture,
RequestCallbackUserFunction callbackFunction,
RequestCallbackUserData callbackData);

virtual void populateDelayedSubmission();

Expand Down
8 changes: 6 additions & 2 deletions cpp/include/ucxx/request_tag_multi.h
Original file line number Diff line number Diff line change
Expand Up @@ -216,9 +216,10 @@ class RequestTagMulti : public std::enable_shared_from_this<RequestTagMulti> {
* which will be later used to evaluate if all frames completed and set the final status
* of the multi-transfer request and the Python future, if enabled.
*
pentschev marked this conversation as resolved.
Show resolved Hide resolved
* @param[in] status the status of the request being completed.
* @param[in] request the `ucxx::BufferRequest` object containing a single tag .
*/
void markCompleted(std::shared_ptr<void> request);
void markCompleted(ucs_status_t status, RequestCallbackUserData request);

/**
* @brief Callback to submit request to receive new header or frames.
Expand All @@ -231,9 +232,12 @@ class RequestTagMulti : public std::enable_shared_from_this<RequestTagMulti> {
* containing the `next` flag set, then the next request is another header. Otherwise, the
* next incoming message(s) is(are) frame(s).
*
pentschev marked this conversation as resolved.
Show resolved Hide resolved
* When called, the callback receives a single argument, the status of the current request.
*
* @param[in] status the status of the request being completed.
* @throws std::runtime_error if called by a send request.
*/
void callback();
void callback(ucs_status_t status);

/**
* @brief Return the status of the request.
Expand Down
4 changes: 4 additions & 0 deletions cpp/include/ucxx/typedefs.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#pragma once

#include <functional>
#include <memory>
#include <string>
#include <unordered_map>

Expand Down Expand Up @@ -32,4 +33,7 @@ typedef enum {

typedef std::unordered_map<std::string, std::string> ConfigMap;

typedef std::function<void(ucs_status_t, std::shared_ptr<void>)> RequestCallbackUserFunction;
typedef std::shared_ptr<void> RequestCallbackUserData;

} // namespace ucxx
13 changes: 6 additions & 7 deletions cpp/include/ucxx/worker.h
Original file line number Diff line number Diff line change
Expand Up @@ -544,13 +544,12 @@ class Worker : public Component {
*
* @returns Request to be subsequently checked for the completion and its state.
*/
std::shared_ptr<Request> tagRecv(
void* buffer,
size_t length,
ucp_tag_t tag,
const bool enableFuture = false,
std::function<void(std::shared_ptr<void>)> callbackFunction = nullptr,
std::shared_ptr<void> callbackData = nullptr);
std::shared_ptr<Request> tagRecv(void* buffer,
size_t length,
ucp_tag_t tag,
const bool enableFuture = false,
RequestCallbackUserFunction callbackFunction = nullptr,
RequestCallbackUserData callbackData = nullptr);

/**
* @brief Get the address of the UCX worker object.
Expand Down
26 changes: 12 additions & 14 deletions cpp/src/endpoint.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,26 +209,24 @@ std::shared_ptr<Request> Endpoint::streamRecv(void* buffer,
createRequestStream(endpoint, false, buffer, length, enablePythonFuture));
}

std::shared_ptr<Request> Endpoint::tagSend(
void* buffer,
size_t length,
ucp_tag_t tag,
const bool enablePythonFuture,
std::function<void(std::shared_ptr<void>)> callbackFunction,
std::shared_ptr<void> callbackData)
std::shared_ptr<Request> Endpoint::tagSend(void* buffer,
size_t length,
ucp_tag_t tag,
const bool enablePythonFuture,
RequestCallbackUserFunction callbackFunction,
RequestCallbackUserData callbackData)
{
auto endpoint = std::dynamic_pointer_cast<Endpoint>(shared_from_this());
return registerInflightRequest(createRequestTag(
endpoint, true, buffer, length, tag, enablePythonFuture, callbackFunction, callbackData));
}

std::shared_ptr<Request> Endpoint::tagRecv(
void* buffer,
size_t length,
ucp_tag_t tag,
const bool enablePythonFuture,
std::function<void(std::shared_ptr<void>)> callbackFunction,
std::shared_ptr<void> callbackData)
std::shared_ptr<Request> Endpoint::tagRecv(void* buffer,
size_t length,
ucp_tag_t tag,
const bool enablePythonFuture,
RequestCallbackUserFunction callbackFunction,
RequestCallbackUserData callbackData)
{
auto endpoint = std::dynamic_pointer_cast<Endpoint>(shared_from_this());
return registerInflightRequest(createRequestTag(
Expand Down
20 changes: 10 additions & 10 deletions cpp/src/request.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,17 +98,17 @@ bool Request::isCompleted() { return _status != UCS_INPROGRESS; }

void Request::callback(void* request, ucs_status_t status)
{
setStatus(status);

ucxx_trace_req_f(_ownerString.c_str(),
request,
_operationName.c_str(),
"callback %p",
_callback.target<void (*)(void)>());
if (_callback) _callback(_callbackData);
if (_callback) _callback(status, _callbackData);

ucp_request_free(request);

ucxx_trace("Request completed: %p, handle: %p", this, request);
setStatus(status);
}

void Request::process()
Expand Down Expand Up @@ -138,13 +138,6 @@ void Request::process()
status,
ucs_status_string(status));

ucxx_trace_req_f(_ownerString.c_str(),
_request,
_operationName.c_str(),
"callback %p",
_callback.target<void (*)(void)>());
if (_callback) _callback(_callbackData);

if (status != UCS_OK) {
ucxx_error(
"error on %s with status %d (%s)", _operationName.c_str(), status, ucs_status_string(status));
Expand All @@ -153,6 +146,13 @@ void Request::process()
_ownerString.c_str(), _request, _operationName.c_str(), "completed immediately");
}

ucxx_trace_req_f(_ownerString.c_str(),
_request,
_operationName.c_str(),
"callback %p",
_callback.target<void (*)(void)>());
if (_callback) _callback(status, _callbackData);

setStatus(status);
}

Expand Down
21 changes: 10 additions & 11 deletions cpp/src/request_tag.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,14 @@

namespace ucxx {

std::shared_ptr<RequestTag> createRequestTag(
std::shared_ptr<Component> endpointOrWorker,
bool send,
void* buffer,
size_t length,
ucp_tag_t tag,
const bool enablePythonFuture = false,
std::function<void(std::shared_ptr<void>)> callbackFunction = nullptr,
std::shared_ptr<void> callbackData = nullptr)
std::shared_ptr<RequestTag> createRequestTag(std::shared_ptr<Component> endpointOrWorker,
bool send,
void* buffer,
size_t length,
ucp_tag_t tag,
const bool enablePythonFuture = false,
RequestCallbackUserFunction callbackFunction = nullptr,
RequestCallbackUserData callbackData = nullptr)
{
return std::shared_ptr<RequestTag>(new RequestTag(endpointOrWorker,
send,
Expand All @@ -39,8 +38,8 @@ RequestTag::RequestTag(std::shared_ptr<Component> endpointOrWorker,
size_t length,
ucp_tag_t tag,
const bool enablePythonFuture,
std::function<void(std::shared_ptr<void>)> callbackFunction,
std::shared_ptr<void> callbackData)
RequestCallbackUserFunction callbackFunction,
RequestCallbackUserData callbackData)
: Request(endpointOrWorker,
std::make_shared<DelayedSubmission>(send, buffer, length, tag),
std::string(send ? "tagSend" : "tagRecv"),
Expand Down
32 changes: 16 additions & 16 deletions cpp/src/request_tag_multi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ RequestTagMulti::RequestTagMulti(std::shared_ptr<Endpoint> endpoint,
if (enablePythonFuture) _future = worker->getFuture();

ucxx_debug("RequestTagMulti created: %p", this);
callback();
callback(UCS_OK);
}

RequestTagMulti::RequestTagMulti(std::shared_ptr<Endpoint> endpoint,
Expand Down Expand Up @@ -126,7 +126,9 @@ void RequestTagMulti::recvFrames()
buf->getSize(),
_tag,
false,
std::bind(std::mem_fn(&RequestTagMulti::markCompleted), this, std::placeholders::_1),
[this](ucs_status_t status, RequestCallbackUserData arg) {
return this->markCompleted(status, arg);
},
bufferRequest);
bufferRequest->buffer = buf;
ucxx_trace_req("RequestTagMulti::recvFrames request: %p, tag: %lx, buffer: %p",
Expand All @@ -144,7 +146,7 @@ void RequestTagMulti::recvFrames()
_isFilled);
};

void RequestTagMulti::markCompleted(std::shared_ptr<void> request)
void RequestTagMulti::markCompleted(ucs_status_t status, RequestCallbackUserData request)
{
ucxx_trace_req("RequestTagMulti::markCompleted request: %p, tag: %lx", this, _tag);
std::lock_guard<std::mutex> lock(_completedRequestsMutex);
Expand Down Expand Up @@ -177,13 +179,13 @@ void RequestTagMulti::recvHeader()
auto bufferRequest = std::make_shared<BufferRequest>();
_bufferRequests.push_back(bufferRequest);
bufferRequest->stringBuffer = std::make_shared<std::string>(Header::dataSize(), 0);
bufferRequest->request =
_endpoint->tagRecv(&bufferRequest->stringBuffer->front(),
bufferRequest->stringBuffer->size(),
_tag,
false,
std::bind(std::mem_fn(&RequestTagMulti::callback), this),
nullptr);
bufferRequest->request = _endpoint->tagRecv(
&bufferRequest->stringBuffer->front(),
bufferRequest->stringBuffer->size(),
_tag,
false,
[this](ucs_status_t status, RequestCallbackUserData arg) { return this->callback(status); },
nullptr);

if (bufferRequest->request->isCompleted()) {
// TODO: Errors may not be raisable within callback
Expand All @@ -196,7 +198,7 @@ void RequestTagMulti::recvHeader()
_bufferRequests.empty());
}

void RequestTagMulti::callback()
void RequestTagMulti::callback(ucs_status_t status)
{
if (_send) throw std::runtime_error("Send requests cannot call callback()");

Expand All @@ -208,10 +210,6 @@ void RequestTagMulti::callback()
} else {
const auto request = _bufferRequests.back();

// nullptr/NULL and UCS_OK have the same meaning, request completed immediately.
const ucs_status_t status =
request->request == nullptr ? UCS_OK : request->request->getStatus();

if (status == UCS_OK) {
ucxx_trace_req(
"RequestTagMulti::callback header received, multi request: %p, tag: %lx", this, _tag);
Expand Down Expand Up @@ -267,7 +265,9 @@ void RequestTagMulti::send(const std::vector<void*>& buffer,
size[i],
_tag,
false,
std::bind(std::mem_fn(&RequestTagMulti::markCompleted), this, std::placeholders::_1),
[this](ucs_status_t status, RequestCallbackUserData arg) {
return this->markCompleted(status, arg);
},
bufferRequest);
bufferRequest->request = r;
_bufferRequests.push_back(bufferRequest);
Expand Down
Loading