Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

*: Add support for MAX_EXECUTION_TIME (#10541) (#10963) #11026

Merged
merged 14 commits into from
Jul 15, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions executor/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ import (

// processinfoSetter is the interface use to set current running process info.
type processinfoSetter interface {
SetProcessInfo(string, time.Time, byte)
SetProcessInfo(string, time.Time, byte, uint64)
}

// recordSet wraps an executor, implements sqlexec.RecordSet interface
Expand Down Expand Up @@ -245,8 +245,9 @@ func (a *ExecStmt) Exec(ctx context.Context) (_ sqlexec.RecordSet, err error) {
sql = ss.SecureText()
}
}
maxExecutionTime := getMaxExecutionTime(sctx, a.StmtNode)
// Update processinfo, ShowProcess() will use it.
pi.SetProcessInfo(sql, time.Now(), cmd)
pi.SetProcessInfo(sql, time.Now(), cmd, maxExecutionTime)
a.Ctx.GetSessionVars().StmtCtx.StmtType = GetStmtLabel(a.StmtNode)
}

Expand Down Expand Up @@ -285,6 +286,20 @@ func (a *ExecStmt) Exec(ctx context.Context) (_ sqlexec.RecordSet, err error) {
}, nil
}

// getMaxExecutionTime get the max execution timeout value.
func getMaxExecutionTime(sctx sessionctx.Context, stmtNode ast.StmtNode) uint64 {
ret := sctx.GetSessionVars().MaxExecutionTime
if sel, ok := stmtNode.(*ast.SelectStmt); ok {
for _, hint := range sel.TableHints {
if hint.HintName.L == variable.MaxExecutionTime {
ret = hint.MaxExecutionTime
break
}
}
}
return ret
}

type chunkRowRecordSet struct {
rows []chunk.Row
idx int
Expand Down
1 change: 1 addition & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,7 @@ golang.org/x/tools v0.0.0-20190130214255-bb1329dc71a0/go.mod h1:n7NCudcB/nEzxVGm
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
google.golang.org/genproto v0.0.0-20180608181217-32ee49c4dd80/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
google.golang.org/genproto v0.0.0-20181004005441-af9cb2a35e7f h1:FU37niK8AQ59mHcskRyQL7H0ErSeNh650vdcj8HqdSI=
google.golang.org/genproto v0.0.0-20181004005441-af9cb2a35e7f/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
google.golang.org/genproto v0.0.0-20190108161440-ae2f86662275 h1:9oFlwfEGIvmxXTcY53ygNyxIQtWciRHjrnUvZJCYXYU=
google.golang.org/genproto v0.0.0-20190108161440-ae2f86662275/go.mod h1:7Ep/1NZk928CDR8SjdVbjWNpdIf6nzjE3BTgJDr2Atg=
Expand Down
22 changes: 19 additions & 3 deletions server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ import (

"github.com/opentracing/opentracing-go"
"github.com/pingcap/errors"
"github.com/pingcap/failpoint"
"github.com/pingcap/parser/auth"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/parser/terror"
Expand Down Expand Up @@ -262,6 +263,11 @@ func (cc *clientConn) readPacket() ([]byte, error) {
}

func (cc *clientConn) writePacket(data []byte) error {
failpoint.Inject("FakeClientConn", func() {
if cc.pkt == nil {
failpoint.Return(nil)
}
})
return cc.pkt.writePacket(data)
}

Expand Down Expand Up @@ -845,7 +851,11 @@ func (cc *clientConn) dispatch(ctx context.Context, data []byte) error {
cc.lastCmd = string(hack.String(data))
token := cc.server.getToken()
defer func() {
cc.ctx.SetProcessInfo("", t, mysql.ComSleep)
// if handleChangeUser failed, cc.ctx may be nil
if cc.ctx != nil {
cc.ctx.SetProcessInfo("", t, mysql.ComSleep, 0)
}

cc.server.releaseToken(token)
span.Finish()
}()
Expand All @@ -860,9 +870,9 @@ func (cc *clientConn) dispatch(ctx context.Context, data []byte) error {
switch cmd {
case mysql.ComPing, mysql.ComStmtClose, mysql.ComStmtSendLongData, mysql.ComStmtReset,
mysql.ComSetOption, mysql.ComChangeUser:
cc.ctx.SetProcessInfo("", t, cmd)
cc.ctx.SetProcessInfo("", t, cmd, 0)
case mysql.ComInitDB:
cc.ctx.SetProcessInfo("use "+dataStr, t, cmd)
cc.ctx.SetProcessInfo("use "+dataStr, t, cmd, 0)
}

switch cmd {
Expand Down Expand Up @@ -925,6 +935,11 @@ func (cc *clientConn) useDB(ctx context.Context, db string) (err error) {
}

func (cc *clientConn) flush() error {
failpoint.Inject("FakeClientConn", func() {
if cc.pkt == nil {
failpoint.Return(nil)
}
})
return cc.pkt.flush()
}

Expand Down Expand Up @@ -1255,6 +1270,7 @@ func (cc *clientConn) writeResultset(ctx context.Context, rs ResultSet, binary b
if err != nil {
return err
}

return cc.flush()
}

Expand Down
2 changes: 1 addition & 1 deletion server/conn_stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ func (cc *clientConn) handleStmtFetch(ctx context.Context, data []byte) (err err
if prepared, ok := cc.ctx.GetStatement(int(stmtID)).(*TiDBStatement); ok {
sql = prepared.sql
}
cc.ctx.SetProcessInfo(sql, time.Now(), mysql.ComStmtExecute)
cc.ctx.SetProcessInfo(sql, time.Now(), mysql.ComStmtExecute, 0)
rs := stmt.GetResultSet()
if rs == nil {
return mysql.NewErr(mysql.ErrUnknownStmtHandler,
Expand Down
77 changes: 77 additions & 0 deletions server/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,17 @@ import (
"bytes"
"context"
"encoding/binary"
"fmt"
"time"

. "github.com/pingcap/check"
"github.com/pingcap/failpoint"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/tidb/domain"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/session"
"github.com/pingcap/tidb/store/mockstore"
"github.com/pingcap/tidb/util/arena"
)

type ConnTestSuite struct {
Expand Down Expand Up @@ -207,3 +211,76 @@ func mapBelong(m1, m2 map[string]string) bool {
}
return true
}

func (ts ConnTestSuite) TestConnExecutionTimeout(c *C) {
//There is no underlying netCon, use failpoint to avoid panic
c.Assert(failpoint.Enable("github.com/pingcap/tidb/server/FakeClientConn", "return(1)"), IsNil)

c.Parallel()
var err error
ts.store, err = mockstore.NewMockTikvStore()
c.Assert(err, IsNil)
ts.dom, err = session.BootstrapSession(ts.store)
c.Assert(err, IsNil)
se, err := session.CreateSession4Test(ts.store)
c.Assert(err, IsNil)

connID := 1
se.SetConnectionID(uint64(connID))
tc := &TiDBContext{
session: se,
stmts: make(map[int]*TiDBStatement),
}
cc := &clientConn{
connectionID: uint32(connID),
server: &Server{
capability: defaultCapability,
},
ctx: tc,
alloc: arena.NewAllocator(32 * 1024),
}
srv := &Server{
clients: map[uint32]*clientConn{
uint32(connID): cc,
},
}
handle := ts.dom.ExpensiveQueryHandle().SetSessionManager(srv)
go handle.Run()
defer handle.Close()

_, err = se.Execute(context.Background(), "use test;")
c.Assert(err, IsNil)
_, err = se.Execute(context.Background(), "CREATE TABLE testTable2 (id bigint PRIMARY KEY, age int)")
c.Assert(err, IsNil)
for i := 0; i < 10; i++ {
str := fmt.Sprintf("insert into testTable2 values(%d, %d)", i, i%80)
_, err = se.Execute(context.Background(), str)
c.Assert(err, IsNil)
}

_, err = se.Execute(context.Background(), "select SLEEP(1);")
c.Assert(err, IsNil)

_, err = se.Execute(context.Background(), "set @@max_execution_time = 500;")
c.Assert(err, IsNil)

now := time.Now()
err = cc.handleQuery(context.Background(), "select * FROM testTable2 WHERE SLEEP(3);")
c.Assert(err, IsNil)
c.Assert(time.Since(now) < 3*time.Second, IsTrue)

_, err = se.Execute(context.Background(), "set @@max_execution_time = 0;")
c.Assert(err, IsNil)

now = time.Now()
err = cc.handleQuery(context.Background(), "select * FROM testTable2 WHERE SLEEP(1);")
c.Assert(err, IsNil)
c.Assert(time.Since(now) > 500*time.Millisecond, IsTrue)

now = time.Now()
err = cc.handleQuery(context.Background(), "select /*+ MAX_EXECUTION_TIME(100)*/ * FROM testTable2 WHERE SLEEP(3);")
c.Assert(err, IsNil)
c.Assert(time.Since(now) < 3*time.Second, IsTrue)

c.Assert(failpoint.Disable("github.com/pingcap/tidb/server/FakeClientConn"), IsNil)
}
2 changes: 1 addition & 1 deletion server/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ type QueryCtx interface {
// SetValue saves a value associated with this context for key.
SetValue(key fmt.Stringer, value interface{})

SetProcessInfo(sql string, t time.Time, command byte)
SetProcessInfo(sql string, t time.Time, command byte, maxExecutionTime uint64)

// CommitTxn commits the transaction operations.
CommitTxn(ctx context.Context) error
Expand Down
4 changes: 2 additions & 2 deletions server/driver_tidb.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,8 @@ func (tc *TiDBContext) CommitTxn(ctx context.Context) error {
}

// SetProcessInfo implements QueryCtx SetProcessInfo method.
func (tc *TiDBContext) SetProcessInfo(sql string, t time.Time, command byte) {
tc.session.SetProcessInfo(sql, t, command)
func (tc *TiDBContext) SetProcessInfo(sql string, t time.Time, command byte, maxExecutionTime uint64) {
tc.session.SetProcessInfo(sql, t, command, maxExecutionTime)
}

// RollbackTxn implements QueryCtx RollbackTxn method.
Expand Down
26 changes: 14 additions & 12 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,13 @@ func init() {
}

var (
errUnknownFieldType = terror.ClassServer.New(codeUnknownFieldType, "unknown field type")
errInvalidPayloadLen = terror.ClassServer.New(codeInvalidPayloadLen, "invalid payload length")
errInvalidSequence = terror.ClassServer.New(codeInvalidSequence, "invalid sequence")
errInvalidType = terror.ClassServer.New(codeInvalidType, "invalid type")
errNotAllowedCommand = terror.ClassServer.New(codeNotAllowedCommand, "the used command is not allowed with this TiDB version")
errAccessDenied = terror.ClassServer.New(codeAccessDenied, mysql.MySQLErrName[mysql.ErrAccessDenied])
errUnknownFieldType = terror.ClassServer.New(codeUnknownFieldType, "unknown field type")
errInvalidPayloadLen = terror.ClassServer.New(codeInvalidPayloadLen, "invalid payload length")
errInvalidSequence = terror.ClassServer.New(codeInvalidSequence, "invalid sequence")
errInvalidType = terror.ClassServer.New(codeInvalidType, "invalid type")
errNotAllowedCommand = terror.ClassServer.New(codeNotAllowedCommand, "the used command is not allowed with this TiDB version")
errAccessDenied = terror.ClassServer.New(codeAccessDenied, mysql.MySQLErrName[mysql.ErrAccessDenied])
errMaxExecTimeExceeded = terror.ClassServer.New(codeMaxExecTimeExceeded, mysql.MySQLErrName[mysql.ErrMaxExecTimeExceeded])
)

// DefaultCapability is the capability of the server when it is created using the default configuration.
Expand All @@ -106,7 +107,7 @@ type Server struct {
driver IDriver
listener net.Listener
socket net.Listener
rwlock *sync.RWMutex
rwlock sync.RWMutex
concurrentLimiter *TokenLimiter
clients map[uint32]*clientConn
capability uint32
Expand Down Expand Up @@ -198,7 +199,6 @@ func NewServer(cfg *config.Config, driver IDriver) (*Server, error) {
cfg: cfg,
driver: driver,
concurrentLimiter: NewTokenLimiter(cfg.TokenLimit),
rwlock: &sync.RWMutex{},
clients: make(map[uint32]*clientConn),
stopListenerCh: make(chan struct{}, 1),
}
Expand Down Expand Up @@ -620,14 +620,16 @@ const (
codeInvalidSequence = 3
codeInvalidType = 4

codeNotAllowedCommand = 1148
codeAccessDenied = mysql.ErrAccessDenied
codeNotAllowedCommand = 1148
codeAccessDenied = mysql.ErrAccessDenied
codeMaxExecTimeExceeded = mysql.ErrMaxExecTimeExceeded
)

func init() {
serverMySQLErrCodes := map[terror.ErrCode]uint16{
codeNotAllowedCommand: mysql.ErrNotAllowedCommand,
codeAccessDenied: mysql.ErrAccessDenied,
codeNotAllowedCommand: mysql.ErrNotAllowedCommand,
codeAccessDenied: mysql.ErrAccessDenied,
codeMaxExecTimeExceeded: mysql.ErrMaxExecTimeExceeded,
}
terror.ErrClassToMySQLCodes[terror.ClassServer] = serverMySQLErrCodes
}
34 changes: 22 additions & 12 deletions session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ type Session interface {
SetClientCapability(uint32) // Set client capability flags.
SetConnectionID(uint64)
SetCommandValue(byte)
SetProcessInfo(string, time.Time, byte)
SetProcessInfo(string, time.Time, byte, uint64)
SetTLSState(*tls.ConnectionState)
SetCollation(coID int) error
SetSessionManager(util.SessionManager)
Expand Down Expand Up @@ -780,6 +780,10 @@ func createSessionFunc(store kv.Storage) pools.Factory {
if err != nil {
return nil, err
}
err = variable.SetSessionSystemVar(se.sessionVars, variable.MaxExecutionTime, types.NewUintDatum(0))
if err != nil {
return nil, errors.Trace(err)
}
se.sessionVars.CommonGlobalLoaded = true
se.sessionVars.InRestrictedSQL = true
return se, nil
Expand All @@ -796,6 +800,10 @@ func createSessionWithDomainFunc(store kv.Storage) func(*domain.Domain) (pools.R
if err != nil {
return nil, err
}
err = variable.SetSessionSystemVar(se.sessionVars, variable.MaxExecutionTime, types.NewUintDatum(0))
if err != nil {
return nil, errors.Trace(err)
}
se.sessionVars.CommonGlobalLoaded = true
se.sessionVars.InRestrictedSQL = true
return se, nil
Expand Down Expand Up @@ -907,7 +915,7 @@ func (s *session) ParseSQL(ctx context.Context, sql, charset, collation string)
return s.parser.Parse(sql, charset, collation)
}

func (s *session) SetProcessInfo(sql string, t time.Time, command byte) {
func (s *session) SetProcessInfo(sql string, t time.Time, command byte, maxExecutionTime uint64) {
var db interface{}
if len(s.sessionVars.CurrentDB) > 0 {
db = s.sessionVars.CurrentDB
Expand All @@ -918,16 +926,17 @@ func (s *session) SetProcessInfo(sql string, t time.Time, command byte) {
info = sql
}
pi := util.ProcessInfo{
ID: s.sessionVars.ConnectionID,
DB: db,
Command: command,
Plan: s.currentPlan,
Time: t,
State: s.Status(),
Info: info,
CurTxnStartTS: s.sessionVars.TxnCtx.StartTS,
StmtCtx: s.sessionVars.StmtCtx,
StatsInfo: plannercore.GetStatsInfo,
ID: s.sessionVars.ConnectionID,
DB: db,
Command: command,
Plan: s.currentPlan,
Time: t,
State: s.Status(),
Info: info,
CurTxnStartTS: s.sessionVars.TxnCtx.StartTS,
StmtCtx: s.sessionVars.StmtCtx,
StatsInfo: plannercore.GetStatsInfo,
MaxExecutionTime: maxExecutionTime,
}
if s.sessionVars.User != nil {
pi.User = s.sessionVars.User.Username
Expand Down Expand Up @@ -1632,6 +1641,7 @@ var builtinGlobalVariable = []string{
variable.WaitTimeout,
variable.InteractiveTimeout,
variable.MaxPreparedStmtCount,
variable.MaxExecutionTime,
/* TiDB specific global variables: */
variable.TiDBSkipUTF8Check,
variable.TiDBIndexJoinBatchSize,
Expand Down
Loading