Skip to content

Commit

Permalink
Fix request complete callback by removing reference to state
Browse files Browse the repository at this point in the history
  • Loading branch information
tanmayv25 committed Sep 13, 2023
1 parent b55662f commit 33d58ca
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 94 deletions.
37 changes: 15 additions & 22 deletions src/grpc/infer_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -670,6 +670,19 @@ InferGRPCToInput(
return nullptr; // success
}

void
InferRequestComplete(
    TRITONSERVER_InferenceRequest* request, const uint32_t flags, void* userp)
{
  // Release callback shared by the gRPC infer handlers. No per-request user
  // state is attached ('userp' is unused); the request object itself is the
  // only resource to reclaim here.
  LOG_VERBOSE(1) << "ModelInferHandler::InferRequestComplete";

  if ((flags & TRITONSERVER_REQUEST_RELEASE_ALL) == 0) {
    // The core is not done with the request yet; nothing to clean up.
    return;
  }

  LOG_TRITONSERVER_ERROR(
      TRITONSERVER_InferenceRequestDelete(request),
      "deleting GRPC inference request");
}

//===========================================================================
// The following section contains the handling mechanism for ModelInfer RPC.
// This implementation is tuned towards performance and reducing latency.
Expand Down Expand Up @@ -905,8 +918,7 @@ ModelInferHandler::Execute(InferHandler::State* state)
}
if (err == nullptr) {
err = TRITONSERVER_InferenceRequestSetReleaseCallback(
irequest, InferRequestComplete,
reinterpret_cast<void*>(state) /* request_release_userp */);
irequest, InferRequestComplete, nullptr /* request_release_userp */);
}
if (err == nullptr) {
err = TRITONSERVER_InferenceRequestSetResponseCallback(
Expand Down Expand Up @@ -971,26 +983,6 @@ ModelInferHandler::Execute(InferHandler::State* state)
}
}

void
ModelInferHandler::InferRequestComplete(
    TRITONSERVER_InferenceRequest* request, const uint32_t flags, void* userp)
{
  // Release callback for requests issued by ModelInferHandler. 'userp'
  // carries the State* that was registered with the release callback.
  LOG_VERBOSE(1) << "ModelInferHandler::InferRequestComplete";

  auto* state = reinterpret_cast<State*>(userp);
  const bool ptr_matches = (state->irequest_ptr_ == request);
  if (!ptr_matches) {
    // Sanity check: the request handed back by the core should be the one
    // recorded on the state when the callback was registered.
    LOG_ERROR << "[INTERNAL] ModelInferHandler::InferRequestComplete: "
                 "TRITONSERVER_InferenceRequest ptr mismatch detected";
  }

  // Stop tracking this request on its context.
  state->context_->EraseInflightState(state);

  const bool release_all = ((flags & TRITONSERVER_REQUEST_RELEASE_ALL) != 0);
  if (release_all) {
    LOG_TRITONSERVER_ERROR(
        TRITONSERVER_InferenceRequestDelete(request),
        "deleting GRPC inference request");
  }
}

void
ModelInferHandler::InferResponseComplete(
TRITONSERVER_InferenceResponse* iresponse, const uint32_t flags,
Expand Down Expand Up @@ -1030,6 +1022,7 @@ ModelInferHandler::InferResponseComplete(
"deleting GRPC inference response");

state->step_ = Steps::CANCELLED;
state->context_->EraseInflightState(state);

LOG_VERBOSE(1) << "ModelInferHandler::InferResponseComplete, "
<< state->unique_id_
Expand Down
64 changes: 20 additions & 44 deletions src/grpc/infer_handler.h
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,9 @@ TRITONSERVER_Error* SetStateParameterFromTritonParameter(
StateParameters& state_params,
const std::pair<std::string, inference::InferParameter>& param);

void InferRequestComplete(
TRITONSERVER_InferenceRequest* request, const uint32_t flags, void* userp);

// Make sure to keep InferResponseAlloc and OutputBufferQuery logic in sync
TRITONSERVER_Error* OutputBufferQuery(
TRITONSERVER_ResponseAllocator* allocator, void* userp,
Expand Down Expand Up @@ -685,51 +688,33 @@ class InferHandlerState {

// Issues the cancellation for all inflight requests
// being tracked by this context.
TRITONSERVER_Error* IssueRequestCancellation()
void IssueRequestCancellation()
{
TRITONSERVER_Error* err = nullptr;
{
std::lock_guard<std::mutex> lock(mu_);

// Issues the request cancellation to the core.
for (auto state : inflight_states_) {
LOG_VERBOSE(1) << "Issuing cancellation for " << state->unique_id_;
if (state->irequest_ptr_ == nullptr) {
continue;
}
bool is_cancelled = false;
err = TRITONSERVER_InferenceRequestIsCancelled(
state->irequest_ptr_, &is_cancelled);
if ((err == nullptr) && (!is_cancelled)) {
err = TRITONSERVER_InferenceRequestCancel(state->irequest_ptr_);
}
if (err != nullptr) {
const char* request_id = "";
TRITONSERVER_Error* tmp_err = TRITONSERVER_InferenceRequestId(
state->irequest_ptr_, &request_id);
if (tmp_err != nullptr) {
// Ignore if failed to get request id.
TRITONSERVER_ErrorDelete(tmp_err);
std::lock_guard<std::recursive_mutex> lock(state->step_mtx_);
if (state->step_ != Steps::CANCELLED) {
LOG_VERBOSE(1) << "Issuing cancellation for " << state->unique_id_;
if (state->irequest_ptr_ == nullptr) {
continue;
}
TRITONSERVER_Error* wrapped_err = TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INTERNAL,
std::string(
std::string("[request id:") + request_id +
"]: " + TRITONSERVER_ErrorMessage(err))
.c_str());
TRITONSERVER_ErrorDelete(err);
return wrapped_err;
}

{
std::lock_guard<std::recursive_mutex> lock(state->step_mtx_);
if (state->step_ != Steps::CANCELLED) {
state->step_ = Steps::CANCELLATION_ISSUED;
// Note that request may or may not be valid at this point.
// Assuming if RequestComplete callback is run asynchronously
// before this point.
TRITONSERVER_Error* err = nullptr;
err = TRITONSERVER_InferenceRequestCancel(state->irequest_ptr_);
// TODO: Add request id to the message
if (err != nullptr) {
LOG_INFO << "Failed to cancel the request: "
<< TRITONSERVER_ErrorMessage(err);
}
state->step_ = Steps::CANCELLATION_ISSUED;
}
}
}
return nullptr;
}


Expand Down Expand Up @@ -760,13 +745,7 @@ class InferHandlerState {
// issue cancellation request to all the inflight
// states belonging to the context.
if (state->context_->step_ != Steps::CANCELLED) {
TRITONSERVER_Error* err = nullptr;
err = IssueRequestCancellation();
if (err != nullptr) {
LOG_VERBOSE(1) << "Failed to cancel request: "
<< TRITONSERVER_ErrorMessage(err);
TRITONSERVER_ErrorDelete(err);
}
IssueRequestCancellation();
// Mark the context as cancelled
state->context_->step_ = Steps::CANCELLED;

Expand Down Expand Up @@ -1305,9 +1284,6 @@ class ModelInferHandler

private:
void Execute(State* state);
static void InferRequestComplete(
TRITONSERVER_InferenceRequest* request, const uint32_t flags,
void* userp);
static void InferResponseComplete(
TRITONSERVER_InferenceResponse* response, const uint32_t flags,
void* userp);
Expand Down
30 changes: 5 additions & 25 deletions src/grpc/stream_infer_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -281,8 +281,7 @@ ModelStreamInferHandler::Process(InferHandler::State* state, bool rpc_ok)
}
if (err == nullptr) {
err = TRITONSERVER_InferenceRequestSetReleaseCallback(
irequest, StreamInferRequestComplete,
reinterpret_cast<void*>(state) /* request_release_userp */);
irequest, InferRequestComplete, nullptr /* request_release_userp */);
}
if (err == nullptr) {
err = TRITONSERVER_InferenceRequestSetResponseCallback(
Expand Down Expand Up @@ -310,7 +309,7 @@ ModelStreamInferHandler::Process(InferHandler::State* state, bool rpc_ok)
// state->step_ == ISSUED and inference request has
// initiated... the completion callback will transition to
// WRITEREADY or WRITTEN or CANCELLED. Recording the state and the
// irequest to handle gRPC stream cancellation.
// irequest to handle gRPC stream cancellation.
if (err == nullptr) {
state->context_->InsertInflightState(state, irequest);
} else {
Expand Down Expand Up @@ -529,28 +528,6 @@ ModelStreamInferHandler::Finish(InferHandler::State* state)
return false;
}

void
ModelStreamInferHandler::StreamInferRequestComplete(
    TRITONSERVER_InferenceRequest* request, const uint32_t flags, void* userp)
{
  // Release callback for requests issued on a gRPC stream. 'userp' carries
  // the State* registered with the release callback.
  LOG_VERBOSE(1) << "ModelStreamInferHandler::StreamInferRequestComplete";

  auto* state = reinterpret_cast<State*>(userp);
  if (request != state->irequest_ptr_) {
    // Sanity check: the request handed back by the core should match the
    // one recorded on the state.
    LOG_ERROR
        << "[INTERNAL] ModelStreamInferHandler::StreamInferRequestComplete: "
           "TRITONSERVER_InferenceRequest ptr mismatch detected";
  }

  // Stop tracking this request on its context.
  state->context_->EraseInflightState(state);

  const bool release_all = ((flags & TRITONSERVER_REQUEST_RELEASE_ALL) != 0);
  if (release_all) {
    LOG_TRITONSERVER_ERROR(
        TRITONSERVER_InferenceRequestDelete(request),
        "deleting GRPC inference request");
  }
}

void
ModelStreamInferHandler::StreamInferResponseComplete(
TRITONSERVER_InferenceResponse* iresponse, const uint32_t flags,
Expand Down Expand Up @@ -603,6 +580,8 @@ ModelStreamInferHandler::StreamInferResponseComplete(
// that state object can be released.
if (state->complete_) {
state->step_ = Steps::CANCELLED;
state->context_->EraseInflightState(state);

state->context_->PutTaskBackToQueue(state);
}

Expand Down Expand Up @@ -701,6 +680,7 @@ ModelStreamInferHandler::StreamInferResponseComplete(
// that state object can be released.
if (state->complete_) {
state->step_ = Steps::CANCELLED;
state->context_->EraseInflightState(state);
state->context_->PutTaskBackToQueue(state);
}

Expand Down
3 changes: 0 additions & 3 deletions src/grpc/stream_infer_handler.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,6 @@ class ModelStreamInferHandler
bool Process(State* state, bool rpc_ok) override;

private:
static void StreamInferRequestComplete(
TRITONSERVER_InferenceRequest* request, const uint32_t flags,
void* userp);
static void StreamInferResponseComplete(
TRITONSERVER_InferenceResponse* response, const uint32_t flags,
void* userp);
Expand Down

0 comments on commit 33d58ca

Please sign in to comment.