Skip to content

Commit

Permalink
backend: send the error to the client when the handler encounters an …
Browse files Browse the repository at this point in the history
…error (#187)
  • Loading branch information
djshow832 authored Jan 16, 2023
1 parent 8dbe706 commit 8288910
Show file tree
Hide file tree
Showing 8 changed files with 177 additions and 28 deletions.
20 changes: 13 additions & 7 deletions pkg/proxy/backend/authenticator.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, cctx ConnConte

clientResp := pnet.ParseHandshakeResponse(pkt)
if err = handshakeHandler.HandleHandshakeResp(cctx, clientResp); err != nil {
return err
return WrapUserError(err, err.Error())
}
auth.user = clientResp.User
auth.dbname = clientResp.DB
Expand All @@ -156,23 +156,29 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, cctx ConnConte
// In case of testing, backendIO is passed manually that we don't want to bother with the routing logic.
backendIO, err := getBackendIO(cctx, auth, clientResp, 5*time.Second)
if err != nil {
return err
return WrapUserError(err, connectErrMsg)
}
backendIO.ResetSequence()

// write proxy header
if err := auth.writeProxyProtocol(clientIO, backendIO); err != nil {
return err
return WrapUserError(err, handshakeErrMsg)
}

// read backend initial handshake
_, backendCapability, err := auth.readInitialHandshake(backendIO)
serverPkt, backendCapability, err := auth.readInitialHandshake(backendIO)
if err != nil {
return err
if IsMySQLError(err) {
if writeErr := clientIO.WritePacket(serverPkt, true); writeErr != nil {
err = writeErr
}
return err
}
return WrapUserError(err, handshakeErrMsg)
}

if err := auth.verifyBackendCaps(logger, backendCapability); err != nil {
return err
return WrapUserError(err, capabilityErrMsg)
}

if common := proxyCapability & backendCapability; (proxyCapability^common)&^pnet.ClientSSL != 0 {
Expand All @@ -193,7 +199,7 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, cctx ConnConte
// send an unknown auth plugin so that the backend will request the auth data again.
unknownAuthPlugin, nil, 0,
); err != nil {
return err
return WrapUserError(err, handshakeErrMsg)
}

// forward other packets
Expand Down
34 changes: 20 additions & 14 deletions pkg/proxy/backend/backend_conn_mgr.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ func (mgr *BackendConnManager) Connect(ctx context.Context, clientIO *pnet.Packe
err := mgr.authenticator.handshakeFirstTime(mgr.logger.Named("authenticator"), mgr, clientIO, mgr.handshakeHandler, mgr.getBackendIO, frontendTLSConfig, backendTLSConfig)
mgr.handshakeHandler.OnHandshake(mgr, mgr.ServerAddr(), err)
if err != nil {
WriteUserError(clientIO, err, mgr.logger)
return err
}

Expand All @@ -166,39 +167,38 @@ func (mgr *BackendConnManager) Connect(ctx context.Context, clientIO *pnet.Packe
func (mgr *BackendConnManager) getBackendIO(cctx ConnContext, auth *Authenticator, resp *pnet.HandshakeResp, timeout time.Duration) (*pnet.PacketIO, error) {
r, err := mgr.handshakeHandler.GetRouter(cctx, resp)
if err != nil {
return nil, err
return nil, WrapUserError(err, err.Error())
}
// Reasons to wait:
// - The TiDB instances may not be initialized yet
// - One TiDB may be just shut down and another is just started but not ready yet
bctx, cancel := context.WithTimeout(context.Background(), timeout)
selector := r.GetBackendSelector()
var addr string
var origErr error
io, err := backoff.RetryNotifyWithData(
func() (*pnet.PacketIO, error) {
// Try to connect to all backup backends one by one.
addr, err := selector.Next()
addr, err = selector.Next()
// If all addrs are enumerated, reset and try again.
if err == nil && addr == "" {
selector.Reset()
addr, err = selector.Next()
}
if err != nil {
return nil, backoff.Permanent(err)
return nil, backoff.Permanent(WrapUserError(err, err.Error()))
}

// if all addrs are enumerated, reset and try again
if addr == "" {
selector.Reset()
if addr, err = selector.Next(); err != nil {
return nil, backoff.Permanent(err)
}
if addr == "" {
return nil, router.ErrNoInstanceToSelect
}
return nil, router.ErrNoInstanceToSelect
}

cn, err := net.DialTimeout("tcp", addr, DialTimeout)
var cn net.Conn
cn, err = net.DialTimeout("tcp", addr, DialTimeout)
if err != nil {
return nil, errors.Wrapf(err, "dial backend %s error", addr)
}

if err := selector.Succeed(mgr); err != nil {
if err = selector.Succeed(mgr); err != nil {
// Bad luck: the backend has been recycled or shut down just after the selector returns it.
if ignoredErr := cn.Close(); ignoredErr != nil {
mgr.logger.Error("close backend connection failed", zap.String("addr", addr), zap.Error(ignoredErr))
Expand All @@ -215,10 +215,16 @@ func (mgr *BackendConnManager) getBackendIO(cctx ConnContext, auth *Authenticato
},
backoff.WithContext(backoff.NewConstantBackOff(200*time.Millisecond), bctx),
func(err error, d time.Duration) {
origErr = err
mgr.handshakeHandler.OnHandshake(cctx, addr, err)
},
)
cancel()
if err != nil && errors.Is(err, context.DeadlineExceeded) {
if origErr != nil {
err = origErr
}
}
return io, err
}

Expand Down
53 changes: 52 additions & 1 deletion pkg/proxy/backend/backend_conn_mgr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -689,6 +689,57 @@ func TestGracefulCloseBeforeHandshake(t *testing.T) {
ts.runTests(runners)
}

func TestHandlerReturnError(t *testing.T) {
tests := []struct {
cfg cfgOverrider
errMsg string
}{
{
cfg: func(config *testConfig) {
config.proxyConfig.handler.getRouter = func(ctx ConnContext, resp *pnet.HandshakeResp) (router.Router, error) {
return nil, errors.New("mocked error")
}
},
errMsg: "mocked error",
},
{
cfg: func(config *testConfig) {
config.proxyConfig.handler.handleHandshakeResp = func(ctx ConnContext, resp *pnet.HandshakeResp) error {
return errors.New("mocked error")
}
},
errMsg: "mocked error",
},
{
// TODO: make it fail faster.
cfg: func(config *testConfig) {
config.proxyConfig.handler.getRouter = func(ctx ConnContext, resp *pnet.HandshakeResp) (router.Router, error) {
return router.NewStaticRouter(nil), nil
}
},
errMsg: connectErrMsg,
},
}
for _, test := range tests {
ts := newBackendMgrTester(t, test.cfg)
rn := runner{
client: func(packetIO *pnet.PacketIO) error {
err := ts.mc.authenticate(packetIO)
require.NoError(t, err)
require.ErrorContains(t, ts.mc.mysqlErr, test.errMsg)
return nil
},
proxy: func(clientIO, backendIO *pnet.PacketIO) error {
err := ts.mp.Connect(context.Background(), clientIO, ts.mp.frontendTLSConfig, ts.mp.backendTLSConfig)
require.Error(t, err)
return nil
},
backend: nil,
}
ts.runAndCheck(ts.t, func(t *testing.T, ts *testSuite) {}, rn.client, rn.backend, rn.proxy)
}
}

func TestGetBackendIO(t *testing.T) {
addrs := make([]string, 0, 3)
listeners := make([]net.Listener, 0, cap(addrs))
Expand Down Expand Up @@ -732,7 +783,7 @@ func TestGetBackendIO(t *testing.T) {
err = listeners[i].Close()
require.NoError(t, err, message)
} else {
require.ErrorIs(t, err, context.DeadlineExceeded, message)
require.Error(t, err, message)
}
require.True(t, len(badAddrs) <= i, message)
badAddrs = make(map[string]struct{}, 3)
Expand Down
5 changes: 3 additions & 2 deletions pkg/proxy/backend/cmd_processor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1036,11 +1036,12 @@ func TestNetworkError(t *testing.T) {
}
clientErrChecker := func(t *testing.T, ts *testSuite) {
require.True(t, pnet.IsDisconnectError(ts.mp.err))
require.True(t, pnet.IsDisconnectError(ts.mp.err))
require.True(t, pnet.IsDisconnectError(ts.mc.err))
require.NotNil(t, ts.mp.err.(*UserError))
}
backendErrChecker := func(t *testing.T, ts *testSuite) {
require.True(t, pnet.IsDisconnectError(ts.mp.err))
require.True(t, pnet.IsDisconnectError(ts.mp.err))
require.True(t, pnet.IsDisconnectError(ts.mb.err))
}
proxyErrChecker := func(t *testing.T, ts *testSuite) {
require.True(t, pnet.IsDisconnectError(ts.mp.err))
Expand Down
4 changes: 3 additions & 1 deletion pkg/proxy/backend/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,9 @@ func (tc *tcpConnSuite) newConn(t *testing.T, enableRoute bool) func() {
if tc.proxyBIO != nil {
_ = tc.proxyBIO.Close()
}
_ = tc.backendIO.Close()
if tc.backendIO != nil {
_ = tc.backendIO.Close()
}
}
}

Expand Down
75 changes: 75 additions & 0 deletions pkg/proxy/backend/error.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
// Copyright 2023 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package backend

import (
"github.com/pingcap/TiProxy/lib/util/errors"
pnet "github.com/pingcap/TiProxy/pkg/proxy/net"
"github.com/pingcap/tidb/parser/mysql"
"go.uber.org/zap"
)

const (
connectErrMsg = "No available TiDB instances, please check TiDB cluster"
handshakeErrMsg = "TiProxy fails to connect to TiDB, please check network"
capabilityErrMsg = "Verify TiDB capability failed, please upgrade TiDB"
)

// UserError is returned to the client.
// err is used to log and userMsg is used to report to the user.
type UserError struct {
err error
userMsg string
}

func WrapUserError(err error, userMsg string) *UserError {
if err == nil {
return nil
}
if ue, ok := err.(*UserError); ok {
return ue
}
return &UserError{
err: err,
userMsg: userMsg,
}
}

func (ue *UserError) UserMsg() string {
return ue.userMsg
}

func (ue *UserError) Unwrap() error {
return ue.err
}

func (ue *UserError) Error() string {
return ue.err.Error()
}

// WriteUserError writes an unknown error to the client.
func WriteUserError(clientIO *pnet.PacketIO, err error, lg *zap.Logger) {
if err == nil {
return
}
var ue *UserError
if !errors.As(err, &ue) {
return
}
myErr := mysql.NewErrf(mysql.ErrUnknown, "%s", nil, ue.UserMsg())
if writeErr := clientIO.WriteErrPacket(myErr); writeErr != nil {
lg.Error("writing error to client failed", zap.NamedError("mysql_err", err), zap.NamedError("write_err", writeErr))
}
}
9 changes: 8 additions & 1 deletion pkg/proxy/backend/mock_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ type mockClient struct {
*clientConfig
// Outputs that received from the server and will be checked by the test.
authSucceed bool
mysqlErr error
}

func newMockClient(cfg *clientConfig) *mockClient {
Expand Down Expand Up @@ -117,6 +118,7 @@ func (mc *mockClient) writePassword(packetIO *pnet.PacketIO) error {
return nil
case mysql.ErrHeader:
mc.authSucceed = false
mc.mysqlErr = pnet.ParseErrorPacket(serverPkt)
return nil
case mysql.AuthSwitchRequest, pnet.ShaCommand:
if err := packetIO.WritePacket(mc.authData, true); err != nil {
Expand Down Expand Up @@ -182,7 +184,10 @@ func (mc *mockClient) requestChangeUser(packetIO *pnet.PacketIO) error {
return err
}
switch resp[0] {
case mysql.OKHeader, mysql.ErrHeader:
case mysql.OKHeader:
return nil
case mysql.ErrHeader:
mc.mysqlErr = pnet.ParseErrorPacket(resp)
return nil
default:
if err := packetIO.WritePacket(mc.authData, true); err != nil {
Expand Down Expand Up @@ -268,6 +273,7 @@ func (mc *mockClient) readUntilResultEnd(packetIO *pnet.PacketIO) (pkt []byte, e
return
}
if pkt[0] == mysql.ErrHeader {
mc.mysqlErr = pnet.ParseErrorPacket(pkt)
return
}
if mc.capability&pnet.ClientDeprecateEOF == 0 {
Expand Down Expand Up @@ -311,6 +317,7 @@ func (mc *mockClient) readResultSet(packetIO *pnet.PacketIO) error {
case mysql.OKHeader:
serverStatus = binary.LittleEndian.Uint16(pkt[3:])
case mysql.ErrHeader:
mc.mysqlErr = pnet.ParseErrorPacket(pkt)
return nil
case mysql.LocalInFileHeader:
for i := 0; i < mc.filePkts; i++ {
Expand Down
5 changes: 3 additions & 2 deletions pkg/proxy/backend/testsuite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,15 +122,16 @@ type checker func(t *testing.T, ts *testSuite)

func newTestSuite(t *testing.T, tc *tcpConnSuite, overriders ...cfgOverrider) (*testSuite, func()) {
ts := &testSuite{}
cfg := newTestConfig(append(overriders, func(config *testConfig) {
overriders = append([]cfgOverrider{func(config *testConfig) {
config.backendConfig.tlsConfig = tc.backendTLSConfig
config.proxyConfig.backendTLSConfig = tc.clientTLSConfig
config.proxyConfig.frontendTLSConfig = tc.backendTLSConfig
config.clientConfig.tlsConfig = tc.clientTLSConfig
config.proxyConfig.handler.getRouter = func(ctx ConnContext, resp *pnet.HandshakeResp) (router.Router, error) {
return router.NewStaticRouter([]string{ts.tc.backendListener.Addr().String()}), nil
}
})...)
}}, overriders...)
cfg := newTestConfig(overriders...)
ts.mb = newMockBackend(cfg.backendConfig)
ts.mp = newMockProxy(t, cfg.proxyConfig)
ts.mc = newMockClient(cfg.clientConfig)
Expand Down

0 comments on commit 8288910

Please sign in to comment.