Skip to content

Commit

Permalink
backend: fix connection won't close if graceful shutdown starts durin…
Browse files Browse the repository at this point in the history
…g handshake (pingcap#164)
  • Loading branch information
djshow832 authored and xhebox committed Mar 13, 2023
1 parent e6e473d commit 1bb6981
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 25 deletions.
29 changes: 13 additions & 16 deletions pkg/proxy/backend/backend_conn_mgr.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,9 @@ type redirectResult struct {
}

const (
statusConnected int32 = iota
statusHandshaked
statusNotifyClose // notified to graceful close
statusClosing // really closing
statusActive int32 = iota
statusNotifyClose // notified to graceful close
statusClosing // really closing
statusClosed
)

Expand Down Expand Up @@ -99,7 +98,7 @@ type BackendConnManager struct {
signal unsafe.Pointer
// redirectResCh is used to notify the event receiver asynchronously.
redirectResCh chan *redirectResult
connStatus atomic.Int32
closeStatus atomic.Int32
// cancelFunc is used to cancel the signal processing goroutine.
cancelFunc context.CancelFunc
backendConn *BackendConnection
Expand Down Expand Up @@ -145,7 +144,6 @@ func (mgr *BackendConnManager) Connect(ctx context.Context, clientIO *pnet.Packe
return err
}

mgr.connStatus.Store(statusHandshaked)
mgr.cmdProcessor.capability = mgr.authenticator.capability
childCtx, cancelFunc := context.WithCancel(ctx)
mgr.cancelFunc = cancelFunc
Expand Down Expand Up @@ -188,7 +186,6 @@ func (mgr *BackendConnManager) getBackendIO(ctx ConnContext, auth *Authenticator
}

auth.serverAddr = addr
mgr.connStatus.Store(statusConnected)
return mgr.backendConn.PacketIO(), nil
}

Expand All @@ -203,7 +200,7 @@ func (mgr *BackendConnManager) ExecuteCmd(ctx context.Context, request []byte, c
mgr.processLock.Lock()
defer mgr.processLock.Unlock()

switch mgr.connStatus.Load() {
switch mgr.closeStatus.Load() {
case statusClosing, statusClosed:
return nil
}
Expand Down Expand Up @@ -251,7 +248,7 @@ func (mgr *BackendConnManager) ExecuteCmd(ctx context.Context, request []byte, c
if err != nil && !IsMySQLError(err) {
return err
}
} else if mgr.connStatus.Load() == statusNotifyClose {
} else if mgr.closeStatus.Load() == statusNotifyClose {
mgr.tryGracefulClose(ctx, clientIO)
} else if waitingRedirect {
mgr.tryRedirect(ctx, clientIO)
Expand Down Expand Up @@ -324,7 +321,7 @@ func (mgr *BackendConnManager) processSignals(ctx context.Context, clientIO *pne
// tryRedirect tries to migrate the session if the session is redirect-able.
// NOTE: processLock should be held before calling this function.
func (mgr *BackendConnManager) tryRedirect(ctx context.Context, clientIO *pnet.PacketIO) {
switch mgr.connStatus.Load() {
switch mgr.closeStatus.Load() {
case statusNotifyClose, statusClosing, statusClosed:
return
}
Expand Down Expand Up @@ -404,7 +401,7 @@ func (mgr *BackendConnManager) Redirect(newAddr string) {
atomic.StorePointer(&mgr.signal, unsafe.Pointer(&signalRedirect{
newAddr: newAddr,
}))
switch mgr.connStatus.Load() {
switch mgr.closeStatus.Load() {
case statusNotifyClose, statusClosing, statusClosed:
return
}
Expand Down Expand Up @@ -443,12 +440,12 @@ func (mgr *BackendConnManager) notifyRedirectResult(ctx context.Context, rs *red

// GracefulClose waits for the end of the transaction and closes the session.
func (mgr *BackendConnManager) GracefulClose() {
mgr.connStatus.Store(statusNotifyClose)
mgr.closeStatus.Store(statusNotifyClose)
mgr.signalReceived <- signalTypeGracefulClose
}

func (mgr *BackendConnManager) tryGracefulClose(ctx context.Context, clientIO *pnet.PacketIO) {
if mgr.connStatus.Load() != statusNotifyClose {
if mgr.closeStatus.Load() != statusNotifyClose {
return
}
if !mgr.cmdProcessor.finishedTxn() {
Expand All @@ -458,12 +455,12 @@ func (mgr *BackendConnManager) tryGracefulClose(ctx context.Context, clientIO *p
if err := clientIO.GracefulClose(); err != nil {
mgr.logger.Warn("graceful close client IO error", zap.Stringer("addr", clientIO.SourceAddr()), zap.Error(err))
}
mgr.connStatus.Store(statusClosing)
mgr.closeStatus.Store(statusClosing)
}

// Close releases all resources.
func (mgr *BackendConnManager) Close() error {
mgr.connStatus.Store(statusClosing)
mgr.closeStatus.Store(statusClosing)
if mgr.cancelFunc != nil {
mgr.cancelFunc()
mgr.cancelFunc = nil
Expand Down Expand Up @@ -495,6 +492,6 @@ func (mgr *BackendConnManager) Close() error {
}
}
}
mgr.connStatus.Store(statusClosed)
mgr.closeStatus.Store(statusClosed)
return errors.Collect(ErrCloseConnMgr, connErr, handErr)
}
28 changes: 26 additions & 2 deletions pkg/proxy/backend/backend_conn_mgr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ func (ts *backendMgrTester) redirectFail4Proxy(clientIO, backendIO *pnet.PacketI

func (ts *backendMgrTester) checkConnClosed(_, _ *pnet.PacketIO) error {
for i := 0; i < 30; i++ {
switch ts.mp.connStatus.Load() {
switch ts.mp.closeStatus.Load() {
case statusClosing, statusClosed:
return nil
}
Expand Down Expand Up @@ -642,7 +642,7 @@ func TestGracefulCloseWhenActive(t *testing.T) {
proxy: func(_, _ *pnet.PacketIO) error {
ts.mp.GracefulClose()
time.Sleep(300 * time.Millisecond)
require.Equal(t, statusNotifyClose, ts.mp.connStatus.Load())
require.Equal(t, statusNotifyClose, ts.mp.closeStatus.Load())
return nil
},
},
Expand All @@ -659,3 +659,27 @@ func TestGracefulCloseWhenActive(t *testing.T) {
}
ts.runTests(runners)
}

func TestGracefulCloseBeforeHandshake(t *testing.T) {
ts := newBackendMgrTester(t)
runners := []runner{
// try to gracefully close before handshake
{
proxy: func(_, _ *pnet.PacketIO) error {
ts.mp.GracefulClose()
return nil
},
},
// 1st handshake
{
client: ts.mc.authenticate,
proxy: ts.firstHandshake4Proxy,
backend: ts.handshake4Backend,
},
// it will then automatically close
{
proxy: ts.checkConnClosed,
},
}
ts.runTests(runners)
}
10 changes: 4 additions & 6 deletions pkg/proxy/client/client_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,14 +72,12 @@ func (cc *ClientConnection) Run(ctx context.Context) {
}

clean:
// graceful close
if errors.Is(err, os.ErrDeadlineExceeded) {
return
}
clientErr := errors.Is(err, ErrClientConn)
if !(clientErr && errors.Is(err, io.EOF)) {
cc.logger.Info(msg, zap.Error(err), zap.Bool("clientErr", clientErr), zap.Bool("serverErr", !clientErr))
// EOF: client closes; DeadlineExceeded: graceful shutdown; Closed: shut down.
if clientErr && (errors.Is(err, io.EOF) || errors.Is(err, os.ErrDeadlineExceeded) || errors.Is(err, net.ErrClosed)) {
return
}
cc.logger.Info(msg, zap.Error(err), zap.Bool("clientErr", clientErr), zap.Bool("serverErr", !clientErr))
}

func (cc *ClientConnection) processMsg(ctx context.Context) error {
Expand Down
13 changes: 12 additions & 1 deletion pkg/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ type SQLServer struct {
hsHandler backend.HandshakeHandler
requireBackendTLS bool
wg waitgroup.WaitGroup
cancelFunc context.CancelFunc

mu serverState
}
Expand Down Expand Up @@ -88,6 +89,8 @@ func (s *SQLServer) reset(cfg *config.ProxyServerOnline) {
}

func (s *SQLServer) Run(ctx context.Context, onlineProxyConfig <-chan *config.ProxyServerOnline) {
// Create another context because it still needs to run after graceful shutdown.
ctx, s.cancelFunc = context.WithCancel(context.Background())
for {
select {
case <-ctx.Done():
Expand Down Expand Up @@ -168,8 +171,12 @@ func (s *SQLServer) onConn(ctx context.Context, conn net.Conn) {
// Whether this affects NLB is to be tested.
func (s *SQLServer) gracefulShutdown() {
s.mu.Lock()
s.mu.inShutdown = true
gracefulWait := s.mu.gracefulWait
if gracefulWait == 0 {
s.mu.Unlock()
return
}
s.mu.inShutdown = true
for _, conn := range s.mu.clients {
conn.GracefulClose()
}
Expand Down Expand Up @@ -197,6 +204,10 @@ func (s *SQLServer) gracefulShutdown() {
func (s *SQLServer) Close() error {
s.gracefulShutdown()

if s.cancelFunc != nil {
s.cancelFunc()
s.cancelFunc = nil
}
errs := make([]error, 0, 4)
if s.listener != nil {
errs = append(errs, s.listener.Close())
Expand Down

0 comments on commit 1bb6981

Please sign in to comment.