Skip to content

Commit

Permalink
server: unix socket should verify user's authentication (#8381)
Browse files Browse the repository at this point in the history
  • Loading branch information
jackysp authored Nov 22, 2018
1 parent f085f4f commit e69aa27
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 10 deletions.
16 changes: 9 additions & 7 deletions server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ import (
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/metrics"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/util/arena"
"github.com/pingcap/tidb/util/chunk"
"github.com/pingcap/tidb/util/hack"
Expand Down Expand Up @@ -391,16 +392,17 @@ func (cc *clientConn) openSessionAndDoAuth(authData []byte) error {
if err != nil {
return errors.Trace(err)
}
if !cc.server.skipAuth() {
// Do Auth.
host := variable.DefHostname
if !cc.server.isUnixSocket() {
addr := cc.bufReadConn.RemoteAddr().String()
host, _, err1 := net.SplitHostPort(addr)
if err1 != nil {
// Do Auth.
host, _, err = net.SplitHostPort(addr)
if err != nil {
return errors.Trace(errAccessDenied.GenWithStackByArgs(cc.user, addr, "YES"))
}
if !cc.ctx.Auth(&auth.UserIdentity{Username: cc.user, Hostname: host}, authData, cc.salt) {
return errors.Trace(errAccessDenied.GenWithStackByArgs(cc.user, host, "YES"))
}
}
if !cc.ctx.Auth(&auth.UserIdentity{Username: cc.user, Hostname: host}, authData, cc.salt) {
return errors.Trace(errAccessDenied.GenWithStackByArgs(cc.user, host, "YES"))
}
if cc.dbname != "" {
err = cc.useDB(context.Background(), cc.dbname)
Expand Down
2 changes: 1 addition & 1 deletion server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ func (s *Server) newConn(conn net.Conn) *clientConn {
return cc
}

func (s *Server) skipAuth() bool {
func (s *Server) isUnixSocket() bool {
return s.cfg.Socket != ""
}

Expand Down
3 changes: 3 additions & 0 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,9 @@ func (dbt *DBTest) mustQueryRows(query string, args ...interface{}) {

func runTestRegression(c *C, overrider configOverrider, dbName string) {
runTestsOnNewDB(c, overrider, dbName, func(dbt *DBTest) {
// Show the user
dbt.mustExec("select user()")

// Create Table
dbt.mustExec("CREATE TABLE test (val TINYINT)")

Expand Down
7 changes: 5 additions & 2 deletions session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -1100,12 +1100,15 @@ func (s *session) GetSessionVars() *variable.SessionVars {
func (s *session) Auth(user *auth.UserIdentity, authentication []byte, salt []byte) bool {
pm := privilege.GetPrivilegeManager(s)

// Check IP.
// Check IP or localhost.
var success bool
user.AuthUsername, user.AuthHostname, success = pm.ConnectionVerification(user.Username, user.Hostname, authentication, salt)
if success {
s.sessionVars.User = user
return true
} else if user.Hostname == variable.DefHostname {
log.Errorf("User connection verification failed %s", user)
return false
}

// Check Hostname.
Expand All @@ -1128,7 +1131,7 @@ func (s *session) Auth(user *auth.UserIdentity, authentication []byte, salt []by

func getHostByIP(ip string) []string {
if ip == "127.0.0.1" {
return []string{"localhost"}
return []string{variable.DefHostname}
}
addrs, err := net.LookupAddr(ip)
terror.Log(errors.Trace(err))
Expand Down
1 change: 1 addition & 0 deletions sessionctx/variable/tidb_vars.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ const (

// Default TiDB system variable values.
const (
DefHostname = "localhost"
DefIndexLookupConcurrency = 4
DefIndexLookupJoinConcurrency = 4
DefIndexSerialScanConcurrency = 1
Expand Down

0 comments on commit e69aa27

Please sign in to comment.