diff --git a/pkg/proxy/driver/domain.go b/pkg/proxy/driver/domain.go index c7c9b8c6..b988e5cb 100644 --- a/pkg/proxy/driver/domain.go +++ b/pkg/proxy/driver/domain.go @@ -10,6 +10,7 @@ import ( "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/util" "github.com/siddontang/go-mysql/mysql" + wauth "github.com/tidb-incubator/weir/pkg/util/auth" ) type NamespaceManager interface { @@ -27,6 +28,7 @@ type Namespace interface { DescConnCount() GetBreaker() (Breaker, error) GetRateLimiter() RateLimiter + GetRouter() Router } type Breaker interface { @@ -45,6 +47,12 @@ type RateLimiter interface { Limit(ctx context.Context, key string) error } +type Router interface { + SetAddresses([]string) + Route() (string, error) + AddConnOnAddr(string, int) +} + type PooledBackendConn interface { // PutBack put conn back to pool PutBack() @@ -95,9 +103,10 @@ type ClientConnection interface { } type BackendConnManager interface { - SetAuthInfo(username string, authData []byte) + SetAuthInfo(authInfo *wauth.AuthInfo) Connect(address string) error Query(ctx context.Context, sql string) (*mysql.Result, error) + Close() error } // QueryCtx is the interface to execute command. @@ -134,4 +143,5 @@ type QueryCtx interface { type IDriver interface { CreateClientConnection(conn net.Conn, connectionID uint64, tlsConfig *tls.Config, serverCapability uint32) ClientConnection + CreateBackendConnManager() BackendConnManager } diff --git a/pkg/proxy/driver/driver.go b/pkg/proxy/driver/driver.go index 34f1c526..53a080b4 100644 --- a/pkg/proxy/driver/driver.go +++ b/pkg/proxy/driver/driver.go @@ -6,20 +6,28 @@ import ( ) type createClientConnFunc func(QueryCtx, net.Conn, uint64, *tls.Config, uint32) ClientConnection +type createBackendConnMgrFunc func() BackendConnManager type DriverImpl struct { - nsmgr NamespaceManager - createClientConnFunc createClientConnFunc + nsmgr NamespaceManager + createClientConnFunc createClientConnFunc + createBackendConnMgrFunc createBackendConnMgrFunc } -func NewDriverImpl(nsmgr NamespaceManager, createClientConnFunc createClientConnFunc) *DriverImpl { +func NewDriverImpl(nsmgr NamespaceManager, createClientConnFunc createClientConnFunc, createBackendConnMgrFunc createBackendConnMgrFunc) *DriverImpl { return &DriverImpl{ - nsmgr: nsmgr, - createClientConnFunc: createClientConnFunc, + nsmgr: nsmgr, + createClientConnFunc: createClientConnFunc, + createBackendConnMgrFunc: createBackendConnMgrFunc, } } func (d *DriverImpl) CreateClientConnection(conn net.Conn, connectionID uint64, tlsConfig *tls.Config, serverCapability uint32) ClientConnection { - queryCtx := NewQueryCtxImpl(d.nsmgr, connectionID) + backendConnMgr := d.createBackendConnMgrFunc() + queryCtx := NewQueryCtxImpl(d.nsmgr, backendConnMgr, connectionID) return d.createClientConnFunc(queryCtx, conn, connectionID, tlsConfig, serverCapability) } + +func (d *DriverImpl) CreateBackendConnManager() BackendConnManager { + return d.createBackendConnMgrFunc() +} diff --git a/pkg/proxy/driver/queryctx.go b/pkg/proxy/driver/queryctx.go index 4fb554d6..3bd9609f 100644 --- a/pkg/proxy/driver/queryctx.go +++ b/pkg/proxy/driver/queryctx.go @@ -13,8 +13,8 @@ import ( "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/util" gomysql "github.com/siddontang/go-mysql/mysql" - "github.com/tidb-incubator/weir/pkg/proxy/sessionmgr/backend" wast "github.com/tidb-incubator/weir/pkg/util/ast" + wauth "github.com/tidb-incubator/weir/pkg/util/auth" cb "github.com/tidb-incubator/weir/pkg/util/rate_limit_breaker/circuit_breaker" ) @@ -41,15 +41,15 @@ type QueryCtxImpl struct { currentDB string parser *parser.Parser sessionVars *SessionVarsWrapper - connMgr *backend.BackendConnManager + connMgr BackendConnManager } -func NewQueryCtxImpl(nsmgr NamespaceManager, connId uint64) *QueryCtxImpl { +func NewQueryCtxImpl(nsmgr NamespaceManager, backendConnMgr BackendConnManager, connId uint64) *QueryCtxImpl { return &QueryCtxImpl{ connId: connId, nsmgr: nsmgr, parser: parser.New(), - connMgr: backend.NewBackendConnManager(), + connMgr: backendConnMgr, sessionVars: NewSessionVarsWrapper(variable.NewSessionVars()), } } @@ -174,16 +174,24 @@ func (q *QueryCtxImpl) Close() error { } func (q *QueryCtxImpl) Auth(user *auth.UserIdentity, authData []byte, salt []byte) error { + // Looks up namespace. ns, ok := q.nsmgr.Auth(user.Username, authData, salt) if !ok { return errors.New("Unrecognized user") } q.ns = ns - authInfo := &backend.AuthInfo{ - Username: user.Username, - AuthData: authData, + authInfo := &wauth.AuthInfo{ + Username: user.Username, + AuthString: authData, + } + addr, err := ns.GetRouter().Route() + if err != nil { + return err } q.connMgr.SetAuthInfo(authInfo) + if err = q.connMgr.Connect(addr); err != nil { + return err + } q.ns.IncrConnCount() return nil } diff --git a/pkg/proxy/namespace/builder.go b/pkg/proxy/namespace/builder.go index 22b3223c..a2985b58 100644 --- a/pkg/proxy/namespace/builder.go +++ b/pkg/proxy/namespace/builder.go @@ -19,7 +19,7 @@ type NamespaceImpl struct { Br driver.Breaker Backend Frontend - router router.Router + router driver.Router rateLimiter *NamespaceRateLimiter } @@ -70,11 +70,11 @@ func (n *NamespaceImpl) GetRateLimiter() driver.RateLimiter { return n.rateLimiter } -func (n *NamespaceImpl) GetRouter() router.Router { +func (n *NamespaceImpl) GetRouter() driver.Router { return n.router } -func BuildRouter(cfg *config.BackendNamespace) (router.Router, error) { +func BuildRouter(cfg *config.BackendNamespace) (driver.Router, error) { if len(cfg.Instances) == 0 { return nil, errors.New("no instances for the backend") } diff --git a/pkg/proxy/namespace/domain.go b/pkg/proxy/namespace/domain.go index e4d7d4ce..d8785029 100644 --- a/pkg/proxy/namespace/domain.go +++ b/pkg/proxy/namespace/domain.go @@ -17,6 +17,7 @@ type Namespace interface { Close() GetBreaker() (driver.Breaker, error) GetRateLimiter() driver.RateLimiter + GetRouter() driver.Router } type Frontend interface { diff --git a/pkg/proxy/namespace/namespace.go b/pkg/proxy/namespace/namespace.go index 454637be..8c29e286 100644 --- a/pkg/proxy/namespace/namespace.go +++ b/pkg/proxy/namespace/namespace.go @@ -108,6 +108,10 @@ func (n *NamespaceWrapper) GetRateLimiter() driver.RateLimiter { return n.mustGetCurrentNamespace().GetRateLimiter() } +func (n *NamespaceWrapper) GetRouter() driver.Router { + return n.mustGetCurrentNamespace().GetRouter() +} + func (n *NamespaceWrapper) mustGetCurrentNamespace() Namespace { ns, ok := n.nsmgr.getCurrentNamespaces().Get(n.name) if !ok { diff --git a/pkg/proxy/proxy.go b/pkg/proxy/proxy.go index 5450f7a5..c38f8699 100644 --- a/pkg/proxy/proxy.go +++ b/pkg/proxy/proxy.go @@ -9,6 +9,7 @@ import ( "github.com/tidb-incubator/weir/pkg/proxy/metrics" "github.com/tidb-incubator/weir/pkg/proxy/namespace" "github.com/tidb-incubator/weir/pkg/proxy/server" + "github.com/tidb-incubator/weir/pkg/proxy/sessionmgr/backend" "github.com/tidb-incubator/weir/pkg/proxy/sessionmgr/client" ) @@ -53,7 +54,7 @@ func (p *Proxy) Init() error { return err } p.nsmgr = nsmgr - driverImpl := driver.NewDriverImpl(nsmgr, client.NewClientConnectionImpl) + driverImpl := driver.NewDriverImpl(nsmgr, client.NewClientConnectionImpl, backend.NewBackendConnManager) svr, err := server.NewServer(p.cfg, driverImpl) if err != nil { return err diff --git a/pkg/proxy/router/router.go b/pkg/proxy/router/router.go index b5dba986..20618742 100644 --- a/pkg/proxy/router/router.go +++ b/pkg/proxy/router/router.go @@ -9,18 +9,12 @@ var ( ErrNoInstanceToSelect = errors.New("no instances to route") ) -type Router interface { - SetAddresses([]string) - Route() (string, error) - AddConnOnAddr(string, int) -} - type RandomRouter struct { addresses []string addr2Conns map[string]int } -func NewRandomRouter() Router { +func NewRandomRouter() *RandomRouter { return &RandomRouter{ addr2Conns: make(map[string]int, 0), } diff --git a/pkg/proxy/sessionmgr/backend/backend_conn.go b/pkg/proxy/sessionmgr/backend/backend_conn.go index 2fa262fc..17ea9344 100644 --- a/pkg/proxy/sessionmgr/backend/backend_conn.go +++ b/pkg/proxy/sessionmgr/backend/backend_conn.go @@ -2,6 +2,7 @@ package backend import ( "github.com/siddontang/go-mysql/packet" + "github.com/tidb-incubator/weir/pkg/util/auth" ) type connectionPhase byte @@ -15,7 +16,7 @@ const ( ) type BackendConnection interface { - Connect(username string, authData []byte) error + Connect(authInfo *auth.AuthInfo) error Close() error } @@ -24,17 +25,17 @@ type BackendConnectionImpl struct { phase connectionPhase capability uint32 - server *BackendServer + address string } -func NewBackendConnectionImpl(backendServer *BackendServer) *BackendConnectionImpl { +func NewBackendConnectionImpl(address string) *BackendConnectionImpl { return &BackendConnectionImpl{ - phase: handshaking, - server: backendServer, + phase: handshaking, + address: address, } } -func (conn *BackendConnectionImpl) Connect(username string, authData []byte) error { +func (conn *BackendConnectionImpl) Connect(authInfo *auth.AuthInfo) error { return nil } diff --git a/pkg/proxy/sessionmgr/backend/backend_conn_mgr.go b/pkg/proxy/sessionmgr/backend/backend_conn_mgr.go index df9c0ed9..9d277633 100644 --- a/pkg/proxy/sessionmgr/backend/backend_conn_mgr.go +++ b/pkg/proxy/sessionmgr/backend/backend_conn_mgr.go @@ -4,6 +4,8 @@ import ( "context" gomysql "github.com/siddontang/go-mysql/mysql" + "github.com/tidb-incubator/weir/pkg/proxy/driver" + "github.com/tidb-incubator/weir/pkg/util/auth" ) type ConnectionPhase byte @@ -24,20 +26,14 @@ const ( StatusPrepareWaitFetch uint32 = 0x08 ) -type AuthInfo struct { - // user information obtained during authentication - Username string - AuthData []byte -} - type BackendConnManager struct { backendConn BackendConnection connectionPhase ConnectionPhase serverStatus uint32 - authInfo *AuthInfo + authInfo *auth.AuthInfo } -func NewBackendConnManager() *BackendConnManager { +func NewBackendConnManager() driver.BackendConnManager { return &BackendConnManager{ connectionPhase: InitBackend, serverStatus: StatusAutoCommit, @@ -56,18 +52,19 @@ func (mgr *BackendConnManager) Run(context context.Context) { } } -func (mgr *BackendConnManager) SetAuthInfo(authInfo *AuthInfo) { +func (mgr *BackendConnManager) SetAuthInfo(authInfo *auth.AuthInfo) { mgr.authInfo = authInfo } -func (mgr *BackendConnManager) Connect(server *BackendServer) error { +func (mgr *BackendConnManager) Connect(address string) error { + // It may be still connecting to the original backend server. if mgr.backendConn != nil { if err := mgr.backendConn.Close(); err != nil { return err } } - mgr.backendConn = NewBackendConnectionImpl(server) - return mgr.backendConn.Connect(mgr.authInfo.Username, mgr.authInfo.AuthData) + mgr.backendConn = NewBackendConnectionImpl(address) + return mgr.backendConn.Connect(mgr.authInfo) } func (mgr *BackendConnManager) initSessionStates() error { diff --git a/pkg/util/auth/auth_info.go b/pkg/util/auth/auth_info.go new file mode 100644 index 00000000..8b4586ac --- /dev/null +++ b/pkg/util/auth/auth_info.go @@ -0,0 +1,13 @@ +package auth + +// AuthInfo the user information that is stored temporarily in the proxy. +type AuthInfo struct { + // user information obtained during authentication + Username string + AuthPlugin string + AuthString []byte // password that sent from the client + BackendSalt []byte // backend salt used to encrypt password + Token []byte // or password + + DefaultDB string +}