diff --git a/go/mysql/auth_server.go b/go/mysql/auth_server.go index 5df328e1896..b568537e16e 100644 --- a/go/mysql/auth_server.go +++ b/go/mysql/auth_server.go @@ -19,7 +19,9 @@ package mysql import ( "bytes" "crypto/rand" + "crypto/rsa" "crypto/sha1" + "crypto/sha256" "encoding/hex" "net" "strings" @@ -117,8 +119,8 @@ func NewSalt() ([]byte, error) { return salt, nil } -// ScramblePassword computes the hash of the password using 4.1+ method. -func ScramblePassword(salt, password []byte) []byte { +// ScrambleMysqlNativePassword computes the hash of the password using 4.1+ method. +func ScrambleMysqlNativePassword(salt, password []byte) []byte { if len(password) == 0 { return nil } @@ -189,6 +191,58 @@ func isPassScrambleMysqlNativePassword(reply, salt []byte, mysqlNativePassword s return bytes.Equal(candidateHash2, hash) } +// ScrambleCachingSha2Password computes the hash of the password using SHA256 as required by +// caching_sha2_password plugin for "fast" authentication +func ScrambleCachingSha2Password(salt []byte, password []byte) []byte { + if len(password) == 0 { + return nil + } + + // stage1Hash = SHA256(password) + crypt := sha256.New() + crypt.Write(password) + stage1 := crypt.Sum(nil) + + // scrambleHash = SHA256(SHA256(stage1Hash) + salt) + crypt.Reset() + crypt.Write(stage1) + innerHash := crypt.Sum(nil) + + crypt.Reset() + crypt.Write(innerHash) + crypt.Write(salt) + scramble := crypt.Sum(nil) + + // token = stage1Hash XOR scrambleHash + for i := range stage1 { + stage1[i] ^= scramble[i] + } + + return stage1 +} + +// EncryptPasswordWithPublicKey obfuscates the password and encrypts it with server's public key as required by +// caching_sha2_password plugin for "full" authentication +func EncryptPasswordWithPublicKey(salt []byte, password []byte, pub *rsa.PublicKey) ([]byte, error) { + if len(password) == 0 { + return nil, nil + } + + buffer := make([]byte, len(password)+1) + copy(buffer, password) + for i := range buffer { + buffer[i] ^= salt[i%len(salt)] + } + + sha1Hash := sha1.New() + enc, err := rsa.EncryptOAEP(sha1Hash, rand.Reader, pub, buffer, nil) + if err != nil { + return nil, err + } + + return enc, nil +} + // Constants for the dialog plugin. const ( mysqlDialogMessage = "Enter password: " diff --git a/go/mysql/auth_server_static.go b/go/mysql/auth_server_static.go index acb90678f54..ef95febb292 100644 --- a/go/mysql/auth_server_static.go +++ b/go/mysql/auth_server_static.go @@ -244,7 +244,7 @@ func (a *AuthServerStatic) ValidateHash(salt []byte, user string, authResponse [ return &StaticUserData{entry.UserData, entry.Groups}, nil } } else { - computedAuthResponse := ScramblePassword(salt, []byte(entry.Password)) + computedAuthResponse := ScrambleMysqlNativePassword(salt, []byte(entry.Password)) // Validate the password. if matchSourceHost(remoteAddr, entry.SourceHost) && bytes.Equal(authResponse, computedAuthResponse) { return &StaticUserData{entry.UserData, entry.Groups}, nil diff --git a/go/mysql/auth_server_static_test.go b/go/mysql/auth_server_static_test.go index 2586da2f60b..9bb3de197f8 100644 --- a/go/mysql/auth_server_static_test.go +++ b/go/mysql/auth_server_static_test.go @@ -92,7 +92,7 @@ func TestValidateHashGetter(t *testing.T) { t.Fatalf("error generating salt: %v", err) } - scrambled := ScramblePassword(salt, []byte("password")) + scrambled := ScrambleMysqlNativePassword(salt, []byte("password")) getter, err := auth.ValidateHash(salt, "mysql_user", scrambled, addr) if err != nil { t.Fatalf("error validating password: %v", err) @@ -274,7 +274,7 @@ func TestStaticPasswords(t *testing.T) { t.Fatalf("error generating salt: %v", err) } - scrambled := ScramblePassword(salt, []byte(c.password)) + scrambled := ScrambleMysqlNativePassword(salt, []byte(c.password)) _, err = auth.ValidateHash(salt, c.user, scrambled, addr) if c.success { diff --git a/go/mysql/client.go b/go/mysql/client.go index cb03c42b155..645f8d60514 100644 --- a/go/mysql/client.go +++ b/go/mysql/client.go @@ -17,7 +17,10 @@ limitations under the License. package mysql import ( + "crypto/rsa" "crypto/tls" + "crypto/x509" + "encoding/pem" "fmt" "net" "strconv" @@ -234,6 +237,7 @@ func (c *Conn) clientHandshake(characterSet uint8, params *ConnParams) error { return err } c.fillFlavor(params) + c.salt = salt // Sanity check. if capabilities&CapabilityClientProtocol41 == 0 { @@ -290,7 +294,12 @@ func (c *Conn) clientHandshake(characterSet uint8, params *ConnParams) error { } // Password encryption. - scrambledPassword := ScramblePassword(salt, []byte(params.Pass)) + var scrambledPassword []byte + if c.authPluginName == CachingSha2Password { + scrambledPassword = ScrambleCachingSha2Password(salt, []byte(params.Pass)) + } else { + scrambledPassword = ScrambleMysqlNativePassword(salt, []byte(params.Pass)) + } // Build and send our handshake response 41. // Note this one will never have SSL flag on. @@ -299,54 +308,8 @@ func (c *Conn) clientHandshake(characterSet uint8, params *ConnParams) error { } // Read the server response. - response, err := c.readPacket() - if err != nil { - return NewSQLError(CRServerLost, SSUnknownSQLState, "%v", err) - } - switch response[0] { - case OKPacket: - // OK packet, we are authenticated. Save the user, keep going. - c.User = params.Uname - case AuthSwitchRequestPacket: - // Server is asking to use a different auth method. We - // only support cleartext plugin. - pluginName, salt, err := parseAuthSwitchRequest(response) - if err != nil { - return NewSQLError(CRServerHandshakeErr, SSUnknownSQLState, "cannot parse auth switch request: %v", err) - } - - if pluginName == MysqlClearPassword { - // Write the cleartext password packet. - if err := c.writeClearTextPassword(params); err != nil { - return err - } - } else if pluginName == MysqlNativePassword { - // Write the mysql_native_password packet. - if err := c.writeMysqlNativePassword(params, salt); err != nil { - return err - } - } else { - return NewSQLError(CRServerHandshakeErr, SSUnknownSQLState, "server asked for unsupported auth method: %v", pluginName) - } - - // Wait for OK packet. - response, err = c.readPacket() - if err != nil { - return NewSQLError(CRServerLost, SSUnknownSQLState, "%v", err) - } - switch response[0] { - case OKPacket: - // OK packet, we are authenticated. Save the user, keep going. - c.User = params.Uname - case ErrPacket: - return ParseErrorPacket(response) - default: - return NewSQLError(CRServerHandshakeErr, SSUnknownSQLState, "initial server response cannot be parsed: %v", response) - } - case ErrPacket: - return ParseErrorPacket(response) - default: - return NewSQLError(CRServerHandshakeErr, SSUnknownSQLState, "initial server response cannot be parsed: %v", response) + if err := c.handleAuthResponse(params); err != nil { + return err } // If the server didn't support DbName in its handshake, set @@ -504,7 +467,7 @@ func (c *Conn) parseInitialHandshakePacket(data []byte) (uint32, []byte, error) // 5.6.2 that don't have a null terminated string. authPluginName = string(data[pos : len(data)-1]) } - c.DefaultAuthPluginName = authPluginName + c.authPluginName = authPluginName } return capabilities, authPluginData, nil @@ -588,7 +551,7 @@ func (c *Conn) writeHandshakeResponse41(capabilities uint32, scrambledPassword [ lenNullString(params.Uname) + // length of scrambled password is handled below. len(scrambledPassword) + - 21 + // "mysql_native_password" string. + len(c.authPluginName) + 1 // terminating zero. // Add the DB name if the server supports it. @@ -637,7 +600,7 @@ func (c *Conn) writeHandshakeResponse41(capabilities uint32, scrambledPassword [ } // Assume native client during response - pos = writeNullString(data, pos, MysqlNativePassword) + pos = writeNullString(data, pos, c.authPluginName) // Sanity-check the length. if pos != len(data) { @@ -650,6 +613,110 @@ func (c *Conn) writeHandshakeResponse41(capabilities uint32, scrambledPassword [ return nil } +// handleAuthResponse parses server's response after client sends the password for authentication +// and handles next steps for AuthSwitchRequestPacket and AuthMoreDataPacket. +func (c *Conn) handleAuthResponse(params *ConnParams) error { + response, err := c.readPacket() + if err != nil { + return NewSQLError(CRServerLost, SSUnknownSQLState, "%v", err) + } + + switch response[0] { + case OKPacket: + // OK packet, we are authenticated. Save the user, keep going. + c.User = params.Uname + case AuthSwitchRequestPacket: + // Server is asking to use a different auth method + if err = c.handleAuthSwitchPacket(params, response); err != nil { + return err + } + case AuthMoreDataPacket: + // Server is requesting more data - maybe un-scrambled password + if err := c.handleAuthMoreDataPacket(response[1], params); err != nil { + return err + } + case ErrPacket: + return ParseErrorPacket(response) + default: + return NewSQLError(CRServerHandshakeErr, SSUnknownSQLState, "initial server response cannot be parsed: %v", response) + } + + return nil +} + +// handleAuthSwitchPacket scrambles password for the plugin requested by the server and retries authentication +func (c *Conn) handleAuthSwitchPacket(params *ConnParams, response []byte) error { + var err error + var salt []byte + c.authPluginName, salt, err = parseAuthSwitchRequest(response) + if err != nil { + return NewSQLError(CRServerHandshakeErr, SSUnknownSQLState, "cannot parse auth switch request: %v", err) + } + if salt != nil { + c.salt = salt + } + switch c.authPluginName { + case MysqlClearPassword: + if err := c.writeClearTextPassword(params); err != nil { + return err + } + case MysqlNativePassword: + scrambledPassword := ScrambleMysqlNativePassword(c.salt, []byte(params.Pass)) + if err := c.writeScrambledPassword(scrambledPassword); err != nil { + return err + } + case CachingSha2Password: + scrambledPassword := ScrambleCachingSha2Password(c.salt, []byte(params.Pass)) + if err := c.writeScrambledPassword(scrambledPassword); err != nil { + return err + } + default: + return NewSQLError(CRServerHandshakeErr, SSUnknownSQLState, "server asked for unsupported auth method: %v", c.authPluginName) + } + + // The response could be an OKPacket, AuthMoreDataPacket or ErrPacket + return c.handleAuthResponse(params) +} + +// handleAuthMoreDataPacket handles response of CachingSha2Password authentication and sends full password to the +// server if requested +func (c *Conn) handleAuthMoreDataPacket(data byte, params *ConnParams) error { + switch data { + case CachingSha2FastAuth: + // User credentials are verified using the cache ("Fast" path). + // Next packet should be an OKPacket + return c.handleAuthResponse(params) + case CachingSha2FullAuth: + // User credentials are not cached, we have to exchange full password. + if c.Capabilities&CapabilityClientSSL > 0 || params.UnixSocket != "" { + // If we are using an SSL connection or Unix socket, write clear text password + if err := c.writeClearTextPassword(params); err != nil { + return err + } + } else { + // If we are not using an SSL connection or Unix socket, we have to fetch a public key + // from the server to encrypt password + pub, err := c.requestPublicKey() + if err != nil { + return err + } + // Encrypt password with public key + enc, err := EncryptPasswordWithPublicKey(c.salt, []byte(params.Pass), pub) + if err != nil { + return vterrors.Errorf(vtrpc.Code_INTERNAL, "error encrypting password with public key: %v", err) + } + // Write encrypted password + if err := c.writeScrambledPassword(enc); err != nil { + return err + } + } + // Next packet should either be an OKPacket or ErrPacket + return c.handleAuthResponse(params) + default: + return NewSQLError(CRServerHandshakeErr, SSUnknownSQLState, "cannot parse AuthMoreDataPacket: %v", data) + } +} + func parseAuthSwitchRequest(data []byte) (string, []byte, error) { pos := 1 pluginName, pos, ok := readNullString(data, pos) @@ -665,6 +732,34 @@ func parseAuthSwitchRequest(data []byte) (string, []byte, error) { return pluginName, salt, nil } +// requestPublicKey requests a public key from the server +func (c *Conn) requestPublicKey() (rsaKey *rsa.PublicKey, err error) { + // get public key from server + data, pos := c.startEphemeralPacketWithHeader(1) + data[pos] = 0x02 + if err := c.writeEphemeralPacket(); err != nil { + return nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "error sending public key request packet: %v", err) + } + + response, err := c.readPacket() + if err != nil { + return nil, NewSQLError(CRServerLost, SSUnknownSQLState, "%v", err) + } + + // Server should respond with a AuthMoreDataPacket containing the public key + if response[0] != AuthMoreDataPacket { + return nil, ParseErrorPacket(response) + } + + block, _ := pem.Decode(response[1:]) + pub, err := x509.ParsePKIXPublicKey(block.Bytes) + if err != nil { + return nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "failed to parse public key from server: %v", err) + } + + return pub.(*rsa.PublicKey), nil +} + // writeClearTextPassword writes the clear text password. // Returns a SQLError. func (c *Conn) writeClearTextPassword(params *ConnParams) error { @@ -678,15 +773,14 @@ func (c *Conn) writeClearTextPassword(params *ConnParams) error { return c.writeEphemeralPacket() } -// writeMysqlNativePassword writes the encrypted mysql_native_password format +// writeScrambledPassword writes the encrypted mysql_native_password format // Returns a SQLError. -func (c *Conn) writeMysqlNativePassword(params *ConnParams, salt []byte) error { - scrambledPassword := ScramblePassword(salt, []byte(params.Pass)) +func (c *Conn) writeScrambledPassword(scrambledPassword []byte) error { data, pos := c.startEphemeralPacketWithHeader(len(scrambledPassword)) pos += copy(data[pos:], scrambledPassword) // Sanity check. if pos != len(data) { - return vterrors.Errorf(vtrpc.Code_INTERNAL, "error building MysqlNativePassword packet: got %v bytes expected %v", pos, len(data)) + return vterrors.Errorf(vtrpc.Code_INTERNAL, "error building %v packet: got %v bytes expected %v", c.authPluginName, pos, len(data)) } return c.writeEphemeralPacket() } diff --git a/go/mysql/conn.go b/go/mysql/conn.go index 4df8348bdb0..3f5132d0a2a 100644 --- a/go/mysql/conn.go +++ b/go/mysql/conn.go @@ -105,10 +105,6 @@ type Conn struct { // and CapabilityClientFoundRows. Capabilities uint32 - // DefaultAuthPluginName is the name of server's default authentication plugin. - // It is set during the initial handshake. - DefaultAuthPluginName string - // CharacterSet is the character set used by the other side of the // connection. // It is set during the initial handshake. @@ -123,6 +119,13 @@ type Conn struct { // It is set during the initial handshake. UserData Getter + // salt is sent by the server during initial handshake to be used for authentication + salt []byte + + // authPluginName is the name of server's authentication plugin. + // It is set during the initial handshake. + authPluginName string + // schemaName is the default database name to use. It is set // during handshake, and by ComInitDb packets. Both client and // servers maintain it. This member is private because it's diff --git a/go/mysql/constants.go b/go/mysql/constants.go index 50d6bca7f3d..d43f6e20343 100644 --- a/go/mysql/constants.go +++ b/go/mysql/constants.go @@ -34,6 +34,9 @@ const ( // MysqlClearPassword transmits the password in the clear. MysqlClearPassword = "mysql_clear_password" + // CachingSha2Password uses a salt and transmits a SHA256 hash on the wire. + CachingSha2Password = "caching_sha2_password" + // MysqlDialog uses the dialog plugin on the client side. // It transmits data in the clear. MysqlDialog = "dialog" @@ -141,12 +144,21 @@ const ( // ComQuit is COM_QUIT. ComQuit = 0x01 + // AuthMoreDataPacket is sent when + AuthMoreDataPacket = 0x01 + // ComInitDB is COM_INIT_DB. ComInitDB = 0x02 // ComQuery is COM_QUERY. ComQuery = 0x03 + // CachingSha2FastAuth is sent before OKPacket when server authenticates using cache + CachingSha2FastAuth = 0x03 + + // CachingSha2FullAuth is sent when server requests un-scrambled password to authenticate + CachingSha2FullAuth = 0x04 + // ComPing is COM_PING. ComPing = 0x0e