Skip to content

Commit

Permalink
feat: support gRPC graceful shutdown
Browse files Browse the repository at this point in the history
  • Loading branch information
DMwangnima committed Sep 19, 2024
1 parent 4e1dbe9 commit 503c19a
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 13 deletions.
2 changes: 1 addition & 1 deletion pkg/remote/trans/nphttp2/conn_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ func (p *connPool) newTransport(ctx context.Context, dialer remote.Dialer, netwo
opts,
p.remoteService,
func(grpc.GoAwayReason) {
// do nothing
p.Clean(network, address)
},
func() {
// do nothing
Expand Down
25 changes: 18 additions & 7 deletions pkg/remote/trans/nphttp2/grpc/http2_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -917,29 +917,40 @@ func (t *http2Client) handleGoAway(f *grpcframe.GoAwayFrame) {
// Notify the clientconn about the GOAWAY before we set the state to
// draining, to allow the client to stop attempting to create streams
// before disallowing new streams on this connection.
if t.onGoAway != nil {
t.onGoAway(t.goAwayReason)
if t.state != draining {
if t.onGoAway != nil {
t.onGoAway(t.goAwayReason)
}
t.state = draining
}
t.state = draining
}
// All streams with IDs greater than the GoAwayId
// and smaller than the previous GoAway ID should be killed.
upperLimit := t.prevGoAwayID
if upperLimit == 0 { // This is the first GoAway Frame.
upperLimit = math.MaxUint32 // Kill all streams after the GoAway ID.
}
t.prevGoAwayID = id
active := len(t.activeStreams)
if active <= 0 {
t.mu.Unlock()
t.Close(connectionErrorf(true, nil, "received goaway and there are no active streams"))
return
}

var unprocessedStream []*Stream
for streamID, stream := range t.activeStreams {
if streamID > id && streamID <= upperLimit {
// The stream was unprocessed by the server.
atomic.StoreUint32(&stream.unprocessed, 1)
unprocessedStream = append(unprocessedStream, stream)
t.closeStream(stream, errStreamDrain, false, http2.ErrCodeNo, statusGoAway, nil, false)
}
}
t.prevGoAwayID = id
active := len(t.activeStreams)
t.mu.Unlock()
if active == 0 {
t.Close(connectionErrorf(true, nil, "received goaway and there are no active streams"))

for _, stream := range unprocessedStream {
t.closeStream(stream, errStreamDrain, false, http2.ErrCodeNo, statusGoAway, nil, false)
}
}

Expand Down
8 changes: 3 additions & 5 deletions pkg/remote/trans/nphttp2/grpc/http2_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -1081,10 +1081,8 @@ func (t *http2Server) outgoingGoAwayHandler(g *goAway) (bool, error) {
if err := t.framer.WriteGoAway(sid, g.code, g.debugData); err != nil {
return false, err
}
t.framer.writer.Flush()
if g.closeConn {
// Abruptly close the connection following the GoAway (via
// loopywriter). But flush out what's inside the buffer first.
t.framer.writer.Flush()
return false, fmt.Errorf("transport: Connection closing")
}
return true, nil
Expand All @@ -1096,15 +1094,15 @@ func (t *http2Server) outgoingGoAwayHandler(g *goAway) (bool, error) {
// originated before the GoAway reaches the client.
// After getting the ack or timer expiration send out another GoAway this
// time with an ID of the max stream server intends to process.
if err := t.framer.WriteGoAway(math.MaxUint32, http2.ErrCodeNo, []byte{}); err != nil {
if err := t.framer.WriteGoAway(math.MaxUint32, http2.ErrCodeNo, g.debugData); err != nil {
return false, err
}
if err := t.framer.WritePing(false, goAwayPing.data); err != nil {
return false, err
}

gofunc.RecoverGoFuncWithInfo(context.Background(), func() {
timer := time.NewTimer(time.Minute)
timer := time.NewTimer(10 * time.Second)
defer timer.Stop()
select {
case <-t.drainChan:
Expand Down
49 changes: 49 additions & 0 deletions pkg/remote/trans/nphttp2/server_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ func newSvrTransHandler(opt *remote.ServerOption) (*svrTransHandler, error) {
opt: opt,
svcSearcher: opt.SvcSearcher,
codec: grpc.NewGRPCCodec(grpc.WithThriftCodec(opt.PayloadCodec)),
transports: make(map[grpcTransport.ServerTransport]struct{}),
}, nil
}

Expand All @@ -72,6 +73,10 @@ type svrTransHandler struct {
svcSearcher remote.ServiceSearcher
inkHdlFunc endpoint.Endpoint
codec remote.Codec
mu sync.Mutex
transports map[grpcTransport.ServerTransport]struct{}

hdlWG sync.WaitGroup
}

var prefaceReadAtMost = func() int {
Expand Down Expand Up @@ -121,7 +126,9 @@ func (t *svrTransHandler) OnRead(ctx context.Context, conn net.Conn) error {
tr := svrTrans.tr

tr.HandleStreams(func(s *grpcTransport.Stream) {
t.hdlWG.Add(1)
gofunc.GoFunc(ctx, func() {
defer t.hdlWG.Done()
t.handleFunc(s, svrTrans, conn)
})
}, func(ctx context.Context, method string) context.Context {
Expand Down Expand Up @@ -315,6 +322,9 @@ func (t *svrTransHandler) OnActive(ctx context.Context, conn net.Conn) (context.
if err != nil {
return nil, err
}
t.mu.Lock()
t.transports[tr] = struct{}{}
t.mu.Unlock()
pool := &sync.Pool{
New: func() interface{} {
// init rpcinfo
Expand All @@ -329,6 +339,9 @@ func (t *svrTransHandler) OnActive(ctx context.Context, conn net.Conn) (context.
// 连接关闭时回调
func (t *svrTransHandler) OnInactive(ctx context.Context, conn net.Conn) {
tr := ctx.Value(ctxKeySvrTransport).(*SvrTrans).tr
t.mu.Lock()
delete(t.transports, tr)
t.mu.Unlock()
tr.Close()
}

Expand All @@ -349,6 +362,42 @@ func (t *svrTransHandler) SetInvokeHandleFunc(inkHdlFunc endpoint.Endpoint) {
func (t *svrTransHandler) SetPipeline(p *remote.TransPipeline) {
}

func (t *svrTransHandler) GracefulShutdown(ctx context.Context) error {
t.mu.Lock()
for trans := range t.transports {
trans.Drain()
}
t.mu.Unlock()

exitCh := make(chan struct{})
// todo: think about a better grace time duration
graceTime := time.Minute * 3
exitTimeout, ok := ctx.Deadline()
if ok {
graceTime = time.Until(exitTimeout)
}
graceTimer := time.NewTimer(graceTime)
gofunc.GoFunc(ctx, func() {
select {
case <-graceTimer.C:
t.mu.Lock()
for trans := range t.transports {
// use CloseWithErr
trans.Close()
}
t.mu.Unlock()
return
case <-exitCh:
return
}
})

t.hdlWG.Wait()
close(exitCh)

return nil
}

func (t *svrTransHandler) startTracer(ctx context.Context, ri rpcinfo.RPCInfo) context.Context {
c := t.opt.TracerCtl.DoStart(ctx, ri)
return c
Expand Down

0 comments on commit 503c19a

Please sign in to comment.