From 3032c435e00392ff439b21da8db2c71ca900c438 Mon Sep 17 00:00:00 2001 From: Yingge He <157551214+yinggeh@users.noreply.github.com> Date: Mon, 9 Dec 2024 11:23:29 -0800 Subject: [PATCH] fix: gRPC segfault due to Low Request Cancellation Timeout (#7840) --- qa/L0_grpc_state_cleanup/test.sh | 1 - .../grpc_cancellation_test.py | 62 +++++++-- qa/L0_request_cancellation/test.sh | 32 ++++- src/grpc/infer_handler.cc | 39 ++++-- src/grpc/infer_handler.h | 126 +++++++++++++----- src/grpc/stream_infer_handler.cc | 6 +- src/grpc/stream_infer_handler.h | 3 +- 7 files changed, 202 insertions(+), 67 deletions(-) diff --git a/qa/L0_grpc_state_cleanup/test.sh b/qa/L0_grpc_state_cleanup/test.sh index df302d5ed1..9def779b72 100755 --- a/qa/L0_grpc_state_cleanup/test.sh +++ b/qa/L0_grpc_state_cleanup/test.sh @@ -56,7 +56,6 @@ function check_state_release() { num_state_new=`cat $log_file | grep "StateNew" | wc -l` if [ $num_state_release -ne $num_state_new ]; then - cat $log_file echo -e "\n***\n*** Test Failed: Mismatch detected, $num_state_new state(s) created, $num_state_release state(s) released. \n***" >> $log_file return 1 fi diff --git a/qa/L0_request_cancellation/grpc_cancellation_test.py b/qa/L0_request_cancellation/grpc_cancellation_test.py index 4b103e21e1..7d9331d5f5 100755 --- a/qa/L0_request_cancellation/grpc_cancellation_test.py +++ b/qa/L0_request_cancellation/grpc_cancellation_test.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 -# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions @@ -27,6 +27,7 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import asyncio +import os import queue import re import time @@ -169,28 +170,61 @@ def test_grpc_async_infer_cancellation_at_step_start(self): with open(server_log_name, "r") as f: server_log = f.read() - cancel_at_start_count = len( - re.findall( - r"Cancellation notification received for ModelInferHandler, rpc_ok=1, context \d+, \d+ step START", - server_log, - ) - ) cur_new_req_handl_count = len( re.findall("New request handler for ModelInferHandler", server_log) ) - self.assertEqual( - cancel_at_start_count, - 2, - "Expected 2 cancellation at step START log entries, but got {}".format( - cancel_at_start_count - ), - ) self.assertGreater( cur_new_req_handl_count, prev_new_req_handl_count, "gRPC Cancellation on step START Test Failed: New request handler for ModelInferHandler was not created", ) + def test_grpc_async_infer_response_complete_during_cancellation(self): + # long test + self.test_duration_delta = 2 + delay_notification_sec = ( + int(os.getenv("TRITONSERVER_DELAY_GRPC_NOTIFICATION")) / 1000 + ) + delay_queue_cancellation_sec = ( + int(os.getenv("TRITONSERVER_DELAY_GRPC_ENQUEUE")) / 1000 + ) + future = self._client.async_infer( + model_name=self._model_name, + inputs=self._inputs, + callback=self._callback, + outputs=self._outputs, + ) + # ensure cancellation is received before InferResponseComplete and is processed after InferResponseComplete + time.sleep(self._model_delay - 2) + future.cancel() + time.sleep( + delay_notification_sec + delay_queue_cancellation_sec + ) # ensure the cancellation is processed + self._assert_callback_cancelled() + + def test_grpc_async_infer_cancellation_during_response_complete(self): + # long test + self.test_duration_delta = 2.5 + delay_notification_sec = ( + int(os.getenv("TRITONSERVER_DELAY_GRPC_NOTIFICATION")) / 1000 + ) + delay_response_completion_sec = ( + int(os.getenv("TRITONSERVER_DELAY_RESPONSE_COMPLETION")) / 1000 + ) + future = self._client.async_infer( + model_name=self._model_name, + inputs=self._inputs, + callback=self._callback, + outputs=self._outputs, + ) + # ensure the cancellation is received between InferResponseComplete checking cancellation and Finish + time.sleep(self._model_delay + 2) + future.cancel() + time.sleep( + delay_notification_sec + delay_response_completion_sec + ) # ensure the cancellation is processed + self._assert_callback_cancelled() + if __name__ == "__main__": unittest.main() diff --git a/qa/L0_request_cancellation/test.sh b/qa/L0_request_cancellation/test.sh index 0c9ab74086..7789f2006a 100755 --- a/qa/L0_request_cancellation/test.sh +++ b/qa/L0_request_cancellation/test.sh @@ -58,7 +58,7 @@ mkdir -p models/model/1 && (cd models/model && \ echo 'instance_group [{ kind: KIND_CPU }]' >> config.pbtxt) SERVER_LOG=server.log -LD_LIBRARY_PATH=/opt/tritonserver/lib:$LD_LIBRARY_PATH ./request_cancellation_test > $SERVER_LOG +LD_LIBRARY_PATH=/opt/tritonserver/lib:$LD_LIBRARY_PATH ./request_cancellation_test > $SERVER_LOG 2>&1 if [ $? -ne 0 ]; then echo -e "\n***\n*** Unit Tests Failed\n***" cat $SERVER_LOG @@ -78,12 +78,23 @@ mkdir -p models/custom_identity_int32/1 && (cd models/custom_identity_int32 && \ echo 'instance_group [{ kind: KIND_CPU }]' >> config.pbtxt && \ echo -e 'parameters [{ key: "execute_delay_ms" \n value: { string_value: "10000" } }]' >> config.pbtxt) -for TEST_CASE in "test_grpc_async_infer" "test_grpc_stream_infer" "test_aio_grpc_async_infer" "test_aio_grpc_stream_infer" "test_grpc_async_infer_cancellation_at_step_start"; do - +for TEST_CASE in "test_grpc_async_infer" \ + "test_grpc_stream_infer" \ + "test_aio_grpc_async_infer" \ + "test_aio_grpc_stream_infer" \ + "test_grpc_async_infer_cancellation_at_step_start" \ + "test_grpc_async_infer_response_complete_during_cancellation" \ + "test_grpc_async_infer_cancellation_during_response_complete"; do TEST_LOG="./grpc_cancellation_test.$TEST_CASE.log" SERVER_LOG="grpc_cancellation_test.$TEST_CASE.server.log" if [ "$TEST_CASE" == "test_grpc_async_infer_cancellation_at_step_start" ]; then export TRITONSERVER_DELAY_GRPC_PROCESS=5000 + elif [ "$TEST_CASE" == "test_grpc_async_infer_response_complete_during_cancellation" ]; then + export TRITONSERVER_DELAY_GRPC_NOTIFICATION=5000 + export TRITONSERVER_DELAY_GRPC_ENQUEUE=5000 + elif [ "$TEST_CASE" == "test_grpc_async_infer_cancellation_during_response_complete" ]; then + export TRITONSERVER_DELAY_GRPC_NOTIFICATION=5000 + export TRITONSERVER_DELAY_RESPONSE_COMPLETION=5000 fi SERVER_ARGS="--model-repository=`pwd`/models --log-verbose=1" @@ -101,11 +112,16 @@ for TEST_CASE in "test_grpc_async_infer" "test_grpc_stream_infer" "test_aio_grpc cat $TEST_LOG RET=1 fi - grep "Cancellation notification received for" $SERVER_LOG - if [ $? -ne 0 ]; then + + count=$(grep -o "Cancellation notification received for" $SERVER_LOG | wc -l) + if [ $count == 0 ]; then echo -e "\n***\n*** Cancellation not received by server on $TEST_CASE\n***" cat $SERVER_LOG RET=1 + elif [ $count -ne 1 ]; then + echo -e "\n***\n*** Unexpected cancellation received by server on $TEST_CASE. Expected 1 but received $count.\n***" + cat $SERVER_LOG + RET=1 fi set -e @@ -114,6 +130,12 @@ for TEST_CASE in "test_grpc_async_infer" "test_grpc_stream_infer" "test_aio_grpc if [ "$TEST_CASE" == "test_grpc_async_infer_cancellation_at_step_start" ]; then unset TRITONSERVER_DELAY_GRPC_PROCESS + elif [ "$TEST_CASE" == "test_grpc_async_infer_response_complete_during_cancellation" ]; then + unset TRITONSERVER_DELAY_GRPC_NOTIFICATION + unset TRITONSERVER_DELAY_GRPC_ENQUEUE + elif [ "$TEST_CASE" == "test_grpc_async_infer_cancellation_during_response_complete" ]; then + unset TRITONSERVER_DELAY_GRPC_NOTIFICATION + unset TRITONSERVER_DELAY_RESPONSE_COMPLETION fi done diff --git a/src/grpc/infer_handler.cc b/src/grpc/infer_handler.cc index 27c6817324..9c7eef48bb 100644 --- a/src/grpc/infer_handler.cc +++ b/src/grpc/infer_handler.cc @@ -677,7 +677,8 @@ ModelInferHandler::StartNewRequest() } bool -ModelInferHandler::Process(InferHandler::State* state, bool rpc_ok) +ModelInferHandler::Process( + InferHandler::State* state, bool rpc_ok, bool is_notification) { // There are multiple handlers registered in the gRPC service. // Hence, there we can have a case where a handler thread is @@ -690,8 +691,8 @@ ModelInferHandler::Process(InferHandler::State* state, bool rpc_ok) // Will delay the Process execution by the specified time. // This can be used to test the flow when cancellation request // issued for the request, which is still at START step. - LOG_INFO << "Delaying the write of the response by " - << state->delay_process_ms_ << " ms..."; + LOG_INFO << "Delaying the Process execution by " << state->delay_process_ms_ + << " ms..."; std::this_thread::sleep_for( std::chrono::milliseconds(state->delay_process_ms_)); } @@ -711,11 +712,12 @@ ModelInferHandler::Process(InferHandler::State* state, bool rpc_ok) // because we will never leave this if body. Refer to PR 7325. // This is a special case for ModelInferHandler, since we have 2 threads, // and each of them can process cancellation. ModelStreamInfer has only 1 - // thread, and cancellation at step START was not reproducible in a + // thread, and cancellation at step START was not reproducible in a // single thread scenario. StartNewRequest(); } - bool resume = state->context_->HandleCancellation(state, rpc_ok, Name()); + bool resume = state->context_->HandleCancellation( + state, rpc_ok, Name(), is_notification); return resume; } @@ -765,7 +767,7 @@ ModelInferHandler::Process(InferHandler::State* state, bool rpc_ok) std::make_pair("GRPC_SEND_START", TraceManager::CaptureTimestamp())); #endif // TRITON_ENABLE_TRACING - state->step_ = COMPLETE; + state->step_ = Steps::COMPLETE; state->context_->responder_->Finish( inference::ModelInferResponse(), status, state); } @@ -1001,7 +1003,7 @@ ModelInferHandler::Execute(InferHandler::State* state) } #endif // TRITON_ENABLE_TRACING - state->step_ = COMPLETE; + state->step_ = Steps::COMPLETE; state->context_->responder_->Finish(error_response, status, state); } } @@ -1051,6 +1053,17 @@ ModelInferHandler::InferResponseComplete( << ", skipping response generation as grpc transaction was " "cancelled... "; + if (state->delay_enqueue_ms_ != 0) { + // Will delay PutTaskBackToQueue by the specified time. + // This can be used to test the flow when cancellation request + // issued for the request during InferResponseComplete + // callback right before Process in the notification thread. + LOG_INFO << "Delaying PutTaskBackToQueue by " << state->delay_enqueue_ms_ + << " ms..."; + std::this_thread::sleep_for( + std::chrono::milliseconds(state->delay_enqueue_ms_)); + } + // Send state back to the queue so that state can be released // in the next cycle. state->context_->PutTaskBackToQueue(state); @@ -1113,7 +1126,17 @@ ModelInferHandler::InferResponseComplete( std::make_pair("GRPC_SEND_START", TraceManager::CaptureTimestamp())); #endif // TRITON_ENABLE_TRACING - state->step_ = COMPLETE; + if (state->delay_response_completion_ms_ != 0) { + // Will delay the Process execution of state at step COMPLETE by the + // specified time. This can be used to test the flow when cancellation + // request issued for the request, which is at InferResponseComplete. + LOG_INFO << "Delaying InferResponseComplete by " + << state->delay_response_completion_ms_ << " ms..."; + std::this_thread::sleep_for( + std::chrono::milliseconds(state->delay_response_completion_ms_)); + } + + state->step_ = Steps::COMPLETE; state->context_->responder_->Finish(*response, state->status_, state); if (response_created) { delete response; diff --git a/src/grpc/infer_handler.h b/src/grpc/infer_handler.h index 45454e09bd..86428a514e 100644 --- a/src/grpc/infer_handler.h +++ b/src/grpc/infer_handler.h @@ -833,7 +833,8 @@ class InferHandlerState { // Returns whether or not to continue cycling through the gRPC // completion queue or not. bool HandleCancellation( - InferHandlerStateType* state, bool rpc_ok, const std::string& name) + InferHandlerStateType* state, bool rpc_ok, const std::string& name, + bool is_notification) { // Check to avoid early exit in case of triton_grpc_error if (!IsCancelled()) { @@ -845,17 +846,22 @@ class InferHandlerState { << " step " << state->step_; return true; } - if ((state->step_ != Steps::CANCELLATION_ISSUED) && - (state->step_ != Steps::CANCELLED)) { + if (is_notification) { LOG_VERBOSE(1) << "Cancellation notification received for " << name << ", rpc_ok=" << rpc_ok << ", context " << state->context_->unique_id_ << ", " << state->unique_id_ << " step " << state->step_; + } + if (state->step_ != Steps::CANCELLATION_ISSUED) { // If the context has not been cancelled then // issue cancellation request to all the inflight // states belonging to the context. - if (state->context_->step_ != Steps::CANCELLED) { + // It means this is the first time we are hiting this line for this grpc + // transaction. + if ((state->step_ != Steps::CANCELLED) && + (state->context_->step_ != Steps::CANCELLED)) { + // Issue the request cancellation as it has not been cancelled yet. IssueRequestCancellation(); // Mark the context as cancelled state->context_->step_ = Steps::CANCELLED; @@ -866,26 +872,34 @@ class InferHandlerState { // next iteration from the completion queue which // would release the state. return true; + } else if (is_notification && state->step_ == Steps::CANCELLED) { + // A corner case where InferResponseComplete is called between the + // cancellation reception but before the cancellation notification + // thread enters Process function. + // Should let the InferResponseComplete callback trigger the state + // release. + LOG_VERBOSE(1) << "Waiting for the state enqueued by callback to " + "complete cancellation for " + << name << ", rpc_ok=" << rpc_ok << ", context " + << state->context_->unique_id_ << ", " + << state->unique_id_ << " step " << state->step_; + return true; + } else { + // The cancellation request has been handled so the state can be + // released. + LOG_VERBOSE(1) << "Completing cancellation for " << name + << ", rpc_ok=" << rpc_ok << ", context " + << state->context_->unique_id_ << ", " + << state->unique_id_ << " step " << state->step_; + return false; } - } - - if (state->step_ != Steps::CANCELLATION_ISSUED) { - // The cancellation has not been issued hence the state can - // be released. - LOG_VERBOSE(1) << "Completing cancellation for " << name - << ", rpc_ok=" << rpc_ok << ", context " - << state->context_->unique_id_ << ", " - << state->unique_id_ << " step " << state->step_; - - return false; - } else { - // Should wait for the ResponseComplete callbacks to be invoked. + } else { // state->step_ == Steps::CANCELLATION_ISSUED + // Should wait for the InferResponseComplete callbacks to be invoked. LOG_VERBOSE(1) << "Waiting for the callback to retrieve cancellation for " << name << ", rpc_ok=" << rpc_ok << ", context " << state->context_->unique_id_ << ", " << state->unique_id_ << " step " << state->step_; - return true; } } @@ -1061,21 +1075,14 @@ class InferHandlerState { : tritonserver_(tritonserver), async_notify_state_(false) { // For debugging and testing - const char* dstr = getenv("TRITONSERVER_DELAY_GRPC_RESPONSE"); - delay_response_ms_ = 0; - if (dstr != nullptr) { - delay_response_ms_ = atoi(dstr); - } - const char* cstr = getenv("TRITONSERVER_DELAY_GRPC_COMPLETE"); - delay_complete_ms_ = 0; - if (cstr != nullptr) { - delay_complete_ms_ = atoi(cstr); - } - const char* pstr = getenv("TRITONSERVER_DELAY_GRPC_PROCESS"); - delay_process_ms_ = 0; - if (pstr != nullptr) { - delay_process_ms_ = atoi(pstr); - } + delay_response_ms_ = ParseDebugVariable("TRITONSERVER_DELAY_GRPC_RESPONSE"); + delay_complete_ms_ = ParseDebugVariable("TRITONSERVER_DELAY_GRPC_COMPLETE"); + delay_process_ms_ = ParseDebugVariable("TRITONSERVER_DELAY_GRPC_PROCESS"); + delay_notification_process_entry_ms_ = + ParseDebugVariable("TRITONSERVER_DELAY_GRPC_NOTIFICATION"); + delay_enqueue_ms_ = ParseDebugVariable("TRITONSERVER_DELAY_GRPC_ENQUEUE"); + delay_response_completion_ms_ = + ParseDebugVariable("TRITONSERVER_DELAY_RESPONSE_COMPLETION"); response_queue_.reset(new ResponseQueue()); Reset(context, start_step); @@ -1083,6 +1090,33 @@ class InferHandlerState { ~InferHandlerState() { ClearTraceTimestamps(); } + int ParseDebugVariable(const char* env_str) + { + const char* str = getenv(env_str); + int val = 0; + if (str != nullptr) { + try { + val = std::stoi(str); + } + catch (const std::invalid_argument& e) { + LOG_ERROR << "Unable to parse the debug variable " << env_str + << ". Value provided: '" << str + << "' is not a valid integer. Error: " << e.what(); + } + catch (const std::out_of_range& e) { + LOG_ERROR << "Unable to parse the debug variable " << env_str + << ". Value provided: '" << str + << "' is out of range for an integer. Error: " << e.what(); + } + catch (const std::exception& e) { + LOG_ERROR + << "An unexpected error occurred while parsing the debug variable " + << env_str << " with value '" << str << "'. Error: " << e.what(); + } + } + return val; + } + bool IsGrpcContextCancelled() { return context_->IsCancelled(); } void Reset( @@ -1172,6 +1206,9 @@ class InferHandlerState { int delay_response_ms_; int delay_complete_ms_; int delay_process_ms_; + int delay_notification_process_entry_ms_; + int delay_enqueue_ms_; + int delay_response_completion_ms_; // For inference requests the allocator payload, unused for other // requests. @@ -1317,7 +1354,7 @@ class InferHandler : public HandlerBase { }; virtual void StartNewRequest() = 0; - virtual bool Process(State* state, bool rpc_ok) = 0; + virtual bool Process(State* state, bool rpc_ok, bool is_notification) = 0; bool ExecutePrecondition(InferHandler::State* state); TRITONSERVER_Error* ForwardHeadersAsParameters( @@ -1396,16 +1433,33 @@ InferHandler< while (cq_->Next(&tag, &ok)) { State* state = static_cast(tag); + bool is_notification = false; if (state->step_ == Steps::WAITING_NOTIFICATION) { State* state_wrapper = state; state = state_wrapper->state_ptr_; state->context_->SetReceivedNotification(true); + is_notification = true; LOG_VERBOSE(1) << "Received notification for " << Name() << ", " << state->unique_id_; + + if (state->delay_notification_process_entry_ms_ != 0) { + // Will delay the entry to Process by the specified time. + // This can be used to test the flow when + // 1. cancellation request issued for the request, which invokes + // InferResponseComplete callback right before Process. + // 2. cancellation request issued for the request during + // InferResponseComplete callback right before Process in the + // notification thread. + LOG_INFO + << "Delaying the entry to Process for notification thread by " + << state->delay_notification_process_entry_ms_ << " ms..."; + std::this_thread::sleep_for(std::chrono::milliseconds( + state->delay_notification_process_entry_ms_)); + } } LOG_VERBOSE(2) << "Grpc::CQ::Next() " << state->context_->DebugString(state); - if (!Process(state, ok)) { + if (!Process(state, ok, is_notification)) { LOG_VERBOSE(1) << "Done for " << Name() << ", " << state->unique_id_; state->context_->EraseState(state); StateRelease(state); @@ -1529,7 +1583,7 @@ class ModelInferHandler protected: void StartNewRequest() override; - bool Process(State* state, bool rpc_ok) override; + bool Process(State* state, bool rpc_ok, bool is_notification) override; private: void Execute(State* state); diff --git a/src/grpc/stream_infer_handler.cc b/src/grpc/stream_infer_handler.cc index a7dd9baa77..f809611767 100644 --- a/src/grpc/stream_infer_handler.cc +++ b/src/grpc/stream_infer_handler.cc @@ -130,7 +130,8 @@ ModelStreamInferHandler::StartNewRequest() } bool -ModelStreamInferHandler::Process(InferHandler::State* state, bool rpc_ok) +ModelStreamInferHandler::Process( + InferHandler::State* state, bool rpc_ok, bool is_notification) { // Because gRPC doesn't allow concurrent writes on the // the stream we only have a single handler thread that @@ -143,7 +144,8 @@ ModelStreamInferHandler::Process(InferHandler::State* state, bool rpc_ok) if (state->context_->ReceivedNotification()) { std::lock_guard lock(state->step_mtx_); if (state->IsGrpcContextCancelled()) { - bool resume = state->context_->HandleCancellation(state, rpc_ok, Name()); + bool resume = state->context_->HandleCancellation( + state, rpc_ok, Name(), is_notification); return resume; } else { if (state->context_->HandleCompletion()) { diff --git a/src/grpc/stream_infer_handler.h b/src/grpc/stream_infer_handler.h index e5163eac59..40a8346703 100644 --- a/src/grpc/stream_infer_handler.h +++ b/src/grpc/stream_infer_handler.h @@ -106,7 +106,8 @@ class ModelStreamInferHandler protected: void StartNewRequest() override; - bool Process(State* state, bool rpc_ok) override; + bool Process( + State* state, bool rpc_ok, bool is_notification = false) override; private: static void StreamInferResponseComplete(