Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

backend: fix connection not closing if graceful shutdown starts during handshake #164

Merged
merged 6 commits into from
Dec 29, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 13 additions & 16 deletions pkg/proxy/backend/backend_conn_mgr.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,9 @@ type redirectResult struct {
}

const (
statusConnected int32 = iota
statusHandshaked
statusNotifyClose // notified to graceful close
statusClosing // really closing
statusActive int32 = iota
statusNotifyClose // notified to graceful close
statusClosing // really closing
statusClosed
)

Expand Down Expand Up @@ -99,7 +98,7 @@ type BackendConnManager struct {
signal unsafe.Pointer
// redirectResCh is used to notify the event receiver asynchronously.
redirectResCh chan *redirectResult
connStatus atomic.Int32
closeStatus atomic.Int32
// cancelFunc is used to cancel the signal processing goroutine.
cancelFunc context.CancelFunc
backendConn *BackendConnection
Expand Down Expand Up @@ -145,7 +144,6 @@ func (mgr *BackendConnManager) Connect(ctx context.Context, clientIO *pnet.Packe
return err
}

mgr.connStatus.Store(statusHandshaked)
mgr.cmdProcessor.capability = mgr.authenticator.capability
childCtx, cancelFunc := context.WithCancel(ctx)
mgr.cancelFunc = cancelFunc
Expand Down Expand Up @@ -188,7 +186,6 @@ func (mgr *BackendConnManager) getBackendIO(ctx ConnContext, auth *Authenticator
}

auth.serverAddr = addr
mgr.connStatus.Store(statusConnected)
return mgr.backendConn.PacketIO(), nil
}

Expand All @@ -203,7 +200,7 @@ func (mgr *BackendConnManager) ExecuteCmd(ctx context.Context, request []byte, c
mgr.processLock.Lock()
defer mgr.processLock.Unlock()

switch mgr.connStatus.Load() {
switch mgr.closeStatus.Load() {
case statusClosing, statusClosed:
return nil
}
Expand Down Expand Up @@ -251,7 +248,7 @@ func (mgr *BackendConnManager) ExecuteCmd(ctx context.Context, request []byte, c
if err != nil && !IsMySQLError(err) {
return err
}
} else if mgr.connStatus.Load() == statusNotifyClose {
} else if mgr.closeStatus.Load() == statusNotifyClose {
mgr.tryGracefulClose(ctx, clientIO)
} else if waitingRedirect {
mgr.tryRedirect(ctx, clientIO)
Expand Down Expand Up @@ -324,7 +321,7 @@ func (mgr *BackendConnManager) processSignals(ctx context.Context, clientIO *pne
// tryRedirect tries to migrate the session if the session is redirect-able.
// NOTE: processLock should be held before calling this function.
func (mgr *BackendConnManager) tryRedirect(ctx context.Context, clientIO *pnet.PacketIO) {
switch mgr.connStatus.Load() {
switch mgr.closeStatus.Load() {
case statusNotifyClose, statusClosing, statusClosed:
return
}
Expand Down Expand Up @@ -404,7 +401,7 @@ func (mgr *BackendConnManager) Redirect(newAddr string) {
atomic.StorePointer(&mgr.signal, unsafe.Pointer(&signalRedirect{
newAddr: newAddr,
}))
switch mgr.connStatus.Load() {
switch mgr.closeStatus.Load() {
case statusNotifyClose, statusClosing, statusClosed:
return
}
Expand Down Expand Up @@ -443,12 +440,12 @@ func (mgr *BackendConnManager) notifyRedirectResult(ctx context.Context, rs *red

// GracefulClose waits for the end of the transaction and closes the session.
func (mgr *BackendConnManager) GracefulClose() {
mgr.connStatus.Store(statusNotifyClose)
mgr.closeStatus.Store(statusNotifyClose)
mgr.signalReceived <- signalTypeGracefulClose
}

func (mgr *BackendConnManager) tryGracefulClose(ctx context.Context, clientIO *pnet.PacketIO) {
if mgr.connStatus.Load() != statusNotifyClose {
if mgr.closeStatus.Load() != statusNotifyClose {
return
}
if !mgr.cmdProcessor.finishedTxn() {
Expand All @@ -458,12 +455,12 @@ func (mgr *BackendConnManager) tryGracefulClose(ctx context.Context, clientIO *p
if err := clientIO.GracefulClose(); err != nil {
mgr.logger.Warn("graceful close client IO error", zap.Stringer("addr", clientIO.SourceAddr()), zap.Error(err))
}
mgr.connStatus.Store(statusClosing)
mgr.closeStatus.Store(statusClosing)
}

// Close releases all resources.
func (mgr *BackendConnManager) Close() error {
mgr.connStatus.Store(statusClosing)
mgr.closeStatus.Store(statusClosing)
if mgr.cancelFunc != nil {
mgr.cancelFunc()
mgr.cancelFunc = nil
Expand Down Expand Up @@ -495,6 +492,6 @@ func (mgr *BackendConnManager) Close() error {
}
}
}
mgr.connStatus.Store(statusClosed)
mgr.closeStatus.Store(statusClosed)
return errors.Collect(ErrCloseConnMgr, connErr, handErr)
}
28 changes: 26 additions & 2 deletions pkg/proxy/backend/backend_conn_mgr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ func (ts *backendMgrTester) redirectFail4Proxy(clientIO, backendIO *pnet.PacketI

func (ts *backendMgrTester) checkConnClosed(_, _ *pnet.PacketIO) error {
for i := 0; i < 30; i++ {
switch ts.mp.connStatus.Load() {
switch ts.mp.closeStatus.Load() {
case statusClosing, statusClosed:
return nil
}
Expand Down Expand Up @@ -642,7 +642,7 @@ func TestGracefulCloseWhenActive(t *testing.T) {
proxy: func(_, _ *pnet.PacketIO) error {
ts.mp.GracefulClose()
time.Sleep(300 * time.Millisecond)
require.Equal(t, statusNotifyClose, ts.mp.connStatus.Load())
require.Equal(t, statusNotifyClose, ts.mp.closeStatus.Load())
return nil
},
},
Expand All @@ -659,3 +659,27 @@ func TestGracefulCloseWhenActive(t *testing.T) {
}
ts.runTests(runners)
}

// TestGracefulCloseBeforeHandshake requests a graceful close before the first
// handshake has run, then performs the handshake and finally runs the
// closed-connection check.
func TestGracefulCloseBeforeHandshake(t *testing.T) {
	tester := newBackendMgrTester(t)

	// Step 1: ask for a graceful close while no handshake has happened yet.
	closeEarly := func(_, _ *pnet.PacketIO) error {
		tester.mp.GracefulClose()
		return nil
	}

	tester.runTests([]runner{
		{proxy: closeEarly},
		// Step 2: the first handshake between client, proxy and backend.
		{
			client:  tester.mc.authenticate,
			proxy:   tester.firstHandshake4Proxy,
			backend: tester.handshake4Backend,
		},
		// Step 3: the connection is expected to close on its own afterwards.
		{proxy: tester.checkConnClosed},
	})
}
10 changes: 4 additions & 6 deletions pkg/proxy/client/client_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,14 +72,12 @@ func (cc *ClientConnection) Run(ctx context.Context) {
}

clean:
// graceful close
if errors.Is(err, os.ErrDeadlineExceeded) {
return
}
clientErr := errors.Is(err, ErrClientConn)
if !(clientErr && errors.Is(err, io.EOF)) {
cc.logger.Info(msg, zap.Error(err), zap.Bool("clientErr", clientErr), zap.Bool("serverErr", !clientErr))
// EOF: client closes; DeadlineExceeded: graceful shutdown; Closed: shut down.
if clientErr && (errors.Is(err, io.EOF) || errors.Is(err, os.ErrDeadlineExceeded) || errors.Is(err, net.ErrClosed)) {
return
}
cc.logger.Info(msg, zap.Error(err), zap.Bool("clientErr", clientErr), zap.Bool("serverErr", !clientErr))
}

func (cc *ClientConnection) processMsg(ctx context.Context) error {
Expand Down
13 changes: 12 additions & 1 deletion pkg/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ type SQLServer struct {
hsHandler backend.HandshakeHandler
requireBackendTLS bool
wg waitgroup.WaitGroup
cancelFunc context.CancelFunc

mu serverState
}
Expand Down Expand Up @@ -88,6 +89,8 @@ func (s *SQLServer) reset(cfg *config.ProxyServerOnline) {
}

func (s *SQLServer) Run(ctx context.Context, onlineProxyConfig <-chan *config.ProxyServerOnline) {
// Create another context because it still needs to run after graceful shutdown.
ctx, s.cancelFunc = context.WithCancel(context.Background())
for {
select {
case <-ctx.Done():
Expand Down Expand Up @@ -168,8 +171,12 @@ func (s *SQLServer) onConn(ctx context.Context, conn net.Conn) {
// Whether this affects NLB is to be tested.
func (s *SQLServer) gracefulShutdown() {
s.mu.Lock()
s.mu.inShutdown = true
gracefulWait := s.mu.gracefulWait
if gracefulWait == 0 {
s.mu.Unlock()
return
}
s.mu.inShutdown = true
for _, conn := range s.mu.clients {
conn.GracefulClose()
}
Expand Down Expand Up @@ -197,6 +204,10 @@ func (s *SQLServer) gracefulShutdown() {
func (s *SQLServer) Close() error {
s.gracefulShutdown()

if s.cancelFunc != nil {
s.cancelFunc()
s.cancelFunc = nil
}
errs := make([]error, 0, 4)
if s.listener != nil {
errs = append(errs, s.listener.Close())
Expand Down