Skip to content

Commit

Permalink
backend: add tests for network error (#39)
Browse files Browse the repository at this point in the history
  • Loading branch information
djshow832 authored Aug 14, 2022
1 parent 930663c commit 4bf1155
Show file tree
Hide file tree
Showing 10 changed files with 226 additions and 149 deletions.
66 changes: 27 additions & 39 deletions pkg/proxy/backend/authenticator.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,9 @@ import (
"fmt"

pnet "github.com/pingcap/TiProxy/pkg/proxy/net"
"github.com/pingcap/errors"
"github.com/pingcap/TiProxy/pkg/util/errors"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/util/hack"
"github.com/pingcap/tidb/util/logutil"
"go.uber.org/zap"
)

// Authenticator handshakes with the client and the backend.
Expand All @@ -42,42 +40,36 @@ func (auth *Authenticator) String() string {
auth.user, auth.dbname, auth.capability, auth.collation)
}

func (auth *Authenticator) handshakeFirstTime(clientIO, backendIO *pnet.PacketIO, serverTLSConfig, backendTLSConfig *tls.Config) (bool, error) {
func (auth *Authenticator) handshakeFirstTime(clientIO, backendIO *pnet.PacketIO, serverTLSConfig, backendTLSConfig *tls.Config) error {
backendIO.ResetSequence()
var (
serverPkt, clientPkt []byte
err error
serverCapability uint32
)

// Read initial handshake packet from the backend.
serverPkt, serverCapability, err = auth.readInitialHandshake(backendIO)
serverPkt, serverCapability, err := auth.readInitialHandshake(backendIO)
if serverPkt != nil {
writeErr := clientIO.WritePacket(serverPkt, true)
if writeErr != nil {
return false, writeErr
}
if err != nil {
return false, nil
if writeErr := clientIO.WritePacket(serverPkt, true); writeErr != nil {
return writeErr
}
} else {
return false, err
}
if err != nil {
return err
}
if serverCapability&mysql.ClientSSL == 0 {
return false, errors.New("the TiDB server must enable TLS")
// The error cannot be sent to the client because the client only expects an initial handshake packet.
// The only way is to log it and disconnect.
return errors.New("the TiDB server must enable TLS")
}

// Read the response from the client.
if clientPkt, err = clientIO.ReadPacket(); err != nil {
return false, err
clientPkt, err := clientIO.ReadPacket()
if err != nil {
return err
}
capability := binary.LittleEndian.Uint16(clientPkt[:2])
// A 2-bytes capability contains the ClientSSL flag, no matter ClientProtocol41 is set or not.
sslEnabled := uint32(capability)&mysql.ClientSSL > 0
if sslEnabled {
// Upgrade TLS with the client if SSL is enabled.
if _, err = clientIO.UpgradeToServerTLS(serverTLSConfig); err != nil {
return false, err
return err
}
} else {
// Rewrite the packet with ClientSSL enabled because we always connect to TiDB with TLS.
Expand All @@ -87,41 +79,39 @@ func (auth *Authenticator) handshakeFirstTime(clientIO, backendIO *pnet.PacketIO
clientPkt = pktWithSSL
}
if err = backendIO.WritePacket(clientPkt, true); err != nil {
return false, err
return err
}
// Always upgrade TLS with the server.
auth.backendTLSConfig = backendTLSConfig
if err = backendIO.UpgradeToClientTLS(backendTLSConfig); err != nil {
return false, err
return err
}
if sslEnabled {
// Read from the client again, where the capability may not contain ClientSSL this time.
if clientPkt, err = clientIO.ReadPacket(); err != nil {
return false, err
return err
}
}
// Send the response again.
if err = backendIO.WritePacket(clientPkt, true); err != nil {
return false, err
return err
}
auth.readHandshakeResponse(clientPkt)

// verify password
for {
serverPkt, err = forwardMsg(backendIO, clientIO)
if err != nil {
return false, err
return err
}
switch serverPkt[0] {
case mysql.OKHeader:
logutil.BgLogger().Debug("parse client handshake response finished", zap.String("authInfo", auth.String()))
return true, nil
return nil
case mysql.ErrHeader:
return false, nil
return pnet.ParseErrorPacket(serverPkt)
default: // mysql.AuthSwitchRequest, ShaCommand
clientPkt, err = forwardMsg(clientIO, backendIO)
if err != nil {
return false, err
if _, err = forwardMsg(clientIO, backendIO); err != nil {
return err
}
}
}
Expand Down Expand Up @@ -159,8 +149,7 @@ func (auth *Authenticator) handshakeSecondTime(backendIO *pnet.PacketIO, session
}

tokenBytes := hack.Slice(sessionToken)
err = auth.writeAuthHandshake(backendIO, tokenBytes)
if err != nil {
if err = auth.writeAuthHandshake(backendIO, tokenBytes); err != nil {
return err
}

Expand All @@ -169,11 +158,10 @@ func (auth *Authenticator) handshakeSecondTime(backendIO *pnet.PacketIO, session

func (auth *Authenticator) readInitialHandshake(backendIO *pnet.PacketIO) (serverPkt []byte, capability uint32, err error) {
if serverPkt, err = backendIO.ReadPacket(); err != nil {
err = errors.Trace(err)
return
}
if serverPkt[0] == mysql.ErrHeader {
err = errors.New("read initial handshake error")
err = pnet.ParseErrorPacket(serverPkt)
return
}
capability = pnet.ParseInitialHandshake(serverPkt)
Expand Down Expand Up @@ -211,7 +199,7 @@ func (auth *Authenticator) handleSecondAuthResult(backendIO *pnet.PacketIO) erro
case mysql.OKHeader:
return nil
case mysql.ErrHeader:
return errors.New("auth failed")
return pnet.ParseErrorPacket(data)
default: // mysql.AuthSwitchRequest, ShaCommand:
return errors.Errorf("read unexpected command: %#x", data[0])
}
Expand Down
11 changes: 2 additions & 9 deletions pkg/proxy/backend/backend_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,27 +20,20 @@ import (
"time"

pnet "github.com/pingcap/TiProxy/pkg/proxy/net"
"github.com/pingcap/tidb/util/arena"
)

const (
DialTimeout = 5 * time.Second
)

type connectionPhase byte

type BackendConnection struct {
pkt *pnet.PacketIO // a helper to read and write data in packet format.
alloc arena.Allocator
phase connectionPhase
capability uint32
address string
pkt *pnet.PacketIO // a helper to read and write data in packet format.
address string
}

func NewBackendConnection(address string) *BackendConnection {
return &BackendConnection{
address: address,
alloc: arena.NewAllocator(32 * 1024),
}
}

Expand Down
33 changes: 16 additions & 17 deletions pkg/proxy/backend/backend_conn_mgr.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,8 @@ func (mgr *BackendConnManager) Connect(ctx context.Context, serverAddr string, c
return err
}
backendIO := mgr.backendConn.PacketIO()
succeed, err := mgr.authenticator.handshakeFirstTime(clientIO, backendIO, serverTLSConfig, backendTLSConfig)
if err != nil {
if err := mgr.authenticator.handshakeFirstTime(clientIO, backendIO, serverTLSConfig, backendTLSConfig); err != nil {
return err
} else if !succeed {
return errors.New("server returns auth failure")
}
if mgr.authenticator.capability&mysql.ClientProtocol41 == 0 {
return errors.New("client must support CLIENT_PROTOCOL_41 capability")
}
childCtx, cancelFunc := context.WithCancel(ctx)
go mgr.processSignals(childCtx)
Expand All @@ -95,29 +89,34 @@ func (mgr *BackendConnManager) ExecuteCmd(ctx context.Context, request []byte, c
mgr.processLock.Lock()
defer mgr.processLock.Unlock()
waitingRedirect := atomic.LoadPointer(&mgr.signal) != nil
holdRequest, succeed, err := mgr.cmdProcessor.executeCmd(request, clientIO, mgr.backendConn.PacketIO(), waitingRedirect)
if err != nil {
holdRequest, err := mgr.cmdProcessor.executeCmd(request, clientIO, mgr.backendConn.PacketIO(), waitingRedirect)
if err != nil && !IsMySQLError(err) {
return err
}
switch request[0] {
case mysql.ComQuit:
return nil
case mysql.ComChangeUser:
if succeed {
if err == nil {
switch request[0] {
case mysql.ComQuit:
return nil
case mysql.ComChangeUser:
username, db := pnet.ParseChangeUser(request)
mgr.authenticator.changeUser(username, db)
return nil
}
return nil
}
// Even if it meets an MySQL error, it may have changed the status, such as when executing multi-statements.
if waitingRedirect && mgr.cmdProcessor.canRedirect() {
if err = mgr.tryRedirect(ctx); err != nil {
return err
}
if holdRequest {
_, _, err = mgr.cmdProcessor.executeCmd(request, clientIO, mgr.backendConn.PacketIO(), false)
_, err = mgr.cmdProcessor.executeCmd(request, clientIO, mgr.backendConn.PacketIO(), false)
}
if err != nil && !IsMySQLError(err) {
return err
}
}
return err
// Ignore MySQL errors, only return unexpected errors.
return nil
}

func (mgr *BackendConnManager) SetEventReceiver(receiver router.ConnEventReceiver) {
Expand Down
9 changes: 9 additions & 0 deletions pkg/proxy/backend/cmd_processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,12 @@ func (cp *CmdProcessor) hasPendingPreparedStmts() bool {
}
return false
}

// IsMySQLError returns true if the error is a MySQL error.
func IsMySQLError(err error) bool {
if err == nil {
return false
}
_, ok := err.(*gomysql.MyError)
return ok
}
Loading

0 comments on commit 4bf1155

Please sign in to comment.