Skip to content

Commit

Permalink
Merge pull request #9264 from mattrobenolt/ref-auth-server
Browse files Browse the repository at this point in the history
mysql: Pass mysql.Conn through {Hash,PlainText,Caching}Storage interfaces
  • Loading branch information
deepthi authored Nov 30, 2021
2 parents 6a695ec + 7c247f4 commit 6d8de8e
Show file tree
Hide file tree
Showing 6 changed files with 16 additions and 21 deletions.
17 changes: 8 additions & 9 deletions go/mysql/auth_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import (
"crypto/sha1"
"crypto/sha256"
"crypto/subtle"
"crypto/x509"
"encoding/hex"
"net"
"sync"
Expand Down Expand Up @@ -132,7 +131,7 @@ const (
// such a hash based on the salt and auth response provided here after retrieving
// the hashed password from the storage.
type HashStorage interface {
UserEntryWithHash(userCerts []*x509.Certificate, salt []byte, user string, authResponse []byte, remoteAddr net.Addr) (Getter, error)
UserEntryWithHash(conn *Conn, salt []byte, user string, authResponse []byte, remoteAddr net.Addr) (Getter, error)
}

// PlainTextStorage describes an object that is suitable to retrieve user information
Expand All @@ -146,7 +145,7 @@ type HashStorage interface {
// When comparing plain text passwords directly, please ensure to use `subtle.ConstantTimeCompare`
// to prevent timing based attacks on the password.
type PlainTextStorage interface {
UserEntryWithPassword(userCerts []*x509.Certificate, user string, password string, remoteAddr net.Addr) (Getter, error)
UserEntryWithPassword(conn *Conn, user string, password string, remoteAddr net.Addr) (Getter, error)
}

// CachingStorage describes an object that is suitable to retrieve user information
Expand All @@ -159,7 +158,7 @@ type PlainTextStorage interface {
// such a hash based on the salt and auth response provided here after retrieving
// the hashed password from the cache.
type CachingStorage interface {
UserEntryWithCacheHash(userCerts []*x509.Certificate, salt []byte, user string, authResponse []byte, remoteAddr net.Addr) (Getter, CacheState, error)
UserEntryWithCacheHash(conn *Conn, salt []byte, user string, authResponse []byte, remoteAddr net.Addr) (Getter, CacheState, error)
}

// NewMysqlNativeAuthMethod will create a new AuthMethod that implements the
Expand Down Expand Up @@ -432,7 +431,7 @@ func (n *mysqlNativePasswordAuthMethod) HandleAuthPluginData(conn *Conn, user st
}

salt := serverAuthPluginData[:len(serverAuthPluginData)-1]
return n.storage.UserEntryWithHash(conn.GetTLSClientCerts(), salt, user, clientAuthPluginData, remoteAddr)
return n.storage.UserEntryWithHash(conn, salt, user, clientAuthPluginData, remoteAddr)
}

type mysqlClearAuthMethod struct {
Expand All @@ -457,7 +456,7 @@ func (n *mysqlClearAuthMethod) AllowClearTextWithoutTLS() bool {
}

func (n *mysqlClearAuthMethod) HandleAuthPluginData(conn *Conn, user string, serverAuthPluginData []byte, clientAuthPluginData []byte, remoteAddr net.Addr) (Getter, error) {
return n.storage.UserEntryWithPassword(conn.GetTLSClientCerts(), user, string(clientAuthPluginData[:len(clientAuthPluginData)-1]), remoteAddr)
return n.storage.UserEntryWithPassword(conn, user, string(clientAuthPluginData[:len(clientAuthPluginData)-1]), remoteAddr)
}

type mysqlDialogAuthMethod struct {
Expand All @@ -482,7 +481,7 @@ func (n *mysqlDialogAuthMethod) AuthPluginData() ([]byte, error) {
}

func (n *mysqlDialogAuthMethod) HandleAuthPluginData(conn *Conn, user string, serverAuthPluginData []byte, clientAuthPluginData []byte, remoteAddr net.Addr) (Getter, error) {
return n.storage.UserEntryWithPassword(conn.GetTLSClientCerts(), user, string(clientAuthPluginData[:len(clientAuthPluginData)-1]), remoteAddr)
return n.storage.UserEntryWithPassword(conn, user, string(clientAuthPluginData[:len(clientAuthPluginData)-1]), remoteAddr)
}

func (n *mysqlDialogAuthMethod) AllowClearTextWithoutTLS() bool {
Expand Down Expand Up @@ -524,7 +523,7 @@ func (n *mysqlCachingSha2AuthMethod) HandleAuthPluginData(c *Conn, user string,
}

salt := serverAuthPluginData[:len(serverAuthPluginData)-1]
result, cacheState, err := n.cache.UserEntryWithCacheHash(c.GetTLSClientCerts(), salt, user, clientAuthPluginData, remoteAddr)
result, cacheState, err := n.cache.UserEntryWithCacheHash(c, salt, user, clientAuthPluginData, remoteAddr)

if err != nil {
return nil, err
Expand Down Expand Up @@ -560,7 +559,7 @@ func (n *mysqlCachingSha2AuthMethod) HandleAuthPluginData(c *Conn, user string,
return nil, err
}

return n.storage.UserEntryWithPassword(c.GetTLSClientCerts(), user, password, remoteAddr)
return n.storage.UserEntryWithPassword(c, user, password, remoteAddr)
default:
// Somehow someone returned an unknown state, let's error with access denied.
return nil, NewSQLError(ERAccessDeniedError, SSAccessDeniedError, "Access denied for user '%v'", user)
Expand Down
4 changes: 2 additions & 2 deletions go/mysql/auth_server_clientcert.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ limitations under the License.
package mysql

import (
"crypto/x509"
"flag"
"fmt"
"net"
Expand Down Expand Up @@ -85,7 +84,8 @@ func (asl *AuthServerClientCert) HandleUser(user string) bool {
}

// UserEntryWithPassword is part of the PlaintextStorage interface
func (asl *AuthServerClientCert) UserEntryWithPassword(userCerts []*x509.Certificate, user string, password string, remoteAddr net.Addr) (Getter, error) {
func (asl *AuthServerClientCert) UserEntryWithPassword(conn *Conn, user string, password string, remoteAddr net.Addr) (Getter, error) {
userCerts := conn.GetTLSClientCerts()
if len(userCerts) == 0 {
return nil, fmt.Errorf("no client certs for connection")
}
Expand Down
3 changes: 1 addition & 2 deletions go/mysql/auth_server_none.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ limitations under the License.
package mysql

import (
"crypto/x509"
"net"

querypb "vitess.io/vitess/go/vt/proto/query"
Expand Down Expand Up @@ -51,7 +50,7 @@ func (a *AuthServerNone) HandleUser(user string) bool {

// UserEntryWithHash validates the user if it exists and returns the information.
// Always accepts any user.
func (a *AuthServerNone) UserEntryWithHash(userCerts []*x509.Certificate, salt []byte, user string, authResponse []byte, remoteAddr net.Addr) (Getter, error) {
func (a *AuthServerNone) UserEntryWithHash(conn *Conn, salt []byte, user string, authResponse []byte, remoteAddr net.Addr) (Getter, error) {
return &NoneGetter{}, nil
}

Expand Down
7 changes: 3 additions & 4 deletions go/mysql/auth_server_static.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ package mysql
import (
"bytes"
"crypto/subtle"
"crypto/x509"
"encoding/json"
"flag"
"net"
Expand Down Expand Up @@ -161,7 +160,7 @@ func (a *AuthServerStatic) HandleUser(user string) bool {

// UserEntryWithPassword implements password lookup based on a plain
// text password that is negotiated with the client.
func (a *AuthServerStatic) UserEntryWithPassword(userCerts []*x509.Certificate, user string, password string, remoteAddr net.Addr) (Getter, error) {
func (a *AuthServerStatic) UserEntryWithPassword(conn *Conn, user string, password string, remoteAddr net.Addr) (Getter, error) {
a.mu.Lock()
entries, ok := a.entries[user]
a.mu.Unlock()
Expand All @@ -181,7 +180,7 @@ func (a *AuthServerStatic) UserEntryWithPassword(userCerts []*x509.Certificate,

// UserEntryWithHash implements password lookup based on a
// mysql_native_password hash that is negotiated with the client.
func (a *AuthServerStatic) UserEntryWithHash(userCerts []*x509.Certificate, salt []byte, user string, authResponse []byte, remoteAddr net.Addr) (Getter, error) {
func (a *AuthServerStatic) UserEntryWithHash(conn *Conn, salt []byte, user string, authResponse []byte, remoteAddr net.Addr) (Getter, error) {
a.mu.Lock()
entries, ok := a.entries[user]
a.mu.Unlock()
Expand Down Expand Up @@ -214,7 +213,7 @@ func (a *AuthServerStatic) UserEntryWithHash(userCerts []*x509.Certificate, salt

// UserEntryWithCacheHash implements password lookup based on a
// caching_sha2_password hash that is negotiated with the client.
func (a *AuthServerStatic) UserEntryWithCacheHash(userCerts []*x509.Certificate, salt []byte, user string, authResponse []byte, remoteAddr net.Addr) (Getter, CacheState, error) {
func (a *AuthServerStatic) UserEntryWithCacheHash(conn *Conn, salt []byte, user string, authResponse []byte, remoteAddr net.Addr) (Getter, CacheState, error) {
a.mu.Lock()
entries, ok := a.entries[user]
a.mu.Unlock()
Expand Down
3 changes: 1 addition & 2 deletions go/mysql/ldapauthserver/auth_server_ldap.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ limitations under the License.
package ldapauthserver

import (
"crypto/x509"
"encoding/json"
"flag"
"fmt"
Expand Down Expand Up @@ -118,7 +117,7 @@ func (asl *AuthServerLdap) HandleUser(user string) bool {

// UserEntryWithPassword is part of the PlaintextStorage interface
// and called after the password is sent by the client.
func (asl *AuthServerLdap) UserEntryWithPassword(userCerts []*x509.Certificate, user string, password string, remoteAddr net.Addr) (mysql.Getter, error) {
func (asl *AuthServerLdap) UserEntryWithPassword(conn *mysql.Conn, user string, password string, remoteAddr net.Addr) (mysql.Getter, error) {
return asl.validate(user, password)
}

Expand Down
3 changes: 1 addition & 2 deletions go/mysql/vault/auth_server_vault.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ package vault

import (
"crypto/subtle"
"crypto/x509"
"flag"
"fmt"
"net"
Expand Down Expand Up @@ -166,7 +165,7 @@ func (a *AuthServerVault) HandleUser(user string) bool {
}

// UserEntryWithHash is called when mysql_native_password is used.
func (a *AuthServerVault) UserEntryWithHash(userCerts []*x509.Certificate, salt []byte, user string, authResponse []byte, remoteAddr net.Addr) (mysql.Getter, error) {
func (a *AuthServerVault) UserEntryWithHash(conn *mysql.Conn, salt []byte, user string, authResponse []byte, remoteAddr net.Addr) (mysql.Getter, error) {
a.mu.Lock()
userEntries, ok := a.entries[user]
a.mu.Unlock()
Expand Down

0 comments on commit 6d8de8e

Please sign in to comment.