From 21d6d9a965bcf0abb7b73b390c7c362f78a0255e Mon Sep 17 00:00:00 2001 From: cbcwestwolf <1004626265@qq.com> Date: Wed, 13 Jul 2022 20:01:03 +0800 Subject: [PATCH 01/29] *: Support for sm3_password authentication --- executor/simple.go | 6 +- parser/ast/misc.go | 6 + parser/auth/caching_sha2.go | 14 +- parser/auth/caching_sha2_test.go | 2 +- parser/auth/sm3.go | 212 +++++++++++++++++++++++++++++ parser/auth/sm3_test.go | 39 ++++++ parser/go.mod | 1 + parser/go.sum | 64 +++++++++ parser/mysql/const.go | 2 + privilege/privileges/privileges.go | 21 ++- server/conn.go | 3 + server/conn_test.go | 166 +++++++++++++++++++++- sessionctx/variable/sysvar.go | 2 +- 13 files changed, 522 insertions(+), 16 deletions(-) create mode 100644 parser/auth/sm3.go create mode 100644 parser/auth/sm3_test.go diff --git a/executor/simple.go b/executor/simple.go index a7794d4fb8498..8c03da4959c02 100644 --- a/executor/simple.go +++ b/executor/simple.go @@ -838,7 +838,7 @@ func (e *SimpleExec) executeCreateUser(ctx context.Context, s *ast.CreateUserStm } switch authPlugin { - case mysql.AuthNativePassword, mysql.AuthCachingSha2Password, mysql.AuthSocket: + case mysql.AuthNativePassword, mysql.AuthCachingSha2Password, mysql.AuthSM3Password, mysql.AuthSocket: default: return ErrPluginIsNotLoaded.GenWithStackByArgs(spec.AuthOpt.AuthPlugin) } @@ -982,7 +982,7 @@ func (e *SimpleExec) executeAlterUser(ctx context.Context, s *ast.AlterUserStmt) spec.AuthOpt.AuthPlugin = authplugin } switch spec.AuthOpt.AuthPlugin { - case mysql.AuthNativePassword, mysql.AuthCachingSha2Password, mysql.AuthSocket, "": + case mysql.AuthNativePassword, mysql.AuthCachingSha2Password, mysql.AuthSM3Password, mysql.AuthSocket, "": default: return ErrPluginIsNotLoaded.GenWithStackByArgs(spec.AuthOpt.AuthPlugin) } @@ -1463,6 +1463,8 @@ func (e *SimpleExec) executeSetPwd(ctx context.Context, s *ast.SetPwdStmt) error switch authplugin { case mysql.AuthCachingSha2Password: pwd = auth.NewSha2Password(s.Password) + case mysql.AuthSM3Password: + pwd = auth.NewSM3Password(s.Password) case mysql.AuthSocket: e.ctx.GetSessionVars().StmtCtx.AppendNote(ErrSetPasswordAuthPlugin.GenWithStackByArgs(u, h)) pwd = "" diff --git a/parser/ast/misc.go b/parser/ast/misc.go index 08e14575c53eb..ea180290c969a 100644 --- a/parser/ast/misc.go +++ b/parser/ast/misc.go @@ -1322,6 +1322,8 @@ func (n *UserSpec) EncodedPassword() (string, bool) { switch opt.AuthPlugin { case mysql.AuthCachingSha2Password: return auth.NewSha2Password(opt.AuthString), true + case mysql.AuthSM3Password: + return auth.NewSM3Password(opt.AuthString), true case mysql.AuthSocket: return "", true default: @@ -1340,6 +1342,10 @@ func (n *UserSpec) EncodedPassword() (string, bool) { if len(opt.HashString) != mysql.SHAPWDHashLen { return "", false } + case mysql.AuthSM3Password: + if len(opt.HashString) != mysql.SM3PWDHashLen { + return "", false + } case "", mysql.AuthNativePassword: if len(opt.HashString) != (mysql.PWDHashLen+1) || !strings.HasPrefix(opt.HashString, "*") { return "", false diff --git a/parser/auth/caching_sha2.go b/parser/auth/caching_sha2.go index 125a7f615c495..ffc0857c97df2 100644 --- a/parser/auth/caching_sha2.go +++ b/parser/auth/caching_sha2.go @@ -177,24 +177,24 @@ func sha256crypt(plaintext string, salt []byte, iterations int) string { return buf.String() } -// Checks if a MySQL style caching_sha2 authentication string matches a password +// CheckShaPassword checks if a MySQL style caching_sha2 authentication string matches a password func CheckShaPassword(pwhash []byte, password string) (bool, error) { - pwhash_parts := bytes.Split(pwhash, []byte("$")) - if len(pwhash_parts) != 4 { + pwhashParts := bytes.Split(pwhash, []byte("$")) + if len(pwhashParts) != 4 { return false, errors.New("failed to decode hash parts") } - hash_type := string(pwhash_parts[1]) - if hash_type != "A" { + hashType := string(pwhashParts[1]) + if hashType != "A" { return false, errors.New("digest type is incompatible") } - iterations, err := strconv.Atoi(string(pwhash_parts[2])) + iterations, err := strconv.Atoi(string(pwhashParts[2])) if err != nil { return false, errors.New("failed to decode iterations") } iterations = iterations * ITERATION_MULTIPLIER - salt := pwhash_parts[3][:SALT_LENGTH] + salt := pwhashParts[3][:SALT_LENGTH] newHash := sha256crypt(password, salt, iterations) diff --git a/parser/auth/caching_sha2_test.go b/parser/auth/caching_sha2_test.go index 6af1d5dc859f9..a583c7f3e6e64 100644 --- a/parser/auth/caching_sha2_test.go +++ b/parser/auth/caching_sha2_test.go @@ -58,7 +58,7 @@ func TestCheckShaPasswordIterationsInvalid(t *testing.T) { require.Error(t, err) } -// The output from NewSha2Password is not stable as the hash is based on the genrated salt. +// The output from NewSha2Password is not stable as the hash is based on the generated salt. // This is why CheckShaPassword is used here. func TestNewSha2Password(t *testing.T) { pwd := "testpwd" diff --git a/parser/auth/sm3.go b/parser/auth/sm3.go new file mode 100644 index 0000000000000..1e5f481300a54 --- /dev/null +++ b/parser/auth/sm3.go @@ -0,0 +1,212 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package auth + +// The concrete SM3 Cryptographic Hash Algorithm can be accessed in http://www.sca.gov.cn/sca/xwdt/2010-12/17/content_1002389.shtml +// This implementation of 'type SM3 struct' is modified from https://github.com/tjfoc/gmsm/tree/601ddb090dcf53d7951cc4dcc66276e2b817837c/sm3 + +import ( + "bytes" + "encoding/binary" + "hash" + "runtime/debug" +) + +type SM3 struct { + digest [8]uint32 // digest represents the partial evaluation of V + length uint64 // length of the message + unhandleMsg []byte +} + +func (sm3 *SM3) ff0(x, y, z uint32) uint32 { return x ^ y ^ z } + +func (sm3 *SM3) ff1(x, y, z uint32) uint32 { return (x & y) | (x & z) | (y & z) } + +func (sm3 *SM3) gg0(x, y, z uint32) uint32 { return x ^ y ^ z } + +func (sm3 *SM3) gg1(x, y, z uint32) uint32 { return (x & y) | (^x & z) } + +func (sm3 *SM3) p0(x uint32) uint32 { return x ^ sm3.leftRotate(x, 9) ^ sm3.leftRotate(x, 17) } + +func (sm3 *SM3) p1(x uint32) uint32 { return x ^ sm3.leftRotate(x, 15) ^ sm3.leftRotate(x, 23) } + +func (sm3 *SM3) leftRotate(x uint32, i uint32) uint32 { return x<<(i%32) | x>>(32-i%32) } + +func (sm3 *SM3) pad() []byte { + msg := sm3.unhandleMsg + // Append '1' + msg = append(msg, 0x80) + // Append until the resulting message length (in bits) is congruent to 448 (mod 512) + blockSize := 64 + for i := len(msg); i%blockSize != 56; i++ { + msg = append(msg, 0x00) + } + // append message length + msg = append(msg, uint8(sm3.length>>56&0xff)) + msg = append(msg, uint8(sm3.length>>48&0xff)) + msg = append(msg, uint8(sm3.length>>40&0xff)) + msg = append(msg, uint8(sm3.length>>32&0xff)) + msg = append(msg, uint8(sm3.length>>24&0xff)) + msg = append(msg, uint8(sm3.length>>16&0xff)) + msg = append(msg, uint8(sm3.length>>8&0xff)) + msg = append(msg, uint8(sm3.length>>0&0xff)) + + if len(msg)%64 != 0 { + panic("------SM3 Pad: error msgLen =") + } + return msg +} + +func (sm3 *SM3) update(msg []byte) [8]uint32 { + var w [68]uint32 + var w1 [64]uint32 + + a, b, c, d, e, f, g, h := sm3.digest[0], sm3.digest[1], sm3.digest[2], sm3.digest[3], sm3.digest[4], sm3.digest[5], sm3.digest[6], sm3.digest[7] + for len(msg) >= 64 { + for i := 0; i < 16; i++ { + w[i] = binary.BigEndian.Uint32(msg[4*i : 4*(i+1)]) + } + for i := 16; i < 68; i++ { + w[i] = sm3.p1(w[i-16]^w[i-9]^sm3.leftRotate(w[i-3], 15)) ^ sm3.leftRotate(w[i-13], 7) ^ w[i-6] + } + for i := 0; i < 64; i++ { + w1[i] = w[i] ^ w[i+4] + } + A, B, C, D, E, F, G, H := a, b, c, d, e, f, g, h + for i := 0; i < 16; i++ { + SS1 := sm3.leftRotate(sm3.leftRotate(A, 12)+E+sm3.leftRotate(0x79cc4519, uint32(i)), 7) + SS2 := SS1 ^ sm3.leftRotate(A, 12) + TT1 := sm3.ff0(A, B, C) + D + SS2 + w1[i] + TT2 := sm3.gg0(E, F, G) + H + SS1 + w[i] + D = C + C = sm3.leftRotate(B, 9) + B = A + A = TT1 + H = G + G = sm3.leftRotate(F, 19) + F = E + E = sm3.p0(TT2) + } + for i := 16; i < 64; i++ { + SS1 := sm3.leftRotate(sm3.leftRotate(A, 12)+E+sm3.leftRotate(0x7a879d8a, uint32(i)), 7) + SS2 := SS1 ^ sm3.leftRotate(A, 12) + TT1 := sm3.ff1(A, B, C) + D + SS2 + w1[i] + TT2 := sm3.gg1(E, F, G) + H + SS1 + w[i] + D = C + C = sm3.leftRotate(B, 9) + B = A + A = TT1 + H = G + G = sm3.leftRotate(F, 19) + F = E + E = sm3.p0(TT2) + } + a ^= A + b ^= B + c ^= C + d ^= D + e ^= E + f ^= F + g ^= G + h ^= H + msg = msg[64:] + } + var digest [8]uint32 + digest[0], digest[1], digest[2], digest[3], digest[4], digest[5], digest[6], digest[7] = a, b, c, d, e, f, g, h + return digest +} + +// New creates a new SM3 hashing instance. +func New() hash.Hash { + var sm3 SM3 + sm3.Reset() + return &sm3 +} + +// BlockSize returns the hash's underlying block size. +// The Write method must be able to accept any amount of data, +// but it may operate more efficiently if all writes are a multiple of the block size. +func (sm3 *SM3) BlockSize() int { return 64 } + +// Size returns the number of bytes Sum will return. +func (sm3 *SM3) Size() int { return 32 } + +// Reset clears the internal state by zeroing bytes in the state buffer. +// This can be skipped for a newly-created hash state; the default zero-allocated state is correct. +func (sm3 *SM3) Reset() { + // Reset digest + sm3.digest[0] = 0x7380166f + sm3.digest[1] = 0x4914b2b9 + sm3.digest[2] = 0x172442d7 + sm3.digest[3] = 0xda8a0600 + sm3.digest[4] = 0xa96f30bc + sm3.digest[5] = 0x163138aa + sm3.digest[6] = 0xe38dee4d + sm3.digest[7] = 0xb0fb0e4e + + sm3.length = 0 + sm3.unhandleMsg = []byte{} +} + +// Write (via the embedded io.Writer interface) adds more data to the running hash. +// It never returns an error. +func (sm3 *SM3) Write(p []byte) (int, error) { + toWrite := len(p) + sm3.length += uint64(len(p) * 8) + msg := append(sm3.unhandleMsg, p...) + nblocks := len(msg) / sm3.BlockSize() + sm3.digest = sm3.update(msg) + sm3.unhandleMsg = msg[nblocks*sm3.BlockSize():] + + return toWrite, nil +} + +// Sum appends the current hash to b and returns the resulting slice. +// It does not change the underlying hash state. +func (sm3 *SM3) Sum(in []byte) []byte { + _, _ = sm3.Write(in) + msg := sm3.pad() + // Finalize + digest := sm3.update(msg) + + // save hash to in + needed := sm3.Size() + if cap(in)-len(in) < needed { + newIn := make([]byte, len(in), len(in)+needed) + copy(newIn, in) + in = newIn + } + out := in[len(in) : len(in)+needed] + for i := 0; i < 8; i++ { + binary.BigEndian.PutUint32(out[i*4:], digest[i]) + } + return out +} + +// CheckSM3Password checks if a SM3 authentication string matches a password +func CheckSM3Password(pwhash []byte, password string) (bool, error) { + debug.PrintStack() + h := New() + h.Write([]byte(password)) + sum := h.Sum(nil) + return bytes.Equal(pwhash, sum), nil +} + +func NewSM3Password(pwd string) string { + debug.PrintStack() + h := New() + h.Write([]byte(pwd)) + sum := h.Sum(nil) + return string(sum) +} diff --git a/parser/auth/sm3_test.go b/parser/auth/sm3_test.go new file mode 100644 index 0000000000000..e4affa47d29aa --- /dev/null +++ b/parser/auth/sm3_test.go @@ -0,0 +1,39 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package auth + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestCheckSM3Password(t *testing.T) { + pwd1 := "test" + pwd2 := "no_pass" + val := "55e12e91650d2fec56ec74e1d3e4ddbfce2ef3a65890c2a19ecf88a307e76a23" + + sm3pwd := NewSM3Password(pwd1) + require.Equal(t, val, fmt.Sprintf("%x", sm3pwd)) + r, err := CheckSM3Password([]byte(sm3pwd), pwd1) + require.NoError(t, err) + require.True(t, r) + + sm3pwd = NewSM3Password(pwd2) + require.NotEqual(t, val, fmt.Sprintf("%x", sm3pwd)) + r, err = CheckSM3Password([]byte(sm3pwd), pwd1) + require.NoError(t, err) + require.False(t, r) +} diff --git a/parser/go.mod b/parser/go.mod index 1f49f53d36814..071b9401a90f6 100644 --- a/parser/go.mod +++ b/parser/go.mod @@ -8,6 +8,7 @@ require ( github.com/pingcap/errors v0.11.5-0.20210425183316-da1aaba5fb63 github.com/pingcap/log v0.0.0-20210625125904-98ed8e2eb1c7 github.com/stretchr/testify v1.7.0 + github.com/tjfoc/gmsm v1.4.1 go.uber.org/goleak v1.1.10 go.uber.org/zap v1.18.1 golang.org/x/exp v0.0.0-20220428152302-39d4317da171 diff --git a/parser/go.sum b/parser/go.sum index 267fe82580882..abd195c60a412 100644 --- a/parser/go.sum +++ b/parser/go.sum @@ -1,7 +1,11 @@ +cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8= github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= +github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= +github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= +github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/cznic/mathutil v0.0.0-20181122101859-297441e03548 h1:iwZdTE0PVqJCos1vaoKsclOGD3ADKpshg3SRtYBbwso= github.com/cznic/mathutil v0.0.0-20181122101859-297441e03548/go.mod h1:e6NPNENfs9mPDVNRekM7lKScauxd5kXTr1Mfyig6TDM= github.com/cznic/sortutil v0.0.0-20181122101858-f5f958428db8 h1:LpMLYGyy67BoAFGda1NeOBQwqlv7nUXpm+rIVHGxZZ4= @@ -11,8 +15,26 @@ github.com/cznic/strutil v0.0.0-20171016134553-529a34b1c186/go.mod h1:AHHPPPXTw0 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= +github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= +github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE= github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= +github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= +github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= +github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= +github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= +github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= +github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= +github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= +github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= +github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= @@ -27,6 +49,7 @@ github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0 h1:OdAsTTz6OkFY5QxjkYwrChwuRruF69c169dPK26NUlk= github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -34,6 +57,8 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/tjfoc/gmsm v1.4.1 h1:aMe1GlZb+0bLjn+cKTPEvvn9oUEBlJitaZiiBwsbgho= +github.com/tjfoc/gmsm v1.4.1/go.mod h1:j4INPkHWMrhJb38G+J6W4Tw0AbuN8Thu3PbdVYhVcTE= go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= @@ -47,25 +72,62 @@ go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= go.uber.org/zap v1.18.1 h1:CSUJ2mjFszzEWt4CdKISEuChVIXGBn3lAPwkRGyVrc4= go.uber.org/zap v1.18.1/go.mod h1:xg/QME4nWcxGxrpdeYfq7UvYrLh66cuVKdrbD1XF/NI= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20201012173705-84dcc777aaee/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/exp v0.0.0-20181106170214-d68db9428509/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20220428152302-39d4317da171 h1:TfdoLivD44QwvssI9Sv1xwa5DcL5XQr4au4sZ2F2NV4= golang.org/x/exp v0.0.0-20220428152302-39d4317da171/go.mod h1:lgLbSvA5ygNOMpwM/9anMpWVlVJ7Z+cHWq/eFuinpGE= +golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= +golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= +golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/lint v0.0.0-20190930215403-16217165b5de h1:5hukYrvBGR8/eNkX5mdUezrA6JiaEZDtJb9Ei+1LlBs= golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20201010224723-4f7140c49acb/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20211019181941-9d821ace8654 h1:id054HUawV2/6IGm2IV8KZQjqtwAOo2CYlOToYqa0d0= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191108193012-7d206e10da11/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.8-0.20211029000441-d6a9af8af023 h1:0c3L82FDQ5rt1bjTBlchS8t6RQ6299/+5bWMnRLh+uI= golang.org/x/tools v0.1.8-0.20211029000441-d6a9af8af023/go.mod h1:nABZi5QlRsZVlzPpHl034qft6wpY4eDcsTt5AaioBiU= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= +google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= +google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= +google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= +google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= +google.golang.org/grpc v1.31.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= +google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= +google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= +google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= +google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= +google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= +google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= @@ -77,6 +139,8 @@ gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo= gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= modernc.org/fileutil v1.0.0/go.mod h1:JHsWpkrk/CnVV1H/eGlFf85BEpfkrp56ro8nojIq9Q8= modernc.org/golex v1.0.1 h1:EYKY1a3wStt0RzHaH8mdSRNg78Ub0OHxYfCRWw35YtM= modernc.org/golex v1.0.1/go.mod h1:QCA53QtsT1NdGkaZZkF5ezFwk4IXh4BGNafAARTC254= diff --git a/parser/mysql/const.go b/parser/mysql/const.go index 81e551a125e01..e25aff8850dfc 100644 --- a/parser/mysql/const.go +++ b/parser/mysql/const.go @@ -166,6 +166,7 @@ const ( const ( AuthNativePassword = "mysql_native_password" // #nosec G101 AuthCachingSha2Password = "caching_sha2_password" // #nosec G101 + AuthSM3Password = "sm3_password" // #nosec G101 AuthSocket = "auth_socket" ) @@ -231,6 +232,7 @@ const MaxTypeSetMembers = 64 // PWDHashLen is the length of mysql_native_password's hash. const PWDHashLen = 40 // excluding the '*' const SHAPWDHashLen = 70 +const SM3PWDHashLen = 64 // Command2Str is the command information to command name. var Command2Str = map[byte]string{ diff --git a/privilege/privileges/privileges.go b/privilege/privileges/privileges.go index e6633b03f1d5d..0ff14a278a83c 100644 --- a/privilege/privileges/privileges.go +++ b/privilege/privileges/privileges.go @@ -198,6 +198,14 @@ func (p *UserPrivileges) isValidHash(record *UserRecord) bool { return false } + if record.AuthPlugin == mysql.AuthSM3Password { + if len(pwd) == mysql.SM3PWDHashLen { + return true + } + logutil.BgLogger().Error("user password from system DB not like a sm3_password format", zap.String("user", record.User), zap.String("plugin", record.AuthPlugin), zap.Int("hash_length", len(pwd))) + return false + } + if record.AuthPlugin == mysql.AuthSocket { return true } @@ -346,12 +354,21 @@ func (p *UserPrivileges) ConnectionVerification(user, host string, authenticatio return } } else if record.AuthPlugin == mysql.AuthCachingSha2Password { - authok, err := auth.CheckShaPassword([]byte(pwd), string(authentication)) + authOK, err := auth.CheckShaPassword([]byte(pwd), string(authentication)) if err != nil { logutil.BgLogger().Error("Failed to check caching_sha2_password", zap.Error(err)) } - if !authok { + if !authOK { + return + } + } else if record.AuthPlugin == mysql.AuthSM3Password { + authOK, err := auth.CheckSM3Password([]byte(pwd), string(authentication)) + if err != nil { + logutil.BgLogger().Error("Failed to check sm3_password", zap.Error(err)) + } + + if !authOK { return } } else if record.AuthPlugin == mysql.AuthSocket { diff --git a/server/conn.go b/server/conn.go index e5289e4fec5af..f147e8d060b66 100644 --- a/server/conn.go +++ b/server/conn.go @@ -729,6 +729,8 @@ func (cc *clientConn) readOptionalSSLRequestAndHandshakeResponse(ctx context.Con if err != nil { return err } + case mysql.AuthSM3Password: + // TODO case mysql.AuthNativePassword: case mysql.AuthSocket: default: @@ -755,6 +757,7 @@ func (cc *clientConn) handleAuthPlugin(ctx context.Context, resp *handshakeRespo switch resp.AuthPlugin { case mysql.AuthCachingSha2Password: + case mysql.AuthSM3Password: case mysql.AuthNativePassword: case mysql.AuthSocket: default: diff --git a/server/conn_test.go b/server/conn_test.go index f9661226ae1c3..7780f5a7afa00 100644 --- a/server/conn_test.go +++ b/server/conn_test.go @@ -1054,7 +1054,30 @@ func TestHandleAuthPlugin(t *testing.T) { } err = cc.handleAuthPlugin(ctx, &resp) require.NoError(t, err) - require.Equal(t, resp.Auth, []byte(mysql.AuthNativePassword)) + require.Equal(t, []byte(mysql.AuthNativePassword), resp.Auth) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/server/FakeAuthSwitch")) + + // client trying to authenticate with sm3_password + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/server/FakeAuthSwitch", "return(1)")) + cc = &clientConn{ + connectionID: 1, + alloc: arena.NewAllocator(1024), + chunkAlloc: chunk.NewAllocator(), + collation: mysql.DefaultCollationID, + peerHost: "localhost", + pkt: &packetIO{ + bufWriter: bufio.NewWriter(bytes.NewBuffer(nil)), + }, + server: srv, + user: "unativepassword", + } + resp = handshakeResponse41{ + Capability: mysql.ClientProtocol41 | mysql.ClientPluginAuth, + AuthPlugin: mysql.AuthSM3Password, + } + err = cc.handleAuthPlugin(ctx, &resp) + require.NoError(t, err) + require.Equal(t, []byte(mysql.AuthNativePassword), resp.Auth) require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/server/FakeAuthSwitch")) // MySQL 5.1 or older client, without authplugin support @@ -1125,6 +1148,29 @@ func TestHandleAuthPlugin(t *testing.T) { require.Equal(t, []byte(mysql.AuthNativePassword), resp.Auth) require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/server/FakeAuthSwitch")) + // client trying to authenticate with sm3_password + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/server/FakeAuthSwitch", "return(1)")) + cc = &clientConn{ + connectionID: 1, + alloc: arena.NewAllocator(1024), + chunkAlloc: chunk.NewAllocator(), + collation: mysql.DefaultCollationID, + peerHost: "localhost", + pkt: &packetIO{ + bufWriter: bufio.NewWriter(bytes.NewBuffer(nil)), + }, + server: srv, + user: "unativepassword", + } + resp = handshakeResponse41{ + Capability: mysql.ClientProtocol41 | mysql.ClientPluginAuth, + AuthPlugin: mysql.AuthSM3Password, + } + err = cc.handleAuthPlugin(ctx, &resp) + require.NoError(t, err) + require.Equal(t, []byte(mysql.AuthNativePassword), resp.Auth) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/server/FakeAuthSwitch")) + // MySQL 5.1 or older client, without authplugin support cc = &clientConn{ connectionID: 1, @@ -1194,6 +1240,121 @@ func TestHandleAuthPlugin(t *testing.T) { require.Equal(t, []byte(mysql.AuthCachingSha2Password), resp.Auth) require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/server/FakeAuthSwitch")) + // client trying to authenticate with sm3_password + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/server/FakeAuthSwitch", "return(1)")) + cc = &clientConn{ + connectionID: 1, + alloc: arena.NewAllocator(1024), + chunkAlloc: chunk.NewAllocator(), + collation: mysql.DefaultCollationID, + peerHost: "localhost", + pkt: &packetIO{ + bufWriter: bufio.NewWriter(bytes.NewBuffer(nil)), + }, + server: srv, + user: "unativepassword", + } + resp = handshakeResponse41{ + Capability: mysql.ClientProtocol41 | mysql.ClientPluginAuth, + AuthPlugin: mysql.AuthSM3Password, + } + err = cc.handleAuthPlugin(ctx, &resp) + require.NoError(t, err) + require.Equal(t, []byte(mysql.AuthCachingSha2Password), resp.Auth) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/server/FakeAuthSwitch")) + + // MySQL 5.1 or older client, without authplugin support + cc = &clientConn{ + connectionID: 1, + alloc: arena.NewAllocator(1024), + chunkAlloc: chunk.NewAllocator(), + collation: mysql.DefaultCollationID, + peerHost: "localhost", + pkt: &packetIO{ + bufWriter: bufio.NewWriter(bytes.NewBuffer(nil)), + }, + server: srv, + user: "unativepassword", + } + resp = handshakeResponse41{ + Capability: mysql.ClientProtocol41, + } + err = cc.handleAuthPlugin(ctx, &resp) + require.Error(t, err) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/server/FakeUser")) + + // === Target account has sm3_password === + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/server/FakeUser", "return(\"sm3_password\")")) + + // 5.7 or newer client trying to authenticate with mysql_native_password + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/server/FakeAuthSwitch", "return(1)")) + cc = &clientConn{ + connectionID: 1, + alloc: arena.NewAllocator(1024), + chunkAlloc: chunk.NewAllocator(), + collation: mysql.DefaultCollationID, + peerHost: "localhost", + pkt: &packetIO{ + bufWriter: bufio.NewWriter(bytes.NewBuffer(nil)), + }, + server: srv, + user: "unativepassword", + } + resp = handshakeResponse41{ + Capability: mysql.ClientProtocol41 | mysql.ClientPluginAuth, + AuthPlugin: mysql.AuthNativePassword, + } + err = cc.handleAuthPlugin(ctx, &resp) + require.NoError(t, err) + require.Equal(t, []byte(mysql.AuthSM3Password), resp.Auth) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/server/FakeAuthSwitch")) + + // 8.0 or newer client trying to authenticate with caching_sha2_password + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/server/FakeAuthSwitch", "return(1)")) + cc = &clientConn{ + connectionID: 1, + alloc: arena.NewAllocator(1024), + chunkAlloc: chunk.NewAllocator(), + collation: mysql.DefaultCollationID, + peerHost: "localhost", + pkt: &packetIO{ + bufWriter: bufio.NewWriter(bytes.NewBuffer(nil)), + }, + server: srv, + user: "unativepassword", + } + resp = handshakeResponse41{ + Capability: mysql.ClientProtocol41 | mysql.ClientPluginAuth, + AuthPlugin: mysql.AuthCachingSha2Password, + } + err = cc.handleAuthPlugin(ctx, &resp) + require.NoError(t, err) + require.Equal(t, []byte(mysql.AuthSM3Password), resp.Auth) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/server/FakeAuthSwitch")) + + // client trying to authenticate with sm3_password + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/server/FakeAuthSwitch", "return(1)")) + cc = &clientConn{ + connectionID: 1, + alloc: arena.NewAllocator(1024), + chunkAlloc: chunk.NewAllocator(), + collation: mysql.DefaultCollationID, + peerHost: "localhost", + pkt: &packetIO{ + bufWriter: bufio.NewWriter(bytes.NewBuffer(nil)), + }, + server: srv, + user: "unativepassword", + } + resp = handshakeResponse41{ + Capability: mysql.ClientProtocol41 | mysql.ClientPluginAuth, + AuthPlugin: mysql.AuthSM3Password, + } + err = cc.handleAuthPlugin(ctx, &resp) + require.NoError(t, err) + require.Equal(t, []byte(mysql.AuthSM3Password), resp.Auth) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/server/FakeAuthSwitch")) + // MySQL 5.1 or older client, without authplugin support cc = &clientConn{ connectionID: 1, @@ -1253,9 +1414,8 @@ func TestAuthPlugin2(t *testing.T) { require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/server/FakeAuthSwitch", "return(1)")) respAuthSwitch, err := cc.checkAuthPlugin(ctx, &resp) require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/server/FakeAuthSwitch")) - require.Equal(t, respAuthSwitch, []byte(mysql.AuthNativePassword)) + require.Equal(t, []byte(mysql.AuthNativePassword), respAuthSwitch) require.NoError(t, err) - } func TestMaxAllowedPacket(t *testing.T) { diff --git a/sessionctx/variable/sysvar.go b/sessionctx/variable/sysvar.go index dba7781dfbe1f..2c513c9798de6 100644 --- a/sessionctx/variable/sysvar.go +++ b/sessionctx/variable/sysvar.go @@ -657,7 +657,7 @@ var defaultSysVars = []*SysVar{ return nil }}, {Scope: ScopeGlobal, Name: SkipNameResolve, Value: Off, Type: TypeBool}, - {Scope: ScopeGlobal, Name: DefaultAuthPlugin, Value: mysql.AuthNativePassword, Type: TypeEnum, PossibleValues: []string{mysql.AuthNativePassword, mysql.AuthCachingSha2Password}}, + {Scope: ScopeGlobal, Name: DefaultAuthPlugin, Value: mysql.AuthNativePassword, Type: TypeEnum, PossibleValues: []string{mysql.AuthNativePassword, mysql.AuthCachingSha2Password, mysql.AuthSM3Password}}, {Scope: ScopeGlobal, Name: TiDBPersistAnalyzeOptions, Value: BoolToOnOff(DefTiDBPersistAnalyzeOptions), Type: TypeBool, GetGlobal: func(s *SessionVars) (string, error) { return BoolToOnOff(PersistAnalyzeOptions.Load()), nil From 77ec1f730e49e9302727cc917d8b6a96c71543b6 Mon Sep 17 00:00:00 2001 From: cbcwestwolf <1004626265@qq.com> Date: Wed, 13 Jul 2022 20:45:58 +0800 Subject: [PATCH 02/29] Fix --- parser/auth/BUILD.bazel | 2 ++ parser/go.mod | 1 - 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/parser/auth/BUILD.bazel b/parser/auth/BUILD.bazel index 93ad9d22d2d29..2e9341368cdf1 100644 --- a/parser/auth/BUILD.bazel +++ b/parser/auth/BUILD.bazel @@ -6,6 +6,7 @@ go_library( "auth.go", "caching_sha2.go", "mysql_native_password.go", + "sm3.go", ], importpath = "github.com/pingcap/tidb/parser/auth", visibility = ["//visibility:public"], @@ -21,6 +22,7 @@ go_test( srcs = [ "caching_sha2_test.go", "mysql_native_password_test.go", + "sm3_test.go", ], embed = [":auth"], deps = ["@com_github_stretchr_testify//require"], diff --git a/parser/go.mod b/parser/go.mod index 071b9401a90f6..1f49f53d36814 100644 --- a/parser/go.mod +++ b/parser/go.mod @@ -8,7 +8,6 @@ require ( github.com/pingcap/errors v0.11.5-0.20210425183316-da1aaba5fb63 github.com/pingcap/log v0.0.0-20210625125904-98ed8e2eb1c7 github.com/stretchr/testify v1.7.0 - github.com/tjfoc/gmsm v1.4.1 go.uber.org/goleak v1.1.10 go.uber.org/zap v1.18.1 golang.org/x/exp v0.0.0-20220428152302-39d4317da171 From 3676087610bc21f8023dc7718989ee826ecc47f6 Mon Sep 17 00:00:00 2001 From: cbcwestwolf <1004626265@qq.com> Date: Wed, 13 Jul 2022 21:26:36 +0800 Subject: [PATCH 03/29] Fix --- DEPS.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/DEPS.bzl b/DEPS.bzl index 909a95484d13c..e5f9e32920a61 100644 --- a/DEPS.bzl +++ b/DEPS.bzl @@ -2502,8 +2502,8 @@ def go_deps(): name = "com_github_pingcap_kvproto", build_file_proto_mode = "disable_global", importpath = "github.com/pingcap/kvproto", - sum = "h1:nP2wmyw9JTRsk5rm+tZtfAso6c/1FvuaFNbXTaYz3FE=", - version = "v0.0.0-20220705053936-aa9c2d20cd2a", + sum = "h1:VKMmvYhtG28j1sCCBdq4s+V9UOYqNgQ6CQviQwOgTeg=", + version = "v0.0.0-20220705090230-a5d4ffd2ba33", ) go_repository( name = "com_github_pingcap_log", From d080a8826bfa8b7f353f6e3ff13140697eb681b2 Mon Sep 17 00:00:00 2001 From: cbcwestwolf <1004626265@qq.com> Date: Thu, 14 Jul 2022 23:31:37 +0800 Subject: [PATCH 04/29] implement sm3 like caching_sha2 --- parser/auth/caching_sha2_test.go | 8 +- parser/auth/sm3.go | 231 ++++++++++++++++++++++++++----- parser/auth/sm3_test.go | 64 +++++++-- server/conn.go | 36 ++++- 4 files changed, 286 insertions(+), 53 deletions(-) diff --git a/parser/auth/caching_sha2_test.go b/parser/auth/caching_sha2_test.go index a583c7f3e6e64..51fe4dcaf692d 100644 --- a/parser/auth/caching_sha2_test.go +++ b/parser/auth/caching_sha2_test.go @@ -20,11 +20,11 @@ import ( "github.com/stretchr/testify/require" ) -var foobarPwdHash, _ = hex.DecodeString("24412430303524031A69251C34295C4B35167C7F1E5A7B63091349503974624D34504B5A424679354856336868686F52485A736E4A733368786E427575516C73446469496537") +var foobarPwdSHA2Hash, _ = hex.DecodeString("24412430303524031A69251C34295C4B35167C7F1E5A7B63091349503974624D34504B5A424679354856336868686F52485A736E4A733368786E427575516C73446469496537") func TestCheckShaPasswordGood(t *testing.T) { pwd := "foobar" - r, err := CheckShaPassword(foobarPwdHash, pwd) + r, err := CheckShaPassword(foobarPwdSHA2Hash, pwd) require.NoError(t, err) require.True(t, r) } @@ -44,7 +44,7 @@ func TestCheckShaPasswordShort(t *testing.T) { require.Error(t, err) } -func TestCheckShaPasswordDigetTypeIncompatible(t *testing.T) { +func TestCheckShaPasswordDigestTypeIncompatible(t *testing.T) { pwd := "not_foobar" pwhash, _ := hex.DecodeString("24422430303524031A69251C34295C4B35167C7F1E5A7B63091349503974624D34504B5A424679354856336868686F52485A736E4A733368786E427575516C73446469496537") _, err := CheckShaPassword(pwhash, pwd) @@ -76,7 +76,7 @@ func TestNewSha2Password(t *testing.T) { func BenchmarkShaPassword(b *testing.B) { for i := 0; i < b.N; i++ { - m, err := CheckShaPassword(foobarPwdHash, "foobar") + m, err := CheckShaPassword(foobarPwdSHA2Hash, "foobar") require.Nil(b, err) require.True(b, m) } diff --git a/parser/auth/sm3.go b/parser/auth/sm3.go index 1e5f481300a54..cb83efc6e9b68 100644 --- a/parser/auth/sm3.go +++ b/parser/auth/sm3.go @@ -14,36 +14,41 @@ package auth // The concrete SM3 Cryptographic Hash Algorithm can be accessed in http://www.sca.gov.cn/sca/xwdt/2010-12/17/content_1002389.shtml -// This implementation of 'type SM3 struct' is modified from https://github.com/tjfoc/gmsm/tree/601ddb090dcf53d7951cc4dcc66276e2b817837c/sm3 +// This implementation of 'type sm3 struct' is modified from https://github.com/tjfoc/gmsm/tree/601ddb090dcf53d7951cc4dcc66276e2b817837c/sm3 +// Some other references: +// https://tools.ietf.org/id/draft-oscca-cfrg-sm3-01.html import ( "bytes" + "crypto/rand" + "crypto/sha256" "encoding/binary" - "hash" - "runtime/debug" + "errors" + "fmt" + "strconv" ) -type SM3 struct { +type sm3 struct { digest [8]uint32 // digest represents the partial evaluation of V length uint64 // length of the message unhandleMsg []byte } -func (sm3 *SM3) ff0(x, y, z uint32) uint32 { return x ^ y ^ z } +func (sm3 *sm3) ff0(x, y, z uint32) uint32 { return x ^ y ^ z } -func (sm3 *SM3) ff1(x, y, z uint32) uint32 { return (x & y) | (x & z) | (y & z) } +func (sm3 *sm3) ff1(x, y, z uint32) uint32 { return (x & y) | (x & z) | (y & z) } -func (sm3 *SM3) gg0(x, y, z uint32) uint32 { return x ^ y ^ z } +func (sm3 *sm3) gg0(x, y, z uint32) uint32 { return x ^ y ^ z } -func (sm3 *SM3) gg1(x, y, z uint32) uint32 { return (x & y) | (^x & z) } +func (sm3 *sm3) gg1(x, y, z uint32) uint32 { return (x & y) | (^x & z) } -func (sm3 *SM3) p0(x uint32) uint32 { return x ^ sm3.leftRotate(x, 9) ^ sm3.leftRotate(x, 17) } +func (sm3 *sm3) p0(x uint32) uint32 { return x ^ sm3.leftRotate(x, 9) ^ sm3.leftRotate(x, 17) } -func (sm3 *SM3) p1(x uint32) uint32 { return x ^ sm3.leftRotate(x, 15) ^ sm3.leftRotate(x, 23) } +func (sm3 *sm3) p1(x uint32) uint32 { return x ^ sm3.leftRotate(x, 15) ^ sm3.leftRotate(x, 23) } -func (sm3 *SM3) leftRotate(x uint32, i uint32) uint32 { return x<<(i%32) | x>>(32-i%32) } +func (sm3 *sm3) leftRotate(x uint32, i uint32) uint32 { return x<<(i%32) | x>>(32-i%32) } -func (sm3 *SM3) pad() []byte { +func (sm3 *sm3) pad() []byte { msg := sm3.unhandleMsg // Append '1' msg = append(msg, 0x80) @@ -63,12 +68,12 @@ func (sm3 *SM3) pad() []byte { msg = append(msg, uint8(sm3.length>>0&0xff)) if len(msg)%64 != 0 { - panic("------SM3 Pad: error msgLen =") + panic("------sm3 Pad: error msgLen =") } return msg } -func (sm3 *SM3) update(msg []byte) [8]uint32 { +func (sm3 *sm3) update(msg []byte) [8]uint32 { var w [68]uint32 var w1 [64]uint32 @@ -127,24 +132,17 @@ func (sm3 *SM3) update(msg []byte) [8]uint32 { return digest } -// New creates a new SM3 hashing instance. -func New() hash.Hash { - var sm3 SM3 - sm3.Reset() - return &sm3 -} - // BlockSize returns the hash's underlying block size. // The Write method must be able to accept any amount of data, // but it may operate more efficiently if all writes are a multiple of the block size. -func (sm3 *SM3) BlockSize() int { return 64 } +func (sm3 *sm3) BlockSize() int { return 64 } // Size returns the number of bytes Sum will return. -func (sm3 *SM3) Size() int { return 32 } +func (sm3 *sm3) Size() int { return 32 } // Reset clears the internal state by zeroing bytes in the state buffer. // This can be skipped for a newly-created hash state; the default zero-allocated state is correct. -func (sm3 *SM3) Reset() { +func (sm3 *sm3) Reset() { // Reset digest sm3.digest[0] = 0x7380166f sm3.digest[1] = 0x4914b2b9 @@ -161,7 +159,7 @@ func (sm3 *SM3) Reset() { // Write (via the embedded io.Writer interface) adds more data to the running hash. // It never returns an error. -func (sm3 *SM3) Write(p []byte) (int, error) { +func (sm3 *sm3) Write(p []byte) (int, error) { toWrite := len(p) sm3.length += uint64(len(p) * 8) msg := append(sm3.unhandleMsg, p...) @@ -174,7 +172,7 @@ func (sm3 *SM3) Write(p []byte) (int, error) { // Sum appends the current hash to b and returns the resulting slice. // It does not change the underlying hash state. -func (sm3 *SM3) Sum(in []byte) []byte { +func (sm3 *sm3) Sum(in []byte) []byte { _, _ = sm3.Write(in) msg := sm3.pad() // Finalize @@ -194,19 +192,178 @@ func (sm3 *SM3) Sum(in []byte) []byte { return out } -// CheckSM3Password checks if a SM3 authentication string matches a password +// SM3 returns the sm3 checksum of the data. +func SM3(data []byte) []byte { + var h sm3 + h.Reset() + h.Write(data) + return h.Sum(nil) +} + +func SM3String(pwd string) string { + var h sm3 + h.Reset() + h.Write([]byte(pwd)) + return string(h.Sum(nil)) +} + +func sm3crypt(plaintext string, salt []byte, iterations int) string { + // Numbers in the comments refer to the description of the algorithm on https://www.akkadia.org/drepper/SHA-crypt.txt + + // 1, 2, 3 + bufA := bytes.NewBuffer(make([]byte, 0, 4096)) + bufA.Write([]byte(plaintext)) + bufA.Write(salt) + + // 4, 5, 6, 7, 8 + bufB := bytes.NewBuffer(make([]byte, 0, 4096)) + bufB.Write([]byte(plaintext)) + bufB.Write(salt) + bufB.Write([]byte(plaintext)) + sumB := SM3(bufB.Bytes()) + bufB.Reset() + + // 9, 10 + var i int + for i = len(plaintext); i > MIXCHARS; i -= MIXCHARS { + bufA.Write(sumB[:MIXCHARS]) + } + bufA.Write(sumB[:i]) + + // 11 + for i = len(plaintext); i > 0; i >>= 1 { + if i%2 == 0 { + bufA.Write([]byte(plaintext)) + } else { + bufA.Write(sumB[:]) + } + } + + // 12 + sumA := SM3(bufA.Bytes()) + bufA.Reset() + + // 13, 14, 15 + bufDP := bufA + for range []byte(plaintext) { + bufDP.Write([]byte(plaintext)) + } + sumDP := SM3(bufDP.Bytes()) + bufDP.Reset() + + // 16 + p := make([]byte, 0, sha256.Size) + for i = len(plaintext); i > 0; i -= MIXCHARS { + if i > MIXCHARS { + p = append(p, sumDP[:]...) + } else { + p = append(p, sumDP[0:i]...) + } + } + + // 17, 18, 19 + bufDS := bufA + for i = 0; i < 16+int(sumA[0]); i++ { + bufDS.Write(salt) + } + sumDS := SM3(bufDS.Bytes()) + bufDS.Reset() + + // 20 + s := make([]byte, 0, 32) + for i = len(salt); i > 0; i -= MIXCHARS { + if i > MIXCHARS { + s = append(s, sumDS[:]...) + } else { + s = append(s, sumDS[0:i]...) + } + } + + // 21 + bufC := bufA + var sumC []byte + for i = 0; i < iterations; i++ { + bufC.Reset() + if i&1 != 0 { + bufC.Write(p) + } else { + bufC.Write(sumA[:]) + } + if i%3 != 0 { + bufC.Write(s) + } + if i%7 != 0 { + bufC.Write(p) + } + if i&1 != 0 { + bufC.Write(sumA[:]) + } else { + bufC.Write(p) + } + sumC = SM3(bufC.Bytes()) + sumA = sumC + } + + // 22 + buf := bytes.NewBuffer(make([]byte, 0, 100)) + buf.Write([]byte{'$', 'A', '$'}) + rounds := fmt.Sprintf("%03d", iterations/ITERATION_MULTIPLIER) + buf.Write([]byte(rounds)) + buf.Write([]byte{'$'}) + buf.Write(salt) + + b64From24bit([]byte{sumC[0], sumC[10], sumC[20]}, 4, buf) + b64From24bit([]byte{sumC[21], sumC[1], sumC[11]}, 4, buf) + b64From24bit([]byte{sumC[12], sumC[22], sumC[2]}, 4, buf) + b64From24bit([]byte{sumC[3], sumC[13], sumC[23]}, 4, buf) + b64From24bit([]byte{sumC[24], sumC[4], sumC[14]}, 4, buf) + b64From24bit([]byte{sumC[15], sumC[25], sumC[5]}, 4, buf) + b64From24bit([]byte{sumC[6], sumC[16], sumC[26]}, 4, buf) + b64From24bit([]byte{sumC[27], sumC[7], sumC[17]}, 4, buf) + b64From24bit([]byte{sumC[18], sumC[28], sumC[8]}, 4, buf) + b64From24bit([]byte{sumC[9], sumC[19], sumC[29]}, 4, buf) + b64From24bit([]byte{0, sumC[31], sumC[30]}, 3, buf) + + return buf.String() +} + +// CheckSM3Password checks if a sm3 authentication string matches a password func CheckSM3Password(pwhash []byte, password string) (bool, error) { - debug.PrintStack() - h := New() - h.Write([]byte(password)) - sum := h.Sum(nil) - return bytes.Equal(pwhash, sum), nil + pwhashParts := bytes.Split(pwhash, []byte("$")) + if len(pwhashParts) != 4 { + return false, errors.New("failed to decode hash parts") + } + + hashType := string(pwhashParts[1]) + if hashType != "A" { + return false, errors.New("digest type is incompatible") + } + + iterations, err := strconv.Atoi(string(pwhashParts[2])) + if err != nil { + return false, errors.New("failed to decode iterations") + } + iterations = iterations * ITERATION_MULTIPLIER + salt := pwhashParts[3][:SALT_LENGTH] + + newHash := sm3crypt(password, salt, iterations) + + return bytes.Equal(pwhash, []byte(newHash)), nil } func NewSM3Password(pwd string) string { - debug.PrintStack() - h := New() - h.Write([]byte(pwd)) - sum := h.Sum(nil) - return string(sum) + salt := make([]byte, SALT_LENGTH) + rand.Read(salt) + + // Restrict to 7-bit to avoid multi-byte UTF-8 + for i := range salt { + salt[i] = salt[i] &^ 128 + for salt[i] == 36 || salt[i] == 0 { // '$' or NUL + newval := make([]byte, 1) + rand.Read(newval) + salt[i] = newval[0] &^ 128 + } + } + + return sm3crypt(pwd, salt, 5*ITERATION_MULTIPLIER) } diff --git a/parser/auth/sm3_test.go b/parser/auth/sm3_test.go index e4affa47d29aa..6aa75bea77c27 100644 --- a/parser/auth/sm3_test.go +++ b/parser/auth/sm3_test.go @@ -14,26 +14,68 @@ package auth import ( - "fmt" + "encoding/hex" "testing" "github.com/stretchr/testify/require" ) -func TestCheckSM3Password(t *testing.T) { - pwd1 := "test" - pwd2 := "no_pass" - val := "55e12e91650d2fec56ec74e1d3e4ddbfce2ef3a65890c2a19ecf88a307e76a23" +var foobarPwdSM3Hash, _ = hex.DecodeString("24412430303524031a69251c34295c4b35167c7f1e5a7b63091349536c72627066426a635061762e556e6c63533159414d7762317261324a5a3047756b4244664177434e3043") - sm3pwd := NewSM3Password(pwd1) - require.Equal(t, val, fmt.Sprintf("%x", sm3pwd)) - r, err := CheckSM3Password([]byte(sm3pwd), pwd1) +func TestCheckSM3PasswordGood(t *testing.T) { + pwd := "foobar" + r, err := CheckSM3Password(foobarPwdSM3Hash, pwd) require.NoError(t, err) require.True(t, r) +} - sm3pwd = NewSM3Password(pwd2) - require.NotEqual(t, val, fmt.Sprintf("%x", sm3pwd)) - r, err = CheckSM3Password([]byte(sm3pwd), pwd1) +func TestCheckSM3PasswordBad(t *testing.T) { + pwd := "not_foobar" + pwhash, _ := hex.DecodeString("24412430303524031a69251c34295c4b35167c7f1e5a7b63091349536c72627066426a635061762e556e6c63533159414d7762317261324a5a3047756b4244664177434e3043") + r, err := CheckSM3Password(pwhash, pwd) require.NoError(t, err) require.False(t, r) } + +func TestCheckSM3PasswordShort(t *testing.T) { + pwd := "not_foobar" + pwhash, _ := hex.DecodeString("aaaaaaaa") + _, err := CheckSM3Password(pwhash, pwd) + require.Error(t, err) +} + +func TestCheckSM3PasswordDigestTypeIncompatible(t *testing.T) { + pwd := "not_foobar" + pwhash, _ := hex.DecodeString("24422430303524031A69251C34295C4B35167C7F1E5A7B63091349503974624D34504B5A424679354856336868686F52485A736E4A733368786E427575516C73446469496537") + _, err := CheckSM3Password(pwhash, pwd) + require.Error(t, err) +} + +func TestCheckSM3PasswordIterationsInvalid(t *testing.T) { + pwd := "not_foobar" + pwhash, _ := hex.DecodeString("24412430304124031A69251C34295C4B35167C7F1E5A7B63091349503974624D34504B5A424679354856336868686F52485A736E4A733368786E427575516C73446469496537") + _, err := CheckSM3Password(pwhash, pwd) + require.Error(t, err) +} + +func TestNewSM3Password(t *testing.T) { + pwd := "testpwd" + pwhash := NewSM3Password(pwd) + r, err := CheckSM3Password([]byte(pwhash), pwd) + require.NoError(t, err) + require.True(t, r) + + for r := range pwhash { + require.Less(t, pwhash[r], uint8(128)) + require.NotEqual(t, pwhash[r], 0) // NUL + require.NotEqual(t, pwhash[r], 36) // '$' + } +} + +func BenchmarkSM3Password(b *testing.B) { + for i := 0; i < b.N; i++ { + m, err := CheckSM3Password(foobarPwdSM3Hash, "foobar") + require.Nil(b, err) + require.True(b, m) + } +} diff --git a/server/conn.go b/server/conn.go index f147e8d060b66..d694926afd847 100644 --- a/server/conn.go +++ b/server/conn.go @@ -730,7 +730,10 @@ func (cc *clientConn) readOptionalSSLRequestAndHandshakeResponse(ctx context.Con return err } case mysql.AuthSM3Password: - // TODO + resp.Auth, err = cc.authSM3(ctx) + if err != nil { + return err + } case mysql.AuthNativePassword: case mysql.AuthSocket: default: @@ -806,6 +809,37 @@ func (cc *clientConn) authSha(ctx context.Context) ([]byte, error) { return bytes.Trim(data, "\x00"), nil } +// authSM3 implements the sm3_password specific part of the protocol. +func (cc *clientConn) authSM3(ctx context.Context) ([]byte, error) { + + const ( + SM3Command = 1 + RequestRsaPubKey = 2 // Not supported yet, only TLS is supported as secure channel. + FastAuthOk = 3 + FastAuthFail = 4 + ) + + // Currently we always send a "FastAuthFail" as the cached part of the protocol isn't implemented yet. + // This triggers the client to send the full response. + err := cc.writePacket([]byte{0, 0, 0, 0, SM3Command, FastAuthFail}) + if err != nil { + logutil.Logger(ctx).Error("authSM3 packet write failed", zap.Error(err)) + return nil, err + } + err = cc.flush(ctx) + if err != nil { + logutil.Logger(ctx).Error("authSM3 packet flush failed", zap.Error(err)) + return nil, err + } + + data, err := cc.readPacket() + if err != nil { + logutil.Logger(ctx).Error("authSM3 packet read failed", zap.Error(err)) + return nil, err + } + return bytes.Trim(data, "\x00"), nil +} + func (cc *clientConn) SessionStatusToString() string { status := cc.ctx.Status() inTxn, autoCommit := 0, 0 From 3aba7979dd87f260b8769f8e7813f4d340c667d5 Mon Sep 17 00:00:00 2001 From: CbcWestwolf <1004626265@qq.com> Date: Fri, 15 Jul 2022 16:51:29 +0800 Subject: [PATCH 05/29] Update parser/auth/sm3.go Co-authored-by: djshow832 --- parser/auth/sm3.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/parser/auth/sm3.go b/parser/auth/sm3.go index cb83efc6e9b68..e2444b6db6960 100644 --- a/parser/auth/sm3.go +++ b/parser/auth/sm3.go @@ -1,4 +1,4 @@ -// Copyright 2021 PingCAP, Inc. +// Copyright 2022 PingCAP, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. From 711736bf49951b5644dcf5f4d3ee8570a77ace31 Mon Sep 17 00:00:00 2001 From: cbcwestwolf <1004626265@qq.com> Date: Fri, 15 Jul 2022 16:52:33 +0800 Subject: [PATCH 06/29] Update --- parser/auth/sm3.go | 4 ++-- parser/auth/sm3_test.go | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/parser/auth/sm3.go b/parser/auth/sm3.go index e2444b6db6960..c4f5a735e2bf5 100644 --- a/parser/auth/sm3.go +++ b/parser/auth/sm3.go @@ -306,7 +306,7 @@ func sm3crypt(plaintext string, salt []byte, iterations int) string { // 22 buf := bytes.NewBuffer(make([]byte, 0, 100)) - buf.Write([]byte{'$', 'A', '$'}) + buf.Write([]byte{'$', 'B', '$'}) rounds := fmt.Sprintf("%03d", iterations/ITERATION_MULTIPLIER) buf.Write([]byte(rounds)) buf.Write([]byte{'$'}) @@ -335,7 +335,7 @@ func CheckSM3Password(pwhash []byte, password string) (bool, error) { } hashType := string(pwhashParts[1]) - if hashType != "A" { + if hashType != "B" { return false, errors.New("digest type is incompatible") } diff --git a/parser/auth/sm3_test.go b/parser/auth/sm3_test.go index 6aa75bea77c27..42d497d8f44e6 100644 --- a/parser/auth/sm3_test.go +++ b/parser/auth/sm3_test.go @@ -1,4 +1,4 @@ -// Copyright 2021 PingCAP, Inc. +// Copyright 2022 PingCAP, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -20,7 +20,7 @@ import ( "github.com/stretchr/testify/require" ) -var foobarPwdSM3Hash, _ = hex.DecodeString("24412430303524031a69251c34295c4b35167c7f1e5a7b63091349536c72627066426a635061762e556e6c63533159414d7762317261324a5a3047756b4244664177434e3043") +var foobarPwdSM3Hash, _ = hex.DecodeString("24422430303524031a69251c34295c4b35167c7f1e5a7b63091349536c72627066426a635061762e556e6c63533159414d7762317261324a5a3047756b4244664177434e3043") func TestCheckSM3PasswordGood(t *testing.T) { pwd := "foobar" @@ -31,7 +31,7 @@ func TestCheckSM3PasswordGood(t *testing.T) { func TestCheckSM3PasswordBad(t *testing.T) { pwd := "not_foobar" - pwhash, _ := hex.DecodeString("24412430303524031a69251c34295c4b35167c7f1e5a7b63091349536c72627066426a635061762e556e6c63533159414d7762317261324a5a3047756b4244664177434e3043") + pwhash, _ := hex.DecodeString("24422430303524031a69251c34295c4b35167c7f1e5a7b63091349536c72627066426a635061762e556e6c63533159414d7762317261324a5a3047756b4244664177434e3043") r, err := CheckSM3Password(pwhash, pwd) require.NoError(t, err) require.False(t, r) @@ -46,7 +46,7 @@ func TestCheckSM3PasswordShort(t *testing.T) { func TestCheckSM3PasswordDigestTypeIncompatible(t *testing.T) { pwd := "not_foobar" - pwhash, _ := hex.DecodeString("24422430303524031A69251C34295C4B35167C7F1E5A7B63091349503974624D34504B5A424679354856336868686F52485A736E4A733368786E427575516C73446469496537") + pwhash, _ := hex.DecodeString("24432430303524031A69251C34295C4B35167C7F1E5A7B63091349503974624D34504B5A424679354856336868686F52485A736E4A733368786E427575516C73446469496537") _, err := CheckSM3Password(pwhash, pwd) require.Error(t, err) } From 0b79e7332fb1bbdb22e0827a9da37ef81dee076b Mon Sep 17 00:00:00 2001 From: cbcwestwolf <1004626265@qq.com> Date: Fri, 15 Jul 2022 16:58:17 +0800 Subject: [PATCH 07/29] Update --- parser/auth/sm3.go | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/parser/auth/sm3.go b/parser/auth/sm3.go index c4f5a735e2bf5..f3185709fbc3a 100644 --- a/parser/auth/sm3.go +++ b/parser/auth/sm3.go @@ -200,13 +200,6 @@ func SM3(data []byte) []byte { return h.Sum(nil) } -func SM3String(pwd string) string { - var h sm3 - h.Reset() - h.Write([]byte(pwd)) - return string(h.Sum(nil)) -} - func sm3crypt(plaintext string, salt []byte, iterations int) string { // Numbers in the comments refer to the description of the algorithm on https://www.akkadia.org/drepper/SHA-crypt.txt @@ -351,6 +344,7 @@ func CheckSM3Password(pwhash []byte, password string) (bool, error) { return bytes.Equal(pwhash, []byte(newHash)), nil } +// NewSM3Password creates a new SM3 password hash func NewSM3Password(pwd string) string { salt := make([]byte, SALT_LENGTH) rand.Read(salt) From 45340e367e8c5d36742361462d83ac8ab86c8643 Mon Sep 17 00:00:00 2001 From: cbcwestwolf <1004626265@qq.com> Date: Fri, 15 Jul 2022 17:19:45 +0800 Subject: [PATCH 08/29] Update bazel --- br/pkg/glue/BUILD.bazel | 1 + br/pkg/task/BUILD.bazel | 2 +- br/pkg/utils/BUILD.bazel | 10 ++++++++++ store/gcworker/BUILD.bazel | 1 + telemetry/BUILD.bazel | 1 + 5 files changed, 14 insertions(+), 1 deletion(-) diff --git a/br/pkg/glue/BUILD.bazel b/br/pkg/glue/BUILD.bazel index 812b2b2c8b6a7..e51549b578927 100644 --- a/br/pkg/glue/BUILD.bazel +++ b/br/pkg/glue/BUILD.bazel @@ -13,6 +13,7 @@ go_library( "//domain", "//kv", "//parser/model", + "//sessionctx", "@com_github_fatih_color//:color", "@com_github_pingcap_log//:log", "@com_github_tikv_pd_client//:client", diff --git a/br/pkg/task/BUILD.bazel b/br/pkg/task/BUILD.bazel index b01d47de6de22..4acaf84014b03 100644 --- a/br/pkg/task/BUILD.bazel +++ b/br/pkg/task/BUILD.bazel @@ -39,6 +39,7 @@ go_library( "//statistics/handle", "//types", "//util/mathutil", + "//util/sqlexec", "//util/table-filter", "@com_github_docker_go_units//:go-units", "@com_github_fatih_color//:color", @@ -78,7 +79,6 @@ go_test( deps = [ "//br/pkg/conn", "//br/pkg/metautil", - "//br/pkg/pdutil", "//br/pkg/restore", "//br/pkg/storage", "//br/pkg/stream", diff --git a/br/pkg/utils/BUILD.bazel b/br/pkg/utils/BUILD.bazel index b708ec2fa7979..f4fabeb825882 100644 --- a/br/pkg/utils/BUILD.bazel +++ b/br/pkg/utils/BUILD.bazel @@ -27,10 +27,13 @@ go_library( "//br/pkg/logutil", "//br/pkg/metautil", "//errno", + "//kv", "//parser/model", "//parser/mysql", "//parser/terror", + "//sessionctx", "//util", + "//util/sqlexec", "@com_github_cheggaaa_pb_v3//:pb", "@com_github_google_uuid//:uuid", "@com_github_pingcap_errors//:errors", @@ -58,6 +61,7 @@ go_test( name = "utils_test", srcs = [ "backoff_test.go", + "db_test.go", "env_test.go", "json_test.go", "key_test.go", @@ -72,11 +76,17 @@ go_test( "//br/pkg/errors", "//br/pkg/metautil", "//br/pkg/storage", + "//parser/ast", "//parser/model", + "//parser/mysql", "//statistics/handle", "//tablecodec", "//testkit/testsetup", + "//types", + "//util/chunk", + "//util/sqlexec", "@com_github_golang_protobuf//proto", + "@com_github_pingcap_errors//:errors", "@com_github_pingcap_kvproto//pkg/brpb", "@com_github_pingcap_kvproto//pkg/encryptionpb", "@com_github_stretchr_testify//require", diff --git a/store/gcworker/BUILD.bazel b/store/gcworker/BUILD.bazel index 9ac77dd4ceb30..f43e28c02958b 100644 --- a/store/gcworker/BUILD.bazel +++ b/store/gcworker/BUILD.bazel @@ -6,6 +6,7 @@ go_library( importpath = "github.com/pingcap/tidb/store/gcworker", visibility = ["//visibility:public"], deps = [ + "//br/pkg/utils", "//ddl", "//ddl/label", "//ddl/placement", diff --git a/telemetry/BUILD.bazel b/telemetry/BUILD.bazel index 05df184817252..7cce3b2401d35 100644 --- a/telemetry/BUILD.bazel +++ b/telemetry/BUILD.bazel @@ -18,6 +18,7 @@ go_library( importpath = "github.com/pingcap/tidb/telemetry", visibility = ["//visibility:public"], deps = [ + "//br/pkg/utils", "//config", "//domain/infosync", "//infoschema", From e64e90203d47e656278b45c9bc1b66b7493e913c Mon Sep 17 00:00:00 2001 From: cbcwestwolf <1004626265@qq.com> Date: Mon, 18 Jul 2022 10:56:16 +0800 Subject: [PATCH 09/29] Fix --- parser/mysql/const.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/parser/mysql/const.go b/parser/mysql/const.go index 4a9d617b50960..37d06b95518e5 100644 --- a/parser/mysql/const.go +++ b/parser/mysql/const.go @@ -232,8 +232,11 @@ const MaxTypeSetMembers = 64 // PWDHashLen is the length of mysql_native_password's hash. const PWDHashLen = 40 // excluding the '*' + // SHAPWDHashLen is the length of sha256_password's hash. const SHAPWDHashLen = 70 + +// SM3PWDHashLen is the length of sm3_password's hash. const SM3PWDHashLen = 64 // Command2Str is the command information to command name. From edd2ae6a617b6a6313b8c01ba2d1a420cf9933c4 Mon Sep 17 00:00:00 2001 From: cbcwestwolf <1004626265@qq.com> Date: Mon, 18 Jul 2022 11:10:45 +0800 Subject: [PATCH 10/29] Fix --- parser/auth/sm3.go | 90 ++++++++++++++++++++++--------------------- parser/mysql/const.go | 2 +- 2 files changed, 48 insertions(+), 44 deletions(-) diff --git a/parser/auth/sm3.go b/parser/auth/sm3.go index f3185709fbc3a..706a7a1996782 100644 --- a/parser/auth/sm3.go +++ b/parser/auth/sm3.go @@ -32,21 +32,23 @@ type sm3 struct { digest [8]uint32 // digest represents the partial evaluation of V length uint64 // length of the message unhandleMsg []byte + blockSize int + size int } -func (sm3 *sm3) ff0(x, y, z uint32) uint32 { return x ^ y ^ z } +func ff0(x, y, z uint32) uint32 { return x ^ y ^ z } -func (sm3 *sm3) ff1(x, y, z uint32) uint32 { return (x & y) | (x & z) | (y & z) } +func ff1(x, y, z uint32) uint32 { return (x & y) | (x & z) | (y & z) } -func (sm3 *sm3) gg0(x, y, z uint32) uint32 { return x ^ y ^ z } +func gg0(x, y, z uint32) uint32 { return x ^ y ^ z } -func (sm3 *sm3) gg1(x, y, z uint32) uint32 { return (x & y) | (^x & z) } +func gg1(x, y, z uint32) uint32 { return (x & y) | (^x & z) } -func (sm3 *sm3) p0(x uint32) uint32 { return x ^ sm3.leftRotate(x, 9) ^ sm3.leftRotate(x, 17) } +func p0(x uint32) uint32 { return x ^ leftRotate(x, 9) ^ leftRotate(x, 17) } -func (sm3 *sm3) p1(x uint32) uint32 { return x ^ sm3.leftRotate(x, 15) ^ sm3.leftRotate(x, 23) } +func p1(x uint32) uint32 { return x ^ leftRotate(x, 15) ^ leftRotate(x, 23) } -func (sm3 *sm3) leftRotate(x uint32, i uint32) uint32 { return x<<(i%32) | x>>(32-i%32) } +func leftRotate(x uint32, i uint32) uint32 { return x<<(i%32) | x>>(32-i%32) } func (sm3 *sm3) pad() []byte { msg := sm3.unhandleMsg @@ -83,48 +85,48 @@ func (sm3 *sm3) update(msg []byte) [8]uint32 { w[i] = binary.BigEndian.Uint32(msg[4*i : 4*(i+1)]) } for i := 16; i < 68; i++ { - w[i] = sm3.p1(w[i-16]^w[i-9]^sm3.leftRotate(w[i-3], 15)) ^ sm3.leftRotate(w[i-13], 7) ^ w[i-6] + w[i] = p1(w[i-16]^w[i-9]^leftRotate(w[i-3], 15)) ^ leftRotate(w[i-13], 7) ^ w[i-6] } for i := 0; i < 64; i++ { w1[i] = w[i] ^ w[i+4] } - A, B, C, D, E, F, G, H := a, b, c, d, e, f, g, h + a1, b1, c1, d1, e1, f1, g1, h1 := a, b, c, d, e, f, g, h for i := 0; i < 16; i++ { - SS1 := sm3.leftRotate(sm3.leftRotate(A, 12)+E+sm3.leftRotate(0x79cc4519, uint32(i)), 7) - SS2 := SS1 ^ sm3.leftRotate(A, 12) - TT1 := sm3.ff0(A, B, C) + D + SS2 + w1[i] - TT2 := sm3.gg0(E, F, G) + H + SS1 + w[i] - D = C - C = sm3.leftRotate(B, 9) - B = A - A = TT1 - H = G - G = sm3.leftRotate(F, 19) - F = E - E = sm3.p0(TT2) + ss1 := leftRotate(leftRotate(a1, 12)+e1+leftRotate(0x79cc4519, uint32(i)), 7) + ss2 := ss1 ^ leftRotate(a1, 12) + tt1 := ff0(a1, b1, c1) + d1 + ss2 + w1[i] + tt2 := gg0(e1, f1, g1) + h1 + ss1 + w[i] + d1 = c1 + c1 = leftRotate(b1, 9) + b1 = a1 + a1 = tt1 + h1 = g1 + g1 = leftRotate(f1, 19) + f1 = e1 + e1 = p0(tt2) } for i := 16; i < 64; i++ { - SS1 := sm3.leftRotate(sm3.leftRotate(A, 12)+E+sm3.leftRotate(0x7a879d8a, uint32(i)), 7) - SS2 := SS1 ^ sm3.leftRotate(A, 12) - TT1 := sm3.ff1(A, B, C) + D + SS2 + w1[i] - TT2 := sm3.gg1(E, F, G) + H + SS1 + w[i] - D = C - C = sm3.leftRotate(B, 9) - B = A - A = TT1 - H = G - G = sm3.leftRotate(F, 19) - F = E - E = sm3.p0(TT2) + ss1 := leftRotate(leftRotate(a1, 12)+e1+leftRotate(0x7a879d8a, uint32(i)), 7) + ss2 := ss1 ^ leftRotate(a1, 12) + tt1 := ff1(a1, b1, c1) + d1 + ss2 + w1[i] + tt2 := gg1(e1, f1, g1) + h1 + ss1 + w[i] + d1 = c1 + c1 = leftRotate(b1, 9) + b1 = a1 + a1 = tt1 + h1 = g1 + g1 = leftRotate(f1, 19) + f1 = e1 + e1 = p0(tt2) } - a ^= A - b ^= B - c ^= C - d ^= D - e ^= E - f ^= F - g ^= G - h ^= H + a ^= a1 + b ^= b1 + c ^= c1 + d ^= d1 + e ^= e1 + f ^= f1 + g ^= g1 + h ^= h1 msg = msg[64:] } var digest [8]uint32 @@ -135,10 +137,10 @@ func (sm3 *sm3) update(msg []byte) [8]uint32 { // BlockSize returns the hash's underlying block size. // The Write method must be able to accept any amount of data, // but it may operate more efficiently if all writes are a multiple of the block size. -func (sm3 *sm3) BlockSize() int { return 64 } +func (sm3 *sm3) BlockSize() int { return sm3.blockSize } // Size returns the number of bytes Sum will return. -func (sm3 *sm3) Size() int { return 32 } +func (sm3 *sm3) Size() int { return sm3.size } // Reset clears the internal state by zeroing bytes in the state buffer. // This can be skipped for a newly-created hash state; the default zero-allocated state is correct. @@ -197,6 +199,8 @@ func SM3(data []byte) []byte { var h sm3 h.Reset() h.Write(data) + h.blockSize = 64 + h.size = 32 return h.Sum(nil) } diff --git a/parser/mysql/const.go b/parser/mysql/const.go index 37d06b95518e5..076982b971817 100644 --- a/parser/mysql/const.go +++ b/parser/mysql/const.go @@ -237,7 +237,7 @@ const PWDHashLen = 40 // excluding the '*' const SHAPWDHashLen = 70 // SM3PWDHashLen is the length of sm3_password's hash. -const SM3PWDHashLen = 64 +const SM3PWDHashLen = 70 // Command2Str is the command information to command name. var Command2Str = map[byte]string{ From 59cc76ee9a7668a3221e02262817e2b64bee9ce2 Mon Sep 17 00:00:00 2001 From: cbcwestwolf <1004626265@qq.com> Date: Mon, 18 Jul 2022 11:28:27 +0800 Subject: [PATCH 11/29] Fix --- parser/auth/sm3.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/parser/auth/sm3.go b/parser/auth/sm3.go index 706a7a1996782..7d5f9e7f680f3 100644 --- a/parser/auth/sm3.go +++ b/parser/auth/sm3.go @@ -157,6 +157,8 @@ func (sm3 *sm3) Reset() { sm3.length = 0 sm3.unhandleMsg = []byte{} + sm3.blockSize = 64 + sm3.size = 32 } // Write (via the embedded io.Writer interface) adds more data to the running hash. @@ -199,8 +201,6 @@ func SM3(data []byte) []byte { var h sm3 h.Reset() h.Write(data) - h.blockSize = 64 - h.size = 32 return h.Sum(nil) } From 6b5cb339b873e6d210d3ace7e3f68709a23b03b3 Mon Sep 17 00:00:00 2001 From: cbcwestwolf <1004626265@qq.com> Date: Mon, 18 Jul 2022 15:59:23 +0800 Subject: [PATCH 12/29] Support bulitin function SM3(str) --- executor/reload_expr_pushdown_blacklist.go | 1 + expression/builtin.go | 1 + expression/builtin_convert_charset.go | 2 +- expression/builtin_encryption.go | 46 +++++++++++++++++++ expression/builtin_encryption_vec.go | 33 +++++++++++++ expression/builtin_encryption_vec_test.go | 3 ++ expression/collation.go | 2 +- expression/typeinfer_test.go | 19 ++++++++ parser/ast/functions.go | 1 + parser/auth/sm3.go | 11 ++++- parser/parser_test.go | 1 + .../pessimistictest/pessimistic_test.go | 4 +- 12 files changed, 118 insertions(+), 6 deletions(-) diff --git a/executor/reload_expr_pushdown_blacklist.go b/executor/reload_expr_pushdown_blacklist.go index c32f84c957e1e..5d2a3bc558021 100644 --- a/executor/reload_expr_pushdown_blacklist.go +++ b/executor/reload_expr_pushdown_blacklist.go @@ -304,6 +304,7 @@ var funcName2Alias = map[string]string{ "sha1": ast.SHA1, "sha": ast.SHA, "sha2": ast.SHA2, + "sm3": ast.SM3, "uncompress": ast.Uncompress, "uncompressed_length": ast.UncompressedLength, "validate_password_strength": ast.ValidatePasswordStrength, diff --git a/expression/builtin.go b/expression/builtin.go index 772ef5c5d48f1..cc1dc0d059a23 100644 --- a/expression/builtin.go +++ b/expression/builtin.go @@ -798,6 +798,7 @@ var funcs = map[string]functionClass{ ast.SHA1: &sha1FunctionClass{baseFunctionClass{ast.SHA1, 1, 1}}, ast.SHA: &sha1FunctionClass{baseFunctionClass{ast.SHA, 1, 1}}, ast.SHA2: &sha2FunctionClass{baseFunctionClass{ast.SHA2, 2, 2}}, + ast.SM3: &sm3FunctionClass{baseFunctionClass{ast.SM3, 1, 1}}, ast.Uncompress: &uncompressFunctionClass{baseFunctionClass{ast.Uncompress, 1, 1}}, ast.UncompressedLength: &uncompressedLengthFunctionClass{baseFunctionClass{ast.UncompressedLength, 1, 1}}, ast.ValidatePasswordStrength: &validatePasswordStrengthFunctionClass{baseFunctionClass{ast.ValidatePasswordStrength, 1, 1}}, diff --git a/expression/builtin_convert_charset.go b/expression/builtin_convert_charset.go index 9fafbd36f3117..1296ed33632cd 100644 --- a/expression/builtin_convert_charset.go +++ b/expression/builtin_convert_charset.go @@ -277,7 +277,7 @@ var convertActionMap = map[funcProp][]string{ ast.ASCII, ast.BitLength, ast.Hex, ast.Length, ast.OctetLength, ast.ToBase64, /* encrypt functions */ ast.AesDecrypt, ast.Decode, ast.Encode, ast.PasswordFunc, ast.MD5, ast.SHA, ast.SHA1, - ast.SHA2, ast.Compress, ast.AesEncrypt, + ast.SHA2, ast.SM3, ast.Compress, ast.AesEncrypt, }, funcPropAuto: { /* string functions */ ast.Concat, ast.ConcatWS, ast.ExportSet, ast.Field, ast.FindInSet, diff --git a/expression/builtin_encryption.go b/expression/builtin_encryption.go index 4229f03402422..2f2a1ff1eb95b 100644 --- a/expression/builtin_encryption.go +++ b/expression/builtin_encryption.go @@ -711,6 +711,52 @@ func (b *builtinSHA2Sig) Clone() builtinFunc { return newSig } +type sm3FunctionClass struct { + baseFunctionClass +} + +func (c *sm3FunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETString, types.ETString) + if err != nil { + return nil, err + } + charset, collate := ctx.GetSessionVars().GetCharsetInfo() + bf.tp.SetCharset(charset) + bf.tp.SetCollate(collate) + bf.tp.SetFlen(40) + sig := &builtinSM3Sig{bf} + //sig.setPbCode(tipb.ScalarFuncSig_SM3) // TODO + return sig, nil +} + +type builtinSM3Sig struct { + baseBuiltinFunc +} + +func (b *builtinSM3Sig) Clone() builtinFunc { + newSig := &builtinSM3Sig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalString evals SM3(str). +// The value is returned as a string of 70 hexadecimal digits, or NULL if the argument was NULL. +func (b *builtinSM3Sig) evalString(row chunk.Row) (string, bool, error) { + str, isNull, err := b.args[0].EvalString(b.ctx, row) + if isNull || err != nil { + return "", isNull, err + } + hasher := auth.NewSM3() + _, err = hasher.Write([]byte(str)) + if err != nil { + return "", true, err + } + return fmt.Sprintf("%x", hasher.Sum(nil)), false, nil +} + // Supported hash length of SHA-2 family const ( SHA0 = 0 diff --git a/expression/builtin_encryption_vec.go b/expression/builtin_encryption_vec.go index 1c2124b8c3001..9736ca4b3ea8e 100644 --- a/expression/builtin_encryption_vec.go +++ b/expression/builtin_encryption_vec.go @@ -516,6 +516,39 @@ func (b *builtinSHA2Sig) vecEvalString(input *chunk.Chunk, result *chunk.Column) return nil } +func (b *builtinSM3Sig) vectorized() bool { + return true +} + +// vecEvalString evals SM3(str). +func (b *builtinSM3Sig) vecEvalString(input *chunk.Chunk, result *chunk.Column) error { + n := input.NumRows() + buf, err := b.bufAllocator.get() + if err != nil { + return err + } + defer b.bufAllocator.put(buf) + if err := b.args[0].VecEvalString(b.ctx, input, buf); err != nil { + return err + } + result.ReserveString(n) + hasher := auth.NewSM3() + for i := 0; i < n; i++ { + if buf.IsNull(i) { + result.AppendNull() + continue + } + str := buf.GetBytes(i) + _, err = hasher.Write(str) + if err != nil { + return err + } + result.AppendString(fmt.Sprintf("%x", hasher.Sum(nil))) + hasher.Reset() + } + return nil +} + func (b *builtinCompressSig) vectorized() bool { return true } diff --git a/expression/builtin_encryption_vec_test.go b/expression/builtin_encryption_vec_test.go index c7cb9d7f58a11..c6caa1eb60d51 100644 --- a/expression/builtin_encryption_vec_test.go +++ b/expression/builtin_encryption_vec_test.go @@ -66,6 +66,9 @@ var vecBuiltinEncryptionCases = map[string][]vecExprBenchCase{ {retEvalType: types.ETString, childrenTypes: []types.EvalType{types.ETString, types.ETInt}, geners: []dataGenerator{newRandLenStrGener(10, 20), newRangeInt64Gener(SHA384, SHA384+1)}}, {retEvalType: types.ETString, childrenTypes: []types.EvalType{types.ETString, types.ETInt}, geners: []dataGenerator{newRandLenStrGener(10, 20), newRangeInt64Gener(SHA512, SHA512+1)}}, }, + ast.SM3: { + {retEvalType: types.ETString, childrenTypes: []types.EvalType{types.ETString}}, + }, ast.Encode: { {retEvalType: types.ETString, childrenTypes: []types.EvalType{types.ETString, types.ETString}}, }, diff --git a/expression/collation.go b/expression/collation.go index f24d1601571fa..922158d830f59 100644 --- a/expression/collation.go +++ b/expression/collation.go @@ -272,7 +272,7 @@ func deriveCollation(ctx sessionctx.Context, funcName string, args []Expression, case ast.Database, ast.User, ast.CurrentUser, ast.Version, ast.CurrentRole, ast.TiDBVersion: chs, coll := charset.GetDefaultCharsetAndCollate() return &ExprCollation{CoercibilitySysconst, UNICODE, chs, coll}, nil - case ast.Format, ast.Space, ast.ToBase64, ast.UUID, ast.Hex, ast.MD5, ast.SHA, ast.SHA2: + case ast.Format, ast.Space, ast.ToBase64, ast.UUID, ast.Hex, ast.MD5, ast.SHA, ast.SHA2, ast.SM3: // should return ASCII repertoire, MySQL's doc says it depends on character_set_connection, but it not true from its source code. ec = &ExprCollation{Coer: CoercibilityCoercible, Repe: ASCII} ec.Charset, ec.Collation = ctx.GetSessionVars().GetCharsetInfo() diff --git a/expression/typeinfer_test.go b/expression/typeinfer_test.go index 04b09303712af..eff9c603569ab 100644 --- a/expression/typeinfer_test.go +++ b/expression/typeinfer_test.go @@ -982,6 +982,25 @@ func (s *InferTypeSuite) createTestCase4EncryptionFuncs() []typeInferTestCase { {"sha2('1234' , '256')", mysql.TypeVarString, charset.CharsetUTF8MB4, mysql.NotNullFlag, 128, types.UnspecifiedLength}, {"sha2(1234 , '256')", mysql.TypeVarString, charset.CharsetUTF8MB4, mysql.NotNullFlag, 128, types.UnspecifiedLength}, + {"sm3(c_int_d )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 40, types.UnspecifiedLength}, + {"sm3(c_bigint_d )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 40, types.UnspecifiedLength}, + {"sm3(c_float_d )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 40, types.UnspecifiedLength}, + {"sm3(c_double_d )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 40, types.UnspecifiedLength}, + {"sm3(c_decimal )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 40, types.UnspecifiedLength}, + {"sm3(c_datetime )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 40, types.UnspecifiedLength}, + {"sm3(c_time_d )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 40, types.UnspecifiedLength}, + {"sm3(c_timestamp_d)", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 40, types.UnspecifiedLength}, + {"sm3(c_char )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 40, types.UnspecifiedLength}, + {"sm3(c_varchar )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 40, types.UnspecifiedLength}, + {"sm3(c_text_d )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 40, types.UnspecifiedLength}, + {"sm3(c_binary )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 40, types.UnspecifiedLength}, + {"sm3(c_varbinary )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 40, types.UnspecifiedLength}, + {"sm3(c_blob_d )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 40, types.UnspecifiedLength}, + {"sm3(c_set )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 40, types.UnspecifiedLength}, + {"sm3(c_enum )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 40, types.UnspecifiedLength}, + {"sm3('1234' )", mysql.TypeVarString, charset.CharsetUTF8MB4, mysql.NotNullFlag, 40, types.UnspecifiedLength}, + {"sm3(1234 )", mysql.TypeVarString, charset.CharsetUTF8MB4, mysql.NotNullFlag, 40, types.UnspecifiedLength}, + {"AES_ENCRYPT(c_int_d, 'key')", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 32, types.UnspecifiedLength}, {"AES_ENCRYPT(c_char, 'key')", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 32, types.UnspecifiedLength}, {"AES_ENCRYPT(c_varchar, 'key')", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 32, types.UnspecifiedLength}, diff --git a/parser/ast/functions.go b/parser/ast/functions.go index ae61e06682656..e53e16eb2d922 100644 --- a/parser/ast/functions.go +++ b/parser/ast/functions.go @@ -309,6 +309,7 @@ const ( SHA1 = "sha1" SHA = "sha" SHA2 = "sha2" + SM3 = "sm3" Uncompress = "uncompress" UncompressedLength = "uncompressed_length" ValidatePasswordStrength = "validate_password_strength" diff --git a/parser/auth/sm3.go b/parser/auth/sm3.go index 7d5f9e7f680f3..6b3016df44bc4 100644 --- a/parser/auth/sm3.go +++ b/parser/auth/sm3.go @@ -25,6 +25,7 @@ import ( "encoding/binary" "errors" "fmt" + "hash" "strconv" ) @@ -196,10 +197,16 @@ func (sm3 *sm3) Sum(in []byte) []byte { return out } -// SM3 returns the sm3 checksum of the data. -func SM3(data []byte) []byte { +// NewSM3 returns a new hash.Hash computing the SM3 checksum. +func NewSM3() hash.Hash { var h sm3 h.Reset() + return &h +} + +// SM3 returns the sm3 checksum of the data. +func SM3(data []byte) []byte { + h := NewSM3() h.Write(data) return h.Sum(nil) } diff --git a/parser/parser_test.go b/parser/parser_test.go index ee5c9fc7e384c..b18f6b7751efe 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -2113,6 +2113,7 @@ func TestBuiltin(t *testing.T) { {`SELECT SHA1('abc');`, true, "SELECT SHA1(_UTF8MB4'abc')"}, {`SELECT SHA('abc');`, true, "SELECT SHA(_UTF8MB4'abc')"}, {`SELECT SHA2('abc', 224);`, true, "SELECT SHA2(_UTF8MB4'abc', 224)"}, + {`SELECT SM3('abc');`, true, "SELECT SM3(_UTF8MB4'abc')"}, {`SELECT UNCOMPRESS('any string');`, true, "SELECT UNCOMPRESS(_UTF8MB4'any string')"}, {`SELECT UNCOMPRESSED_LENGTH(@compressed_string);`, true, "SELECT UNCOMPRESSED_LENGTH(@`compressed_string`)"}, {`SELECT VALIDATE_PASSWORD_STRENGTH(@str);`, true, "SELECT VALIDATE_PASSWORD_STRENGTH(@`str`)"}, diff --git a/tests/realtikvtest/pessimistictest/pessimistic_test.go b/tests/realtikvtest/pessimistictest/pessimistic_test.go index 58bb2b73a86f5..20dbc2283d81e 100644 --- a/tests/realtikvtest/pessimistictest/pessimistic_test.go +++ b/tests/realtikvtest/pessimistictest/pessimistic_test.go @@ -2476,7 +2476,7 @@ func TestAmendForUniqueIndex(t *testing.T) { tk2.MustExec("insert into t1 values(1, 1, 1);") tk2.MustExec("insert into t1 values(2, 2, 2);") - // New value has duplicates. + // NewSM3 value has duplicates. tk.MustExec("begin pessimistic") tk.MustExec("insert into t1 values(3, 3, 3)") tk.MustExec("insert into t1 values(4, 4, 3)") @@ -2485,7 +2485,7 @@ func TestAmendForUniqueIndex(t *testing.T) { tk2.MustExec("alter table t1 drop index uk1") tk2.MustExec("admin check table t1") - // New values has duplicates with old values. + // NewSM3 values has duplicates with old values. tk.MustExec("begin pessimistic") tk.MustExec("insert into t1 values(3, 3, 3)") tk.MustExec("insert into t1 values(4, 4, 1)") From a98c4e22aa2579da82e8675a3c43f5d97ddec8b4 Mon Sep 17 00:00:00 2001 From: cbcwestwolf <1004626265@qq.com> Date: Mon, 18 Jul 2022 16:19:37 +0800 Subject: [PATCH 13/29] Add license from Suzhou Tongji Fintech Research Institute --- parser/auth/sm3.go | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/parser/auth/sm3.go b/parser/auth/sm3.go index 6b3016df44bc4..8c1eefa917a22 100644 --- a/parser/auth/sm3.go +++ b/parser/auth/sm3.go @@ -13,11 +13,6 @@ package auth -// The concrete SM3 Cryptographic Hash Algorithm can be accessed in http://www.sca.gov.cn/sca/xwdt/2010-12/17/content_1002389.shtml -// This implementation of 'type sm3 struct' is modified from https://github.com/tjfoc/gmsm/tree/601ddb090dcf53d7951cc4dcc66276e2b817837c/sm3 -// Some other references: -// https://tools.ietf.org/id/draft-oscca-cfrg-sm3-01.html - import ( "bytes" "crypto/rand" @@ -29,6 +24,24 @@ import ( "strconv" ) +// The concrete SM3 Cryptographic Hash Algorithm can be accessed in http://www.sca.gov.cn/sca/xwdt/2010-12/17/content_1002389.shtml +// This implementation of 'type sm3 struct' is modified from https://github.com/tjfoc/gmsm/tree/601ddb090dcf53d7951cc4dcc66276e2b817837c/sm3 +// Some other references: +// https://tools.ietf.org/id/draft-oscca-cfrg-sm3-01.html + +/* +Copyright Suzhou Tongji Fintech Research Institute 2017 All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + type sm3 struct { digest [8]uint32 // digest represents the partial evaluation of V length uint64 // length of the message From 69d779af6ff7db549118603807da9679a130b7d1 Mon Sep 17 00:00:00 2001 From: cbcwestwolf <1004626265@qq.com> Date: Mon, 18 Jul 2022 16:24:10 +0800 Subject: [PATCH 14/29] Fix --- executor/showtest/show_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/executor/showtest/show_test.go b/executor/showtest/show_test.go index 2994069bcb0f6..263cf33e7cf7c 100644 --- a/executor/showtest/show_test.go +++ b/executor/showtest/show_test.go @@ -1445,7 +1445,7 @@ func TestShowBuiltin(t *testing.T) { res := tk.MustQuery("show builtins;") require.NotNil(t, res) rows := res.Rows() - const builtinFuncNum = 275 + const builtinFuncNum = 276 require.Equal(t, len(rows), builtinFuncNum) require.Equal(t, rows[0][0].(string), "abs") require.Equal(t, rows[builtinFuncNum-1][0].(string), "yearweek") From cf565a8eb89ba1c0e5d7b30759658d0b9d26818a Mon Sep 17 00:00:00 2001 From: cbcwestwolf <1004626265@qq.com> Date: Mon, 18 Jul 2022 21:17:43 +0800 Subject: [PATCH 15/29] Add test for builtin SM3() --- expression/integration_test.go | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/expression/integration_test.go b/expression/integration_test.go index b3dc43fe084d1..1fd73b4776d2c 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -1029,6 +1029,22 @@ func TestEncryptionBuiltin(t *testing.T) { result = tk.MustQuery("select sha2('123', 512), sha2(123, 512), sha2('', 512), sha2('你儽', 224), sha2(NULL, 256), sha2('foo', 123)") result.Check(testkit.Rows(`3c9909afec25354d551dae21590bb26e38d53f2173b8d3dc3eee4c047e7ab1c1eb8b85103e3be7ba613b31bb5c9c36214dc9f14a42fd7a2fdb84856bca5c44c2 3c9909afec25354d551dae21590bb26e38d53f2173b8d3dc3eee4c047e7ab1c1eb8b85103e3be7ba613b31bb5c9c36214dc9f14a42fd7a2fdb84856bca5c44c2 cf83e1357eefb8bdf1542850d66d8007d620e4050b5715dc83f4a921d36ce9ce47d0d13c5d85f2b0ff8318d2877eec2f63b931bd47417a81a538327af927da3e e91f006ed4e0882de2f6a3c96ec228a6a5c715f356d00091bce842b5 `)) + // for sm3 + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a char(10), b int, c double, d datetime, e time, f bit(4), g binary(20), h blob(10), i text(30))") + tk.MustExec(`insert into t values('2', 2, 2.3, "2017-01-01 12:01:01", "12:01:01", 0b1010, "512", "48", "tidb")`) + result = tk.MustQuery("select sm3(a), sm3(b), sm3(c), sm3(d), sm3(e), sm3(f), sm3(g), sm3(h), sm3(i) from t") + result.Check(testkit.Rows("a0dc2d74b9b0e3c87e076003dbfe472a424cb3032463cb339e351460765a822e a0dc2d74b9b0e3c87e076003dbfe472a424cb3032463cb339e351460765a822e b01f6234a2c1d98af2d8bfb79a8c95677c6e9f5750eb756890f29b33b712f804 8485b2ccde69acf41e333e8fba2f55a1b3556e1a42443095235db1d5c78b25d1 f71ab1aad211e14a47b549e8df55b627c36fa75c1aa75b9682cccae2de00babc f4051d239b766c4111e92979aa31af0b35def053646e347bc41e8b73cfd080bc d42cb1657149a8057cef0ba0ededef7f23c9a2f133bfd286ad0f4a6a8bdb5cb2 19dfccdab83e610f04c414a96edb45007b9a022af01473fccf2073b546ad092e 5e0fb8467c33dae5879fb296c9766c78b0a6fc966372f76ac000cc1fcafc2876")) + result = tk.MustQuery("select sm3('123'), sm3(123), sm3(''), sm3('你儽'), sm3(NULL)") + result.Check(testkit.Rows(`6e0f9e14344c5406a0cf5a3b4dfb665f87f4a771a31f7edbb5c72874a32b2957 6e0f9e14344c5406a0cf5a3b4dfb665f87f4a771a31f7edbb5c72874a32b2957 1ab21d8355cfa17f8e61194831e81a8f22bec8c728fefb747ed035eb5082aa2b 78e5c78c5322ca174089e58dc7790acf8ce9d542bee6ae4a5a0797d5e356be61 `)) + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a char(10), b int, c double, d datetime, e time, f bit(4), g binary(20), h blob(10), i text(30))") + tk.MustExec(`insert into t values('2', 2, 2.3, "2017-01-01 12:01:01", "12:01:01", 0b1010, "512", "48", "tidb")`) + result = tk.MustQuery("select sm3(a), sm3(b), sm3(c), sm3(d), sm3(e), sm3(f), sm3(g), sm3(h), sm3(i) from t") + result.Check(testkit.Rows("a0dc2d74b9b0e3c87e076003dbfe472a424cb3032463cb339e351460765a822e a0dc2d74b9b0e3c87e076003dbfe472a424cb3032463cb339e351460765a822e b01f6234a2c1d98af2d8bfb79a8c95677c6e9f5750eb756890f29b33b712f804 8485b2ccde69acf41e333e8fba2f55a1b3556e1a42443095235db1d5c78b25d1 f71ab1aad211e14a47b549e8df55b627c36fa75c1aa75b9682cccae2de00babc f4051d239b766c4111e92979aa31af0b35def053646e347bc41e8b73cfd080bc d42cb1657149a8057cef0ba0ededef7f23c9a2f133bfd286ad0f4a6a8bdb5cb2 19dfccdab83e610f04c414a96edb45007b9a022af01473fccf2073b546ad092e 5e0fb8467c33dae5879fb296c9766c78b0a6fc966372f76ac000cc1fcafc2876")) + result = tk.MustQuery("select sm3('123'), sm3(123), sm3(''), sm3('你儽'), sm3(NULL)") + result.Check(testkit.Rows(`6e0f9e14344c5406a0cf5a3b4dfb665f87f4a771a31f7edbb5c72874a32b2957 6e0f9e14344c5406a0cf5a3b4dfb665f87f4a771a31f7edbb5c72874a32b2957 1ab21d8355cfa17f8e61194831e81a8f22bec8c728fefb747ed035eb5082aa2b 78e5c78c5322ca174089e58dc7790acf8ce9d542bee6ae4a5a0797d5e356be61 `)) + // for AES_ENCRYPT tk.MustExec("drop table if exists t") tk.MustExec("create table t(a char(10), b int, c double, d datetime, e time, f bit(4), g binary(20), h blob(10), i text(30))") From e11e90e42ec11b8c467565f550b76d64a552040b Mon Sep 17 00:00:00 2001 From: cbcwestwolf <1004626265@qq.com> Date: Tue, 19 Jul 2022 10:20:59 +0800 Subject: [PATCH 16/29] Add test for SM3 --- parser/auth/sm3_test.go | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/parser/auth/sm3_test.go b/parser/auth/sm3_test.go index 42d497d8f44e6..a5afd911b29a4 100644 --- a/parser/auth/sm3_test.go +++ b/parser/auth/sm3_test.go @@ -22,6 +22,21 @@ import ( var foobarPwdSM3Hash, _ = hex.DecodeString("24422430303524031a69251c34295c4b35167c7f1e5a7b63091349536c72627066426a635061762e556e6c63533159414d7762317261324a5a3047756b4244664177434e3043") +func TestSM3(t *testing.T) { + var testCases [][]string = [][]string{ + {"abc", "66c7f0f462eeedd9d1f2d46bdc10e4e24167c4875cf2f7a2297da02b8f4ba8e0"}, + {"abcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcd", "debe9ff92275b8a138604889c18e5a4d6fdb70e5387e5765293dcba39c0c5732"}, + } + var expect []byte + + for _, testCase := range testCases { + text := testCase[0] + expect, _ = hex.DecodeString(testCase[1]) + result := SM3([]byte(text)) + require.Equal(t, expect, result) + } +} + func TestCheckSM3PasswordGood(t *testing.T) { pwd := "foobar" r, err := CheckSM3Password(foobarPwdSM3Hash, pwd) From d893f3dd740f12b60f06803a7d28105dd8452fff Mon Sep 17 00:00:00 2001 From: cbcwestwolf <1004626265@qq.com> Date: Tue, 19 Jul 2022 10:23:14 +0800 Subject: [PATCH 17/29] Fix --- parser/auth/sm3_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/parser/auth/sm3_test.go b/parser/auth/sm3_test.go index a5afd911b29a4..4bd885c13154d 100644 --- a/parser/auth/sm3_test.go +++ b/parser/auth/sm3_test.go @@ -23,7 +23,7 @@ import ( var foobarPwdSM3Hash, _ = hex.DecodeString("24422430303524031a69251c34295c4b35167c7f1e5a7b63091349536c72627066426a635061762e556e6c63533159414d7762317261324a5a3047756b4244664177434e3043") func TestSM3(t *testing.T) { - var testCases [][]string = [][]string{ + testCases := [][]string{ {"abc", "66c7f0f462eeedd9d1f2d46bdc10e4e24167c4875cf2f7a2297da02b8f4ba8e0"}, {"abcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcd", "debe9ff92275b8a138604889c18e5a4d6fdb70e5387e5765293dcba39c0c5732"}, } From 9da78465cb07d75bb99ab9be87fab87df27406eb Mon Sep 17 00:00:00 2001 From: cbcwestwolf <1004626265@qq.com> Date: Tue, 19 Jul 2022 10:47:30 +0800 Subject: [PATCH 18/29] Fix --- br/cmd/br/BUILD.bazel | 1 + br/pkg/backup/BUILD.bazel | 2 ++ br/pkg/conn/BUILD.bazel | 2 ++ br/pkg/gluetikv/BUILD.bazel | 1 + br/pkg/lightning/BUILD.bazel | 1 + br/pkg/restore/BUILD.bazel | 2 ++ br/pkg/stream/BUILD.bazel | 1 + br/pkg/streamhelper/BUILD.bazel | 1 + br/pkg/task/BUILD.bazel | 2 ++ br/pkg/utils/BUILD.bazel | 6 ------ br/pkg/version/BUILD.bazel | 1 + dumpling/export/BUILD.bazel | 1 + store/gcworker/BUILD.bazel | 1 + telemetry/BUILD.bazel | 1 + 14 files changed, 17 insertions(+), 6 deletions(-) diff --git a/br/cmd/br/BUILD.bazel b/br/cmd/br/BUILD.bazel index 5f50c3876765a..e558bf367b4ec 100644 --- a/br/cmd/br/BUILD.bazel +++ b/br/cmd/br/BUILD.bazel @@ -27,6 +27,7 @@ go_library( "//br/pkg/summary", "//br/pkg/task", "//br/pkg/trace", + "//br/pkg/utils", "//br/pkg/version/build", "//config", "//ddl", diff --git a/br/pkg/backup/BUILD.bazel b/br/pkg/backup/BUILD.bazel index c63ff5fb482e1..14f4d1a87c6f9 100644 --- a/br/pkg/backup/BUILD.bazel +++ b/br/pkg/backup/BUILD.bazel @@ -22,6 +22,7 @@ go_library( "//br/pkg/rtree", "//br/pkg/storage", "//br/pkg/summary", + "//br/pkg/utils", "//ddl", "//distsql", "//kv", @@ -68,6 +69,7 @@ go_test( "//br/pkg/mock", "//br/pkg/pdutil", "//br/pkg/storage", + "//br/pkg/utils", "//kv", "//parser/model", "//sessionctx/variable", diff --git a/br/pkg/conn/BUILD.bazel b/br/pkg/conn/BUILD.bazel index 733347e168fb3..ad1bcacaaac8d 100644 --- a/br/pkg/conn/BUILD.bazel +++ b/br/pkg/conn/BUILD.bazel @@ -10,6 +10,7 @@ go_library( "//br/pkg/glue", "//br/pkg/logutil", "//br/pkg/pdutil", + "//br/pkg/utils", "//br/pkg/version", "//domain", "//kv", @@ -42,6 +43,7 @@ go_test( embed = [":conn"], deps = [ "//br/pkg/pdutil", + "//br/pkg/utils", "//testkit/testsetup", "@com_github_docker_go_units//:go-units", "@com_github_pingcap_errors//:errors", diff --git a/br/pkg/gluetikv/BUILD.bazel b/br/pkg/gluetikv/BUILD.bazel index 879b39b3a952b..7d8c6118604e4 100644 --- a/br/pkg/gluetikv/BUILD.bazel +++ b/br/pkg/gluetikv/BUILD.bazel @@ -8,6 +8,7 @@ go_library( deps = [ "//br/pkg/glue", "//br/pkg/summary", + "//br/pkg/utils", "//br/pkg/version/build", "//config", "//domain", diff --git a/br/pkg/lightning/BUILD.bazel b/br/pkg/lightning/BUILD.bazel index d37012b0c3d3a..99d534762fc69 100644 --- a/br/pkg/lightning/BUILD.bazel +++ b/br/pkg/lightning/BUILD.bazel @@ -24,6 +24,7 @@ go_library( "//br/pkg/lightning/web", "//br/pkg/redact", "//br/pkg/storage", + "//br/pkg/utils", "//br/pkg/version/build", "//util/promutil", "@com_github_pingcap_errors//:errors", diff --git a/br/pkg/restore/BUILD.bazel b/br/pkg/restore/BUILD.bazel index 54df5d446e71c..e89f76d24d54c 100644 --- a/br/pkg/restore/BUILD.bazel +++ b/br/pkg/restore/BUILD.bazel @@ -36,6 +36,7 @@ go_library( "//br/pkg/storage", "//br/pkg/stream", "//br/pkg/summary", + "//br/pkg/utils", "//config", "//ddl/util", "//domain", @@ -114,6 +115,7 @@ go_test( "//br/pkg/rtree", "//br/pkg/storage", "//br/pkg/stream", + "//br/pkg/utils", "//infoschema", "//kv", "//meta/autoid", diff --git a/br/pkg/stream/BUILD.bazel b/br/pkg/stream/BUILD.bazel index 10d30053d4533..15ee92d85b2a2 100644 --- a/br/pkg/stream/BUILD.bazel +++ b/br/pkg/stream/BUILD.bazel @@ -18,6 +18,7 @@ go_library( "//br/pkg/logutil", "//br/pkg/storage", "//br/pkg/streamhelper", + "//br/pkg/utils", "//kv", "//meta", "//parser/model", diff --git a/br/pkg/streamhelper/BUILD.bazel b/br/pkg/streamhelper/BUILD.bazel index c7fa8a914fd9b..e3761c9b14361 100644 --- a/br/pkg/streamhelper/BUILD.bazel +++ b/br/pkg/streamhelper/BUILD.bazel @@ -21,6 +21,7 @@ go_library( "//br/pkg/logutil", "//br/pkg/redact", "//br/pkg/streamhelper/config", + "//br/pkg/utils", "//config", "//kv", "//metrics", diff --git a/br/pkg/task/BUILD.bazel b/br/pkg/task/BUILD.bazel index 246533cdfc1e2..4acaf84014b03 100644 --- a/br/pkg/task/BUILD.bazel +++ b/br/pkg/task/BUILD.bazel @@ -29,6 +29,7 @@ go_library( "//br/pkg/streamhelper", "//br/pkg/streamhelper/config", "//br/pkg/summary", + "//br/pkg/utils", "//br/pkg/version", "//config", "//kv", @@ -81,6 +82,7 @@ go_test( "//br/pkg/restore", "//br/pkg/storage", "//br/pkg/stream", + "//br/pkg/utils", "//config", "//parser/model", "//statistics/handle", diff --git a/br/pkg/utils/BUILD.bazel b/br/pkg/utils/BUILD.bazel index c97c453f426c0..631e0e7603ab6 100644 --- a/br/pkg/utils/BUILD.bazel +++ b/br/pkg/utils/BUILD.bazel @@ -32,10 +32,7 @@ go_library( "//parser/model", "//parser/mysql", "//parser/terror", -<<<<<<< HEAD -======= "//parser/types", ->>>>>>> de017e9eea670d050ae5dc7519da11baf632efa0 "//sessionctx", "//util", "//util/sqlexec", @@ -85,10 +82,7 @@ go_test( "//parser/ast", "//parser/model", "//parser/mysql", -<<<<<<< HEAD -======= "//parser/types", ->>>>>>> de017e9eea670d050ae5dc7519da11baf632efa0 "//statistics/handle", "//tablecodec", "//testkit/testsetup", diff --git a/br/pkg/version/BUILD.bazel b/br/pkg/version/BUILD.bazel index b72d763df2096..26bbda74a0d32 100644 --- a/br/pkg/version/BUILD.bazel +++ b/br/pkg/version/BUILD.bazel @@ -8,6 +8,7 @@ go_library( deps = [ "//br/pkg/errors", "//br/pkg/logutil", + "//br/pkg/utils", "//br/pkg/version/build", "@com_github_coreos_go_semver//semver", "@com_github_pingcap_errors//:errors", diff --git a/dumpling/export/BUILD.bazel b/dumpling/export/BUILD.bazel index 3076145c9ddb7..b12dc5f87246b 100644 --- a/dumpling/export/BUILD.bazel +++ b/dumpling/export/BUILD.bazel @@ -28,6 +28,7 @@ go_library( deps = [ "//br/pkg/storage", "//br/pkg/summary", + "//br/pkg/utils", "//br/pkg/version", "//config", "//dumpling/cli", diff --git a/store/gcworker/BUILD.bazel b/store/gcworker/BUILD.bazel index 9ac77dd4ceb30..f43e28c02958b 100644 --- a/store/gcworker/BUILD.bazel +++ b/store/gcworker/BUILD.bazel @@ -6,6 +6,7 @@ go_library( importpath = "github.com/pingcap/tidb/store/gcworker", visibility = ["//visibility:public"], deps = [ + "//br/pkg/utils", "//ddl", "//ddl/label", "//ddl/placement", diff --git a/telemetry/BUILD.bazel b/telemetry/BUILD.bazel index 0f3a93d35ecaa..15ff2794210fb 100644 --- a/telemetry/BUILD.bazel +++ b/telemetry/BUILD.bazel @@ -18,6 +18,7 @@ go_library( importpath = "github.com/pingcap/tidb/telemetry", visibility = ["//visibility:public"], deps = [ + "//br/pkg/utils", "//config", "//domain/infosync", "//infoschema", From cbd467ef3b8ace572f0fa3b00d455eb9079d02da Mon Sep 17 00:00:00 2001 From: cbcwestwolf <1004626265@qq.com> Date: Tue, 19 Jul 2022 11:15:24 +0800 Subject: [PATCH 19/29] Fix UT --- executor/showtest/show_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/executor/showtest/show_test.go b/executor/showtest/show_test.go index 263cf33e7cf7c..b0b8d0f6fb398 100644 --- a/executor/showtest/show_test.go +++ b/executor/showtest/show_test.go @@ -1445,7 +1445,7 @@ func TestShowBuiltin(t *testing.T) { res := tk.MustQuery("show builtins;") require.NotNil(t, res) rows := res.Rows() - const builtinFuncNum = 276 + const builtinFuncNum = 277 require.Equal(t, len(rows), builtinFuncNum) require.Equal(t, rows[0][0].(string), "abs") require.Equal(t, rows[builtinFuncNum-1][0].(string), "yearweek") From 64516b5ec52df46e43b79f727d5cbe8a63f4eb75 Mon Sep 17 00:00:00 2001 From: cbcwestwolf <1004626265@qq.com> Date: Tue, 19 Jul 2022 11:16:16 +0800 Subject: [PATCH 20/29] Fix --- br/pkg/lightning/backend/kv/BUILD.bazel | 1 + br/pkg/lightning/backend/local/BUILD.bazel | 2 ++ br/pkg/lightning/backend/tidb/BUILD.bazel | 1 + br/pkg/lightning/common/BUILD.bazel | 1 + br/pkg/lightning/errormanager/BUILD.bazel | 2 ++ br/pkg/lightning/restore/BUILD.bazel | 1 + 6 files changed, 8 insertions(+) diff --git a/br/pkg/lightning/backend/kv/BUILD.bazel b/br/pkg/lightning/backend/kv/BUILD.bazel index 0994eb48cbf8a..f0b8c5545c330 100644 --- a/br/pkg/lightning/backend/kv/BUILD.bazel +++ b/br/pkg/lightning/backend/kv/BUILD.bazel @@ -19,6 +19,7 @@ go_library( "//br/pkg/lightning/verification", "//br/pkg/logutil", "//br/pkg/redact", + "//br/pkg/utils", "//expression", "//kv", "//meta/autoid", diff --git a/br/pkg/lightning/backend/local/BUILD.bazel b/br/pkg/lightning/backend/local/BUILD.bazel index 09477b6ec091e..02358eb492d32 100644 --- a/br/pkg/lightning/backend/local/BUILD.bazel +++ b/br/pkg/lightning/backend/local/BUILD.bazel @@ -33,6 +33,7 @@ go_library( "//br/pkg/membuf", "//br/pkg/pdutil", "//br/pkg/restore", + "//br/pkg/utils", "//br/pkg/version", "//distsql", "//infoschema", @@ -98,6 +99,7 @@ go_test( "//br/pkg/mock", "//br/pkg/pdutil", "//br/pkg/restore", + "//br/pkg/utils", "//br/pkg/version", "//kv", "//parser/mysql", diff --git a/br/pkg/lightning/backend/tidb/BUILD.bazel b/br/pkg/lightning/backend/tidb/BUILD.bazel index 5e6f3c1a546d2..9dd8a9a876c14 100644 --- a/br/pkg/lightning/backend/tidb/BUILD.bazel +++ b/br/pkg/lightning/backend/tidb/BUILD.bazel @@ -14,6 +14,7 @@ go_library( "//br/pkg/lightning/log", "//br/pkg/lightning/verification", "//br/pkg/redact", + "//br/pkg/utils", "//br/pkg/version", "//parser/model", "//parser/mysql", diff --git a/br/pkg/lightning/common/BUILD.bazel b/br/pkg/lightning/common/BUILD.bazel index 6873746f3afcd..c64a3fdf85654 100644 --- a/br/pkg/lightning/common/BUILD.bazel +++ b/br/pkg/lightning/common/BUILD.bazel @@ -20,6 +20,7 @@ go_library( "//br/pkg/errors", "//br/pkg/httputil", "//br/pkg/lightning/log", + "//br/pkg/utils", "//errno", "//parser/model", "@com_github_go_sql_driver_mysql//:mysql", diff --git a/br/pkg/lightning/errormanager/BUILD.bazel b/br/pkg/lightning/errormanager/BUILD.bazel index f008eb669eb98..7aea8447865e8 100644 --- a/br/pkg/lightning/errormanager/BUILD.bazel +++ b/br/pkg/lightning/errormanager/BUILD.bazel @@ -10,6 +10,7 @@ go_library( "//br/pkg/lightning/config", "//br/pkg/lightning/log", "//br/pkg/redact", + "//br/pkg/utils", "@com_github_jedib0t_go_pretty_v6//table", "@com_github_jedib0t_go_pretty_v6//text", "@com_github_pingcap_errors//:errors", @@ -26,6 +27,7 @@ go_test( deps = [ "//br/pkg/lightning/config", "//br/pkg/lightning/log", + "//br/pkg/utils", "@com_github_data_dog_go_sqlmock//:go-sqlmock", "@com_github_stretchr_testify//require", "@org_uber_go_atomic//:atomic", diff --git a/br/pkg/lightning/restore/BUILD.bazel b/br/pkg/lightning/restore/BUILD.bazel index 633d283627673..afb763512059c 100644 --- a/br/pkg/lightning/restore/BUILD.bazel +++ b/br/pkg/lightning/restore/BUILD.bazel @@ -36,6 +36,7 @@ go_library( "//br/pkg/pdutil", "//br/pkg/redact", "//br/pkg/storage", + "//br/pkg/utils", "//br/pkg/version", "//br/pkg/version/build", "//ddl", From 661f7209970c903ed4a5fa47d3451e4ed243d340 Mon Sep 17 00:00:00 2001 From: cbcwestwolf <1004626265@qq.com> Date: Tue, 19 Jul 2022 14:00:47 +0800 Subject: [PATCH 21/29] Fix --- tests/realtikvtest/pessimistictest/pessimistic_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/realtikvtest/pessimistictest/pessimistic_test.go b/tests/realtikvtest/pessimistictest/pessimistic_test.go index 20dbc2283d81e..58bb2b73a86f5 100644 --- a/tests/realtikvtest/pessimistictest/pessimistic_test.go +++ b/tests/realtikvtest/pessimistictest/pessimistic_test.go @@ -2476,7 +2476,7 @@ func TestAmendForUniqueIndex(t *testing.T) { tk2.MustExec("insert into t1 values(1, 1, 1);") tk2.MustExec("insert into t1 values(2, 2, 2);") - // NewSM3 value has duplicates. + // New value has duplicates. tk.MustExec("begin pessimistic") tk.MustExec("insert into t1 values(3, 3, 3)") tk.MustExec("insert into t1 values(4, 4, 3)") @@ -2485,7 +2485,7 @@ func TestAmendForUniqueIndex(t *testing.T) { tk2.MustExec("alter table t1 drop index uk1") tk2.MustExec("admin check table t1") - // NewSM3 values has duplicates with old values. + // New values has duplicates with old values. tk.MustExec("begin pessimistic") tk.MustExec("insert into t1 values(3, 3, 3)") tk.MustExec("insert into t1 values(4, 4, 1)") From 1ab815b475ba8aff9eb09ce29b155d09705b54a4 Mon Sep 17 00:00:00 2001 From: cbcwestwolf <1004626265@qq.com> Date: Thu, 21 Jul 2022 17:07:41 +0800 Subject: [PATCH 22/29] Fix --- expression/builtin_encryption_vec.go | 2 +- parser/auth/sm3.go | 3 -- parser/go.sum | 64 ---------------------------- 3 files changed, 1 insertion(+), 68 deletions(-) diff --git a/expression/builtin_encryption_vec.go b/expression/builtin_encryption_vec.go index 9736ca4b3ea8e..789b58e8e07f6 100644 --- a/expression/builtin_encryption_vec.go +++ b/expression/builtin_encryption_vec.go @@ -529,7 +529,7 @@ func (b *builtinSM3Sig) vecEvalString(input *chunk.Chunk, result *chunk.Column) } defer b.bufAllocator.put(buf) if err := b.args[0].VecEvalString(b.ctx, input, buf); err != nil { - return err + return errors.Trace(err) } result.ReserveString(n) hasher := auth.NewSM3() diff --git a/parser/auth/sm3.go b/parser/auth/sm3.go index 8c1eefa917a22..31e2573a1be89 100644 --- a/parser/auth/sm3.go +++ b/parser/auth/sm3.go @@ -83,9 +83,6 @@ func (sm3 *sm3) pad() []byte { msg = append(msg, uint8(sm3.length>>8&0xff)) msg = append(msg, uint8(sm3.length>>0&0xff)) - if len(msg)%64 != 0 { - panic("------sm3 Pad: error msgLen =") - } return msg } diff --git a/parser/go.sum b/parser/go.sum index abd195c60a412..267fe82580882 100644 --- a/parser/go.sum +++ b/parser/go.sum @@ -1,11 +1,7 @@ -cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8= github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= -github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= -github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= -github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/cznic/mathutil v0.0.0-20181122101859-297441e03548 h1:iwZdTE0PVqJCos1vaoKsclOGD3ADKpshg3SRtYBbwso= github.com/cznic/mathutil v0.0.0-20181122101859-297441e03548/go.mod h1:e6NPNENfs9mPDVNRekM7lKScauxd5kXTr1Mfyig6TDM= github.com/cznic/sortutil v0.0.0-20181122101858-f5f958428db8 h1:LpMLYGyy67BoAFGda1NeOBQwqlv7nUXpm+rIVHGxZZ4= @@ -15,26 +11,8 @@ github.com/cznic/strutil v0.0.0-20171016134553-529a34b1c186/go.mod h1:AHHPPPXTw0 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= -github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= -github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE= github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= -github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= -github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= -github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= -github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= -github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= -github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= -github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= -github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= -github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= -github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= -github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= -github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= -github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= @@ -49,7 +27,6 @@ github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0 h1:OdAsTTz6OkFY5QxjkYwrChwuRruF69c169dPK26NUlk= github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -57,8 +34,6 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/tjfoc/gmsm v1.4.1 h1:aMe1GlZb+0bLjn+cKTPEvvn9oUEBlJitaZiiBwsbgho= -github.com/tjfoc/gmsm v1.4.1/go.mod h1:j4INPkHWMrhJb38G+J6W4Tw0AbuN8Thu3PbdVYhVcTE= go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= @@ -72,62 +47,25 @@ go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= go.uber.org/zap v1.18.1 h1:CSUJ2mjFszzEWt4CdKISEuChVIXGBn3lAPwkRGyVrc4= go.uber.org/zap v1.18.1/go.mod h1:xg/QME4nWcxGxrpdeYfq7UvYrLh66cuVKdrbD1XF/NI= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.0.0-20201012173705-84dcc777aaee/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/exp v0.0.0-20181106170214-d68db9428509/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= -golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20220428152302-39d4317da171 h1:TfdoLivD44QwvssI9Sv1xwa5DcL5XQr4au4sZ2F2NV4= golang.org/x/exp v0.0.0-20220428152302-39d4317da171/go.mod h1:lgLbSvA5ygNOMpwM/9anMpWVlVJ7Z+cHWq/eFuinpGE= -golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= -golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= -golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/lint v0.0.0-20190930215403-16217165b5de h1:5hukYrvBGR8/eNkX5mdUezrA6JiaEZDtJb9Ei+1LlBs= golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= -golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20201010224723-4f7140c49acb/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= -golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= -golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20211019181941-9d821ace8654 h1:id054HUawV2/6IGm2IV8KZQjqtwAOo2CYlOToYqa0d0= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191108193012-7d206e10da11/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.8-0.20211029000441-d6a9af8af023 h1:0c3L82FDQ5rt1bjTBlchS8t6RQ6299/+5bWMnRLh+uI= golang.org/x/tools v0.1.8-0.20211029000441-d6a9af8af023/go.mod h1:nABZi5QlRsZVlzPpHl034qft6wpY4eDcsTt5AaioBiU= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= -google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= -google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= -google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= -google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= -google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= -google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= -google.golang.org/grpc v1.31.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= -google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= -google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= -google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= -google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= -google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= -google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= @@ -139,8 +77,6 @@ gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo= gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= modernc.org/fileutil v1.0.0/go.mod h1:JHsWpkrk/CnVV1H/eGlFf85BEpfkrp56ro8nojIq9Q8= modernc.org/golex v1.0.1 h1:EYKY1a3wStt0RzHaH8mdSRNg78Ub0OHxYfCRWw35YtM= modernc.org/golex v1.0.1/go.mod h1:QCA53QtsT1NdGkaZZkF5ezFwk4IXh4BGNafAARTC254= From 9db5b4de43204f3cd9f4bbabe749bc3bc14dfd6d Mon Sep 17 00:00:00 2001 From: cbcwestwolf <1004626265@qq.com> Date: Sun, 28 Aug 2022 16:20:43 +0800 Subject: [PATCH 23/29] Fix --- server/conn.go | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/server/conn.go b/server/conn.go index d63a8455d8547..25924072e220f 100644 --- a/server/conn.go +++ b/server/conn.go @@ -753,17 +753,16 @@ func (cc *clientConn) authSha(ctx context.Context) ([]byte, error) { // authSM3 implements the sm3_password specific part of the protocol. func (cc *clientConn) authSM3(ctx context.Context) ([]byte, error) { - const ( - SM3Command = 1 - RequestRsaPubKey = 2 // Not supported yet, only TLS is supported as secure channel. - FastAuthOk = 3 - FastAuthFail = 4 + sm3Command = 1 + requestRsaPubKey = 2 // Not supported yet, only TLS is supported as secure channel. + fastAuthOk = 3 + fastAuthFail = 4 ) // Currently we always send a "FastAuthFail" as the cached part of the protocol isn't implemented yet. // This triggers the client to send the full response. - err := cc.writePacket([]byte{0, 0, 0, 0, SM3Command, FastAuthFail}) + err := cc.writePacket([]byte{0, 0, 0, 0, sm3Command, fastAuthFail}) if err != nil { logutil.Logger(ctx).Error("authSM3 packet write failed", zap.Error(err)) return nil, err From 8e091892a422c1f3a6fbd5c54e558d8c1d37cf09 Mon Sep 17 00:00:00 2001 From: CbcWestwolf <1004626265@qq.com> Date: Fri, 2 Sep 2022 14:07:50 +0800 Subject: [PATCH 24/29] Update parser/auth/sm3.go MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: DaniĆ«l van Eeden --- parser/auth/sm3.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/parser/auth/sm3.go b/parser/auth/sm3.go index 31e2573a1be89..99452fbf0e941 100644 --- a/parser/auth/sm3.go +++ b/parser/auth/sm3.go @@ -27,7 +27,7 @@ import ( // The concrete SM3 Cryptographic Hash Algorithm can be accessed in http://www.sca.gov.cn/sca/xwdt/2010-12/17/content_1002389.shtml // This implementation of 'type sm3 struct' is modified from https://github.com/tjfoc/gmsm/tree/601ddb090dcf53d7951cc4dcc66276e2b817837c/sm3 // Some other references: -// https://tools.ietf.org/id/draft-oscca-cfrg-sm3-01.html +// https://datatracker.ietf.org/doc/draft-sca-cfrg-sm3/ /* Copyright Suzhou Tongji Fintech Research Institute 2017 All Rights Reserved. From a88861283c2919c759899a16bda8e2e282814aed Mon Sep 17 00:00:00 2001 From: cbcwestwolf <1004626265@qq.com> Date: Mon, 5 Sep 2022 12:03:41 +0800 Subject: [PATCH 25/29] Update --- executor/simple.go | 6 +++--- parser/ast/misc.go | 4 ++-- parser/mysql/const.go | 4 ++-- privilege/privileges/privileges.go | 8 ++++---- server/conn.go | 17 ++++------------- server/conn_test.go | 26 +++++++++++++------------- sessionctx/variable/sysvar.go | 2 +- 7 files changed, 29 insertions(+), 38 deletions(-) diff --git a/executor/simple.go b/executor/simple.go index 42243beda380a..8bab46026b124 100644 --- a/executor/simple.go +++ b/executor/simple.go @@ -852,7 +852,7 @@ func (e *SimpleExec) executeCreateUser(ctx context.Context, s *ast.CreateUserStm } switch authPlugin { - case mysql.AuthNativePassword, mysql.AuthCachingSha2Password, mysql.AuthSM3Password, mysql.AuthSocket: + case mysql.AuthNativePassword, mysql.AuthCachingSha2Password, mysql.AuthTiDBSM3Password, mysql.AuthSocket: default: return ErrPluginIsNotLoaded.GenWithStackByArgs(spec.AuthOpt.AuthPlugin) } @@ -1010,7 +1010,7 @@ func (e *SimpleExec) executeAlterUser(ctx context.Context, s *ast.AlterUserStmt) spec.AuthOpt.AuthPlugin = authplugin } switch spec.AuthOpt.AuthPlugin { - case mysql.AuthNativePassword, mysql.AuthCachingSha2Password, mysql.AuthSM3Password, mysql.AuthSocket, "": + case mysql.AuthNativePassword, mysql.AuthCachingSha2Password, mysql.AuthTiDBSM3Password, mysql.AuthSocket, "": default: return ErrPluginIsNotLoaded.GenWithStackByArgs(spec.AuthOpt.AuthPlugin) } @@ -1497,7 +1497,7 @@ func (e *SimpleExec) executeSetPwd(ctx context.Context, s *ast.SetPwdStmt) error switch authplugin { case mysql.AuthCachingSha2Password: pwd = auth.NewSha2Password(s.Password) - case mysql.AuthSM3Password: + case mysql.AuthTiDBSM3Password: pwd = auth.NewSM3Password(s.Password) case mysql.AuthSocket: e.ctx.GetSessionVars().StmtCtx.AppendNote(ErrSetPasswordAuthPlugin.GenWithStackByArgs(u, h)) diff --git a/parser/ast/misc.go b/parser/ast/misc.go index 404a15a318ce6..c719c6d50ffc1 100644 --- a/parser/ast/misc.go +++ b/parser/ast/misc.go @@ -1334,7 +1334,7 @@ func (n *UserSpec) EncodedPassword() (string, bool) { switch opt.AuthPlugin { case mysql.AuthCachingSha2Password: return auth.NewSha2Password(opt.AuthString), true - case mysql.AuthSM3Password: + case mysql.AuthTiDBSM3Password: return auth.NewSM3Password(opt.AuthString), true case mysql.AuthSocket: return "", true @@ -1354,7 +1354,7 @@ func (n *UserSpec) EncodedPassword() (string, bool) { if len(opt.HashString) != mysql.SHAPWDHashLen { return "", false } - case mysql.AuthSM3Password: + case mysql.AuthTiDBSM3Password: if len(opt.HashString) != mysql.SM3PWDHashLen { return "", false } diff --git a/parser/mysql/const.go b/parser/mysql/const.go index 56f8432cef35c..056be1934265a 100644 --- a/parser/mysql/const.go +++ b/parser/mysql/const.go @@ -176,7 +176,7 @@ const ( const ( AuthNativePassword = "mysql_native_password" // #nosec G101 AuthCachingSha2Password = "caching_sha2_password" // #nosec G101 - AuthSM3Password = "sm3_password" // #nosec G101 + AuthTiDBSM3Password = "tidb_sm3_password" // #nosec G101 AuthSocket = "auth_socket" AuthTiDBSessionToken = "tidb_session_token" ) @@ -247,7 +247,7 @@ const PWDHashLen = 40 // excluding the '*' // SHAPWDHashLen is the length of sha256_password's hash. const SHAPWDHashLen = 70 -// SM3PWDHashLen is the length of sm3_password's hash. +// SM3PWDHashLen is the length of tidb_sm3_password's hash. const SM3PWDHashLen = 70 // Command2Str is the command information to command name. diff --git a/privilege/privileges/privileges.go b/privilege/privileges/privileges.go index eaaa5d0a10c85..d238d6feef26d 100644 --- a/privilege/privileges/privileges.go +++ b/privilege/privileges/privileges.go @@ -198,11 +198,11 @@ func (p *UserPrivileges) isValidHash(record *UserRecord) bool { return false } - if record.AuthPlugin == mysql.AuthSM3Password { + if record.AuthPlugin == mysql.AuthTiDBSM3Password { if len(pwd) == mysql.SM3PWDHashLen { return true } - logutil.BgLogger().Error("user password from system DB not like a sm3_password format", zap.String("user", record.User), zap.String("plugin", record.AuthPlugin), zap.Int("hash_length", len(pwd))) + logutil.BgLogger().Error("user password from system DB not like a tidb_sm3_password format", zap.String("user", record.User), zap.String("plugin", record.AuthPlugin), zap.Int("hash_length", len(pwd))) return false } @@ -343,10 +343,10 @@ func (p *UserPrivileges) ConnectionVerification(user *auth.UserIdentity, authUse if !authok { return errAccessDenied.FastGenByArgs(user.Username, user.Hostname, hasPassword) } - case mysql.AuthSM3Password: + case mysql.AuthTiDBSM3Password: authOK, err := auth.CheckSM3Password([]byte(pwd), string(authentication)) if err != nil { - logutil.BgLogger().Error("Failed to check sm3_password", zap.Error(err)) + logutil.BgLogger().Error("Failed to check tidb_sm3_password", zap.Error(err)) } if !authOK { diff --git a/server/conn.go b/server/conn.go index 25924072e220f..f4cf518543efa 100644 --- a/server/conn.go +++ b/server/conn.go @@ -670,7 +670,7 @@ func (cc *clientConn) readOptionalSSLRequestAndHandshakeResponse(ctx context.Con if err != nil { return err } - case mysql.AuthSM3Password: + case mysql.AuthTiDBSM3Password: resp.Auth, err = cc.authSM3(ctx) if err != nil { return err @@ -702,7 +702,7 @@ func (cc *clientConn) handleAuthPlugin(ctx context.Context, resp *handshakeRespo switch resp.AuthPlugin { case mysql.AuthCachingSha2Password: - case mysql.AuthSM3Password: + case mysql.AuthTiDBSM3Password: case mysql.AuthNativePassword: case mysql.AuthSocket: case mysql.AuthTiDBSessionToken: @@ -751,18 +751,9 @@ func (cc *clientConn) authSha(ctx context.Context) ([]byte, error) { return bytes.Trim(data, "\x00"), nil } -// authSM3 implements the sm3_password specific part of the protocol. +// authSM3 implements the tidb_sm3_password specific part of the protocol. func (cc *clientConn) authSM3(ctx context.Context) ([]byte, error) { - const ( - sm3Command = 1 - requestRsaPubKey = 2 // Not supported yet, only TLS is supported as secure channel. - fastAuthOk = 3 - fastAuthFail = 4 - ) - - // Currently we always send a "FastAuthFail" as the cached part of the protocol isn't implemented yet. - // This triggers the client to send the full response. - err := cc.writePacket([]byte{0, 0, 0, 0, sm3Command, fastAuthFail}) + err := cc.writePacket([]byte{0, 0, 0, 0, 1, 4}) if err != nil { logutil.Logger(ctx).Error("authSM3 packet write failed", zap.Error(err)) return nil, err diff --git a/server/conn_test.go b/server/conn_test.go index 17849463dbc8b..3c6edcf30bba7 100644 --- a/server/conn_test.go +++ b/server/conn_test.go @@ -1039,7 +1039,7 @@ func TestHandleAuthPlugin(t *testing.T) { require.Equal(t, []byte(mysql.AuthNativePassword), resp.Auth) require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/server/FakeAuthSwitch")) - // client trying to authenticate with sm3_password + // client trying to authenticate with tidb_sm3_password require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/server/FakeAuthSwitch", "return(1)")) cc = &clientConn{ connectionID: 1, @@ -1055,7 +1055,7 @@ func TestHandleAuthPlugin(t *testing.T) { } resp = handshakeResponse41{ Capability: mysql.ClientProtocol41 | mysql.ClientPluginAuth, - AuthPlugin: mysql.AuthSM3Password, + AuthPlugin: mysql.AuthTiDBSM3Password, } err = cc.handleAuthPlugin(ctx, &resp) require.NoError(t, err) @@ -1130,7 +1130,7 @@ func TestHandleAuthPlugin(t *testing.T) { require.Equal(t, []byte(mysql.AuthNativePassword), resp.Auth) require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/server/FakeAuthSwitch")) - // client trying to authenticate with sm3_password + // client trying to authenticate with tidb_sm3_password require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/server/FakeAuthSwitch", "return(1)")) cc = &clientConn{ connectionID: 1, @@ -1146,7 +1146,7 @@ func TestHandleAuthPlugin(t *testing.T) { } resp = handshakeResponse41{ Capability: mysql.ClientProtocol41 | mysql.ClientPluginAuth, - AuthPlugin: mysql.AuthSM3Password, + AuthPlugin: mysql.AuthTiDBSM3Password, } err = cc.handleAuthPlugin(ctx, &resp) require.NoError(t, err) @@ -1222,7 +1222,7 @@ func TestHandleAuthPlugin(t *testing.T) { require.Equal(t, []byte(mysql.AuthCachingSha2Password), resp.Auth) require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/server/FakeAuthSwitch")) - // client trying to authenticate with sm3_password + // client trying to authenticate with tidb_sm3_password require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/server/FakeAuthSwitch", "return(1)")) cc = &clientConn{ connectionID: 1, @@ -1238,7 +1238,7 @@ func TestHandleAuthPlugin(t *testing.T) { } resp = handshakeResponse41{ Capability: mysql.ClientProtocol41 | mysql.ClientPluginAuth, - AuthPlugin: mysql.AuthSM3Password, + AuthPlugin: mysql.AuthTiDBSM3Password, } err = cc.handleAuthPlugin(ctx, &resp) require.NoError(t, err) @@ -1265,8 +1265,8 @@ func TestHandleAuthPlugin(t *testing.T) { require.Error(t, err) require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/server/FakeUser")) - // === Target account has sm3_password === - require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/server/FakeUser", "return(\"sm3_password\")")) + // === Target account has tidb_sm3_password === + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/server/FakeUser", "return(\"tidb_sm3_password\")")) // 5.7 or newer client trying to authenticate with mysql_native_password require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/server/FakeAuthSwitch", "return(1)")) @@ -1288,7 +1288,7 @@ func TestHandleAuthPlugin(t *testing.T) { } err = cc.handleAuthPlugin(ctx, &resp) require.NoError(t, err) - require.Equal(t, []byte(mysql.AuthSM3Password), resp.Auth) + require.Equal(t, []byte(mysql.AuthTiDBSM3Password), resp.Auth) require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/server/FakeAuthSwitch")) // 8.0 or newer client trying to authenticate with caching_sha2_password @@ -1311,10 +1311,10 @@ func TestHandleAuthPlugin(t *testing.T) { } err = cc.handleAuthPlugin(ctx, &resp) require.NoError(t, err) - require.Equal(t, []byte(mysql.AuthSM3Password), resp.Auth) + require.Equal(t, []byte(mysql.AuthTiDBSM3Password), resp.Auth) require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/server/FakeAuthSwitch")) - // client trying to authenticate with sm3_password + // client trying to authenticate with tidb_sm3_password require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/server/FakeAuthSwitch", "return(1)")) cc = &clientConn{ connectionID: 1, @@ -1330,11 +1330,11 @@ func TestHandleAuthPlugin(t *testing.T) { } resp = handshakeResponse41{ Capability: mysql.ClientProtocol41 | mysql.ClientPluginAuth, - AuthPlugin: mysql.AuthSM3Password, + AuthPlugin: mysql.AuthTiDBSM3Password, } err = cc.handleAuthPlugin(ctx, &resp) require.NoError(t, err) - require.Equal(t, []byte(mysql.AuthSM3Password), resp.Auth) + require.Equal(t, []byte(mysql.AuthTiDBSM3Password), resp.Auth) require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/server/FakeAuthSwitch")) // MySQL 5.1 or older client, without authplugin support diff --git a/sessionctx/variable/sysvar.go b/sessionctx/variable/sysvar.go index e0b55cfd6aff9..c18b02b612681 100644 --- a/sessionctx/variable/sysvar.go +++ b/sessionctx/variable/sysvar.go @@ -675,7 +675,7 @@ var defaultSysVars = []*SysVar{ return nil }}, {Scope: ScopeGlobal, Name: SkipNameResolve, Value: Off, Type: TypeBool}, - {Scope: ScopeGlobal, Name: DefaultAuthPlugin, Value: mysql.AuthNativePassword, Type: TypeEnum, PossibleValues: []string{mysql.AuthNativePassword, mysql.AuthCachingSha2Password, mysql.AuthSM3Password}}, + {Scope: ScopeGlobal, Name: DefaultAuthPlugin, Value: mysql.AuthNativePassword, Type: TypeEnum, PossibleValues: []string{mysql.AuthNativePassword, mysql.AuthCachingSha2Password, mysql.AuthTiDBSM3Password}}, {Scope: ScopeGlobal, Name: TiDBPersistAnalyzeOptions, Value: BoolToOnOff(DefTiDBPersistAnalyzeOptions), Type: TypeBool, GetGlobal: func(s *SessionVars) (string, error) { return BoolToOnOff(PersistAnalyzeOptions.Load()), nil From 9a08797d0a066a98232f31311e39038d2d92558d Mon Sep 17 00:00:00 2001 From: cbcwestwolf <1004626265@qq.com> Date: Tue, 6 Sep 2022 15:32:54 +0800 Subject: [PATCH 26/29] Update --- executor/simple.go | 6 +- expression/builtin_encryption.go | 2 +- expression/builtin_encryption_vec.go | 2 +- parser/ast/misc.go | 6 +- parser/auth/caching_sha2.go | 47 +++-- parser/auth/caching_sha2_test.go | 21 ++- parser/auth/{sm3.go => tidb_sm3.go} | 176 +----------------- parser/auth/{sm3_test.go => tidb_sm3_test.go} | 23 +-- privilege/privileges/privileges.go | 21 +-- 9 files changed, 73 insertions(+), 231 deletions(-) rename parser/auth/{sm3.go => tidb_sm3.go} (59%) rename parser/auth/{sm3_test.go => tidb_sm3_test.go} (73%) diff --git a/executor/simple.go b/executor/simple.go index 8bab46026b124..8ae64826804b3 100644 --- a/executor/simple.go +++ b/executor/simple.go @@ -1495,10 +1495,8 @@ func (e *SimpleExec) executeSetPwd(ctx context.Context, s *ast.SetPwdStmt) error } var pwd string switch authplugin { - case mysql.AuthCachingSha2Password: - pwd = auth.NewSha2Password(s.Password) - case mysql.AuthTiDBSM3Password: - pwd = auth.NewSM3Password(s.Password) + case mysql.AuthCachingSha2Password, mysql.AuthTiDBSM3Password: + pwd = auth.NewHashPassword(s.Password, authplugin) case mysql.AuthSocket: e.ctx.GetSessionVars().StmtCtx.AppendNote(ErrSetPasswordAuthPlugin.GenWithStackByArgs(u, h)) pwd = "" diff --git a/expression/builtin_encryption.go b/expression/builtin_encryption.go index 2f2a1ff1eb95b..a206a9d4970bb 100644 --- a/expression/builtin_encryption.go +++ b/expression/builtin_encryption.go @@ -742,7 +742,7 @@ func (b *builtinSM3Sig) Clone() builtinFunc { return newSig } -// evalString evals SM3(str). +// evalString evals Sm3Hash(str). // The value is returned as a string of 70 hexadecimal digits, or NULL if the argument was NULL. func (b *builtinSM3Sig) evalString(row chunk.Row) (string, bool, error) { str, isNull, err := b.args[0].EvalString(b.ctx, row) diff --git a/expression/builtin_encryption_vec.go b/expression/builtin_encryption_vec.go index 43561e4b9bc8c..e9a1d45ae67be 100644 --- a/expression/builtin_encryption_vec.go +++ b/expression/builtin_encryption_vec.go @@ -521,7 +521,7 @@ func (b *builtinSM3Sig) vectorized() bool { return true } -// vecEvalString evals SM3(str). +// vecEvalString evals Sm3Hash(str). func (b *builtinSM3Sig) vecEvalString(input *chunk.Chunk, result *chunk.Column) error { n := input.NumRows() buf, err := b.bufAllocator.get() diff --git a/parser/ast/misc.go b/parser/ast/misc.go index 6a75b504ff7d9..32390011f2513 100644 --- a/parser/ast/misc.go +++ b/parser/ast/misc.go @@ -1332,10 +1332,8 @@ func (n *UserSpec) EncodedPassword() (string, bool) { opt := n.AuthOpt if opt.ByAuthString { switch opt.AuthPlugin { - case mysql.AuthCachingSha2Password: - return auth.NewSha2Password(opt.AuthString), true - case mysql.AuthTiDBSM3Password: - return auth.NewSM3Password(opt.AuthString), true + case mysql.AuthCachingSha2Password, mysql.AuthTiDBSM3Password: + return auth.NewHashPassword(opt.AuthString, opt.AuthPlugin), true case mysql.AuthSocket: return "", true default: diff --git a/parser/auth/caching_sha2.go b/parser/auth/caching_sha2.go index b7166d201ea08..055f78f90133a 100644 --- a/parser/auth/caching_sha2.go +++ b/parser/auth/caching_sha2.go @@ -38,6 +38,8 @@ import ( "errors" "fmt" "strconv" + + "github.com/pingcap/tidb/parser/mysql" ) const ( @@ -60,7 +62,13 @@ func b64From24bit(b []byte, n int, buf *bytes.Buffer) { } } -func sha256crypt(plaintext string, salt []byte, iterations int) string { +func Sha256Hash(input []byte) []byte { + res := sha256.Sum256(input) + return res[:] +} + +// 'hash' function should return an array with 32 bytes, the same as SHA-256 +func hashCrypt(plaintext string, salt []byte, iterations int, hash func([]byte) []byte) string { // Numbers in the comments refer to the description of the algorithm on https://www.akkadia.org/drepper/SHA-crypt.txt // 1, 2, 3 @@ -73,7 +81,7 @@ func sha256crypt(plaintext string, salt []byte, iterations int) string { bufB.Write([]byte(plaintext)) bufB.Write(salt) bufB.Write([]byte(plaintext)) - sumB := sha256.Sum256(bufB.Bytes()) + sumB := hash(bufB.Bytes()) bufB.Reset() // 9, 10 @@ -93,7 +101,7 @@ func sha256crypt(plaintext string, salt []byte, iterations int) string { } // 12 - sumA := sha256.Sum256(bufA.Bytes()) + sumA := hash(bufA.Bytes()) bufA.Reset() // 13, 14, 15 @@ -101,7 +109,7 @@ func sha256crypt(plaintext string, salt []byte, iterations int) string { for range []byte(plaintext) { bufDP.Write([]byte(plaintext)) } - sumDP := sha256.Sum256(bufDP.Bytes()) + sumDP := hash(bufDP.Bytes()) bufDP.Reset() // 16 @@ -119,7 +127,7 @@ func sha256crypt(plaintext string, salt []byte, iterations int) string { for i = 0; i < 16+int(sumA[0]); i++ { bufDS.Write(salt) } - sumDS := sha256.Sum256(bufDS.Bytes()) + sumDS := hash(bufDS.Bytes()) bufDS.Reset() // 20 @@ -134,7 +142,7 @@ func sha256crypt(plaintext string, salt []byte, iterations int) string { // 21 bufC := bufA - var sumC [32]byte + var sumC []byte for i = 0; i < iterations; i++ { bufC.Reset() if i&1 != 0 { @@ -153,7 +161,7 @@ func sha256crypt(plaintext string, salt []byte, iterations int) string { } else { bufC.Write(p) } - sumC = sha256.Sum256(bufC.Bytes()) + sumC = hash(bufC.Bytes()) sumA = sumC } @@ -180,8 +188,8 @@ func sha256crypt(plaintext string, salt []byte, iterations int) string { return buf.String() } -// CheckShaPassword is to check if a MySQL style caching_sha2 authentication string matches a password -func CheckShaPassword(pwhash []byte, password string) (bool, error) { +// CheckHashingPassword checks if a caching_sha2_password or tidb_sm3_password authentication string matches a password +func CheckHashingPassword(pwhash []byte, password string, hash string) (bool, error) { pwhashParts := bytes.Split(pwhash, []byte("$")) if len(pwhashParts) != 4 { return false, errors.New("failed to decode hash parts") @@ -199,13 +207,19 @@ func CheckShaPassword(pwhash []byte, password string) (bool, error) { iterations = iterations * ITERATION_MULTIPLIER salt := pwhashParts[3][:SALT_LENGTH] - newHash := sha256crypt(password, salt, iterations) + var newHash string + switch hash { + case mysql.AuthCachingSha2Password: + newHash = hashCrypt(password, salt, iterations, Sha256Hash) + case mysql.AuthTiDBSM3Password: + newHash = hashCrypt(password, salt, iterations, Sm3Hash) + } return bytes.Equal(pwhash, []byte(newHash)), nil } -// NewSha2Password creates a new MySQL style caching_sha2 password hash -func NewSha2Password(pwd string) string { +// NewHashPassword creates a new password for caching_sha2_password or tidb_sm3_password +func NewHashPassword(pwd string, hash string) string { salt := make([]byte, SALT_LENGTH) rand.Read(salt) @@ -219,5 +233,12 @@ func NewSha2Password(pwd string) string { } } - return sha256crypt(pwd, salt, 5*ITERATION_MULTIPLIER) + switch hash { + case mysql.AuthCachingSha2Password: + return hashCrypt(pwd, salt, 5*ITERATION_MULTIPLIER, Sha256Hash) + case mysql.AuthTiDBSM3Password: + return hashCrypt(pwd, salt, 5*ITERATION_MULTIPLIER, Sm3Hash) + default: + return "" + } } diff --git a/parser/auth/caching_sha2_test.go b/parser/auth/caching_sha2_test.go index 51fe4dcaf692d..d7c1f8ff8dee9 100644 --- a/parser/auth/caching_sha2_test.go +++ b/parser/auth/caching_sha2_test.go @@ -17,6 +17,7 @@ import ( "encoding/hex" "testing" + "github.com/pingcap/tidb/parser/mysql" "github.com/stretchr/testify/require" ) @@ -24,7 +25,7 @@ var foobarPwdSHA2Hash, _ = hex.DecodeString("24412430303524031A69251C34295C4B351 func TestCheckShaPasswordGood(t *testing.T) { pwd := "foobar" - r, err := CheckShaPassword(foobarPwdSHA2Hash, pwd) + r, err := CheckHashingPassword(foobarPwdSHA2Hash, pwd, mysql.AuthCachingSha2Password) require.NoError(t, err) require.True(t, r) } @@ -32,7 +33,7 @@ func TestCheckShaPasswordGood(t *testing.T) { func TestCheckShaPasswordBad(t *testing.T) { pwd := "not_foobar" pwhash, _ := hex.DecodeString("24412430303524031A69251C34295C4B35167C7F1E5A7B63091349503974624D34504B5A424679354856336868686F52485A736E4A733368786E427575516C73446469496537") - r, err := CheckShaPassword(pwhash, pwd) + r, err := CheckHashingPassword(pwhash, pwd, mysql.AuthCachingSha2Password) require.NoError(t, err) require.False(t, r) } @@ -40,30 +41,30 @@ func TestCheckShaPasswordBad(t *testing.T) { func TestCheckShaPasswordShort(t *testing.T) { pwd := "not_foobar" pwhash, _ := hex.DecodeString("aaaaaaaa") - _, err := CheckShaPassword(pwhash, pwd) + _, err := CheckHashingPassword(pwhash, pwd, mysql.AuthCachingSha2Password) require.Error(t, err) } func TestCheckShaPasswordDigestTypeIncompatible(t *testing.T) { pwd := "not_foobar" pwhash, _ := hex.DecodeString("24422430303524031A69251C34295C4B35167C7F1E5A7B63091349503974624D34504B5A424679354856336868686F52485A736E4A733368786E427575516C73446469496537") - _, err := CheckShaPassword(pwhash, pwd) + _, err := CheckHashingPassword(pwhash, pwd, mysql.AuthCachingSha2Password) require.Error(t, err) } func TestCheckShaPasswordIterationsInvalid(t *testing.T) { pwd := "not_foobar" pwhash, _ := hex.DecodeString("24412430304124031A69251C34295C4B35167C7F1E5A7B63091349503974624D34504B5A424679354856336868686F52485A736E4A733368786E427575516C73446469496537") - _, err := CheckShaPassword(pwhash, pwd) + _, err := CheckHashingPassword(pwhash, pwd, mysql.AuthCachingSha2Password) require.Error(t, err) } -// The output from NewSha2Password is not stable as the hash is based on the generated salt. -// This is why CheckShaPassword is used here. +// The output from NewHashPassword is not stable as the hash is based on the generated salt. +// This is why CheckHashingPassword is used here. func TestNewSha2Password(t *testing.T) { pwd := "testpwd" - pwhash := NewSha2Password(pwd) - r, err := CheckShaPassword([]byte(pwhash), pwd) + pwhash := NewHashPassword(pwd, mysql.AuthCachingSha2Password) + r, err := CheckHashingPassword([]byte(pwhash), pwd, mysql.AuthCachingSha2Password) require.NoError(t, err) require.True(t, r) @@ -76,7 +77,7 @@ func TestNewSha2Password(t *testing.T) { func BenchmarkShaPassword(b *testing.B) { for i := 0; i < b.N; i++ { - m, err := CheckShaPassword(foobarPwdSHA2Hash, "foobar") + m, err := CheckHashingPassword(foobarPwdSHA2Hash, "foobar", mysql.AuthCachingSha2Password) require.Nil(b, err) require.True(b, m) } diff --git a/parser/auth/sm3.go b/parser/auth/tidb_sm3.go similarity index 59% rename from parser/auth/sm3.go rename to parser/auth/tidb_sm3.go index 99452fbf0e941..10e96f0cfdcd3 100644 --- a/parser/auth/sm3.go +++ b/parser/auth/tidb_sm3.go @@ -14,17 +14,11 @@ package auth import ( - "bytes" - "crypto/rand" - "crypto/sha256" "encoding/binary" - "errors" - "fmt" "hash" - "strconv" ) -// The concrete SM3 Cryptographic Hash Algorithm can be accessed in http://www.sca.gov.cn/sca/xwdt/2010-12/17/content_1002389.shtml +// The concrete Sm3Hash Cryptographic Hash Algorithm can be accessed in http://www.sca.gov.cn/sca/xwdt/2010-12/17/content_1002389.shtml // This implementation of 'type sm3 struct' is modified from https://github.com/tjfoc/gmsm/tree/601ddb090dcf53d7951cc4dcc66276e2b817837c/sm3 // Some other references: // https://datatracker.ietf.org/doc/draft-sca-cfrg-sm3/ @@ -207,178 +201,16 @@ func (sm3 *sm3) Sum(in []byte) []byte { return out } -// NewSM3 returns a new hash.Hash computing the SM3 checksum. +// NewSM3 returns a new hash.Hash computing the Sm3Hash checksum. func NewSM3() hash.Hash { var h sm3 h.Reset() return &h } -// SM3 returns the sm3 checksum of the data. -func SM3(data []byte) []byte { +// Sm3Hash returns the sm3 checksum of the data. +func Sm3Hash(data []byte) []byte { h := NewSM3() h.Write(data) return h.Sum(nil) } - -func sm3crypt(plaintext string, salt []byte, iterations int) string { - // Numbers in the comments refer to the description of the algorithm on https://www.akkadia.org/drepper/SHA-crypt.txt - - // 1, 2, 3 - bufA := bytes.NewBuffer(make([]byte, 0, 4096)) - bufA.Write([]byte(plaintext)) - bufA.Write(salt) - - // 4, 5, 6, 7, 8 - bufB := bytes.NewBuffer(make([]byte, 0, 4096)) - bufB.Write([]byte(plaintext)) - bufB.Write(salt) - bufB.Write([]byte(plaintext)) - sumB := SM3(bufB.Bytes()) - bufB.Reset() - - // 9, 10 - var i int - for i = len(plaintext); i > MIXCHARS; i -= MIXCHARS { - bufA.Write(sumB[:MIXCHARS]) - } - bufA.Write(sumB[:i]) - - // 11 - for i = len(plaintext); i > 0; i >>= 1 { - if i%2 == 0 { - bufA.Write([]byte(plaintext)) - } else { - bufA.Write(sumB[:]) - } - } - - // 12 - sumA := SM3(bufA.Bytes()) - bufA.Reset() - - // 13, 14, 15 - bufDP := bufA - for range []byte(plaintext) { - bufDP.Write([]byte(plaintext)) - } - sumDP := SM3(bufDP.Bytes()) - bufDP.Reset() - - // 16 - p := make([]byte, 0, sha256.Size) - for i = len(plaintext); i > 0; i -= MIXCHARS { - if i > MIXCHARS { - p = append(p, sumDP[:]...) - } else { - p = append(p, sumDP[0:i]...) - } - } - - // 17, 18, 19 - bufDS := bufA - for i = 0; i < 16+int(sumA[0]); i++ { - bufDS.Write(salt) - } - sumDS := SM3(bufDS.Bytes()) - bufDS.Reset() - - // 20 - s := make([]byte, 0, 32) - for i = len(salt); i > 0; i -= MIXCHARS { - if i > MIXCHARS { - s = append(s, sumDS[:]...) - } else { - s = append(s, sumDS[0:i]...) - } - } - - // 21 - bufC := bufA - var sumC []byte - for i = 0; i < iterations; i++ { - bufC.Reset() - if i&1 != 0 { - bufC.Write(p) - } else { - bufC.Write(sumA[:]) - } - if i%3 != 0 { - bufC.Write(s) - } - if i%7 != 0 { - bufC.Write(p) - } - if i&1 != 0 { - bufC.Write(sumA[:]) - } else { - bufC.Write(p) - } - sumC = SM3(bufC.Bytes()) - sumA = sumC - } - - // 22 - buf := bytes.NewBuffer(make([]byte, 0, 100)) - buf.Write([]byte{'$', 'B', '$'}) - rounds := fmt.Sprintf("%03d", iterations/ITERATION_MULTIPLIER) - buf.Write([]byte(rounds)) - buf.Write([]byte{'$'}) - buf.Write(salt) - - b64From24bit([]byte{sumC[0], sumC[10], sumC[20]}, 4, buf) - b64From24bit([]byte{sumC[21], sumC[1], sumC[11]}, 4, buf) - b64From24bit([]byte{sumC[12], sumC[22], sumC[2]}, 4, buf) - b64From24bit([]byte{sumC[3], sumC[13], sumC[23]}, 4, buf) - b64From24bit([]byte{sumC[24], sumC[4], sumC[14]}, 4, buf) - b64From24bit([]byte{sumC[15], sumC[25], sumC[5]}, 4, buf) - b64From24bit([]byte{sumC[6], sumC[16], sumC[26]}, 4, buf) - b64From24bit([]byte{sumC[27], sumC[7], sumC[17]}, 4, buf) - b64From24bit([]byte{sumC[18], sumC[28], sumC[8]}, 4, buf) - b64From24bit([]byte{sumC[9], sumC[19], sumC[29]}, 4, buf) - b64From24bit([]byte{0, sumC[31], sumC[30]}, 3, buf) - - return buf.String() -} - -// CheckSM3Password checks if a sm3 authentication string matches a password -func CheckSM3Password(pwhash []byte, password string) (bool, error) { - pwhashParts := bytes.Split(pwhash, []byte("$")) - if len(pwhashParts) != 4 { - return false, errors.New("failed to decode hash parts") - } - - hashType := string(pwhashParts[1]) - if hashType != "B" { - return false, errors.New("digest type is incompatible") - } - - iterations, err := strconv.Atoi(string(pwhashParts[2])) - if err != nil { - return false, errors.New("failed to decode iterations") - } - iterations = iterations * ITERATION_MULTIPLIER - salt := pwhashParts[3][:SALT_LENGTH] - - newHash := sm3crypt(password, salt, iterations) - - return bytes.Equal(pwhash, []byte(newHash)), nil -} - -// NewSM3Password creates a new SM3 password hash -func NewSM3Password(pwd string) string { - salt := make([]byte, SALT_LENGTH) - rand.Read(salt) - - // Restrict to 7-bit to avoid multi-byte UTF-8 - for i := range salt { - salt[i] = salt[i] &^ 128 - for salt[i] == 36 || salt[i] == 0 { // '$' or NUL - newval := make([]byte, 1) - rand.Read(newval) - salt[i] = newval[0] &^ 128 - } - } - - return sm3crypt(pwd, salt, 5*ITERATION_MULTIPLIER) -} diff --git a/parser/auth/sm3_test.go b/parser/auth/tidb_sm3_test.go similarity index 73% rename from parser/auth/sm3_test.go rename to parser/auth/tidb_sm3_test.go index 4bd885c13154d..ae2d3162c4406 100644 --- a/parser/auth/sm3_test.go +++ b/parser/auth/tidb_sm3_test.go @@ -17,10 +17,11 @@ import ( "encoding/hex" "testing" + "github.com/pingcap/tidb/parser/mysql" "github.com/stretchr/testify/require" ) -var foobarPwdSM3Hash, _ = hex.DecodeString("24422430303524031a69251c34295c4b35167c7f1e5a7b63091349536c72627066426a635061762e556e6c63533159414d7762317261324a5a3047756b4244664177434e3043") +var foobarPwdSM3Hash, _ = hex.DecodeString("24412430303524031a69251c34295c4b35167c7f1e5a7b63091349536c72627066426a635061762e556e6c63533159414d7762317261324a5a3047756b4244664177434e3043") func TestSM3(t *testing.T) { testCases := [][]string{ @@ -32,22 +33,22 @@ func TestSM3(t *testing.T) { for _, testCase := range testCases { text := testCase[0] expect, _ = hex.DecodeString(testCase[1]) - result := SM3([]byte(text)) + result := Sm3Hash([]byte(text)) require.Equal(t, expect, result) } } func TestCheckSM3PasswordGood(t *testing.T) { pwd := "foobar" - r, err := CheckSM3Password(foobarPwdSM3Hash, pwd) + r, err := CheckHashingPassword(foobarPwdSM3Hash, pwd, mysql.AuthTiDBSM3Password) require.NoError(t, err) require.True(t, r) } func TestCheckSM3PasswordBad(t *testing.T) { pwd := "not_foobar" - pwhash, _ := hex.DecodeString("24422430303524031a69251c34295c4b35167c7f1e5a7b63091349536c72627066426a635061762e556e6c63533159414d7762317261324a5a3047756b4244664177434e3043") - r, err := CheckSM3Password(pwhash, pwd) + pwhash, _ := hex.DecodeString("24412430303524031a69251c34295c4b35167c7f1e5a7b6309134956387565426743446d3643446176712f6c4b63323667346e48624872776f39512e4342416a693656676f2f") + r, err := CheckHashingPassword(pwhash, pwd, mysql.AuthTiDBSM3Password) require.NoError(t, err) require.False(t, r) } @@ -55,28 +56,28 @@ func TestCheckSM3PasswordBad(t *testing.T) { func TestCheckSM3PasswordShort(t *testing.T) { pwd := "not_foobar" pwhash, _ := hex.DecodeString("aaaaaaaa") - _, err := CheckSM3Password(pwhash, pwd) + _, err := CheckHashingPassword(pwhash, pwd, mysql.AuthTiDBSM3Password) require.Error(t, err) } func TestCheckSM3PasswordDigestTypeIncompatible(t *testing.T) { pwd := "not_foobar" pwhash, _ := hex.DecodeString("24432430303524031A69251C34295C4B35167C7F1E5A7B63091349503974624D34504B5A424679354856336868686F52485A736E4A733368786E427575516C73446469496537") - _, err := CheckSM3Password(pwhash, pwd) + _, err := CheckHashingPassword(pwhash, pwd, mysql.AuthTiDBSM3Password) require.Error(t, err) } func TestCheckSM3PasswordIterationsInvalid(t *testing.T) { pwd := "not_foobar" pwhash, _ := hex.DecodeString("24412430304124031A69251C34295C4B35167C7F1E5A7B63091349503974624D34504B5A424679354856336868686F52485A736E4A733368786E427575516C73446469496537") - _, err := CheckSM3Password(pwhash, pwd) + _, err := CheckHashingPassword(pwhash, pwd, mysql.AuthTiDBSM3Password) require.Error(t, err) } func TestNewSM3Password(t *testing.T) { pwd := "testpwd" - pwhash := NewSM3Password(pwd) - r, err := CheckSM3Password([]byte(pwhash), pwd) + pwhash := NewHashPassword(pwd, mysql.AuthTiDBSM3Password) + r, err := CheckHashingPassword([]byte(pwhash), pwd, mysql.AuthTiDBSM3Password) require.NoError(t, err) require.True(t, r) @@ -89,7 +90,7 @@ func TestNewSM3Password(t *testing.T) { func BenchmarkSM3Password(b *testing.B) { for i := 0; i < b.N; i++ { - m, err := CheckSM3Password(foobarPwdSM3Hash, "foobar") + m, err := CheckHashingPassword(foobarPwdSM3Hash, "foobar", mysql.AuthTiDBSM3Password) require.Nil(b, err) require.True(b, m) } diff --git a/privilege/privileges/privileges.go b/privilege/privileges/privileges.go index 9a7fd2a3de543..534b31ed77d57 100644 --- a/privilege/privileges/privileges.go +++ b/privilege/privileges/privileges.go @@ -186,7 +186,7 @@ func (p *UserPrivileges) isValidHash(record *UserRecord) bool { if len(pwd) == mysql.PWDHashLen+1 { return true } - logutil.BgLogger().Error("user password from system DB not like a mysql_native_password format", zap.String("user", record.User), zap.String("plugin", record.AuthPlugin), zap.Int("hash_length", len(pwd))) + logutil.BgLogger().Error("the password from the mysql.user table does not match the definition of a mysql_native_password", zap.String("user", record.User), zap.String("plugin", record.AuthPlugin), zap.Int("hash_length", len(pwd))) return false } @@ -194,7 +194,7 @@ func (p *UserPrivileges) isValidHash(record *UserRecord) bool { if len(pwd) == mysql.SHAPWDHashLen { return true } - logutil.BgLogger().Error("user password from system DB not like a caching_sha2_password format", zap.String("user", record.User), zap.String("plugin", record.AuthPlugin), zap.Int("hash_length", len(pwd))) + logutil.BgLogger().Error("the password from the mysql.user table does not match the definition of a caching_sha2_password", zap.String("user", record.User), zap.String("plugin", record.AuthPlugin), zap.Int("hash_length", len(pwd))) return false } @@ -202,7 +202,7 @@ func (p *UserPrivileges) isValidHash(record *UserRecord) bool { if len(pwd) == mysql.SM3PWDHashLen { return true } - logutil.BgLogger().Error("user password from system DB not like a tidb_sm3_password format", zap.String("user", record.User), zap.String("plugin", record.AuthPlugin), zap.Int("hash_length", len(pwd))) + logutil.BgLogger().Error("the password from the mysql.user table does not match the definition of a tidb_sm3_password", zap.String("user", record.User), zap.String("plugin", record.AuthPlugin), zap.Int("hash_length", len(pwd))) return false } @@ -210,7 +210,7 @@ func (p *UserPrivileges) isValidHash(record *UserRecord) bool { return true } - logutil.BgLogger().Error("user password from system DB not like a known hash format", zap.String("user", record.User), zap.String("plugin", record.AuthPlugin), zap.Int("hash_length", len(pwd))) + logutil.BgLogger().Error("user password from the mysql.user table not like a known hash format", zap.String("user", record.User), zap.String("plugin", record.AuthPlugin), zap.Int("hash_length", len(pwd))) return false } @@ -334,8 +334,8 @@ func (p *UserPrivileges) ConnectionVerification(user *auth.UserIdentity, authUse if !auth.CheckScrambledPassword(salt, hpwd, authentication) { return ErrAccessDenied.FastGenByArgs(user.Username, user.Hostname, hasPassword) } - case mysql.AuthCachingSha2Password: - authok, err := auth.CheckShaPassword([]byte(pwd), string(authentication)) + case mysql.AuthCachingSha2Password, mysql.AuthTiDBSM3Password: + authok, err := auth.CheckHashingPassword([]byte(pwd), string(authentication), record.AuthPlugin) if err != nil { logutil.BgLogger().Error("Failed to check caching_sha2_password", zap.Error(err)) } @@ -343,15 +343,6 @@ func (p *UserPrivileges) ConnectionVerification(user *auth.UserIdentity, authUse if !authok { return ErrAccessDenied.FastGenByArgs(user.Username, user.Hostname, hasPassword) } - case mysql.AuthTiDBSM3Password: - authOK, err := auth.CheckSM3Password([]byte(pwd), string(authentication)) - if err != nil { - logutil.BgLogger().Error("Failed to check tidb_sm3_password", zap.Error(err)) - } - - if !authOK { - return errAccessDenied.FastGenByArgs(user.Username, user.Hostname, hasPassword) - } case mysql.AuthSocket: if string(authentication) != authUser && string(authentication) != pwd { logutil.BgLogger().Error("Failed socket auth", zap.String("authUser", authUser), From ffa4176176452b06173ae08848641ac868cd790a Mon Sep 17 00:00:00 2001 From: cbcwestwolf <1004626265@qq.com> Date: Tue, 6 Sep 2022 15:52:29 +0800 Subject: [PATCH 27/29] Fix --- parser/auth/caching_sha2.go | 1 + 1 file changed, 1 insertion(+) diff --git a/parser/auth/caching_sha2.go b/parser/auth/caching_sha2.go index 055f78f90133a..a31d466d9b09e 100644 --- a/parser/auth/caching_sha2.go +++ b/parser/auth/caching_sha2.go @@ -62,6 +62,7 @@ func b64From24bit(b []byte, n int, buf *bytes.Buffer) { } } +// Sha256Hash is an util function to calculate sha256 hash. func Sha256Hash(input []byte) []byte { res := sha256.Sum256(input) return res[:] From eda8c41e4e7c1148b7364e87da6f4c64053fb438 Mon Sep 17 00:00:00 2001 From: cbcwestwolf <1004626265@qq.com> Date: Wed, 7 Sep 2022 15:33:17 +0800 Subject: [PATCH 28/29] Improve compatibility --- server/conn.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/server/conn.go b/server/conn.go index 5fac5f199898f..14368727306dc 100644 --- a/server/conn.go +++ b/server/conn.go @@ -923,6 +923,11 @@ func (cc *clientConn) checkAuthPlugin(ctx context.Context, resp *handshakeRespon // method to match the one configured for that specific user. if (cc.authPlugin != userplugin) || (cc.authPlugin != resp.AuthPlugin) { if resp.Capability&mysql.ClientPluginAuth > 0 { + // For compatibility, since most mysql client doesn't support 'tidb_sm3_password', + // they can connect to TiDB using a `tidb_sm3_password` user with a 'caching_sha2_password' plugin. + if userplugin == mysql.AuthTiDBSM3Password { + userplugin = mysql.AuthCachingSha2Password + } authData, err := cc.authSwitchRequest(ctx, userplugin) if err != nil { return nil, err From 0012cc4420d38bbfdf50442b1de75da1f59a8adc Mon Sep 17 00:00:00 2001 From: cbcwestwolf <1004626265@qq.com> Date: Wed, 7 Sep 2022 16:25:28 +0800 Subject: [PATCH 29/29] Fix UT --- server/conn_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/server/conn_test.go b/server/conn_test.go index 3c6edcf30bba7..3acc67c7ff4f5 100644 --- a/server/conn_test.go +++ b/server/conn_test.go @@ -1288,7 +1288,7 @@ func TestHandleAuthPlugin(t *testing.T) { } err = cc.handleAuthPlugin(ctx, &resp) require.NoError(t, err) - require.Equal(t, []byte(mysql.AuthTiDBSM3Password), resp.Auth) + require.Equal(t, []byte(mysql.AuthCachingSha2Password), resp.Auth) require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/server/FakeAuthSwitch")) // 8.0 or newer client trying to authenticate with caching_sha2_password @@ -1311,7 +1311,7 @@ func TestHandleAuthPlugin(t *testing.T) { } err = cc.handleAuthPlugin(ctx, &resp) require.NoError(t, err) - require.Equal(t, []byte(mysql.AuthTiDBSM3Password), resp.Auth) + require.Equal(t, []byte(mysql.AuthCachingSha2Password), resp.Auth) require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/server/FakeAuthSwitch")) // client trying to authenticate with tidb_sm3_password @@ -1334,7 +1334,7 @@ func TestHandleAuthPlugin(t *testing.T) { } err = cc.handleAuthPlugin(ctx, &resp) require.NoError(t, err) - require.Equal(t, []byte(mysql.AuthTiDBSM3Password), resp.Auth) + require.Equal(t, []byte(mysql.AuthCachingSha2Password), resp.Auth) require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/server/FakeAuthSwitch")) // MySQL 5.1 or older client, without authplugin support