Skip to content

Commit

Permalink
extension: support extension to listen connection events (#38624)
Browse files Browse the repository at this point in the history
close #38623
  • Loading branch information
lcwangchao authored Oct 26, 2022
1 parent d27c706 commit 753e3da
Show file tree
Hide file tree
Showing 14 changed files with 429 additions and 25 deletions.
8 changes: 8 additions & 0 deletions extension/extensions.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,11 @@ func (es *Extensions) Bootstrap(ctx BootstrapContext) error {
}
return nil
}

// NewSessionExtensions creates a new ConnExtensions object
func (es *Extensions) NewSessionExtensions() *SessionExtensions {
if es == nil {
return nil
}
return newSessionExtensions(es)
}
20 changes: 14 additions & 6 deletions extension/manifest.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,13 @@ func WithCustomFunctions(funcs []*FunctionDef) Option {
}
}

// WithSessionHandlerFactory specifies a factory function to handle session
func WithSessionHandlerFactory(factory func() *SessionHandler) Option {
return func(m *Manifest) {
m.sessionHandlerFactory = factory
}
}

// WithClose specifies the close function of an extension.
// It will be invoked when `extension.Reset` is called
func WithClose(fn func()) Option {
Expand Down Expand Up @@ -82,12 +89,13 @@ func WithBootstrapSQL(sqlList ...string) Option {

// Manifest is an extension's manifest
type Manifest struct {
name string
sysVariables []*variable.SysVar
dynPrivs []string
bootstrap func(BootstrapContext) error
funcs []*FunctionDef
close func()
name string
sysVariables []*variable.SysVar
dynPrivs []string
bootstrap func(BootstrapContext) error
funcs []*FunctionDef
sessionHandlerFactory func() *SessionHandler
close func()
}

// Name returns the extension's name
Expand Down
72 changes: 72 additions & 0 deletions extension/session.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// Copyright 2022 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package extension

import "github.com/pingcap/tidb/sessionctx/variable"

// ConnEventInfo is the connection info for the event
type ConnEventInfo variable.ConnectionInfo

// ConnEventTp is the type of the connection event
type ConnEventTp uint8

const (
// ConnConnected means connection connected, but not handshake yet
ConnConnected ConnEventTp = iota
// ConnHandshakeAccepted means connection is accepted after handshake
ConnHandshakeAccepted
// ConnHandshakeRejected means connections is rejected after handshake
ConnHandshakeRejected
// ConnReset means the connection is reset
ConnReset
// ConnDisconnected means the connection is disconnected
ConnDisconnected
)

// SessionHandler is used to listen session events
type SessionHandler struct {
OnConnectionEvent func(ConnEventTp, *ConnEventInfo)
}

func newSessionExtensions(es *Extensions) *SessionExtensions {
connExtensions := &SessionExtensions{}
for _, m := range es.Manifests() {
if m.sessionHandlerFactory != nil {
if handler := m.sessionHandlerFactory(); handler != nil {
if fn := handler.OnConnectionEvent; fn != nil {
connExtensions.connectionEventFuncs = append(connExtensions.connectionEventFuncs, fn)
}
}
}
}
return connExtensions
}

// SessionExtensions is the extensions
type SessionExtensions struct {
connectionEventFuncs []func(ConnEventTp, *ConnEventInfo)
}

// OnConnectionEvent will be called when a connection event happens
func (es *SessionExtensions) OnConnectionEvent(tp ConnEventTp, info *variable.ConnectionInfo) {
if es == nil {
return
}

eventInfo := ConnEventInfo(*info)
for _, fn := range es.connectionEventFuncs {
fn(tp, &eventInfo)
}
}
11 changes: 8 additions & 3 deletions server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ import (
"github.com/pingcap/tidb/domain/infosync"
"github.com/pingcap/tidb/errno"
"github.com/pingcap/tidb/executor"
"github.com/pingcap/tidb/extension"
"github.com/pingcap/tidb/infoschema"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/metrics"
Expand Down Expand Up @@ -212,6 +213,7 @@ type clientConn struct {
sync.RWMutex
cancelFunc context.CancelFunc
}
extensions *extension.SessionExtensions
}

func (cc *clientConn) getCtx() *TiDBContext {
Expand Down Expand Up @@ -792,7 +794,7 @@ func (cc *clientConn) openSession() error {
tlsState := cc.tlsConn.ConnectionState()
tlsStatePtr = &tlsState
}
ctx, err := cc.server.driver.OpenCtx(cc.connectionID, cc.capability, cc.collation, cc.dbname, tlsStatePtr)
ctx, err := cc.server.driver.OpenCtx(cc.connectionID, cc.capability, cc.collation, cc.dbname, tlsStatePtr, cc.extensions)
if err != nil {
return err
}
Expand Down Expand Up @@ -2477,7 +2479,7 @@ func (cc *clientConn) handleResetConnection(ctx context.Context) error {
tlsState := cc.tlsConn.ConnectionState()
tlsStatePtr = &tlsState
}
tidbCtx, err := cc.server.driver.OpenCtx(cc.connectionID, cc.capability, cc.collation, cc.dbname, tlsStatePtr)
tidbCtx, err := cc.server.driver.OpenCtx(cc.connectionID, cc.capability, cc.collation, cc.dbname, tlsStatePtr, cc.extensions)
if err != nil {
return err
}
Expand All @@ -2497,7 +2499,10 @@ func (cc *clientConn) handleResetConnection(ctx context.Context) error {
}

func (cc *clientConn) handleCommonConnectionReset(ctx context.Context) error {
cc.ctx.GetSessionVars().ConnectionInfo = cc.connectInfo()
connectionInfo := cc.connectInfo()
cc.ctx.GetSessionVars().ConnectionInfo = connectionInfo

cc.extensions.OnConnectionEvent(extension.ConnReset, connectionInfo)

err := plugin.ForeachPlugin(plugin.Audit, func(p *plugin.Plugin) error {
authPlugin := plugin.DeclareAuditManifest(p.Manifest)
Expand Down
125 changes: 124 additions & 1 deletion server/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (
"github.com/pingcap/tidb/config"
"github.com/pingcap/tidb/domain"
"github.com/pingcap/tidb/executor"
"github.com/pingcap/tidb/extension"
"github.com/pingcap/tidb/parser/auth"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/session"
Expand Down Expand Up @@ -1425,7 +1426,7 @@ func TestAuthTokenPlugin(t *testing.T) {
tk.MustExec("CREATE USER auth_session_token")
tk.MustExec("CREATE USER another_user")

tc, err := drv.OpenCtx(uint64(0), 0, uint8(mysql.DefaultCollationID), "", nil)
tc, err := drv.OpenCtx(uint64(0), 0, uint8(mysql.DefaultCollationID), "", nil, nil)
require.NoError(t, err)
cc := &clientConn{
connectionID: 1,
Expand Down Expand Up @@ -1584,3 +1585,125 @@ func TestOkEof(t *testing.T) {
require.Equal(t, mysql.EOFHeader, outBuffer.Bytes()[4])
require.Equal(t, []byte{0x7, 0x0, 0x0, 0x1, 0xfe, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0}, outBuffer.Bytes())
}

func TestExtensionChangeUser(t *testing.T) {
defer extension.Reset()
extension.Reset()

logged := false
var logTp extension.ConnEventTp
var logInfo *extension.ConnEventInfo
require.NoError(t, extension.Register("test", extension.WithSessionHandlerFactory(func() *extension.SessionHandler {
return &extension.SessionHandler{
OnConnectionEvent: func(tp extension.ConnEventTp, info *extension.ConnEventInfo) {
require.False(t, logged)
logTp = tp
logInfo = info
logged = true
},
}
})))

extensions, err := extension.GetExtensions()
require.NoError(t, err)

store := testkit.CreateMockStore(t)

var outBuffer bytes.Buffer
tidbdrv := NewTiDBDriver(store)
cfg := newTestConfig()
cfg.Port, cfg.Status.StatusPort = 0, 0
cfg.Status.ReportStatus = false
server, err := NewServer(cfg, tidbdrv)
require.NoError(t, err)
defer server.Close()

cc := &clientConn{
connectionID: 1,
salt: []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, 0x11, 0x12, 0x13, 0x14},
server: server,
pkt: &packetIO{
bufWriter: bufio.NewWriter(&outBuffer),
},
collation: mysql.DefaultCollationID,
peerHost: "localhost",
alloc: arena.NewAllocator(512),
chunkAlloc: chunk.NewAllocator(),
capability: mysql.ClientProtocol41,
extensions: extensions.NewSessionExtensions(),
}

tk := testkit.NewTestKit(t, store)
ctx := &TiDBContext{Session: tk.Session()}
cc.setCtx(ctx)
tk.MustExec("create user user1")
tk.MustExec("create user user2")
tk.MustExec("create database db1")
tk.MustExec("create database db2")
tk.MustExec("grant select on db1.* to user1@'%'")
tk.MustExec("grant select on db2.* to user2@'%'")

// change user.
doDispatch := func(req dispatchInput) {
inBytes := append([]byte{req.com}, req.in...)
err = cc.dispatch(context.Background(), inBytes)
require.Equal(t, req.err, err)
if err == nil {
err = cc.flush(context.TODO())
require.NoError(t, err)
require.Equal(t, req.out, outBuffer.Bytes())
} else {
_ = cc.flush(context.TODO())
}
outBuffer.Reset()
}

expectedConnInfo := extension.ConnEventInfo(*cc.connectInfo())
expectedConnInfo.User = "user1"
expectedConnInfo.DB = "db1"

require.False(t, logged)
userData := append([]byte("user1"), 0x0, 0x0)
userData = append(userData, []byte("db1")...)
userData = append(userData, 0x0)
doDispatch(dispatchInput{
com: mysql.ComChangeUser,
in: userData,
err: nil,
out: []byte{0x7, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0},
})
require.True(t, logged)
require.Equal(t, extension.ConnReset, logTp)
require.Equal(t, expectedConnInfo, *logInfo)

logged = false
logTp = 0
logInfo = nil
expectedConnInfo.User = "user2"
expectedConnInfo.DB = "db2"
userData = append([]byte("user2"), 0x0, 0x0)
userData = append(userData, []byte("db2")...)
userData = append(userData, 0x0)
doDispatch(dispatchInput{
com: mysql.ComChangeUser,
in: userData,
err: nil,
out: []byte{0x7, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0},
})
require.True(t, logged)
require.Equal(t, extension.ConnReset, logTp)
require.Equal(t, expectedConnInfo, *logInfo)

logged = false
logTp = 0
logInfo = nil
doDispatch(dispatchInput{
com: mysql.ComResetConnection,
in: nil,
err: nil,
out: []byte{0x7, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0},
})
require.True(t, logged)
require.Equal(t, extension.ConnReset, logTp)
require.Equal(t, expectedConnInfo, *logInfo)
}
3 changes: 2 additions & 1 deletion server/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,14 @@ import (
"crypto/tls"

"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/extension"
"github.com/pingcap/tidb/util/chunk"
)

// IDriver opens IContext.
type IDriver interface {
// OpenCtx opens an IContext with connection id, client capability, collation, dbname and optionally the tls state.
OpenCtx(connID uint64, capability uint32, collation uint8, dbname string, tlsState *tls.ConnectionState) (*TiDBContext, error)
OpenCtx(connID uint64, capability uint32, collation uint8, dbname string, tlsState *tls.ConnectionState, extensions *extension.SessionExtensions) (*TiDBContext, error)
}

// PreparedStatement is the interface to use a prepared statement.
Expand Down
4 changes: 3 additions & 1 deletion server/driver_tidb.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (

"github.com/pingcap/errors"
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/extension"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/parser/ast"
"github.com/pingcap/tidb/parser/charset"
Expand Down Expand Up @@ -191,7 +192,7 @@ func (ts *TiDBStatement) Close() error {
}

// OpenCtx implements IDriver.
func (qd *TiDBDriver) OpenCtx(connID uint64, capability uint32, collation uint8, dbname string, tlsState *tls.ConnectionState) (*TiDBContext, error) {
func (qd *TiDBDriver) OpenCtx(connID uint64, capability uint32, collation uint8, dbname string, tlsState *tls.ConnectionState, extensions *extension.SessionExtensions) (*TiDBContext, error) {
se, err := session.CreateSession(qd.store)
if err != nil {
return nil, err
Expand All @@ -208,6 +209,7 @@ func (qd *TiDBDriver) OpenCtx(connID uint64, capability uint32, collation uint8,
stmts: make(map[int]*TiDBStatement),
}
se.SetSessionStatesHandler(sessionstates.StatePrepareStmt, tc)
se.SetExtensions(extensions)
return tc, nil
}

Expand Down
2 changes: 1 addition & 1 deletion server/mock_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ func CreateMockServer(t *testing.T, store kv.Storage) *Server {
// CreateMockConn creates a mock connection together with a session.
func CreateMockConn(t *testing.T, server *Server) MockConn {
connID := rand.Uint64()
tc, err := server.driver.OpenCtx(connID, 0, uint8(tmysql.DefaultCollationID), "", nil)
tc, err := server.driver.OpenCtx(connID, 0, uint8(tmysql.DefaultCollationID), "", nil, nil)
require.NoError(t, err)

cc := &clientConn{
Expand Down
Loading

0 comments on commit 753e3da

Please sign in to comment.