Skip to content

Commit

Permalink
server: refact the code structure (#45228)
Browse files Browse the repository at this point in the history
ref #44940
  • Loading branch information
hawkingrei authored Jul 7, 2023
1 parent 351e379 commit cf0ae34
Show file tree
Hide file tree
Showing 12 changed files with 330 additions and 253 deletions.
4 changes: 3 additions & 1 deletion server/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
go_library(
name = "server",
srcs = [
"buffered_read_conn.go",
"conn.go",
"conn_stmt.go",
"driver.go",
Expand Down Expand Up @@ -55,6 +54,7 @@ go_library(
"//privilege/privileges/ldap",
"//server/internal/column",
"//server/internal/dump",
"//server/internal/handshake",
"//server/internal/parse",
"//server/internal/util",
"//server/metrics",
Expand Down Expand Up @@ -185,6 +185,8 @@ go_test(
"//parser/mysql",
"//planner/core",
"//server/internal/column",
"//server/internal/handshake",
"//server/internal/parse",
"//server/internal/util",
"//session",
"//sessionctx",
Expand Down
180 changes: 22 additions & 158 deletions server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ import (
"time"
"unsafe"

"github.com/klauspost/compress/zstd"
"github.com/pingcap/errors"
"github.com/pingcap/failpoint"
"github.com/pingcap/tidb/config"
Expand All @@ -77,6 +76,8 @@ import (
"github.com/pingcap/tidb/privilege/privileges/ldap"
"github.com/pingcap/tidb/server/internal/column"
"github.com/pingcap/tidb/server/internal/dump"
"github.com/pingcap/tidb/server/internal/handshake"
"github.com/pingcap/tidb/server/internal/parse"
util2 "github.com/pingcap/tidb/server/internal/util"
server_metrics "github.com/pingcap/tidb/server/metrics"
"github.com/pingcap/tidb/session"
Expand Down Expand Up @@ -128,16 +129,16 @@ func newClientConn(s *Server) *clientConn {
// clientConn represents a connection between server and client, it maintains connection specific state,
// handles client query.
type clientConn struct {
pkt *packetIO // a helper to read and write data in packet format.
bufReadConn *bufferedReadConn // a buffered-read net.Conn or buffered-read tls.Conn.
tlsConn *tls.Conn // TLS connection, nil if not TLS.
server *Server // a reference of server instance.
capability uint32 // client capability affects the way server handles client request.
connectionID uint64 // atomically allocated by a global variable, unique in process scope.
user string // user of the client.
dbname string // default database name.
salt []byte // random bytes used for authentication.
alloc arena.Allocator // an memory allocator for reducing memory allocation.
pkt *packetIO // a helper to read and write data in packet format.
bufReadConn *util2.BufferedReadConn // a buffered-read net.Conn or buffered-read tls.Conn.
tlsConn *tls.Conn // TLS connection, nil if not TLS.
server *Server // a reference of server instance.
capability uint32 // client capability affects the way server handles client request.
connectionID uint64 // atomically allocated by a global variable, unique in process scope.
user string // user of the client.
dbname string // default database name.
salt []byte // random bytes used for authentication.
alloc arena.Allocator // an memory allocator for reducing memory allocation.
chunkAlloc chunk.Allocator
lastPacket []byte // latest sql query string, currently used for logging error.
// ShowProcess() and mysql.ComChangeUser both visit this field, ShowProcess() read information through
Expand Down Expand Up @@ -448,143 +449,6 @@ func (cc *clientConn) getSessionVarsWaitTimeout(ctx context.Context) uint64 {
return waitTimeout
}

type handshakeResponse41 struct {
Capability uint32
Collation uint8
User string
DBName string
Auth []byte
AuthPlugin string
Attrs map[string]string
ZstdLevel zstd.EncoderLevel
}

// parseHandshakeResponseHeader parses the common header of SSLRequest and HandshakeResponse41.
func parseHandshakeResponseHeader(ctx context.Context, packet *handshakeResponse41, data []byte) (parsedBytes int, err error) {
// Ensure there are enough data to read:
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest
if len(data) < 4+4+1+23 {
logutil.Logger(ctx).Error("got malformed handshake response", zap.ByteString("packetData", data))
return 0, mysql.ErrMalformPacket
}

offset := 0
// capability
capability := binary.LittleEndian.Uint32(data[:4])
packet.Capability = capability
offset += 4
// skip max packet size
offset += 4
// charset, skip, if you want to use another charset, use set names
packet.Collation = data[offset]
offset++
// skip reserved 23[00]
offset += 23

return offset, nil
}

// parseHandshakeResponseBody parse the HandshakeResponse (except the common header part).
func parseHandshakeResponseBody(ctx context.Context, packet *handshakeResponse41, data []byte, offset int) (err error) {
defer func() {
// Check malformat packet cause out of range is disgusting, but don't panic!
if r := recover(); r != nil {
logutil.Logger(ctx).Error("handshake panic", zap.ByteString("packetData", data))
err = mysql.ErrMalformPacket
}
}()
// user name
packet.User = string(data[offset : offset+bytes.IndexByte(data[offset:], 0)])
offset += len(packet.User) + 1

if packet.Capability&mysql.ClientPluginAuthLenencClientData > 0 {
// MySQL client sets the wrong capability, it will set this bit even server doesn't
// support ClientPluginAuthLenencClientData.
// https://github.com/mysql/mysql-server/blob/5.7/sql-common/client.c#L3478
if data[offset] == 0x1 { // No auth data
offset += 2
} else {
num, null, off := util2.ParseLengthEncodedInt(data[offset:])
offset += off
if !null {
packet.Auth = data[offset : offset+int(num)]
offset += int(num)
}
}
} else if packet.Capability&mysql.ClientSecureConnection > 0 {
// auth length and auth
authLen := int(data[offset])
offset++
packet.Auth = data[offset : offset+authLen]
offset += authLen
} else {
packet.Auth = data[offset : offset+bytes.IndexByte(data[offset:], 0)]
offset += len(packet.Auth) + 1
}

if packet.Capability&mysql.ClientConnectWithDB > 0 {
if len(data[offset:]) > 0 {
idx := bytes.IndexByte(data[offset:], 0)
packet.DBName = string(data[offset : offset+idx])
offset += idx + 1
}
}

if packet.Capability&mysql.ClientPluginAuth > 0 {
idx := bytes.IndexByte(data[offset:], 0)
s := offset
f := offset + idx
if s < f { // handle unexpected bad packets
packet.AuthPlugin = string(data[s:f])
}
offset += idx + 1
}

if packet.Capability&mysql.ClientConnectAtts > 0 {
if len(data[offset:]) == 0 {
// Defend some ill-formated packet, connection attribute is not important and can be ignored.
return nil
}
if num, null, intOff := util2.ParseLengthEncodedInt(data[offset:]); !null {
offset += intOff // Length of variable length encoded integer itself in bytes
row := data[offset : offset+int(num)]
attrs, err := parseAttrs(row)
if err != nil {
logutil.Logger(ctx).Warn("parse attrs failed", zap.Error(err))
return nil
}
packet.Attrs = attrs
offset += int(num) // Length of attributes
}
}

if packet.Capability&mysql.ClientZstdCompressionAlgorithm > 0 {
packet.ZstdLevel = zstd.EncoderLevelFromZstd(int(data[offset]))
}

return nil
}

func parseAttrs(data []byte) (map[string]string, error) {
attrs := make(map[string]string)
pos := 0
for pos < len(data) {
key, _, off, err := util2.ParseLengthEncodedBytes(data[pos:])
if err != nil {
return attrs, err
}
pos += off
value, _, off, err := util2.ParseLengthEncodedBytes(data[pos:])
if err != nil {
return attrs, err
}
pos += off

attrs[string(key)] = string(value)
}
return attrs, nil
}

func (cc *clientConn) readOptionalSSLRequestAndHandshakeResponse(ctx context.Context) error {
// Read a packet. It may be a SSLRequest or HandshakeResponse.
data, err := cc.readPacket()
Expand All @@ -598,7 +462,7 @@ func (cc *clientConn) readOptionalSSLRequestAndHandshakeResponse(ctx context.Con
return err
}

var resp handshakeResponse41
var resp handshake.Response41
var pos int

if len(data) < 2 {
Expand All @@ -611,7 +475,7 @@ func (cc *clientConn) readOptionalSSLRequestAndHandshakeResponse(ctx context.Con
logutil.Logger(ctx).Error("ClientProtocol41 flag is not set, please upgrade client")
return errNotSupportedAuthMode
}
pos, err = parseHandshakeResponseHeader(ctx, &resp, data)
pos, err = parse.HandshakeResponseHeader(ctx, &resp, data)
if err != nil {
terror.Log(err)
return err
Expand Down Expand Up @@ -645,7 +509,7 @@ func (cc *clientConn) readOptionalSSLRequestAndHandshakeResponse(ctx context.Con
logutil.Logger(ctx).Warn("read handshake response failure after upgrade to TLS", zap.Error(err))
return err
}
pos, err = parseHandshakeResponseHeader(ctx, &resp, data)
pos, err = parse.HandshakeResponseHeader(ctx, &resp, data)
if err != nil {
terror.Log(err)
return err
Expand All @@ -660,7 +524,7 @@ func (cc *clientConn) readOptionalSSLRequestAndHandshakeResponse(ctx context.Con
}

// Read the remaining part of the packet.
err = parseHandshakeResponseBody(ctx, &resp, data, pos)
err = parse.HandshakeResponseBody(ctx, &resp, data, pos)
if err != nil {
terror.Log(err)
return err
Expand Down Expand Up @@ -707,7 +571,7 @@ func (cc *clientConn) readOptionalSSLRequestAndHandshakeResponse(ctx context.Con
return err
}

func (cc *clientConn) handleAuthPlugin(ctx context.Context, resp *handshakeResponse41) error {
func (cc *clientConn) handleAuthPlugin(ctx context.Context, resp *handshake.Response41) error {
if resp.Capability&mysql.ClientPluginAuth > 0 {
newAuth, err := cc.checkAuthPlugin(ctx, resp)
if err != nil {
Expand Down Expand Up @@ -743,7 +607,7 @@ func (cc *clientConn) handleAuthPlugin(ctx context.Context, resp *handshakeRespo
}

// authSha implements the caching_sha2_password specific part of the protocol.
func (cc *clientConn) authSha(ctx context.Context, resp handshakeResponse41) ([]byte, error) {
func (cc *clientConn) authSha(ctx context.Context, resp handshake.Response41) ([]byte, error) {
const (
shaCommand = 1
requestRsaPubKey = 2 // Not supported yet, only TLS is supported as secure channel.
Expand Down Expand Up @@ -781,7 +645,7 @@ func (cc *clientConn) authSha(ctx context.Context, resp handshakeResponse41) ([]

// authSM3 implements the tidb_sm3_password specific part of the protocol.
// tidb_sm3_password is very similar to caching_sha2_password.
func (cc *clientConn) authSM3(ctx context.Context, resp handshakeResponse41) ([]byte, error) {
func (cc *clientConn) authSM3(ctx context.Context, resp handshake.Response41) ([]byte, error) {
// If no password is specified, we don't send the FastAuthFail to do the full authentication
// as that doesn't make sense without a password and confuses the client.
// https://github.com/pingcap/tidb/issues/40831
Expand Down Expand Up @@ -880,7 +744,7 @@ func (cc *clientConn) openSessionAndDoAuth(authData []byte, authPlugin string) e
}

// Check if the Authentication Plugin of the server, client and user configuration matches
func (cc *clientConn) checkAuthPlugin(ctx context.Context, resp *handshakeResponse41) ([]byte, error) {
func (cc *clientConn) checkAuthPlugin(ctx context.Context, resp *handshake.Response41) ([]byte, error) {
// Open a context unless this was done before.
if ctx := cc.getCtx(); ctx == nil {
err := cc.openSession()
Expand Down Expand Up @@ -2470,7 +2334,7 @@ func (cc *clientConn) writeChunksWithFetchSize(ctx context.Context, rs cursorRes
}

func (cc *clientConn) setConn(conn net.Conn) {
cc.bufReadConn = newBufferedReadConn(conn)
cc.bufReadConn = util2.NewBufferedReadConn(conn)
if cc.pkt == nil {
cc.pkt = newPacketIO(cc.bufReadConn)
} else {
Expand Down Expand Up @@ -2526,7 +2390,7 @@ func (cc *clientConn) handleChangeUser(ctx context.Context, data []byte) error {
if err != nil {
return err
}
fakeResp := &handshakeResponse41{
fakeResp := &handshake.Response41{
Auth: pass,
AuthPlugin: pluginName,
Capability: cc.capability,
Expand Down
Loading

0 comments on commit cf0ae34

Please sign in to comment.