diff --git a/dbms/src/Flash/EstablishCall.cpp b/dbms/src/Flash/EstablishCall.cpp index 5a6486feaed..b32928e9b8b 100644 --- a/dbms/src/Flash/EstablishCall.cpp +++ b/dbms/src/Flash/EstablishCall.cpp @@ -4,10 +4,11 @@ namespace DB { -EstablishCallData::EstablishCallData(AsyncFlashService * service, grpc::ServerCompletionQueue * cq, grpc::ServerCompletionQueue * notify_cq) +EstablishCallData::EstablishCallData(AsyncFlashService * service, grpc::ServerCompletionQueue * cq, grpc::ServerCompletionQueue * notify_cq, const std::shared_ptr> & is_shutdown) : service(service) , cq(cq) , notify_cq(notify_cq) + , is_shutdown(is_shutdown) , responder(&ctx) , state(NEW_REQUEST) { @@ -17,9 +18,9 @@ EstablishCallData::EstablishCallData(AsyncFlashService * service, grpc::ServerCo service->RequestEstablishMPPConnection(&ctx, &request, &responder, cq, notify_cq, this); } -EstablishCallData * EstablishCallData::spawn(AsyncFlashService * service, grpc::ServerCompletionQueue * cq, grpc::ServerCompletionQueue * notify_cq) +EstablishCallData * EstablishCallData::spawn(AsyncFlashService * service, grpc::ServerCompletionQueue * cq, grpc::ServerCompletionQueue * notify_cq, const std::shared_ptr> & is_shutdown) { - return new EstablishCallData(service, cq, notify_cq); + return new EstablishCallData(service, cq, notify_cq, is_shutdown); } void EstablishCallData::tryFlushOne() @@ -36,6 +37,12 @@ void EstablishCallData::tryFlushOne() mpp_tunnel->sendJob(false); } +void EstablishCallData::responderFinish(const grpc::Status & status) +{ + if (!(*is_shutdown)) + responder.Finish(status, this); +} + void EstablishCallData::initRpc() { std::exception_ptr eptr = nullptr; @@ -51,12 +58,14 @@ void EstablishCallData::initRpc() { state = FINISH; grpc::Status status(static_cast(GRPC_STATUS_UNKNOWN), getExceptionMessage(eptr, false)); - responder.Finish(status, this); + responderFinish(status); } } bool EstablishCallData::write(const mpp::MPPDataPacket & packet) { + if (*is_shutdown) + return false; responder.Write(packet, this); return true; } @@ -77,7 +86,7 @@ void EstablishCallData::writeDone(const ::grpc::Status & status) { LOG_FMT_INFO(mpp_tunnel->getLogger(), "connection for {} cost {} ms.", mpp_tunnel->id(), stopwatch->elapsedMilliseconds()); } - responder.Finish(status, this); + responderFinish(status); } void EstablishCallData::notifyReady() @@ -93,10 +102,11 @@ void EstablishCallData::cancel() delete this; return; } + state = FINISH; if (mpp_tunnel) mpp_tunnel->consumerFinish("grpc writes failed.", true); //trigger mpp tunnel finish work grpc::Status status(static_cast(GRPC_STATUS_UNKNOWN), "Consumer exits unexpected, grpc writes failed."); - responder.Finish(status, this); + responderFinish(status); } void EstablishCallData::proceed() @@ -105,7 +115,7 @@ void EstablishCallData::proceed() { state = PROCESSING; - spawn(service, cq, notify_cq); + spawn(service, cq, notify_cq, is_shutdown); notifyReady(); initRpc(); } diff --git a/dbms/src/Flash/EstablishCall.h b/dbms/src/Flash/EstablishCall.h index 288966ee006..f54e2015e40 100644 --- a/dbms/src/Flash/EstablishCall.h +++ b/dbms/src/Flash/EstablishCall.h @@ -31,7 +31,11 @@ class EstablishCallData : public PacketWriter // it reacts base on current state. The completion queue "cq" and "notify_cq" // used for asynchronous communication with the gRPC runtime. // "notify_cq" gets the tag back indicating a call has started. All subsequent operations (reads, writes, etc) on that call report back to "cq". - EstablishCallData(AsyncFlashService * service, grpc::ServerCompletionQueue * cq, grpc::ServerCompletionQueue * notify_cq); + EstablishCallData( + AsyncFlashService * service, + grpc::ServerCompletionQueue * cq, + grpc::ServerCompletionQueue * notify_cq, + const std::shared_ptr> & is_shutdown); bool write(const mpp::MPPDataPacket & packet) override; @@ -50,13 +54,19 @@ class EstablishCallData : public PacketWriter // Spawn a new EstablishCallData instance to serve new clients while we process the one for this EstablishCallData. // The instance will deallocate itself as part of its FINISH state. // EstablishCallData will handle its lifecycle by itself. - static EstablishCallData * spawn(AsyncFlashService * service, grpc::ServerCompletionQueue * cq, grpc::ServerCompletionQueue * notify_cq); + static EstablishCallData * spawn( + AsyncFlashService * service, + grpc::ServerCompletionQueue * cq, + grpc::ServerCompletionQueue * notify_cq, + const std::shared_ptr> & is_shutdown); private: void notifyReady(); void initRpc(); + void responderFinish(const grpc::Status & status); + std::mutex mu; // server instance AsyncFlashService * service; @@ -64,6 +74,7 @@ class EstablishCallData : public PacketWriter // The producer-consumer queue where for asynchronous server notifications. ::grpc::ServerCompletionQueue * cq; ::grpc::ServerCompletionQueue * notify_cq; + std::shared_ptr> is_shutdown; ::grpc::ServerContext ctx; ::grpc::Status err_status; diff --git a/dbms/src/Server/Server.cpp b/dbms/src/Server/Server.cpp index e2b1345e52b..c80538bd425 100644 --- a/dbms/src/Server/Server.cpp +++ b/dbms/src/Server/Server.cpp @@ -554,6 +554,7 @@ class Server::FlashGrpcServerHolder public: FlashGrpcServerHolder(Server & server, const TiFlashRaftConfig & raft_config, Poco::Logger * log_) : log(log_) + , is_shutdown(std::make_shared>(false)) { grpc::ServerBuilder builder; if (server.security_config.has_tls_config) @@ -617,7 +618,7 @@ class Server::FlashGrpcServerHolder for (int j = 0; j < preallocated_request_count_per_poller; ++j) { // EstablishCallData will handle its lifecycle by itself. - EstablishCallData::spawn(assert_cast(flash_service.get()), cq, notify_cq); + EstablishCallData::spawn(assert_cast(flash_service.get()), cq, notify_cq, is_shutdown); } thread_manager->schedule(false, "async_poller", [cq, this] { handleRpcs(cq, log); }); thread_manager->schedule(false, "async_poller", [notify_cq, this] { handleRpcs(notify_cq, log); }); @@ -627,6 +628,9 @@ class Server::FlashGrpcServerHolder ~FlashGrpcServerHolder() { + *is_shutdown = true; + const int wait_calldata_after_shutdown_interval_ms = 500; + std::this_thread::sleep_for(std::chrono::milliseconds(wait_calldata_after_shutdown_interval_ms)); // sleep 500ms to let operations of calldata called by MPPTunnel done. /// Shut down grpc server. // wait 5 seconds for pending rpcs to gracefully stop gpr_timespec deadline{5, 0, GPR_TIMESPAN}; @@ -649,6 +653,7 @@ class Server::FlashGrpcServerHolder private: Poco::Logger * log; + std::shared_ptr> is_shutdown; std::unique_ptr flash_service = nullptr; std::unique_ptr diagnostics_service = nullptr; std::unique_ptr flash_grpc_server = nullptr;