-
Notifications
You must be signed in to change notification settings - Fork 1.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Handle cancellation notification in gRPC server #6298
Changes from all commits
63e9d56
bb7c2af
9a4522b
34cbdbb
8c5563d
3c27f67
4c46a63
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -61,7 +61,9 @@ typedef enum { | |
ISSUED, | ||
READ, | ||
WRITEREADY, | ||
WRITTEN | ||
WRITTEN, | ||
CANCELLATION_ISSUED, | ||
CANCELLED | ||
} Steps; | ||
|
||
// Debugging helper | ||
|
@@ -652,12 +654,127 @@ class InferHandlerState { | |
ctx_->set_compression_level(compression_level); | ||
} | ||
|
||
// Asks the underlying gRPC context for an asynchronous notification,
// tagged with 'state', once the transaction is done.
void GrpcContextAsyncNotifyWhenDone(InferHandlerStateType* state)
{
  auto* done_tag = state;
  ctx_->AsyncNotifyWhenDone(done_tag);
}
|
||
// Reports whether the underlying gRPC context says it was cancelled.
bool IsCancelled()
{
  const bool cancelled = ctx_->IsCancelled();
  return cancelled;
}
|
||
// Adds one to the count of requests that are ongoing on this context.
void IncrementRequestCounter()
{
  ++ongoing_requests_;
}
|
||
// Removes one from the count of requests that are ongoing on this
// context.
void DecrementRequestCounter()
{
  --ongoing_requests_;
}
|
||
|
||
// Starts tracking 'state' in the set of requests that are active
// within the server core and remembers the Triton request it was
// issued with. Call this only after the request was successfully
// enqueued on Triton.
void InsertInflightState(
    InferHandlerStateType* state, TRITONSERVER_InferenceRequest* irequest)
{
  const std::lock_guard<std::mutex> guard(mu_);
  // irequest_ptr_ is populated only here, i.e. once the request has
  // been successfully enqueued to Triton core using
  // TRITONSERVER_ServerInferAsync and the state counts as active.
  state->irequest_ptr_ = irequest;
  inflight_states_.insert(state);
}
|
||
// Stops tracking 'state' in the set of requests that are active
// within the server core. The set is modified under mu_.
void EraseInflightState(InferHandlerStateType* state)
{
  const std::lock_guard<std::mutex> guard(mu_);
  inflight_states_.erase(state);
}
|
||
// Issues the cancellation for all inflight requests | ||
// being tracked by this context. | ||
void IssueRequestCancellation() | ||
{ | ||
{ | ||
std::lock_guard<std::mutex> lock(mu_); | ||
|
||
// Issues the request cancellation to the core. | ||
for (auto state : inflight_states_) { | ||
std::lock_guard<std::recursive_mutex> lock(state->step_mtx_); | ||
if (state->step_ != Steps::CANCELLED) { | ||
LOG_VERBOSE(1) << "Issuing cancellation for " << state->unique_id_; | ||
if (state->irequest_ptr_ == nullptr) { | ||
// The context might be holding some states that have | ||
// not been issued to Triton core. Need to skip calling | ||
// issuing cancellation for such requests. | ||
continue; | ||
} | ||
// Note that request may or may not be valid at this point. | ||
// Assuming if RequestComplete callback is run asynchronously | ||
// before this point. | ||
TRITONSERVER_Error* err = nullptr; | ||
err = TRITONSERVER_InferenceRequestCancel(state->irequest_ptr_); | ||
// TODO: Add request id to the message | ||
if (err != nullptr) { | ||
LOG_INFO << "Failed to cancel the request: " | ||
<< TRITONSERVER_ErrorMessage(err); | ||
} | ||
state->step_ = Steps::CANCELLATION_ISSUED; | ||
} | ||
} | ||
} | ||
} | ||
|
||
|
||
// Handles the gRPC context cancellation. This function can be called | ||
// multiple times and is supposed to be re-entrant. | ||
// Returns whether or not to continue cycling through the gRPC | ||
// completion queue or not. | ||
bool HandleCancellation( | ||
InferHandlerStateType* state, bool rpc_ok, const std::string& name) | ||
{ | ||
if (!IsCancelled()) { | ||
LOG_ERROR | ||
<< "[INTERNAL] HandleCancellation called even when the context was " | ||
"not cancelled for " | ||
<< name << ", rpc_ok=" << rpc_ok << ", context " | ||
<< state->context_->unique_id_ << ", " << state->unique_id_ | ||
<< " step " << state->step_; | ||
return true; | ||
Comment on lines
+736
to
+743
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Ideally this should never happen, is that correct? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. No. This should only be called if the context was cancelled. |
||
} | ||
if ((state->step_ != Steps::CANCELLATION_ISSUED) && | ||
(state->step_ != Steps::CANCELLED)) { | ||
LOG_VERBOSE(1) << "Cancellation notification received for " << name | ||
<< ", rpc_ok=" << rpc_ok << ", context " | ||
<< state->context_->unique_id_ << ", " | ||
<< state->unique_id_ << " step " << state->step_; | ||
Comment on lines
+747
to
+750
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Side note, not for this PR, but it might be nice to have some lightly wrapped object or helper function around some of these values used for logging purposes: struct RpcCallInfo {
State state;
bool rpc_ok;
};
ostream operator ...
{
out << ", rpc_ok=" << rpc_ok << ", context "
<< state->context_->unique_id_ << ", "
<< state->unique_id_ << " step " << state->step_;
return out;
} So when we call this logging in most of our methods, it can be simplified to something like: LOG_VERBOSE(1) << "Cancellation notification received for: " << RpcCallInfo(state, rpc_ok); Or something along those lines. |
||
|
||
// If the context has not been cancelled then | ||
// issue cancellation request to all the inflight | ||
// states belonging to the context. | ||
if (state->context_->step_ != Steps::CANCELLED) { | ||
IssueRequestCancellation(); | ||
// Mark the context as cancelled | ||
state->context_->step_ = Steps::CANCELLED; | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is it possible for the context `step_` to be accessed by multiple threads at the same time, so that it would need a lock? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We only have a single ModelStreamInferHandler thread. Hence, we don't need a lock around the context step_. |
||
|
||
// The state returns true because the CancelExecution | ||
// call above would have raised alarm objects on all | ||
// pending inflight states objects. This state will | ||
// be taken up along with all the other states in the | ||
// next iteration from the completion queue which | ||
// would release the state. | ||
return true; | ||
} | ||
} | ||
|
||
LOG_VERBOSE(1) << "Completing cancellation for " << name | ||
<< ", rpc_ok=" << rpc_ok << ", context " | ||
<< state->context_->unique_id_ << ", " << state->unique_id_ | ||
<< " step " << state->step_; | ||
|
||
return false; | ||
} | ||
|
||
// Enqueue 'state' so that its response is delivered in the | ||
// correct order. | ||
void EnqueueForResponse(InferHandlerStateType* state) | ||
|
@@ -781,6 +898,11 @@ class InferHandlerState { | |
std::queue<InferHandlerStateType*> states_; | ||
std::atomic<uint32_t> ongoing_requests_; | ||
|
||
// Tracks the inflight requests sent to Triton core via this | ||
// context. We will use this structure to issue cancellations | ||
// on these requests. | ||
std::set<InferHandlerStateType*> inflight_states_; | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. By using a set, are we possibly hiding any bugs/errors where we insert a duplicate state, with the parallel nature of Handler/ResponseCallback/AsyncNotifyWhenDone logic for adding states onto queue? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The items are added to this data structure iff the InferAsync function was successfully called. This will only be done once per state object. Hence, duplication should not be an issue. |
||
|
||
// The step of the entire context. | ||
Steps step_; | ||
|
||
|
@@ -809,12 +931,15 @@ class InferHandlerState { | |
|
||
// Runs ClearTraceTimestamps() when the state object is destroyed.
~InferHandlerState()
{
  ClearTraceTimestamps();
}
|
||
// Forwards to the context this state belongs to and reports whether
// it was cancelled.
bool IsGrpcContextCancelled()
{
  const bool cancelled = context_->IsCancelled();
  return cancelled;
}
|
||
void Reset( | ||
const std::shared_ptr<Context>& context, Steps start_step = Steps::START) | ||
{ | ||
unique_id_ = NEXT_UNIQUE_ID; | ||
context_ = context; | ||
step_ = start_step; | ||
irequest_ptr_ = nullptr; | ||
cb_count_ = 0; | ||
is_decoupled_ = false; | ||
complete_ = false; | ||
|
@@ -859,7 +984,9 @@ class InferHandlerState { | |
|
||
std::shared_ptr<Context> context_; | ||
Steps step_; | ||
std::mutex step_mtx_; | ||
std::recursive_mutex step_mtx_; | ||
rmccorm4 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
TRITONSERVER_InferenceRequest* irequest_ptr_; | ||
|
||
#ifdef TRITON_ENABLE_TRACING | ||
std::shared_ptr<TraceManager::Trace> trace_; | ||
|
@@ -939,6 +1066,10 @@ class InferHandler : public HandlerBase { | |
state = new State(tritonserver, context, start_step); | ||
} | ||
|
||
// Need to be called to receive an asynchronous notification | ||
// when the transaction is cancelled. | ||
context->GrpcContextAsyncNotifyWhenDone(state); | ||
|
||
return state; | ||
} | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would be good to mention the logic behind saving/checking the request object for sanity checking state cleanup/reuse in CQ, as we discussed offline