From 67b4624cb072c0c67696dd1fa41c1a4744503d6a Mon Sep 17 00:00:00 2001 From: "yuxuan.wang1" Date: Thu, 19 Sep 2024 09:19:18 +0800 Subject: [PATCH] feat: support gRPC graceful shutdown --- pkg/remote/trans/nphttp2/conn_pool.go | 2 +- .../nphttp2/grpc/graceful_shutdown_test.go | 52 ++++++++++++ pkg/remote/trans/nphttp2/grpc/http2_client.go | 30 +++++-- pkg/remote/trans/nphttp2/grpc/http2_server.go | 8 +- .../trans/nphttp2/grpc/transport_test.go | 82 +++++++++++++++++-- pkg/remote/trans/nphttp2/server_handler.go | 54 +++++++++++- 6 files changed, 206 insertions(+), 22 deletions(-) create mode 100644 pkg/remote/trans/nphttp2/grpc/graceful_shutdown_test.go diff --git a/pkg/remote/trans/nphttp2/conn_pool.go b/pkg/remote/trans/nphttp2/conn_pool.go index cdd1a2f785..14c422f123 100644 --- a/pkg/remote/trans/nphttp2/conn_pool.go +++ b/pkg/remote/trans/nphttp2/conn_pool.go @@ -121,7 +121,7 @@ func (p *connPool) newTransport(ctx context.Context, dialer remote.Dialer, netwo opts, p.remoteService, func(grpc.GoAwayReason) { - // do nothing + p.Clean(network, address) }, func() { // do nothing diff --git a/pkg/remote/trans/nphttp2/grpc/graceful_shutdown_test.go b/pkg/remote/trans/nphttp2/grpc/graceful_shutdown_test.go new file mode 100644 index 0000000000..884aa52503 --- /dev/null +++ b/pkg/remote/trans/nphttp2/grpc/graceful_shutdown_test.go @@ -0,0 +1,52 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package grpc + +import ( + "context" + "math" + "testing" + "time" + + "github.com/cloudwego/kitex/internal/test" +) + +func TestGracefulShutdown(t *testing.T) { + srv, cli := setUp(t, 0, math.MaxUint32, gracefulShutdown) + defer cli.Close(errSelfCloseForTest) + + stream, err := cli.NewStream(context.Background(), &CallHdr{}) + test.Assert(t, err == nil, err) + <-srv.srvReady + go srv.gracefulShutdown() + err = cli.Write(stream, nil, []byte("hello"), &Options{}) + test.Assert(t, err == nil, err) + msg := make([]byte, 5) + num, err := stream.Read(msg) + test.Assert(t, err == nil, err) + test.Assert(t, num == 5, num) + _, err = cli.NewStream(context.Background(), &CallHdr{}) + test.Assert(t, err != nil, err) + t.Logf("NewStream err: %v", err) + time.Sleep(1 * time.Second) + err = cli.Write(stream, nil, []byte("hello"), &Options{}) + test.Assert(t, err != nil, err) + t.Logf("After timeout, Write err: %v", err) + _, err = stream.Read(msg) + test.Assert(t, err != nil, err) + t.Logf("After timeout, Read err: %v", err) +} diff --git a/pkg/remote/trans/nphttp2/grpc/http2_client.go b/pkg/remote/trans/nphttp2/grpc/http2_client.go index 58b2eaba01..9e46c30a68 100644 --- a/pkg/remote/trans/nphttp2/grpc/http2_client.go +++ b/pkg/remote/trans/nphttp2/grpc/http2_client.go @@ -468,9 +468,10 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea s.id = h.streamID s.fc = &inFlow{limit: uint32(t.initialWindowSize)} t.mu.Lock() - if t.activeStreams == nil { // Can be niled from Close(). + // Don't create a stream if the transport is in a state of graceful shutdown or already closed + if t.state == draining || t.activeStreams == nil { // Can be niled from Close(). t.mu.Unlock() - return false // Don't create a stream if the transport is already closed. + return false } t.activeStreams[s.id] = s t.mu.Unlock() @@ -917,10 +918,12 @@ func (t *http2Client) handleGoAway(f *grpcframe.GoAwayFrame) { // Notify the clientconn about the GOAWAY before we set the state to // draining, to allow the client to stop attempting to create streams // before disallowing new streams on this connection. - if t.onGoAway != nil { - t.onGoAway(t.goAwayReason) + if t.state != draining { + if t.onGoAway != nil { + t.onGoAway(t.goAwayReason) + } + t.state = draining } - t.state = draining } // All streams with IDs greater than the GoAwayId // and smaller than the previous GoAway ID should be killed. @@ -928,18 +931,27 @@ func (t *http2Client) handleGoAway(f *grpcframe.GoAwayFrame) { if upperLimit == 0 { // This is the first GoAway Frame. upperLimit = math.MaxUint32 // Kill all streams after the GoAway ID. } + t.prevGoAwayID = id + active := len(t.activeStreams) + if active <= 0 { + t.mu.Unlock() + t.Close(connectionErrorf(true, nil, "received goaway and there are no active streams")) + return + } + + var unprocessedStream []*Stream for streamID, stream := range t.activeStreams { if streamID > id && streamID <= upperLimit { // The stream was unprocessed by the server. atomic.StoreUint32(&stream.unprocessed, 1) + unprocessedStream = append(unprocessedStream, stream) t.closeStream(stream, errStreamDrain, false, http2.ErrCodeNo, statusGoAway, nil, false) } } - t.prevGoAwayID = id - active := len(t.activeStreams) t.mu.Unlock() - if active == 0 { - t.Close(connectionErrorf(true, nil, "received goaway and there are no active streams")) + + for _, stream := range unprocessedStream { + t.closeStream(stream, errStreamDrain, false, http2.ErrCodeNo, statusGoAway, nil, false) } } diff --git a/pkg/remote/trans/nphttp2/grpc/http2_server.go b/pkg/remote/trans/nphttp2/grpc/http2_server.go index c2b84efe13..ee5283716d 100644 --- a/pkg/remote/trans/nphttp2/grpc/http2_server.go +++ b/pkg/remote/trans/nphttp2/grpc/http2_server.go @@ -1081,10 +1081,8 @@ func (t *http2Server) outgoingGoAwayHandler(g *goAway) (bool, error) { if err := t.framer.WriteGoAway(sid, g.code, g.debugData); err != nil { return false, err } + t.framer.writer.Flush() if g.closeConn { - // Abruptly close the connection following the GoAway (via - // loopywriter). But flush out what's inside the buffer first. - t.framer.writer.Flush() return false, fmt.Errorf("transport: Connection closing") } return true, nil @@ -1096,7 +1094,7 @@ func (t *http2Server) outgoingGoAwayHandler(g *goAway) (bool, error) { // originated before the GoAway reaches the client. // After getting the ack or timer expiration send out another GoAway this // time with an ID of the max stream server intends to process. - if err := t.framer.WriteGoAway(math.MaxUint32, http2.ErrCodeNo, []byte{}); err != nil { + if err := t.framer.WriteGoAway(math.MaxUint32, http2.ErrCodeNo, g.debugData); err != nil { return false, err } if err := t.framer.WritePing(false, goAwayPing.data); err != nil { @@ -1104,7 +1102,7 @@ func (t *http2Server) outgoingGoAwayHandler(g *goAway) (bool, error) { } gofunc.RecoverGoFuncWithInfo(context.Background(), func() { - timer := time.NewTimer(time.Minute) + timer := time.NewTimer(10 * time.Second) defer timer.Stop() select { case <-t.drainChan: diff --git a/pkg/remote/trans/nphttp2/grpc/transport_test.go b/pkg/remote/trans/nphttp2/grpc/transport_test.go index ec98259a52..78b5ebf8e5 100644 --- a/pkg/remote/trans/nphttp2/grpc/transport_test.go +++ b/pkg/remote/trans/nphttp2/grpc/transport_test.go @@ -57,6 +57,10 @@ type server struct { conns map[ServerTransport]bool h *testStreamHandler ready chan struct{} + hdlWG sync.WaitGroup + transWG sync.WaitGroup + + srvReady chan struct{} } var ( @@ -77,6 +81,7 @@ func init() { type testStreamHandler struct { t *http2Server + srv *server notify chan struct{} getNotified chan struct{} } @@ -92,6 +97,8 @@ const ( invalidHeaderField delayRead pingpong + + gracefulShutdown ) func (h *testStreamHandler) handleStreamAndNotify(s *Stream) { @@ -292,6 +299,20 @@ func (h *testStreamHandler) handleStreamDelayRead(t *testing.T, s *Stream) { } } +func (h *testStreamHandler) gracefulShutdown(t *testing.T, s *Stream) { + close(h.srv.srvReady) + msg := make([]byte, 5) + num, err := s.Read(msg) + test.Assert(t, err == nil, err) + test.Assert(t, num == 5, num) + test.Assert(t, string(msg) == "hello", string(msg)) + err = h.t.Write(s, nil, msg, &Options{}) + test.Assert(t, err == nil, err) + _, err = s.Read(msg) + test.Assert(t, err != nil, err) + t.Logf("Server-side after timeout err: %v", err) +} + // start starts server. Other goroutines should block on s.readyChan for further operations. func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hType) { // 创建 listener @@ -329,6 +350,7 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hT s.conns[transport] = true h := &testStreamHandler{t: transport.(*http2Server)} s.h = h + h.srv = s s.mu.Unlock() switch ht { case notifyCall: @@ -379,12 +401,26 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hT }, func(ctx context.Context, method string) context.Context { return ctx }) + case gracefulShutdown: + s.transWG.Add(1) + go func() { + defer s.transWG.Done() + transport.HandleStreams(func(stream *Stream) { + s.hdlWG.Add(1) + go func() { + defer s.hdlWG.Done() + h.gracefulShutdown(t, stream) + }() + }, func(ctx context.Context, method string) context.Context { return ctx }) + }() default: - go transport.HandleStreams(func(s *Stream) { - go h.handleStream(t, s) - }, func(ctx context.Context, method string) context.Context { - return ctx - }) + go func() { + transport.HandleStreams(func(s *Stream) { + go h.handleStream(t, s) + }, func(ctx context.Context, method string) context.Context { + return ctx + }) + }() } return ctx } @@ -434,6 +470,40 @@ func (s *server) stop() { s.mu.Unlock() } +func (s *server) gracefulShutdown() { + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + s.lis.Close() + s.mu.Lock() + for trans := range s.conns { + trans.Drain() + } + s.mu.Unlock() + timeout, _ := ctx.Deadline() + graceTimer := time.NewTimer(time.Until(timeout)) + exitCh := make(chan struct{}) + go func() { + select { + case <-graceTimer.C: + s.mu.Lock() + for trans := range s.conns { + trans.Close() + } + s.mu.Unlock() + return + case <-exitCh: + return + } + }() + s.hdlWG.Wait() + s.transWG.Wait() + close(exitCh) + s.conns = nil + if err := s.eventLoop.Shutdown(ctx); err != nil { + fmt.Printf("netpoll server exit failed, err=%v", err) + } +} + func (s *server) addr() string { if s.lis == nil { return "" @@ -442,7 +512,7 @@ func (s *server) addr() string { } func setUpServerOnly(t *testing.T, port int, serverConfig *ServerConfig, ht hType) *server { - server := &server{startedErr: make(chan error, 1), ready: make(chan struct{})} + server := &server{startedErr: make(chan error, 1), ready: make(chan struct{}), srvReady: make(chan struct{})} go server.start(t, port, serverConfig, ht) server.wait(t, time.Second) return server diff --git a/pkg/remote/trans/nphttp2/server_handler.go b/pkg/remote/trans/nphttp2/server_handler.go index f249f84242..e22f5652a2 100644 --- a/pkg/remote/trans/nphttp2/server_handler.go +++ b/pkg/remote/trans/nphttp2/server_handler.go @@ -62,6 +62,7 @@ func newSvrTransHandler(opt *remote.ServerOption) (*svrTransHandler, error) { opt: opt, svcSearcher: opt.SvcSearcher, codec: grpc.NewGRPCCodec(grpc.WithThriftCodec(opt.PayloadCodec)), + transports: make(map[grpcTransport.ServerTransport]struct{}), }, nil } @@ -72,6 +73,11 @@ type svrTransHandler struct { svcSearcher remote.ServiceSearcher inkHdlFunc endpoint.Endpoint codec remote.Codec + mu sync.Mutex + transports map[grpcTransport.ServerTransport]struct{} + + hdlWG sync.WaitGroup + transWG sync.WaitGroup } var prefaceReadAtMost = func() int { @@ -119,9 +125,11 @@ func (t *svrTransHandler) Read(ctx context.Context, conn net.Conn, msg remote.Me func (t *svrTransHandler) OnRead(ctx context.Context, conn net.Conn) error { svrTrans := ctx.Value(ctxKeySvrTransport).(*SvrTrans) tr := svrTrans.tr - + defer t.transWG.Done() tr.HandleStreams(func(s *grpcTransport.Stream) { + t.hdlWG.Add(1) gofunc.GoFunc(ctx, func() { + defer t.hdlWG.Done() t.handleFunc(s, svrTrans, conn) }) }, func(ctx context.Context, method string) context.Context { @@ -315,6 +323,10 @@ func (t *svrTransHandler) OnActive(ctx context.Context, conn net.Conn) (context. if err != nil { return nil, err } + t.transWG.Add(1) + t.mu.Lock() + t.transports[tr] = struct{}{} + t.mu.Unlock() pool := &sync.Pool{ New: func() interface{} { // init rpcinfo @@ -329,6 +341,9 @@ func (t *svrTransHandler) OnActive(ctx context.Context, conn net.Conn) (context. // 连接关闭时回调 func (t *svrTransHandler) OnInactive(ctx context.Context, conn net.Conn) { tr := ctx.Value(ctxKeySvrTransport).(*SvrTrans).tr + t.mu.Lock() + delete(t.transports, tr) + t.mu.Unlock() tr.Close() } @@ -349,6 +364,43 @@ func (t *svrTransHandler) SetInvokeHandleFunc(inkHdlFunc endpoint.Endpoint) { func (t *svrTransHandler) SetPipeline(p *remote.TransPipeline) { } +func (t *svrTransHandler) GracefulShutdown(ctx context.Context) error { + t.mu.Lock() + for trans := range t.transports { + trans.Drain() + } + t.mu.Unlock() + + exitCh := make(chan struct{}) + // todo: think about a better grace time duration + graceTime := time.Minute * 3 + exitTimeout, ok := ctx.Deadline() + if ok { + graceTime = time.Until(exitTimeout) + } + graceTimer := time.NewTimer(graceTime) + gofunc.GoFunc(ctx, func() { + select { + case <-graceTimer.C: + t.mu.Lock() + for trans := range t.transports { + // use CloseWithErr + trans.Close() + } + t.mu.Unlock() + return + case <-exitCh: + return + } + }) + + t.hdlWG.Wait() + t.transWG.Wait() + close(exitCh) + + return nil +} + func (t *svrTransHandler) startTracer(ctx context.Context, ri rpcinfo.RPCInfo) context.Context { c := t.opt.TracerCtl.DoStart(ctx, ri) return c