Skip to content

Commit

Permalink
Initial support for init_connect
Browse files Browse the repository at this point in the history
  • Loading branch information
morgo committed Dec 8, 2020
1 parent accf775 commit 1b59933
Showing 1 changed file with 37 additions and 17 deletions.
54 changes: 37 additions & 17 deletions server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -721,7 +721,10 @@ func (cc *clientConn) PeerHost(hasPassword string) (host string, err error) {
}

func (cc *clientConn) initConnect(ctx context.Context) error {
val, _ := cc.ctx.GetSessionVars().GlobalVarsAccessor.GetGlobalSysVar(variable.InitConnect)
val, err := cc.ctx.GetSessionVars().GlobalVarsAccessor.GetGlobalSysVar(variable.InitConnect)
if err != nil {
return err
}
if val == "" {
return nil
}
Expand All @@ -732,23 +735,33 @@ func (cc *clientConn) initConnect(ctx context.Context) error {
if checker != nil && checker.RequestVerification(activeRoles, "", "", "", mysql.SuperPriv) {
return nil
}
logutil.Logger(ctx).Debug("executing init_connect", zap.String("initConnect", val))
stmts, err := cc.ctx.Parse(ctx, val)
if err != nil {
logutil.Logger(ctx).Warn("failed to parse init_connect", zap.Error(err))
return err
}
// We don't care about the results, but the semantics of
// init_connect requires that they are fully read.
rss, err := cc.ctx.Execute(ctx, val)
for _, rs := range rss {
req := rs.NewChunk()
for {
_ = rs.Next(ctx, req)
if req.NumRows() == 0 {
break
for _, stmt := range stmts {
rs, err1 := cc.ctx.ExecuteStmt(ctx, stmt)
if err1 != nil {
logutil.Logger(ctx).Warn("init_connect stmt failed", zap.Error(err1))
return err
}
if rs != nil { // it could have been a SET stmt
req := rs.NewChunk()
for {
if err = rs.Next(ctx, req); err != nil {
return err
}
if req.NumRows() == 0 {
break
}
}
if err := rs.Close(); err != nil {
return err
}
}
rs.Close()
}
if err != nil {
logutil.Logger(ctx).Warn("init_connect failed", zap.Error(err))
return err
}
logutil.Logger(ctx).Debug("init_connect complete")
return nil
Expand Down Expand Up @@ -783,10 +796,9 @@ func (cc *clientConn) Run(ctx context.Context) {
// MySQL supports an "init_connect" query, which can be run on initial connection.
// The query must return a non-error or the client is disconnected.
// It is not executed for SUPER users.
initConnectFailed := false
if err := cc.initConnect(ctx); err != nil {
initErr := errNewAbortingConnection.FastGenByArgs(cc.connectionID, "unconnected", cc.user, cc.peerHost, "init_connect command failed")
cc.writeError(ctx, initErr)
return
initConnectFailed = true
}

// Usually, client connection status changes between [dispatching] <=> [reading].
Expand All @@ -808,6 +820,14 @@ func (cc *clientConn) Run(ctx context.Context) {
cc.pkt.setReadTimeout(time.Duration(waitTimeout) * time.Second)
start := time.Now()
data, err := cc.readPacket()
if initConnectFailed {
disconnectErrorUndetermined.Inc()
initErr := errNewAbortingConnection.FastGenByArgs(cc.connectionID, "unconnected", cc.user, cc.peerHost, "init_connect command failed")
if err1 := cc.writeError(ctx, initErr); err1 != nil {
terror.Log(err1)
}
return
}
if err != nil {
if terror.ErrorNotEqual(err, io.EOF) {
if netErr, isNetErr := errors.Cause(err).(net.Error); isNetErr && netErr.Timeout() {
Expand Down

0 comments on commit 1b59933

Please sign in to comment.