Timeout generic callbacks used in destructors and inflight request cancelation #123

Merged
merged 6 commits into from
Nov 20, 2023
29 changes: 27 additions & 2 deletions cpp/include/ucxx/endpoint.h
@@ -232,9 +232,21 @@ class Endpoint : public Component {
* This is usually executed by `close()`, when pending requests will no longer be able
* to complete.
*
* If the parent worker is running a progress thread, a maximum timeout may be specified
* for which each cancelation attempt will wait. This can be particularly important for
* cases where the progress thread might be attempting to acquire a resource (e.g., the
* Python GIL) while the current thread owns that resource. In particular for Python,
* `~Endpoint()` will call this method, and the GIL cannot be released when the garbage
* collector runs and destroys the object.
*
* @param[in] period maximum period, in nanoseconds, to wait for each generic pre/post
* progress thread operation, or `0` to wait forever.
* @param[in] maxAttempts maximum number of attempts to cancel inflight requests, only
* applicable if worker is running a progress thread and `period > 0`.
*
* @returns Number of requests that were canceled.
*/
size_t cancelInflightRequests();
size_t cancelInflightRequests(uint64_t period = 0, uint64_t maxAttempts = 1);

/**
* @brief Register a user-defined callback to call when endpoint closes.
@@ -507,8 +519,21 @@ class Endpoint : public Component {
* If the endpoint was created with error handling support, the error callback will be
* executed, implying the user-defined callback will also be executed if one was
* registered with `setCloseCallback()`.
*
* If the parent worker is running a progress thread, a maximum timeout may be specified
* for which each close attempt will wait. This can be particularly important for cases
* where the progress thread might be attempting to acquire a resource (e.g., the Python
* GIL) while the current thread owns that resource. In particular for Python,
* `~Endpoint()` will call this method, and the GIL cannot be released when the garbage
* collector runs and destroys the object.
*
* @param[in] period maximum period, in nanoseconds, to wait for each generic pre/post
* progress thread operation, or `0` to wait forever.
* @param[in] maxAttempts maximum number of attempts to close endpoint, only applicable
* if worker is running a progress thread and `period > 0`.
*
*/
void close();
void close(uint64_t period = 0, uint64_t maxAttempts = 1);
};

} // namespace ucxx
11 changes: 9 additions & 2 deletions cpp/include/ucxx/utils/callback_notifier.h
@@ -48,11 +48,18 @@ class CallbackNotifier {
void set();

/**
* @brief Wait until `set()` has been called
* @brief Wait until `set()` has been called or period has elapsed.
*
* Wait until `set()` has been called, or period (in nanoseconds) has elapsed (only
* applicable if using glibc 2.25 and higher).
*
* See also `std::condition_variable::wait`.
*
* @param[in] period maximum period in nanoseconds to wait for or `0` to wait forever.
*
* @return `true` if waiting finished or `false` if a timeout occurred.
*/
void wait();
bool wait(uint64_t period = 0);
};

} // namespace utils
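
The timed `wait()` contract documented in this header can be sketched in isolation. This is a minimal stand-in and not the UCXX implementation: `TimedNotifier` and its members are hypothetical names, only a mutex/condition-variable path is covered, and it assumes (as the doc comment states) that a period of `0` means waiting without a timeout.

```cpp
#include <cassert>
#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <mutex>

// Hypothetical stand-in for a notifier with a timed wait; not ucxx::utils::CallbackNotifier.
class TimedNotifier {
 public:
  // Mark the event as having happened and wake all waiters.
  void set()
  {
    {
      std::lock_guard<std::mutex> lock(_mutex);
      _flag = true;
    }
    _conditionVariable.notify_all();
  }

  // Returns true if set() was observed, false if periodNs elapsed first.
  // periodNs == 0 waits without a timeout and therefore always returns true.
  bool wait(uint64_t periodNs = 0)
  {
    std::unique_lock<std::mutex> lock(_mutex);
    if (periodNs == 0) {
      _conditionVariable.wait(lock, [this]() { return _flag; });
      return true;
    }
    // wait_for with a predicate returns the predicate's value once it stops waiting.
    return _conditionVariable.wait_for(
      lock, std::chrono::nanoseconds(static_cast<int64_t>(periodNs)), [this]() { return _flag; });
  }

 private:
  std::mutex _mutex;
  std::condition_variable _conditionVariable;
  bool _flag{false};
};
```

A caller would treat a `false` return as "the callback has not run yet" rather than as an error in the callback itself.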
14 changes: 13 additions & 1 deletion cpp/include/ucxx/worker.h
@@ -573,9 +573,21 @@ class Worker : public Component {
* Cancel inflight requests, returning the total number of requests that were canceled.
* This is usually executed during the progress loop.
*
* If the worker is running a progress thread, a maximum timeout may be specified for
* which each cancelation attempt will wait. This can be particularly important for cases
* where the progress thread might be attempting to acquire a resource (e.g., the Python
* GIL) while the current thread owns that resource. In particular for Python,
* `~Worker()` will call this method, and the GIL cannot be released when the garbage
* collector runs and destroys the object.
*
* @param[in] period maximum period, in nanoseconds, to wait for each generic pre/post
* progress thread operation, or `0` to wait forever.
* @param[in] maxAttempts maximum number of attempts to cancel inflight requests, only
* applicable if worker is running a progress thread and `period > 0`.
*
* @returns Number of requests that were canceled.
*/
size_t cancelInflightRequests();
size_t cancelInflightRequests(uint64_t period = 0, uint64_t maxAttempts = 1);

/**
* @brief Schedule cancelation of inflight requests.
88 changes: 54 additions & 34 deletions cpp/src/endpoint.cpp
@@ -138,15 +138,15 @@ std::shared_ptr<Endpoint> createEndpointFromWorkerAddress(std::shared_ptr<Worker

Endpoint::~Endpoint()
{
close();
close(10000000000 /* 10s */);
ucxx_trace("Endpoint destroyed: %p, UCP handle: %p", this, _originalHandle);
}

void Endpoint::close()
void Endpoint::close(uint64_t period, uint64_t maxAttempts)
{
if (_handle == nullptr) return;

size_t canceled = cancelInflightRequests();
size_t canceled = cancelInflightRequests(3000000000 /* 3s */, 3);
ucxx_debug("Endpoint %p canceled %lu requests", _handle, canceled);

// Close the endpoint
@@ -161,29 +161,39 @@ void Endpoint::close()
ucs_status_ptr_t status;

if (worker->isProgressThreadRunning()) {
utils::CallbackNotifier callbackNotifierPre{};
worker->registerGenericPre([this, &callbackNotifierPre, &status, closeMode]() {
status = ucp_ep_close_nb(_handle, closeMode);
callbackNotifierPre.set();
});
callbackNotifierPre.wait();

while (UCS_PTR_IS_PTR(status)) {
utils::CallbackNotifier callbackNotifierPost{};
worker->registerGenericPost([this, &callbackNotifierPost, &status]() {
ucs_status_t s = ucp_request_check_status(status);
if (UCS_PTR_STATUS(s) != UCS_INPROGRESS) {
ucp_request_free(status);
_callbackData->status = UCS_PTR_STATUS(s);
if (UCS_PTR_STATUS(status) != UCS_OK) {
ucxx_error("Error while closing endpoint: %s",
ucs_status_string(UCS_PTR_STATUS(status)));
bool closeSuccess = false;
for (uint64_t i = 0; i < maxAttempts && !closeSuccess; ++i) {
utils::CallbackNotifier callbackNotifierPre{};
worker->registerGenericPre([this, &callbackNotifierPre, &status, closeMode]() {
status = ucp_ep_close_nb(_handle, closeMode);
callbackNotifierPre.set();
});
if (!callbackNotifierPre.wait(period)) continue;

while (UCS_PTR_IS_PTR(status)) {
utils::CallbackNotifier callbackNotifierPost{};
worker->registerGenericPost([this, &callbackNotifierPost, &status]() {
ucs_status_t s = ucp_request_check_status(status);
if (UCS_PTR_STATUS(s) != UCS_INPROGRESS) {
ucp_request_free(status);
_callbackData->status = UCS_PTR_STATUS(s);
if (UCS_PTR_STATUS(status) != UCS_OK) {
ucxx_error("Error while closing endpoint: %s",
ucs_status_string(UCS_PTR_STATUS(status)));
}
}
}

callbackNotifierPost.set();
});
callbackNotifierPost.wait();
callbackNotifierPost.set();
});
if (!callbackNotifierPost.wait(period)) continue;
}

closeSuccess = true;
}

if (!closeSuccess) {
_callbackData->status = UCS_ERR_ENDPOINT_TIMEOUT;
ucxx_error("All attempts to close timed out on endpoint: %p, UCP handle: %p", this, _handle);
}
} else {
status = ucp_ep_close_nb(_handle, closeMode);
@@ -257,7 +267,7 @@ void Endpoint::removeInflightRequest(const Request* const request)
_inflightRequests->remove(request);
}

size_t Endpoint::cancelInflightRequests()
size_t Endpoint::cancelInflightRequests(uint64_t period, uint64_t maxAttempts)
{
auto worker = ::ucxx::getWorker(this->_parent);
size_t canceled = 0;
@@ -266,15 +276,25 @@ size_t Endpoint::cancelInflightRequests()
canceled = _inflightRequests->cancelAll();
worker->progress();
} else if (worker->isProgressThreadRunning()) {
utils::CallbackNotifier callbackNotifierPre{};
worker->registerGenericPre([this, &callbackNotifierPre, &canceled]() {
canceled = _inflightRequests->cancelAll();
callbackNotifierPre.set();
});
callbackNotifierPre.wait();
utils::CallbackNotifier callbackNotifierPost{};
worker->registerGenericPost([&callbackNotifierPost]() { callbackNotifierPost.set(); });
callbackNotifierPost.wait();
bool cancelSuccess = false;
for (uint64_t i = 0; i < maxAttempts && !cancelSuccess; ++i) {
utils::CallbackNotifier callbackNotifierPre{};
worker->registerGenericPre([this, &callbackNotifierPre, &canceled]() {
canceled = _inflightRequests->cancelAll();
callbackNotifierPre.set();
});
if (!callbackNotifierPre.wait(period)) continue;

utils::CallbackNotifier callbackNotifierPost{};
worker->registerGenericPost([&callbackNotifierPost]() { callbackNotifierPost.set(); });
if (!callbackNotifierPost.wait(period)) continue;

cancelSuccess = true;
}
if (!cancelSuccess)
ucxx_error("All attempts to cancel inflight requests failed on endpoint: %p, UCP handle: %p",
this,
_handle);
} else {
canceled = _inflightRequests->cancelAll();
}
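
The retry loops in this file register a callback that captures a stack-local notifier by reference and then wait up to `period` for it. If a wait times out and the callback runs later, it would refer to a notifier that has already gone out of scope. Whether that can happen depends on worker internals not shown in this diff. One way to sidestep the question is to let the callback share ownership of the notifier. The sketch below assumes this is acceptable; `SharedNotifier`, `Submit` and `runWithRetries` are hypothetical names, and `submit` merely stands in for `registerGenericPre()`/`registerGenericPost()`.

```cpp
#include <cassert>
#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <functional>
#include <memory>
#include <mutex>
#include <utility>

// Hypothetical notifier standing in for utils::CallbackNotifier.
struct SharedNotifier {
  std::mutex mutex;
  std::condition_variable cv;
  bool flag{false};

  void set()
  {
    {
      std::lock_guard<std::mutex> lock(mutex);
      flag = true;
    }
    cv.notify_all();
  }

  // Returns true if set() was observed before periodNs elapsed.
  bool waitFor(uint64_t periodNs)
  {
    std::unique_lock<std::mutex> lock(mutex);
    return cv.wait_for(
      lock, std::chrono::nanoseconds(static_cast<int64_t>(periodNs)), [this]() { return flag; });
  }
};

using Callback = std::function<void()>;
// Hypothetical: hands a callback to whatever thread will eventually run it.
using Submit = std::function<void(Callback)>;

// Submit `work` up to maxAttempts times, waiting periodNs for each attempt.
// Returns true as soon as one attempt is observed to complete.
bool runWithRetries(const Submit& submit,
                    const Callback& work,
                    uint64_t periodNs,
                    uint64_t maxAttempts)
{
  assert(periodNs > 0);  // this sketch only covers the timed case
  for (uint64_t i = 0; i < maxAttempts; ++i) {
    // The callback holds its own reference, so the notifier outlives a timed-out wait.
    auto notifier = std::make_shared<SharedNotifier>();
    submit([notifier, work]() {
      work();
      notifier->set();
    });
    if (notifier->waitFor(periodNs)) return true;
  }
  return false;
}
```

Note that, as in the loops above, a timed-out attempt is not withdrawn, so `work` may still run once per submitted attempt if the stalled thread eventually catches up.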
4 changes: 2 additions & 2 deletions cpp/src/listener.cpp
@@ -56,11 +56,11 @@ Listener::~Listener()
ucp_listener_destroy(_handle);
callbackNotifierPre.set();
});
callbackNotifierPre.wait();
callbackNotifierPre.wait(10000000000 /* 10s */);

utils::CallbackNotifier callbackNotifierPost{};
worker->registerGenericPost([&callbackNotifierPost]() { callbackNotifierPost.set(); });
callbackNotifierPost.wait();
callbackNotifierPost.wait(10000000000 /* 10s */);
} else {
ucp_listener_destroy(_handle);
worker->progress();
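
The timeouts in this PR are written as raw nanosecond literals with a comment, such as `10000000000 /* 10s */`. As a readability aid only, the same values could be derived from `std::chrono` durations. The helper below is a hypothetical sketch, not part of UCXX, and assumes non-negative durations.

```cpp
#include <cassert>
#include <chrono>
#include <cstdint>

// Hypothetical helper: convert a non-negative std::chrono duration to the
// nanosecond count expected by a `uint64_t period` parameter.
constexpr uint64_t toNanoseconds(std::chrono::nanoseconds duration)
{
  return static_cast<uint64_t>(duration.count());
}

// Coarser durations convert implicitly to nanoseconds at the call site.
static_assert(toNanoseconds(std::chrono::seconds(10)) == 10000000000ULL, "10s");
static_assert(toNanoseconds(std::chrono::seconds(3)) == 3000000000ULL, "3s");
```

With such a helper, a call could read `callbackNotifierPre.wait(toNanoseconds(std::chrono::seconds(10)))`.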
13 changes: 11 additions & 2 deletions cpp/src/utils/callback_notifier.cpp
@@ -49,15 +49,24 @@ void CallbackNotifier::set()
_conditionVariable.notify_all();
}
}
void CallbackNotifier::wait()

bool CallbackNotifier::wait(uint64_t period)
{
if (_useSpinlock) {
while (!_flag.load(std::memory_order_acquire)) {}
} else {
std::unique_lock lock(_mutex);
// Likewise here, the mutex provides ordering.
_conditionVariable.wait(lock, [this]() { return _flag.load(std::memory_order_relaxed); });
if (period > 0) {
return _conditionVariable.wait_for(
lock, std::chrono::duration<uint64_t, std::nano>(period), [this]() {
return _flag.load(std::memory_order_relaxed);
});
} else {
_conditionVariable.wait(lock, [this]() { return _flag.load(std::memory_order_relaxed); });
}
}
return true;
}

} // namespace utils
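
In the hunk above, the spinlock branch does not consult `period` and the function then returns `true`, so a timeout is only reported on the condition-variable path. If a bounded spin were ever wanted, one possible shape is sketched below. `spinWait` is a hypothetical free function, not part of UCXX, and it uses `std::chrono::steady_clock` for the deadline.

```cpp
#include <atomic>
#include <cassert>
#include <chrono>
#include <cstdint>

// Hypothetical helper: spin until `flag` becomes true or periodNs elapses.
// periodNs == 0 spins without a deadline and always returns true.
inline bool spinWait(const std::atomic<bool>& flag, uint64_t periodNs)
{
  if (periodNs == 0) {
    while (!flag.load(std::memory_order_acquire)) {}
    return true;
  }

  const auto deadline = std::chrono::steady_clock::now() +
                        std::chrono::nanoseconds(static_cast<int64_t>(periodNs));
  while (!flag.load(std::memory_order_acquire)) {
    // Give up once the deadline has passed and the flag is still unset.
    if (std::chrono::steady_clock::now() >= deadline) return false;
  }
  return true;
}
```

Reading the clock on every iteration adds overhead to the spin, so this trades some latency for a bounded wait.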
34 changes: 22 additions & 12 deletions cpp/src/worker.cpp
@@ -150,7 +150,7 @@ std::shared_ptr<Worker> createWorker(std::shared_ptr<Context> context,

Worker::~Worker()
{
size_t canceled = cancelInflightRequests();
size_t canceled = cancelInflightRequests(3000000000 /* 3s */, 3);
ucxx_debug("Worker %p canceled %lu requests", _handle, canceled);

stopProgressThreadNoWarn();
@@ -266,7 +266,7 @@ bool Worker::progress()
if (progressScheduledCancel) ret |= progressPending();

// Requests that were not completed now must be canceled.
if (cancelInflightRequests() > 0) ret |= progressPending();
if (cancelInflightRequests(3000000000 /* 3s */, 3) > 0) ret |= progressPending();

return ret;
}
@@ -399,7 +399,7 @@ bool Worker::isProgressThreadRunning() { return _progressThread != nullptr; }

std::thread::id Worker::getProgressThreadId() { return _progressThreadId; }

size_t Worker::cancelInflightRequests()
size_t Worker::cancelInflightRequests(uint64_t period, uint64_t maxAttempts)
{
size_t canceled = 0;

@@ -413,16 +413,26 @@ size_t Worker::cancelInflightRequests()
canceled = inflightRequestsToCancel->cancelAll();
progressPending();
} else if (isProgressThreadRunning()) {
utils::CallbackNotifier callbackNotifierPre{};
registerGenericPre([&callbackNotifierPre, &canceled, &inflightRequestsToCancel]() {
canceled = inflightRequestsToCancel->cancelAll();
callbackNotifierPre.set();
});
callbackNotifierPre.wait();
bool cancelSuccess = false;
for (uint64_t i = 0; i < maxAttempts && !cancelSuccess; ++i) {
utils::CallbackNotifier callbackNotifierPre{};
registerGenericPre([&callbackNotifierPre, &canceled, &inflightRequestsToCancel]() {
canceled = inflightRequestsToCancel->cancelAll();
callbackNotifierPre.set();
});
if (!callbackNotifierPre.wait(period)) continue;

utils::CallbackNotifier callbackNotifierPost{};
registerGenericPost([&callbackNotifierPost]() { callbackNotifierPost.set(); });
if (!callbackNotifierPost.wait(period)) continue;

cancelSuccess = true;
}

utils::CallbackNotifier callbackNotifierPost{};
registerGenericPost([&callbackNotifierPost]() { callbackNotifierPost.set(); });
callbackNotifierPost.wait();
if (!cancelSuccess)
ucxx_error("All attempts to cancel inflight requests failed on worker: %p, UCP handle: %p",
this,
_handle);
} else {
canceled = inflightRequestsToCancel->cancelAll();
}