Handle cancellation notification in gRPC server #6298

Merged: 7 commits, Sep 19, 2023
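For context, the handlers below rely on gRPC's asynchronous done/cancel notification: grpc::ServerContext::AsyncNotifyWhenDone() registers a tag that the completion queue delivers when the RPC terminates, and IsCancelled() then reports whether the client cancelled the call. A minimal sketch of that pattern, with helper names that are illustrative assumptions rather than code from this PR:

#include <grpcpp/grpcpp.h>

// Illustrative helpers, not part of the PR. AsyncNotifyWhenDone() must be
// called before the server starts reacting to the RPC; the tag passed here
// is later returned by the completion queue when the call finishes or is
// cancelled.
void
RegisterDoneNotification(grpc::ServerContext* ctx, void* tag)
{
  ctx->AsyncNotifyWhenDone(tag);
}

// Only meaningful once the done-notification tag has been delivered:
// true if the client cancelled the call.
bool
ClientCancelled(grpc::ServerContext* ctx)
{
  return ctx->IsCancelled();
}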
58 changes: 56 additions & 2 deletions src/grpc/infer_handler.cc
@@ -62,6 +62,12 @@ operator<<(std::ostream& out, const Steps& step)
case WRITTEN:
out << "WRITTEN";
break;
case CANCELLATION_ISSUED:
out << "CANCELLATION_ISSUED";
break;
case CANCELLED:
out << "CANCELLED";
break;
}

return out;
@@ -707,6 +713,21 @@ ModelInferHandler::StartNewRequest()
bool
ModelInferHandler::Process(InferHandler::State* state, bool rpc_ok)
{
// There are multiple handlers registered in the gRPC service.
// Hence, one handler thread can be making progress in the state
// machine for a request while another thread is issuing cancellation
// on the same request. The state transitions need to be protected
// for these cases.
std::lock_guard<std::recursive_mutex> lock(state->step_mtx_);

// Handle notification for cancellation which can be raised
// asynchronously if detected on the network.
if (state->IsGrpcContextCancelled()) {
bool resume = state->context_->HandleCancellation(state, rpc_ok, Name());
return resume;
}


LOG_VERBOSE(1) << "Process for " << Name() << ", rpc_ok=" << rpc_ok << ", "
<< state->unique_id_ << " step " << state->step_;

@@ -933,8 +954,12 @@ ModelInferHandler::Execute(InferHandler::State* state)

// If not error then state->step_ == ISSUED and inference request
// has initiated... completion callback will transition to
// COMPLETE. If error go immediately to COMPLETE.
if (err != nullptr) {
// COMPLETE or CANCELLED. Recording the state and the irequest
// to handle gRPC stream cancellation.
if (err == nullptr) {
state->context_->InsertInflightState(state, irequest);
} else {
// If error go immediately to COMPLETE.
LOG_VERBOSE(1) << "[request id: " << request_id << "] "
<< "Infer failed: " << TRITONSERVER_ErrorMessage(err);

@@ -965,6 +990,12 @@ ModelInferHandler::InferResponseComplete(
{
State* state = reinterpret_cast<State*>(userp);

// There are multiple handlers registered in the gRPC service.
// Hence, this thread needs to be properly synchronized with the
// handler thread handling the async cancellation notification.
std::lock_guard<std::recursive_mutex> lock(state->step_mtx_);

// Increment the callback index
state->cb_count_++;

@@ -982,6 +1013,29 @@
"INFER_RESPONSE_COMPLETE", TraceManager::CaptureTimestamp()));
#endif // TRITON_ENABLE_TRACING

// If the gRPC stream is cancelled then there is no need to form and
// return a response.
if (state->IsGrpcContextCancelled()) {
// Clean-up the received response object.
LOG_TRITONSERVER_ERROR(
TRITONSERVER_InferenceResponseDelete(iresponse),
"deleting GRPC inference response");

state->step_ = Steps::CANCELLED;
state->context_->EraseInflightState(state);

LOG_VERBOSE(1) << "ModelInferHandler::InferResponseComplete, "
<< state->unique_id_
<< ", skipping response generation as grpc transaction was "
"cancelled... ";

// Send state back to the queue so that state can be released
// in the next cycle.
state->context_->PutTaskBackToQueue(state);

return;
}

TRITONSERVER_Error* err = nullptr;
// This callback is expected to be called exactly once for each request.
// Will use the single response object in the response list to hold the
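One design note before the infer_handler.h changes below: step_mtx_ moves from std::mutex to std::recursive_mutex. A plausible, hedged reading of why, sketched with standalone names that are not from the PR: the cancellation path can re-acquire a lock the same thread already holds, for example when Process() holds a state's step_mtx_ while HandleCancellation()/IssueRequestCancellation() locks the step_mtx_ of every inflight state, possibly including the state being processed.

#include <mutex>

// Toy illustration (assumption, not PR code): a recursive_mutex lets the
// same thread lock again along a nested call path where a plain
// std::mutex would self-deadlock.
static std::recursive_mutex step_mtx;

void
issue_cancellation()
{
  // Re-acquired by the thread that already holds the lock in process().
  std::lock_guard<std::recursive_mutex> inner(step_mtx);
}

void
process()
{
  std::lock_guard<std::recursive_mutex> outer(step_mtx);
  issue_cancellation();  // fine with recursive_mutex; deadlock with std::mutex
}

int
main()
{
  process();
  return 0;
}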
135 changes: 133 additions & 2 deletions src/grpc/infer_handler.h
@@ -61,7 +61,9 @@ typedef enum {
ISSUED,
READ,
WRITEREADY,
WRITTEN
WRITTEN,
CANCELLATION_ISSUED,
CANCELLED
} Steps;

// Debugging helper
@@ -652,12 +654,127 @@ class InferHandlerState {
ctx_->set_compression_level(compression_level);
}

void GrpcContextAsyncNotifyWhenDone(InferHandlerStateType* state)
{
ctx_->AsyncNotifyWhenDone(state);
}

bool IsCancelled() { return ctx_->IsCancelled(); }

// Increments the ongoing request counter
void IncrementRequestCounter() { ongoing_requests_++; }

// Decrements the ongoing request counter
void DecrementRequestCounter() { ongoing_requests_--; }


// Inserts the state into the set tracking requests active
// within the server core. Should only be called when
// the request was successfully enqueued on Triton.
void InsertInflightState(
InferHandlerStateType* state, TRITONSERVER_InferenceRequest* irequest)
{
std::lock_guard<std::mutex> lock(mu_);
// The irequest_ptr_ will get populated when it is
// marked as active which means the request has been
// successfully enqueued to Triton core using
// TRITONSERVER_ServerInferAsync.
state->irequest_ptr_ = irequest;
Review comment (Contributor): Would be good to mention the logic behind saving/checking the request object for sanity checking state cleanup/reuse in CQ, as we discussed offline.

inflight_states_.insert(state);
}

// Erases the state from the set tracking active requests
// within the server core.
void EraseInflightState(InferHandlerStateType* state)
{
std::lock_guard<std::mutex> lock(mu_);
inflight_states_.erase(state);
}

// Issues the cancellation for all inflight requests
// being tracked by this context.
void IssueRequestCancellation()
{
{
std::lock_guard<std::mutex> lock(mu_);

// Issues the request cancellation to the core.
for (auto state : inflight_states_) {
std::lock_guard<std::recursive_mutex> lock(state->step_mtx_);
if (state->step_ != Steps::CANCELLED) {
LOG_VERBOSE(1) << "Issuing cancellation for " << state->unique_id_;
if (state->irequest_ptr_ == nullptr) {
// The context might be holding some states that have
// not been issued to Triton core. Need to skip
// issuing cancellation for such requests.
continue;
}
// Note that the request may or may not be valid at this point,
// e.g. if the RequestComplete callback ran asynchronously
// before this point.
TRITONSERVER_Error* err = nullptr;
err = TRITONSERVER_InferenceRequestCancel(state->irequest_ptr_);
// TODO: Add request id to the message
if (err != nullptr) {
LOG_INFO << "Failed to cancel the request: "
<< TRITONSERVER_ErrorMessage(err);
}
state->step_ = Steps::CANCELLATION_ISSUED;
}
}
}
}


// Handles the gRPC context cancellation. This function can be called
// multiple times and is supposed to be re-entrant.
// Returns whether to continue cycling through the gRPC
// completion queue.
bool HandleCancellation(
InferHandlerStateType* state, bool rpc_ok, const std::string& name)
{
if (!IsCancelled()) {
LOG_ERROR
<< "[INTERNAL] HandleCancellation called even when the context was "
"not cancelled for "
<< name << ", rpc_ok=" << rpc_ok << ", context "
<< state->context_->unique_id_ << ", " << state->unique_id_
<< " step " << state->step_;
return true;
Review comment on lines +736 to +743 (Member): Ideally this should never happen, is that correct?

Reply (Contributor Author, @tanmayv25, Sep 18, 2023): No. This should only be called if the context was cancelled.

}
if ((state->step_ != Steps::CANCELLATION_ISSUED) &&
(state->step_ != Steps::CANCELLED)) {
LOG_VERBOSE(1) << "Cancellation notification received for " << name
<< ", rpc_ok=" << rpc_ok << ", context "
<< state->context_->unique_id_ << ", "
<< state->unique_id_ << " step " << state->step_;
Review comment on lines +747 to +750 (Contributor, @rmccorm4, Sep 12, 2023): Side note, not for this PR, but it might be nice to have a lightly wrapped object or helper function around some of these values used for logging purposes:

struct RpcCallInfo {
  State state;
  bool rpc_ok;
};

std::ostream&
operator<<(std::ostream& out, const RpcCallInfo& info)
{
  out << ", rpc_ok=" << info.rpc_ok << ", context "
      << info.state.context_->unique_id_ << ", " << info.state.unique_id_
      << " step " << info.state.step_;
  return out;
}

So when we call this logging in most of our methods, it can be simplified to something like:

LOG_VERBOSE(1) << "Cancellation notification received for: "
               << RpcCallInfo(state, rpc_ok);

Or something along those lines.


// If the context has not been cancelled then
// issue cancellation request to all the inflight
// states belonging to the context.
if (state->context_->step_ != Steps::CANCELLED) {
IssueRequestCancellation();
// Mark the context as cancelled
state->context_->step_ = Steps::CANCELLED;
Review comment (Member): Is it possible for the step to be modified by another thread? Do we need to protect it using a mutex?

Reply (Contributor Author): We only have a single ModelStreamInferHandler thread. Hence, we don't need a lock around the context step_.


// We return true here because the cancellation issued above
// would have raised alarm objects on all pending inflight
// state objects. This state will be taken up along with all
// the other states in the next iteration from the completion
// queue, which would release the state.
return true;
}
}

LOG_VERBOSE(1) << "Completing cancellation for " << name
<< ", rpc_ok=" << rpc_ok << ", context "
<< state->context_->unique_id_ << ", " << state->unique_id_
<< " step " << state->step_;

return false;
}

// Enqueue 'state' so that its response is delivered in the
// correct order.
void EnqueueForResponse(InferHandlerStateType* state)
@@ -781,6 +898,11 @@ class InferHandlerState {
std::queue<InferHandlerStateType*> states_;
std::atomic<uint32_t> ongoing_requests_;

// Tracks the inflight requests sent to Triton core via this
// context. We will use this structure to issue cancellations
// on these requests.
std::set<InferHandlerStateType*> inflight_states_;
Review comment (Contributor): By using a set, are we possibly hiding any bugs/errors where we insert a duplicate state, with the parallel nature of Handler/ResponseCallback/AsyncNotifyWhenDone logic for adding states onto queue?

Reply (Contributor Author, @tanmayv25, Sep 18, 2023): The items are added in this data structure iff the InferAsync function was successfully called. This will only be done once per state object. Hence, duplication should not be an issue.

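If a stricter check were ever wanted here, std::set::insert() already reports whether the element was newly inserted. A hypothetical fragment sketched against InsertInflightState above (not part of the PR) that would surface an unexpected duplicate instead of silently ignoring it:

// Hypothetical tightening of InsertInflightState, not in the PR:
auto result = inflight_states_.insert(state);
if (!result.second) {
  LOG_ERROR << "state " << state->unique_id_
            << " was already tracked as inflight";
}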

// The step of the entire context.
Steps step_;

@@ -809,12 +931,15 @@ class InferHandlerState {

~InferHandlerState() { ClearTraceTimestamps(); }

bool IsGrpcContextCancelled() { return context_->IsCancelled(); }

void Reset(
const std::shared_ptr<Context>& context, Steps start_step = Steps::START)
{
unique_id_ = NEXT_UNIQUE_ID;
context_ = context;
step_ = start_step;
irequest_ptr_ = nullptr;
cb_count_ = 0;
is_decoupled_ = false;
complete_ = false;
Expand Down Expand Up @@ -859,7 +984,9 @@ class InferHandlerState {

std::shared_ptr<Context> context_;
Steps step_;
std::mutex step_mtx_;
std::recursive_mutex step_mtx_;

TRITONSERVER_InferenceRequest* irequest_ptr_;

#ifdef TRITON_ENABLE_TRACING
std::shared_ptr<TraceManager::Trace> trace_;
@@ -939,6 +1066,10 @@ class InferHandler : public HandlerBase {
state = new State(tritonserver, context, start_step);
}

// This needs to be called to receive an asynchronous notification
// when the transaction is cancelled.
context->GrpcContextAsyncNotifyWhenDone(state);

return state;
}

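For completeness, a hedged client-side sketch of what exercises this server path: cancelling the grpc::ClientContext of an in-flight ModelInfer call. The stub and message types follow Triton's grpc_service.proto (package inference) but are assumptions here, not part of this PR's diff.

#include <grpcpp/grpcpp.h>

#include "grpc_service.grpc.pb.h"  // assumed generated Triton gRPC stubs

void
CancelInflightInfer(inference::GRPCInferenceService::Stub& stub)
{
  grpc::ClientContext ctx;
  grpc::CompletionQueue cq;
  inference::ModelInferRequest request;    // model name, inputs, ... omitted
  inference::ModelInferResponse response;
  grpc::Status status;

  // Start the call asynchronously and register for its completion.
  auto rpc = stub.AsyncModelInfer(&ctx, request, &cq);
  rpc->Finish(&response, &status, /*tag=*/reinterpret_cast<void*>(1));

  // Abandon the request: the server observes this as its
  // AsyncNotifyWhenDone tag firing with IsCancelled() == true.
  ctx.TryCancel();

  // The completion tag is still delivered; status is expected to be
  // grpc::StatusCode::CANCELLED.
  void* tag = nullptr;
  bool ok = false;
  cq.Next(&tag, &ok);
}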