Skip to content

Commit

Permalink
Merge pull request #2 from djshow832/auth_and_ns
Browse files Browse the repository at this point in the history
*: make handshake be compatible with namespace
  • Loading branch information
djshow832 authored Mar 17, 2022
2 parents be4de2c + f0deaba commit dfbb217
Show file tree
Hide file tree
Showing 11 changed files with 80 additions and 43 deletions.
12 changes: 11 additions & 1 deletion pkg/proxy/driver/domain.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/util"
"github.com/siddontang/go-mysql/mysql"
wauth "github.com/tidb-incubator/weir/pkg/util/auth"
)

type NamespaceManager interface {
Expand All @@ -27,6 +28,7 @@ type Namespace interface {
DescConnCount()
GetBreaker() (Breaker, error)
GetRateLimiter() RateLimiter
GetRouter() Router
}

type Breaker interface {
Expand All @@ -45,6 +47,12 @@ type RateLimiter interface {
Limit(ctx context.Context, key string) error
}

type Router interface {
SetAddresses([]string)
Route() (string, error)
AddConnOnAddr(string, int)
}

type PooledBackendConn interface {
// PutBack put conn back to pool
PutBack()
Expand Down Expand Up @@ -95,9 +103,10 @@ type ClientConnection interface {
}

type BackendConnManager interface {
SetAuthInfo(username string, authData []byte)
SetAuthInfo(authInfo *wauth.AuthInfo)
Connect(address string) error
Query(ctx context.Context, sql string) (*mysql.Result, error)
Close() error
}

// QueryCtx is the interface to execute command.
Expand Down Expand Up @@ -134,4 +143,5 @@ type QueryCtx interface {

type IDriver interface {
CreateClientConnection(conn net.Conn, connectionID uint64, tlsConfig *tls.Config, serverCapability uint32) ClientConnection
CreateBackendConnManager() BackendConnManager
}
20 changes: 14 additions & 6 deletions pkg/proxy/driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,28 @@ import (
)

type createClientConnFunc func(QueryCtx, net.Conn, uint64, *tls.Config, uint32) ClientConnection
type createBackendConnMgrFunc func() BackendConnManager

type DriverImpl struct {
nsmgr NamespaceManager
createClientConnFunc createClientConnFunc
nsmgr NamespaceManager
createClientConnFunc createClientConnFunc
createBackendConnMgrFunc createBackendConnMgrFunc
}

func NewDriverImpl(nsmgr NamespaceManager, createClientConnFunc createClientConnFunc) *DriverImpl {
func NewDriverImpl(nsmgr NamespaceManager, createClientConnFunc createClientConnFunc, createBackendConnMgrFunc createBackendConnMgrFunc) *DriverImpl {
return &DriverImpl{
nsmgr: nsmgr,
createClientConnFunc: createClientConnFunc,
nsmgr: nsmgr,
createClientConnFunc: createClientConnFunc,
createBackendConnMgrFunc: createBackendConnMgrFunc,
}
}

func (d *DriverImpl) CreateClientConnection(conn net.Conn, connectionID uint64, tlsConfig *tls.Config, serverCapability uint32) ClientConnection {
queryCtx := NewQueryCtxImpl(d.nsmgr, connectionID)
backendConnMgr := d.createBackendConnMgrFunc()
queryCtx := NewQueryCtxImpl(d.nsmgr, backendConnMgr, connectionID)
return d.createClientConnFunc(queryCtx, conn, connectionID, tlsConfig, serverCapability)
}

func (d *DriverImpl) CreateBackendConnManager() BackendConnManager {
return d.createBackendConnMgrFunc()
}
22 changes: 15 additions & 7 deletions pkg/proxy/driver/queryctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ import (
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/util"
gomysql "github.com/siddontang/go-mysql/mysql"
"github.com/tidb-incubator/weir/pkg/proxy/sessionmgr/backend"
wast "github.com/tidb-incubator/weir/pkg/util/ast"
wauth "github.com/tidb-incubator/weir/pkg/util/auth"
cb "github.com/tidb-incubator/weir/pkg/util/rate_limit_breaker/circuit_breaker"
)

Expand All @@ -41,15 +41,15 @@ type QueryCtxImpl struct {
currentDB string
parser *parser.Parser
sessionVars *SessionVarsWrapper
connMgr *backend.BackendConnManager
connMgr BackendConnManager
}

func NewQueryCtxImpl(nsmgr NamespaceManager, connId uint64) *QueryCtxImpl {
func NewQueryCtxImpl(nsmgr NamespaceManager, backendConnMgr BackendConnManager, connId uint64) *QueryCtxImpl {
return &QueryCtxImpl{
connId: connId,
nsmgr: nsmgr,
parser: parser.New(),
connMgr: backend.NewBackendConnManager(),
connMgr: backendConnMgr,
sessionVars: NewSessionVarsWrapper(variable.NewSessionVars()),
}
}
Expand Down Expand Up @@ -174,16 +174,24 @@ func (q *QueryCtxImpl) Close() error {
}

func (q *QueryCtxImpl) Auth(user *auth.UserIdentity, authData []byte, salt []byte) error {
// Looks up namespace.
ns, ok := q.nsmgr.Auth(user.Username, authData, salt)
if !ok {
return errors.New("Unrecognized user")
}
q.ns = ns
authInfo := &backend.AuthInfo{
Username: user.Username,
AuthData: authData,
authInfo := &wauth.AuthInfo{
Username: user.Username,
AuthString: authData,
}
addr, err := ns.GetRouter().Route()
if err != nil {
return err
}
q.connMgr.SetAuthInfo(authInfo)
if err = q.connMgr.Connect(addr); err != nil {
return err
}
q.ns.IncrConnCount()
return nil
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/proxy/namespace/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ type NamespaceImpl struct {
Br driver.Breaker
Backend
Frontend
router router.Router
router driver.Router
rateLimiter *NamespaceRateLimiter
}

Expand Down Expand Up @@ -70,11 +70,11 @@ func (n *NamespaceImpl) GetRateLimiter() driver.RateLimiter {
return n.rateLimiter
}

func (n *NamespaceImpl) GetRouter() router.Router {
func (n *NamespaceImpl) GetRouter() driver.Router {
return n.router
}

func BuildRouter(cfg *config.BackendNamespace) (router.Router, error) {
func BuildRouter(cfg *config.BackendNamespace) (driver.Router, error) {
if len(cfg.Instances) == 0 {
return nil, errors.New("no instances for the backend")
}
Expand Down
1 change: 1 addition & 0 deletions pkg/proxy/namespace/domain.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ type Namespace interface {
Close()
GetBreaker() (driver.Breaker, error)
GetRateLimiter() driver.RateLimiter
GetRouter() driver.Router
}

type Frontend interface {
Expand Down
4 changes: 4 additions & 0 deletions pkg/proxy/namespace/namespace.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,10 @@ func (n *NamespaceWrapper) GetRateLimiter() driver.RateLimiter {
return n.mustGetCurrentNamespace().GetRateLimiter()
}

func (n *NamespaceWrapper) GetRouter() driver.Router {
return n.mustGetCurrentNamespace().GetRouter()
}

func (n *NamespaceWrapper) mustGetCurrentNamespace() Namespace {
ns, ok := n.nsmgr.getCurrentNamespaces().Get(n.name)
if !ok {
Expand Down
3 changes: 2 additions & 1 deletion pkg/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/tidb-incubator/weir/pkg/proxy/metrics"
"github.com/tidb-incubator/weir/pkg/proxy/namespace"
"github.com/tidb-incubator/weir/pkg/proxy/server"
"github.com/tidb-incubator/weir/pkg/proxy/sessionmgr/backend"
"github.com/tidb-incubator/weir/pkg/proxy/sessionmgr/client"
)

Expand Down Expand Up @@ -53,7 +54,7 @@ func (p *Proxy) Init() error {
return err
}
p.nsmgr = nsmgr
driverImpl := driver.NewDriverImpl(nsmgr, client.NewClientConnectionImpl)
driverImpl := driver.NewDriverImpl(nsmgr, client.NewClientConnectionImpl, backend.NewBackendConnManager)
svr, err := server.NewServer(p.cfg, driverImpl)
if err != nil {
return err
Expand Down
8 changes: 1 addition & 7 deletions pkg/proxy/router/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,12 @@ var (
ErrNoInstanceToSelect = errors.New("no instances to route")
)

type Router interface {
SetAddresses([]string)
Route() (string, error)
AddConnOnAddr(string, int)
}

type RandomRouter struct {
addresses []string
addr2Conns map[string]int
}

func NewRandomRouter() Router {
func NewRandomRouter() *RandomRouter {
return &RandomRouter{
addr2Conns: make(map[string]int, 0),
}
Expand Down
13 changes: 7 additions & 6 deletions pkg/proxy/sessionmgr/backend/backend_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package backend

import (
"github.com/siddontang/go-mysql/packet"
"github.com/tidb-incubator/weir/pkg/util/auth"
)

type connectionPhase byte
Expand All @@ -15,7 +16,7 @@ const (
)

type BackendConnection interface {
Connect(username string, authData []byte) error
Connect(authInfo *auth.AuthInfo) error
Close() error
}

Expand All @@ -24,17 +25,17 @@ type BackendConnectionImpl struct {

phase connectionPhase
capability uint32
server *BackendServer
address string
}

func NewBackendConnectionImpl(backendServer *BackendServer) *BackendConnectionImpl {
func NewBackendConnectionImpl(address string) *BackendConnectionImpl {
return &BackendConnectionImpl{
phase: handshaking,
server: backendServer,
phase: handshaking,
address: address,
}
}

func (conn *BackendConnectionImpl) Connect(username string, authData []byte) error {
func (conn *BackendConnectionImpl) Connect(authInfo *auth.AuthInfo) error {
return nil
}

Expand Down
21 changes: 9 additions & 12 deletions pkg/proxy/sessionmgr/backend/backend_conn_mgr.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"context"

gomysql "github.com/siddontang/go-mysql/mysql"
"github.com/tidb-incubator/weir/pkg/proxy/driver"
"github.com/tidb-incubator/weir/pkg/util/auth"
)

type ConnectionPhase byte
Expand All @@ -24,20 +26,14 @@ const (
StatusPrepareWaitFetch uint32 = 0x08
)

type AuthInfo struct {
// user information obtained during authentication
Username string
AuthData []byte
}

type BackendConnManager struct {
backendConn BackendConnection
connectionPhase ConnectionPhase
serverStatus uint32
authInfo *AuthInfo
authInfo *auth.AuthInfo
}

func NewBackendConnManager() *BackendConnManager {
func NewBackendConnManager() driver.BackendConnManager {
return &BackendConnManager{
connectionPhase: InitBackend,
serverStatus: StatusAutoCommit,
Expand All @@ -56,18 +52,19 @@ func (mgr *BackendConnManager) Run(context context.Context) {
}
}

func (mgr *BackendConnManager) SetAuthInfo(authInfo *AuthInfo) {
func (mgr *BackendConnManager) SetAuthInfo(authInfo *auth.AuthInfo) {
mgr.authInfo = authInfo
}

func (mgr *BackendConnManager) Connect(server *BackendServer) error {
func (mgr *BackendConnManager) Connect(address string) error {
// It may be still connecting to the original backend server.
if mgr.backendConn != nil {
if err := mgr.backendConn.Close(); err != nil {
return err
}
}
mgr.backendConn = NewBackendConnectionImpl(server)
return mgr.backendConn.Connect(mgr.authInfo.Username, mgr.authInfo.AuthData)
mgr.backendConn = NewBackendConnectionImpl(address)
return mgr.backendConn.Connect(mgr.authInfo)
}

func (mgr *BackendConnManager) initSessionStates() error {
Expand Down
13 changes: 13 additions & 0 deletions pkg/util/auth/auth_info.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package auth

// AuthInfo the user information that is stored temporarily in the proxy.
type AuthInfo struct {
// user information obtained during authentication
Username string
AuthPlugin string
AuthString []byte // password that sent from the client
BackendSalt []byte // backend salt used to encrypt password
Token []byte // or password

DefaultDB string
}

0 comments on commit dfbb217

Please sign in to comment.