Skip to content

Commit

Permalink
*: Add support for MAX_EXECUTION_TIME (#10541) (#10963) (#11026)
Browse files Browse the repository at this point in the history
  • Loading branch information
tiancaiamao authored and winkyao committed Jul 15, 2019
1 parent 5433468 commit 3ad5ff6
Show file tree
Hide file tree
Showing 17 changed files with 235 additions and 60 deletions.
19 changes: 17 additions & 2 deletions executor/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ import (

// processinfoSetter is the interface use to set current running process info.
type processinfoSetter interface {
SetProcessInfo(string, time.Time, byte)
SetProcessInfo(string, time.Time, byte, uint64)
}

// recordSet wraps an executor, implements sqlexec.RecordSet interface
Expand Down Expand Up @@ -245,8 +245,9 @@ func (a *ExecStmt) Exec(ctx context.Context) (_ sqlexec.RecordSet, err error) {
sql = ss.SecureText()
}
}
maxExecutionTime := getMaxExecutionTime(sctx, a.StmtNode)
// Update processinfo, ShowProcess() will use it.
pi.SetProcessInfo(sql, time.Now(), cmd)
pi.SetProcessInfo(sql, time.Now(), cmd, maxExecutionTime)
a.Ctx.GetSessionVars().StmtCtx.StmtType = GetStmtLabel(a.StmtNode)
}

Expand Down Expand Up @@ -285,6 +286,20 @@ func (a *ExecStmt) Exec(ctx context.Context) (_ sqlexec.RecordSet, err error) {
}, nil
}

// getMaxExecutionTime get the max execution timeout value.
func getMaxExecutionTime(sctx sessionctx.Context, stmtNode ast.StmtNode) uint64 {
ret := sctx.GetSessionVars().MaxExecutionTime
if sel, ok := stmtNode.(*ast.SelectStmt); ok {
for _, hint := range sel.TableHints {
if hint.HintName.L == variable.MaxExecutionTime {
ret = hint.MaxExecutionTime
break
}
}
}
return ret
}

type chunkRowRecordSet struct {
rows []chunk.Row
idx int
Expand Down
1 change: 1 addition & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,7 @@ golang.org/x/tools v0.0.0-20190130214255-bb1329dc71a0/go.mod h1:n7NCudcB/nEzxVGm
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
google.golang.org/genproto v0.0.0-20180608181217-32ee49c4dd80/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
google.golang.org/genproto v0.0.0-20181004005441-af9cb2a35e7f h1:FU37niK8AQ59mHcskRyQL7H0ErSeNh650vdcj8HqdSI=
google.golang.org/genproto v0.0.0-20181004005441-af9cb2a35e7f/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
google.golang.org/genproto v0.0.0-20190108161440-ae2f86662275 h1:9oFlwfEGIvmxXTcY53ygNyxIQtWciRHjrnUvZJCYXYU=
google.golang.org/genproto v0.0.0-20190108161440-ae2f86662275/go.mod h1:7Ep/1NZk928CDR8SjdVbjWNpdIf6nzjE3BTgJDr2Atg=
Expand Down
22 changes: 19 additions & 3 deletions server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ import (

"github.com/opentracing/opentracing-go"
"github.com/pingcap/errors"
"github.com/pingcap/failpoint"
"github.com/pingcap/parser/auth"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/parser/terror"
Expand Down Expand Up @@ -262,6 +263,11 @@ func (cc *clientConn) readPacket() ([]byte, error) {
}

func (cc *clientConn) writePacket(data []byte) error {
failpoint.Inject("FakeClientConn", func() {
if cc.pkt == nil {
failpoint.Return(nil)
}
})
return cc.pkt.writePacket(data)
}

Expand Down Expand Up @@ -845,7 +851,11 @@ func (cc *clientConn) dispatch(ctx context.Context, data []byte) error {
cc.lastCmd = string(hack.String(data))
token := cc.server.getToken()
defer func() {
cc.ctx.SetProcessInfo("", t, mysql.ComSleep)
// if handleChangeUser failed, cc.ctx may be nil
if cc.ctx != nil {
cc.ctx.SetProcessInfo("", t, mysql.ComSleep, 0)
}

cc.server.releaseToken(token)
span.Finish()
}()
Expand All @@ -860,9 +870,9 @@ func (cc *clientConn) dispatch(ctx context.Context, data []byte) error {
switch cmd {
case mysql.ComPing, mysql.ComStmtClose, mysql.ComStmtSendLongData, mysql.ComStmtReset,
mysql.ComSetOption, mysql.ComChangeUser:
cc.ctx.SetProcessInfo("", t, cmd)
cc.ctx.SetProcessInfo("", t, cmd, 0)
case mysql.ComInitDB:
cc.ctx.SetProcessInfo("use "+dataStr, t, cmd)
cc.ctx.SetProcessInfo("use "+dataStr, t, cmd, 0)
}

switch cmd {
Expand Down Expand Up @@ -925,6 +935,11 @@ func (cc *clientConn) useDB(ctx context.Context, db string) (err error) {
}

func (cc *clientConn) flush() error {
failpoint.Inject("FakeClientConn", func() {
if cc.pkt == nil {
failpoint.Return(nil)
}
})
return cc.pkt.flush()
}

Expand Down Expand Up @@ -1255,6 +1270,7 @@ func (cc *clientConn) writeResultset(ctx context.Context, rs ResultSet, binary b
if err != nil {
return err
}

return cc.flush()
}

Expand Down
2 changes: 1 addition & 1 deletion server/conn_stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ func (cc *clientConn) handleStmtFetch(ctx context.Context, data []byte) (err err
if prepared, ok := cc.ctx.GetStatement(int(stmtID)).(*TiDBStatement); ok {
sql = prepared.sql
}
cc.ctx.SetProcessInfo(sql, time.Now(), mysql.ComStmtExecute)
cc.ctx.SetProcessInfo(sql, time.Now(), mysql.ComStmtExecute, 0)
rs := stmt.GetResultSet()
if rs == nil {
return mysql.NewErr(mysql.ErrUnknownStmtHandler,
Expand Down
77 changes: 77 additions & 0 deletions server/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,17 @@ import (
"bytes"
"context"
"encoding/binary"
"fmt"
"time"

. "github.com/pingcap/check"
"github.com/pingcap/failpoint"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/tidb/domain"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/session"
"github.com/pingcap/tidb/store/mockstore"
"github.com/pingcap/tidb/util/arena"
)

type ConnTestSuite struct {
Expand Down Expand Up @@ -207,3 +211,76 @@ func mapBelong(m1, m2 map[string]string) bool {
}
return true
}

func (ts ConnTestSuite) TestConnExecutionTimeout(c *C) {
//There is no underlying netCon, use failpoint to avoid panic
c.Assert(failpoint.Enable("github.com/pingcap/tidb/server/FakeClientConn", "return(1)"), IsNil)

c.Parallel()
var err error
ts.store, err = mockstore.NewMockTikvStore()
c.Assert(err, IsNil)
ts.dom, err = session.BootstrapSession(ts.store)
c.Assert(err, IsNil)
se, err := session.CreateSession4Test(ts.store)
c.Assert(err, IsNil)

connID := 1
se.SetConnectionID(uint64(connID))
tc := &TiDBContext{
session: se,
stmts: make(map[int]*TiDBStatement),
}
cc := &clientConn{
connectionID: uint32(connID),
server: &Server{
capability: defaultCapability,
},
ctx: tc,
alloc: arena.NewAllocator(32 * 1024),
}
srv := &Server{
clients: map[uint32]*clientConn{
uint32(connID): cc,
},
}
handle := ts.dom.ExpensiveQueryHandle().SetSessionManager(srv)
go handle.Run()
defer handle.Close()

_, err = se.Execute(context.Background(), "use test;")
c.Assert(err, IsNil)
_, err = se.Execute(context.Background(), "CREATE TABLE testTable2 (id bigint PRIMARY KEY, age int)")
c.Assert(err, IsNil)
for i := 0; i < 10; i++ {
str := fmt.Sprintf("insert into testTable2 values(%d, %d)", i, i%80)
_, err = se.Execute(context.Background(), str)
c.Assert(err, IsNil)
}

_, err = se.Execute(context.Background(), "select SLEEP(1);")
c.Assert(err, IsNil)

_, err = se.Execute(context.Background(), "set @@max_execution_time = 500;")
c.Assert(err, IsNil)

now := time.Now()
err = cc.handleQuery(context.Background(), "select * FROM testTable2 WHERE SLEEP(3);")
c.Assert(err, IsNil)
c.Assert(time.Since(now) < 3*time.Second, IsTrue)

_, err = se.Execute(context.Background(), "set @@max_execution_time = 0;")
c.Assert(err, IsNil)

now = time.Now()
err = cc.handleQuery(context.Background(), "select * FROM testTable2 WHERE SLEEP(1);")
c.Assert(err, IsNil)
c.Assert(time.Since(now) > 500*time.Millisecond, IsTrue)

now = time.Now()
err = cc.handleQuery(context.Background(), "select /*+ MAX_EXECUTION_TIME(100)*/ * FROM testTable2 WHERE SLEEP(3);")
c.Assert(err, IsNil)
c.Assert(time.Since(now) < 3*time.Second, IsTrue)

c.Assert(failpoint.Disable("github.com/pingcap/tidb/server/FakeClientConn"), IsNil)
}
2 changes: 1 addition & 1 deletion server/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ type QueryCtx interface {
// SetValue saves a value associated with this context for key.
SetValue(key fmt.Stringer, value interface{})

SetProcessInfo(sql string, t time.Time, command byte)
SetProcessInfo(sql string, t time.Time, command byte, maxExecutionTime uint64)

// CommitTxn commits the transaction operations.
CommitTxn(ctx context.Context) error
Expand Down
4 changes: 2 additions & 2 deletions server/driver_tidb.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,8 @@ func (tc *TiDBContext) CommitTxn(ctx context.Context) error {
}

// SetProcessInfo implements QueryCtx SetProcessInfo method.
func (tc *TiDBContext) SetProcessInfo(sql string, t time.Time, command byte) {
tc.session.SetProcessInfo(sql, t, command)
func (tc *TiDBContext) SetProcessInfo(sql string, t time.Time, command byte, maxExecutionTime uint64) {
tc.session.SetProcessInfo(sql, t, command, maxExecutionTime)
}

// RollbackTxn implements QueryCtx RollbackTxn method.
Expand Down
26 changes: 14 additions & 12 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,13 @@ func init() {
}

var (
errUnknownFieldType = terror.ClassServer.New(codeUnknownFieldType, "unknown field type")
errInvalidPayloadLen = terror.ClassServer.New(codeInvalidPayloadLen, "invalid payload length")
errInvalidSequence = terror.ClassServer.New(codeInvalidSequence, "invalid sequence")
errInvalidType = terror.ClassServer.New(codeInvalidType, "invalid type")
errNotAllowedCommand = terror.ClassServer.New(codeNotAllowedCommand, "the used command is not allowed with this TiDB version")
errAccessDenied = terror.ClassServer.New(codeAccessDenied, mysql.MySQLErrName[mysql.ErrAccessDenied])
errUnknownFieldType = terror.ClassServer.New(codeUnknownFieldType, "unknown field type")
errInvalidPayloadLen = terror.ClassServer.New(codeInvalidPayloadLen, "invalid payload length")
errInvalidSequence = terror.ClassServer.New(codeInvalidSequence, "invalid sequence")
errInvalidType = terror.ClassServer.New(codeInvalidType, "invalid type")
errNotAllowedCommand = terror.ClassServer.New(codeNotAllowedCommand, "the used command is not allowed with this TiDB version")
errAccessDenied = terror.ClassServer.New(codeAccessDenied, mysql.MySQLErrName[mysql.ErrAccessDenied])
errMaxExecTimeExceeded = terror.ClassServer.New(codeMaxExecTimeExceeded, mysql.MySQLErrName[mysql.ErrMaxExecTimeExceeded])
)

// DefaultCapability is the capability of the server when it is created using the default configuration.
Expand All @@ -106,7 +107,7 @@ type Server struct {
driver IDriver
listener net.Listener
socket net.Listener
rwlock *sync.RWMutex
rwlock sync.RWMutex
concurrentLimiter *TokenLimiter
clients map[uint32]*clientConn
capability uint32
Expand Down Expand Up @@ -198,7 +199,6 @@ func NewServer(cfg *config.Config, driver IDriver) (*Server, error) {
cfg: cfg,
driver: driver,
concurrentLimiter: NewTokenLimiter(cfg.TokenLimit),
rwlock: &sync.RWMutex{},
clients: make(map[uint32]*clientConn),
stopListenerCh: make(chan struct{}, 1),
}
Expand Down Expand Up @@ -620,14 +620,16 @@ const (
codeInvalidSequence = 3
codeInvalidType = 4

codeNotAllowedCommand = 1148
codeAccessDenied = mysql.ErrAccessDenied
codeNotAllowedCommand = 1148
codeAccessDenied = mysql.ErrAccessDenied
codeMaxExecTimeExceeded = mysql.ErrMaxExecTimeExceeded
)

func init() {
serverMySQLErrCodes := map[terror.ErrCode]uint16{
codeNotAllowedCommand: mysql.ErrNotAllowedCommand,
codeAccessDenied: mysql.ErrAccessDenied,
codeNotAllowedCommand: mysql.ErrNotAllowedCommand,
codeAccessDenied: mysql.ErrAccessDenied,
codeMaxExecTimeExceeded: mysql.ErrMaxExecTimeExceeded,
}
terror.ErrClassToMySQLCodes[terror.ClassServer] = serverMySQLErrCodes
}
34 changes: 22 additions & 12 deletions session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ type Session interface {
SetClientCapability(uint32) // Set client capability flags.
SetConnectionID(uint64)
SetCommandValue(byte)
SetProcessInfo(string, time.Time, byte)
SetProcessInfo(string, time.Time, byte, uint64)
SetTLSState(*tls.ConnectionState)
SetCollation(coID int) error
SetSessionManager(util.SessionManager)
Expand Down Expand Up @@ -780,6 +780,10 @@ func createSessionFunc(store kv.Storage) pools.Factory {
if err != nil {
return nil, err
}
err = variable.SetSessionSystemVar(se.sessionVars, variable.MaxExecutionTime, types.NewUintDatum(0))
if err != nil {
return nil, errors.Trace(err)
}
se.sessionVars.CommonGlobalLoaded = true
se.sessionVars.InRestrictedSQL = true
return se, nil
Expand All @@ -796,6 +800,10 @@ func createSessionWithDomainFunc(store kv.Storage) func(*domain.Domain) (pools.R
if err != nil {
return nil, err
}
err = variable.SetSessionSystemVar(se.sessionVars, variable.MaxExecutionTime, types.NewUintDatum(0))
if err != nil {
return nil, errors.Trace(err)
}
se.sessionVars.CommonGlobalLoaded = true
se.sessionVars.InRestrictedSQL = true
return se, nil
Expand Down Expand Up @@ -907,7 +915,7 @@ func (s *session) ParseSQL(ctx context.Context, sql, charset, collation string)
return s.parser.Parse(sql, charset, collation)
}

func (s *session) SetProcessInfo(sql string, t time.Time, command byte) {
func (s *session) SetProcessInfo(sql string, t time.Time, command byte, maxExecutionTime uint64) {
var db interface{}
if len(s.sessionVars.CurrentDB) > 0 {
db = s.sessionVars.CurrentDB
Expand All @@ -918,16 +926,17 @@ func (s *session) SetProcessInfo(sql string, t time.Time, command byte) {
info = sql
}
pi := util.ProcessInfo{
ID: s.sessionVars.ConnectionID,
DB: db,
Command: command,
Plan: s.currentPlan,
Time: t,
State: s.Status(),
Info: info,
CurTxnStartTS: s.sessionVars.TxnCtx.StartTS,
StmtCtx: s.sessionVars.StmtCtx,
StatsInfo: plannercore.GetStatsInfo,
ID: s.sessionVars.ConnectionID,
DB: db,
Command: command,
Plan: s.currentPlan,
Time: t,
State: s.Status(),
Info: info,
CurTxnStartTS: s.sessionVars.TxnCtx.StartTS,
StmtCtx: s.sessionVars.StmtCtx,
StatsInfo: plannercore.GetStatsInfo,
MaxExecutionTime: maxExecutionTime,
}
if s.sessionVars.User != nil {
pi.User = s.sessionVars.User.Username
Expand Down Expand Up @@ -1632,6 +1641,7 @@ var builtinGlobalVariable = []string{
variable.WaitTimeout,
variable.InteractiveTimeout,
variable.MaxPreparedStmtCount,
variable.MaxExecutionTime,
/* TiDB specific global variables: */
variable.TiDBSkipUTF8Check,
variable.TiDBIndexJoinBatchSize,
Expand Down
Loading

0 comments on commit 3ad5ff6

Please sign in to comment.