Skip to content

Commit

Permalink
mysql/client: Add support for caching_sha2_password plugin
Browse files Browse the repository at this point in the history
caching_sha2_password plugin, unlike mysql_native_password, is a multi
step process. Client hashes the password using SHA2 algorithm and sends
it to the server. Server responds with either an AUTH_MORE_DATA (0x01)
packet or an Error packet.

Error packet is sent when authentication fails during "fast" auth (more
on this below).

The second byte of AUTH_MORE_DATA packet will either be 0x03 or 0x04.

0x03 represents a successful "fast" auth meaning that the server has a
cached hash of the password for the user and it matches the hash sent by
the client. Server will send an OK packet next.

Server sends 0x04 when the hash of the password is not yet cached. In
this case, client has to do a "full" authentication by sending un-hashed
password. If the client is connected using SSL or a Unix socket, client
can write the password in clear text. If this is not the case, client
has to request a public key from the server to encrypt the password.
Client should obfuscate the password using xor operation before
encrypting it with public key and sending it to the server. Server will
respond with an OK packet if autentication is successful or Error packet
otherwise.

Signed-off-by: Vamsi Atluri <vamc19@gmail.com>
  • Loading branch information
vamc19 committed Sep 14, 2020
1 parent e91548e commit 1302db0
Show file tree
Hide file tree
Showing 6 changed files with 228 additions and 65 deletions.
58 changes: 56 additions & 2 deletions go/mysql/auth_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ package mysql
import (
"bytes"
"crypto/rand"
"crypto/rsa"
"crypto/sha1"
"crypto/sha256"
"encoding/hex"
"net"
"strings"
Expand Down Expand Up @@ -117,8 +119,8 @@ func NewSalt() ([]byte, error) {
return salt, nil
}

// ScramblePassword computes the hash of the password using 4.1+ method.
func ScramblePassword(salt, password []byte) []byte {
// ScrambleMysqlNativePassword computes the hash of the password using 4.1+ method.
func ScrambleMysqlNativePassword(salt, password []byte) []byte {
if len(password) == 0 {
return nil
}
Expand Down Expand Up @@ -189,6 +191,58 @@ func isPassScrambleMysqlNativePassword(reply, salt []byte, mysqlNativePassword s
return bytes.Equal(candidateHash2, hash)
}

// ScrambleCachingSha2Password computes the hash of the password using SHA256 as required by
// caching_sha2_password plugin for "fast" authentication
func ScrambleCachingSha2Password(salt []byte, password []byte) []byte {
if len(password) == 0 {
return nil
}

// stage1Hash = SHA256(password)
crypt := sha256.New()
crypt.Write(password)
stage1 := crypt.Sum(nil)

// scrambleHash = SHA256(SHA256(stage1Hash) + salt)
crypt.Reset()
crypt.Write(stage1)
innerHash := crypt.Sum(nil)

crypt.Reset()
crypt.Write(innerHash)
crypt.Write(salt)
scramble := crypt.Sum(nil)

// token = stage1Hash XOR scrambleHash
for i := range stage1 {
stage1[i] ^= scramble[i]
}

return stage1
}

// EncryptPasswordWithPublicKey obfuscates the password and encrypts it with server's public key as required by
// caching_sha2_password plugin for "full" authentication
func EncryptPasswordWithPublicKey(salt []byte, password []byte, pub *rsa.PublicKey) ([]byte, error) {
if len(password) == 0 {
return nil, nil
}

buffer := make([]byte, len(password)+1)
copy(buffer, password)
for i := range buffer {
buffer[i] ^= salt[i%len(salt)]
}

sha1Hash := sha1.New()
enc, err := rsa.EncryptOAEP(sha1Hash, rand.Reader, pub, buffer, nil)
if err != nil {
return nil, err
}

return enc, nil
}

// Constants for the dialog plugin.
const (
mysqlDialogMessage = "Enter password: "
Expand Down
2 changes: 1 addition & 1 deletion go/mysql/auth_server_static.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ func (a *AuthServerStatic) ValidateHash(salt []byte, user string, authResponse [
return &StaticUserData{entry.UserData, entry.Groups}, nil
}
} else {
computedAuthResponse := ScramblePassword(salt, []byte(entry.Password))
computedAuthResponse := ScrambleMysqlNativePassword(salt, []byte(entry.Password))
// Validate the password.
if matchSourceHost(remoteAddr, entry.SourceHost) && bytes.Equal(authResponse, computedAuthResponse) {
return &StaticUserData{entry.UserData, entry.Groups}, nil
Expand Down
4 changes: 2 additions & 2 deletions go/mysql/auth_server_static_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func TestValidateHashGetter(t *testing.T) {
t.Fatalf("error generating salt: %v", err)
}

scrambled := ScramblePassword(salt, []byte("password"))
scrambled := ScrambleMysqlNativePassword(salt, []byte("password"))
getter, err := auth.ValidateHash(salt, "mysql_user", scrambled, addr)
if err != nil {
t.Fatalf("error validating password: %v", err)
Expand Down Expand Up @@ -274,7 +274,7 @@ func TestStaticPasswords(t *testing.T) {
t.Fatalf("error generating salt: %v", err)
}

scrambled := ScramblePassword(salt, []byte(c.password))
scrambled := ScrambleMysqlNativePassword(salt, []byte(c.password))
_, err = auth.ValidateHash(salt, c.user, scrambled, addr)

if c.success {
Expand Down
206 changes: 150 additions & 56 deletions go/mysql/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@ limitations under the License.
package mysql

import (
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"fmt"
"net"
"strconv"
Expand Down Expand Up @@ -234,6 +237,7 @@ func (c *Conn) clientHandshake(characterSet uint8, params *ConnParams) error {
return err
}
c.fillFlavor(params)
c.salt = salt

// Sanity check.
if capabilities&CapabilityClientProtocol41 == 0 {
Expand Down Expand Up @@ -290,7 +294,12 @@ func (c *Conn) clientHandshake(characterSet uint8, params *ConnParams) error {
}

// Password encryption.
scrambledPassword := ScramblePassword(salt, []byte(params.Pass))
var scrambledPassword []byte
if c.authPluginName == CachingSha2Password {
scrambledPassword = ScrambleCachingSha2Password(salt, []byte(params.Pass))
} else {
scrambledPassword = ScrambleMysqlNativePassword(salt, []byte(params.Pass))
}

// Build and send our handshake response 41.
// Note this one will never have SSL flag on.
Expand All @@ -299,54 +308,8 @@ func (c *Conn) clientHandshake(characterSet uint8, params *ConnParams) error {
}

// Read the server response.
response, err := c.readPacket()
if err != nil {
return NewSQLError(CRServerLost, SSUnknownSQLState, "%v", err)
}
switch response[0] {
case OKPacket:
// OK packet, we are authenticated. Save the user, keep going.
c.User = params.Uname
case AuthSwitchRequestPacket:
// Server is asking to use a different auth method. We
// only support cleartext plugin.
pluginName, salt, err := parseAuthSwitchRequest(response)
if err != nil {
return NewSQLError(CRServerHandshakeErr, SSUnknownSQLState, "cannot parse auth switch request: %v", err)
}

if pluginName == MysqlClearPassword {
// Write the cleartext password packet.
if err := c.writeClearTextPassword(params); err != nil {
return err
}
} else if pluginName == MysqlNativePassword {
// Write the mysql_native_password packet.
if err := c.writeMysqlNativePassword(params, salt); err != nil {
return err
}
} else {
return NewSQLError(CRServerHandshakeErr, SSUnknownSQLState, "server asked for unsupported auth method: %v", pluginName)
}

// Wait for OK packet.
response, err = c.readPacket()
if err != nil {
return NewSQLError(CRServerLost, SSUnknownSQLState, "%v", err)
}
switch response[0] {
case OKPacket:
// OK packet, we are authenticated. Save the user, keep going.
c.User = params.Uname
case ErrPacket:
return ParseErrorPacket(response)
default:
return NewSQLError(CRServerHandshakeErr, SSUnknownSQLState, "initial server response cannot be parsed: %v", response)
}
case ErrPacket:
return ParseErrorPacket(response)
default:
return NewSQLError(CRServerHandshakeErr, SSUnknownSQLState, "initial server response cannot be parsed: %v", response)
if err := c.handleAuthResponse(params); err != nil {
return err
}

// If the server didn't support DbName in its handshake, set
Expand Down Expand Up @@ -504,7 +467,7 @@ func (c *Conn) parseInitialHandshakePacket(data []byte) (uint32, []byte, error)
// 5.6.2 that don't have a null terminated string.
authPluginName = string(data[pos : len(data)-1])
}
c.DefaultAuthPluginName = authPluginName
c.authPluginName = authPluginName
}

return capabilities, authPluginData, nil
Expand Down Expand Up @@ -588,7 +551,7 @@ func (c *Conn) writeHandshakeResponse41(capabilities uint32, scrambledPassword [
lenNullString(params.Uname) +
// length of scrambled password is handled below.
len(scrambledPassword) +
21 + // "mysql_native_password" string.
len(c.authPluginName) +
1 // terminating zero.

// Add the DB name if the server supports it.
Expand Down Expand Up @@ -637,7 +600,7 @@ func (c *Conn) writeHandshakeResponse41(capabilities uint32, scrambledPassword [
}

// Assume native client during response
pos = writeNullString(data, pos, MysqlNativePassword)
pos = writeNullString(data, pos, c.authPluginName)

// Sanity-check the length.
if pos != len(data) {
Expand All @@ -650,6 +613,110 @@ func (c *Conn) writeHandshakeResponse41(capabilities uint32, scrambledPassword [
return nil
}

// handleAuthResponse parses server's response after client sends the password for authentication
// and handles next steps for AuthSwitchRequestPacket and AuthMoreDataPacket.
func (c *Conn) handleAuthResponse(params *ConnParams) error {
response, err := c.readPacket()
if err != nil {
return NewSQLError(CRServerLost, SSUnknownSQLState, "%v", err)
}

switch response[0] {
case OKPacket:
// OK packet, we are authenticated. Save the user, keep going.
c.User = params.Uname
case AuthSwitchRequestPacket:
// Server is asking to use a different auth method
if err = c.handleAuthSwitchPacket(params, response); err != nil {
return err
}
case AuthMoreDataPacket:
// Server is requesting more data - maybe un-scrambled password
if err := c.handleAuthMoreDataPacket(response[1], params); err != nil {
return err
}
case ErrPacket:
return ParseErrorPacket(response)
default:
return NewSQLError(CRServerHandshakeErr, SSUnknownSQLState, "initial server response cannot be parsed: %v", response)
}

return nil
}

// handleAuthSwitchPacket scrambles password for the plugin requested by the server and retries authentication
func (c *Conn) handleAuthSwitchPacket(params *ConnParams, response []byte) error {
var err error
var salt []byte
c.authPluginName, salt, err = parseAuthSwitchRequest(response)
if err != nil {
return NewSQLError(CRServerHandshakeErr, SSUnknownSQLState, "cannot parse auth switch request: %v", err)
}
if salt != nil {
c.salt = salt
}
switch c.authPluginName {
case MysqlClearPassword:
if err := c.writeClearTextPassword(params); err != nil {
return err
}
case MysqlNativePassword:
scrambledPassword := ScrambleMysqlNativePassword(c.salt, []byte(params.Pass))
if err := c.writeScrambledPassword(scrambledPassword); err != nil {
return err
}
case CachingSha2Password:
scrambledPassword := ScrambleCachingSha2Password(c.salt, []byte(params.Pass))
if err := c.writeScrambledPassword(scrambledPassword); err != nil {
return err
}
default:
return NewSQLError(CRServerHandshakeErr, SSUnknownSQLState, "server asked for unsupported auth method: %v", c.authPluginName)
}

// The response could be an OKPacket, AuthMoreDataPacket or ErrPacket
return c.handleAuthResponse(params)
}

// handleAuthMoreDataPacket handles response of CachingSha2Password authentication and sends full password to the
// server if requested
func (c *Conn) handleAuthMoreDataPacket(data byte, params *ConnParams) error {
switch data {
case CachingSha2FastAuth:
// User credentials are verified using the cache ("Fast" path).
// Next packet should be an OKPacket
return c.handleAuthResponse(params)
case CachingSha2FullAuth:
// User credentials are not cached, we have to exchange full password.
if c.Capabilities&CapabilityClientSSL > 0 || params.UnixSocket != "" {
// If we are using an SSL connection or Unix socket, write clear text password
if err := c.writeClearTextPassword(params); err != nil {
return err
}
} else {
// If we are not using an SSL connection or Unix socket, we have to fetch a public key
// from the server to encrypt password
pub, err := c.requestPublicKey()
if err != nil {
return err
}
// Encrypt password with public key
enc, err := EncryptPasswordWithPublicKey(c.salt, []byte(params.Pass), pub)
if err != nil {
return vterrors.Errorf(vtrpc.Code_INTERNAL, "error encrypting password with public key: %v", err)
}
// Write encrypted password
if err := c.writeScrambledPassword(enc); err != nil {
return err
}
}
// Next packet should either be an OKPacket or ErrPacket
return c.handleAuthResponse(params)
default:
return NewSQLError(CRServerHandshakeErr, SSUnknownSQLState, "cannot parse AuthMoreDataPacket: %v", data)
}
}

func parseAuthSwitchRequest(data []byte) (string, []byte, error) {
pos := 1
pluginName, pos, ok := readNullString(data, pos)
Expand All @@ -665,6 +732,34 @@ func parseAuthSwitchRequest(data []byte) (string, []byte, error) {
return pluginName, salt, nil
}

// requestPublicKey requests a public key from the server
func (c *Conn) requestPublicKey() (rsaKey *rsa.PublicKey, err error) {
// get public key from server
data, pos := c.startEphemeralPacketWithHeader(1)
data[pos] = 0x02
if err := c.writeEphemeralPacket(); err != nil {
return nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "error sending public key request packet: %v", err)
}

response, err := c.readPacket()
if err != nil {
return nil, NewSQLError(CRServerLost, SSUnknownSQLState, "%v", err)
}

// Server should respond with a AuthMoreDataPacket containing the public key
if response[0] != AuthMoreDataPacket {
return nil, ParseErrorPacket(response)
}

block, _ := pem.Decode(response[1:])
pub, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
return nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "failed to parse public key from server: %v", err)
}

return pub.(*rsa.PublicKey), nil
}

// writeClearTextPassword writes the clear text password.
// Returns a SQLError.
func (c *Conn) writeClearTextPassword(params *ConnParams) error {
Expand All @@ -678,15 +773,14 @@ func (c *Conn) writeClearTextPassword(params *ConnParams) error {
return c.writeEphemeralPacket()
}

// writeMysqlNativePassword writes the encrypted mysql_native_password format
// writeScrambledPassword writes the encrypted mysql_native_password format
// Returns a SQLError.
func (c *Conn) writeMysqlNativePassword(params *ConnParams, salt []byte) error {
scrambledPassword := ScramblePassword(salt, []byte(params.Pass))
func (c *Conn) writeScrambledPassword(scrambledPassword []byte) error {
data, pos := c.startEphemeralPacketWithHeader(len(scrambledPassword))
pos += copy(data[pos:], scrambledPassword)
// Sanity check.
if pos != len(data) {
return vterrors.Errorf(vtrpc.Code_INTERNAL, "error building MysqlNativePassword packet: got %v bytes expected %v", pos, len(data))
return vterrors.Errorf(vtrpc.Code_INTERNAL, "error building %v packet: got %v bytes expected %v", c.authPluginName, pos, len(data))
}
return c.writeEphemeralPacket()
}
Loading

0 comments on commit 1302db0

Please sign in to comment.