diff --git a/plugin/plugin.go b/plugin/plugin.go index fd02eb6a86d88..f0093dfef60bb 100644 --- a/plugin/plugin.go +++ b/plugin/plugin.go @@ -374,6 +374,21 @@ func ForeachPlugin(kind Kind, fn func(plugin *Plugin) error) error { return nil } +// IsEnable checks plugin's enable state. +func IsEnable(kind Kind) bool { + plugins := pluginGlobal.plugins() + if plugins == nil { + return false + } + for i := range plugins.plugins[kind] { + p := &plugins.plugins[kind][i] + if p.State == Ready { + return true + } + } + return false +} + // GetAll finds and returns all plugins. func GetAll() map[Kind][]Plugin { plugins := pluginGlobal.plugins() diff --git a/server/conn.go b/server/conn.go index 8213c6a07432f..d22f6f5c91e10 100644 --- a/server/conn.go +++ b/server/conn.go @@ -1449,10 +1449,14 @@ func (cc *clientConn) handleChangeUser(ctx context.Context, data []byte) error { return err } + if plugin.IsEnable(plugin.Audit) { + cc.ctx.GetSessionVars().ConnectionInfo = cc.connectInfo() + } + err = plugin.ForeachPlugin(plugin.Audit, func(p *plugin.Plugin) error { authPlugin := plugin.DeclareAuditManifest(p.Manifest) if authPlugin.OnConnectionEvent != nil { - connInfo := cc.connectInfo() + connInfo := cc.ctx.GetSessionVars().ConnectionInfo err = authPlugin.OnConnectionEvent(context.Background(), &auth.UserIdentity{Hostname: connInfo.Host}, plugin.ChangeUser, connInfo) if err != nil { return err diff --git a/server/server.go b/server/server.go index 906fc0e1fc518..12a3e303cc841 100644 --- a/server/server.go +++ b/server/server.go @@ -422,11 +422,14 @@ func (s *Server) onConn(conn *clientConn) { s.rwlock.Unlock() metrics.ConnGauge.Set(float64(connections)) + if plugin.IsEnable(plugin.Audit) { + conn.ctx.GetSessionVars().ConnectionInfo = conn.connectInfo() + } err := plugin.ForeachPlugin(plugin.Audit, func(p *plugin.Plugin) error { authPlugin := plugin.DeclareAuditManifest(p.Manifest) if authPlugin.OnConnectionEvent != nil { - connInfo := conn.connectInfo() - return authPlugin.OnConnectionEvent(context.Background(), conn.ctx.GetSessionVars().User, plugin.Connected, connInfo) + sessionVars := conn.ctx.GetSessionVars() + return authPlugin.OnConnectionEvent(context.Background(), sessionVars.User, plugin.Connected, sessionVars.ConnectionInfo) } return nil }) @@ -440,9 +443,9 @@ func (s *Server) onConn(conn *clientConn) { err = plugin.ForeachPlugin(plugin.Audit, func(p *plugin.Plugin) error { authPlugin := plugin.DeclareAuditManifest(p.Manifest) if authPlugin.OnConnectionEvent != nil { - connInfo := conn.connectInfo() - connInfo.Duration = float64(time.Since(connectedTime)) / float64(time.Millisecond) - err := authPlugin.OnConnectionEvent(context.Background(), conn.ctx.GetSessionVars().User, plugin.Disconnect, connInfo) + sessionVars := conn.ctx.GetSessionVars() + sessionVars.ConnectionInfo.Duration = float64(time.Since(connectedTime)) / float64(time.Millisecond) + err := authPlugin.OnConnectionEvent(context.Background(), sessionVars.User, plugin.Disconnect, sessionVars.ConnectionInfo) if err != nil { logutil.BgLogger().Warn("do connection event failed", zap.String("plugin", authPlugin.Name), zap.Error(err)) } diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go index 91e854b1c8851..fcc1a03c72a3f 100644 --- a/sessionctx/variable/session.go +++ b/sessionctx/variable/session.go @@ -389,6 +389,9 @@ type SessionVars struct { // Killed is a flag to indicate that this query is killed. Killed uint32 + + // ConnectionInfo indicates current connection info used by current session, only be lazy assigned by plugin. + ConnectionInfo *ConnectionInfo } // ConnectionInfo present connection used by audit.