From 1b599331c8519dbc7ed8644cb706da17d9e931fd Mon Sep 17 00:00:00 2001 From: Morgan Tocker Date: Mon, 7 Dec 2020 17:24:56 -0700 Subject: [PATCH] Initial support for init_connect --- server/conn.go | 54 ++++++++++++++++++++++++++++++++++---------------- 1 file changed, 37 insertions(+), 17 deletions(-) diff --git a/server/conn.go b/server/conn.go index 9ac40d3b9e286..74d29e79b0a5d 100644 --- a/server/conn.go +++ b/server/conn.go @@ -721,7 +721,10 @@ func (cc *clientConn) PeerHost(hasPassword string) (host string, err error) { } func (cc *clientConn) initConnect(ctx context.Context) error { - val, _ := cc.ctx.GetSessionVars().GlobalVarsAccessor.GetGlobalSysVar(variable.InitConnect) + val, err := cc.ctx.GetSessionVars().GlobalVarsAccessor.GetGlobalSysVar(variable.InitConnect) + if err != nil { + return err + } if val == "" { return nil } @@ -732,23 +735,33 @@ func (cc *clientConn) initConnect(ctx context.Context) error { if checker != nil && checker.RequestVerification(activeRoles, "", "", "", mysql.SuperPriv) { return nil } - logutil.Logger(ctx).Debug("executing init_connect", zap.String("initConnect", val)) + stmts, err := cc.ctx.Parse(ctx, val) + if err != nil { + logutil.Logger(ctx).Warn("failed to parse init_connect", zap.Error(err)) + return err + } // We don't care about the results, but the semantics of // init_connect requires that they are fully read. - rss, err := cc.ctx.Execute(ctx, val) - for _, rs := range rss { - req := rs.NewChunk() - for { - _ = rs.Next(ctx, req) - if req.NumRows() == 0 { - break + for _, stmt := range stmts { + rs, err1 := cc.ctx.ExecuteStmt(ctx, stmt) + if err1 != nil { + logutil.Logger(ctx).Warn("init_connect stmt failed", zap.Error(err1)) + return err + } + if rs != nil { // it could have been a SET stmt + req := rs.NewChunk() + for { + if err = rs.Next(ctx, req); err != nil { + return err + } + if req.NumRows() == 0 { + break + } + } + if err := rs.Close(); err != nil { + return err } } - rs.Close() - } - if err != nil { - logutil.Logger(ctx).Warn("init_connect failed", zap.Error(err)) - return err } logutil.Logger(ctx).Debug("init_connect complete") return nil @@ -783,10 +796,9 @@ func (cc *clientConn) Run(ctx context.Context) { // MySQL supports an "init_connect" query, which can be run on initial connection. // The query must return a non-error or the client is disconnected. // It is not executed for SUPER users. + initConnectFailed := false if err := cc.initConnect(ctx); err != nil { - initErr := errNewAbortingConnection.FastGenByArgs(cc.connectionID, "unconnected", cc.user, cc.peerHost, "init_connect command failed") - cc.writeError(ctx, initErr) - return + initConnectFailed = true } // Usually, client connection status changes between [dispatching] <=> [reading]. @@ -808,6 +820,14 @@ func (cc *clientConn) Run(ctx context.Context) { cc.pkt.setReadTimeout(time.Duration(waitTimeout) * time.Second) start := time.Now() data, err := cc.readPacket() + if initConnectFailed { + disconnectErrorUndetermined.Inc() + initErr := errNewAbortingConnection.FastGenByArgs(cc.connectionID, "unconnected", cc.user, cc.peerHost, "init_connect command failed") + if err1 := cc.writeError(ctx, initErr); err1 != nil { + terror.Log(err1) + } + return + } if err != nil { if terror.ErrorNotEqual(err, io.EOF) { if netErr, isNetErr := errors.Cause(err).(net.Error); isNetErr && netErr.Timeout() {