diff --git a/.github/labeler.yml b/.github/labeler.yml index 5040b7cfc9da0..12a5b495d6a0e 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -1,6 +1,7 @@ component/executor: - distsql/* - executor/* + - !executor/brie* - util/chunk/* - util/disk/* - util/execdetails/* @@ -31,3 +32,6 @@ component/DDL: component/config: - config/* + +sig/migrate: + - executor/brie* diff --git a/bindinfo/bind_test.go b/bindinfo/bind_test.go index cbc0a08852034..d0df1d6f54591 100644 --- a/bindinfo/bind_test.go +++ b/bindinfo/bind_test.go @@ -1541,6 +1541,28 @@ func (s *testSuite) TestCapturePlanBaselineIgnoreTiFlash(c *C) { c.Assert(rows[0][1], Equals, "SELECT /*+ use_index(@`sel_1` `test`.`t` )*/ * FROM `test`.`t`") } +func (s *testSuite) TestSPMHitInfo(c *C) { + tk := testkit.NewTestKit(c, s.store) + s.cleanBindingEnv(tk) + tk.MustExec("use test") + tk.MustExec("drop table if exists t1") + tk.MustExec("drop table if exists t2") + tk.MustExec("create table t1(id int)") + tk.MustExec("create table t2(id int)") + + c.Assert(tk.HasPlan("SELECT * from t1,t2 where t1.id = t2.id", "HashJoin"), IsTrue) + c.Assert(tk.HasPlan("SELECT /*+ TIDB_SMJ(t1, t2) */ * from t1,t2 where t1.id = t2.id", "MergeJoin"), IsTrue) + + tk.MustExec("SELECT * from t1,t2 where t1.id = t2.id") + tk.MustQuery(`select @@last_plan_from_binding;`).Check(testkit.Rows("0")) + tk.MustExec("create global binding for SELECT * from t1,t2 where t1.id = t2.id using SELECT /*+ TIDB_SMJ(t1, t2) */ * from t1,t2 where t1.id = t2.id") + + c.Assert(tk.HasPlan("SELECT * from t1,t2 where t1.id = t2.id", "MergeJoin"), IsTrue) + tk.MustExec("SELECT * from t1,t2 where t1.id = t2.id") + tk.MustQuery(`select @@last_plan_from_binding;`).Check(testkit.Rows("1")) + tk.MustExec("drop global binding for SELECT * from t1,t2 where t1.id = t2.id") +} + func (s *testSuite) TestNotEvolvePlanForReadStorageHint(c *C) { tk := testkit.NewTestKit(c, s.store) s.cleanBindingEnv(tk) @@ -1843,3 +1865,17 @@ func (s *testSuite) TestCaptureWithZeroSlowLogThreshold(c *C) { c.Assert(len(rows), Equals, 1) c.Assert(rows[0][0], Equals, "select * from test . 
t") } + +func (s *testSuite) TestSPMWithoutUseDatabase(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk1 := testkit.NewTestKit(c, s.store) + s.cleanBindingEnv(tk) + s.cleanBindingEnv(tk1) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a int, b int, key(a))") + tk.MustExec("create global binding for select * from t using select * from t force index(a)") + + err := tk1.ExecToErr("select * from t") + c.Assert(err, ErrorMatches, "*No database selected") +} diff --git a/bindinfo/handle.go b/bindinfo/handle.go index 93d4b65391452..2a2355c7c2014 100644 --- a/bindinfo/handle.go +++ b/bindinfo/handle.go @@ -28,7 +28,6 @@ import ( "github.com/pingcap/parser/format" "github.com/pingcap/parser/mysql" "github.com/pingcap/parser/terror" - "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/metrics" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/variable" @@ -123,17 +122,21 @@ func NewBindHandle(ctx sessionctx.Context) *BindHandle { func (h *BindHandle) Update(fullLoad bool) (err error) { h.bindInfo.Lock() lastUpdateTime := h.bindInfo.lastUpdateTime + updateTime := lastUpdateTime.String() + if fullLoad { + updateTime = "0000-00-00 00:00:00" + } - sql := "select original_sql, bind_sql, default_db, status, create_time, update_time, charset, collation, source from mysql.bind_info" - if !fullLoad { - sql += " where update_time > \"" + lastUpdateTime.String() + "\"" + exec := h.sctx.Context.(sqlexec.RestrictedSQLExecutor) + stmt, err := exec.ParseWithParams(context.TODO(), `SELECT original_sql, bind_sql, default_db, status, create_time, update_time, charset, collation, source + FROM mysql.bind_info WHERE update_time > %? ORDER BY update_time`, updateTime) + if err != nil { + return err } - // We need to apply the updates by order, wrong apply order of same original sql may cause inconsistent state. - sql += " order by update_time" - // No need to acquire the session context lock for ExecRestrictedSQL, it + // No need to acquire the session context lock for ExecRestrictedStmt, it // uses another background session. - rows, _, err := h.sctx.Context.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql) + rows, _, err := exec.ExecRestrictedStmt(context.Background(), stmt) if err != nil { h.bindInfo.Unlock() return err @@ -215,7 +218,7 @@ func (h *BindHandle) CreateBindRecord(sctx sessionctx.Context, record *BindRecor return err } // Binding recreation should physically delete previous bindings. - _, err = exec.ExecuteInternal(context.TODO(), h.deleteBindInfoSQL(record.OriginalSQL, record.Db, "")) + _, err = exec.ExecuteInternal(context.TODO(), `DELETE FROM mysql.bind_info WHERE original_sql = %?`, record.OriginalSQL) if err != nil { return err } @@ -227,7 +230,17 @@ func (h *BindHandle) CreateBindRecord(sctx sessionctx.Context, record *BindRecor record.Bindings[i].UpdateTime = now // Insert the BindRecord to the storage. 
- _, err = exec.ExecuteInternal(context.TODO(), h.insertBindInfoSQL(record.OriginalSQL, record.Db, record.Bindings[i])) + _, err = exec.ExecuteInternal(context.TODO(), `INSERT INTO mysql.bind_info VALUES (%?,%?, %?, %?, %?, %?, %?, %?, %?)`, + record.OriginalSQL, + record.Bindings[i].BindSQL, + record.Db, + record.Bindings[i].Status, + record.Bindings[i].CreateTime.String(), + record.Bindings[i].UpdateTime.String(), + record.Bindings[i].Charset, + record.Bindings[i].Collation, + record.Bindings[i].Source, + ) if err != nil { return err } @@ -289,7 +302,7 @@ func (h *BindHandle) AddBindRecord(sctx sessionctx.Context, record *BindRecord) return err } if duplicateBinding != nil { - _, err = exec.ExecuteInternal(context.TODO(), h.deleteBindInfoSQL(record.OriginalSQL, record.Db, duplicateBinding.BindSQL)) + _, err = exec.ExecuteInternal(context.TODO(), `DELETE FROM mysql.bind_info WHERE original_sql = %? AND bind_sql = %?`, record.OriginalSQL, duplicateBinding.BindSQL) if err != nil { return err } @@ -305,7 +318,17 @@ func (h *BindHandle) AddBindRecord(sctx sessionctx.Context, record *BindRecord) record.Bindings[i].UpdateTime = now // Insert the BindRecord to the storage. - _, err = exec.ExecuteInternal(context.TODO(), h.insertBindInfoSQL(record.OriginalSQL, record.Db, record.Bindings[i])) + _, err = exec.ExecuteInternal(context.TODO(), `INSERT INTO mysql.bind_info VALUES (%?, %?, %?, %?, %?, %?, %?, %?, %?)`, + record.OriginalSQL, + record.Bindings[i].BindSQL, + record.Db, + record.Bindings[i].Status, + record.Bindings[i].CreateTime.String(), + record.Bindings[i].UpdateTime.String(), + record.Bindings[i].Charset, + record.Bindings[i].Collation, + record.Bindings[i].Source, + ) if err != nil { return err } @@ -349,17 +372,19 @@ func (h *BindHandle) DropBindRecord(originalSQL, db string, binding *Binding) (e // Lock mysql.bind_info to synchronize with CreateBindRecord / AddBindRecord / DropBindRecord on other tidb instances. if err = h.lockBindInfoTable(); err != nil { - return + return err } - updateTs := types.NewTime(types.FromGoTime(time.Now()), mysql.TypeTimestamp, 3) + updateTs := types.NewTime(types.FromGoTime(time.Now()), mysql.TypeTimestamp, 3).String() - bindSQL := "" - if binding != nil { - bindSQL = binding.BindSQL + if binding == nil { + _, err = exec.ExecuteInternal(context.TODO(), `UPDATE mysql.bind_info SET status = %?, update_time = %? WHERE original_sql = %? AND update_time < %?`, + deleted, updateTs, originalSQL, updateTs) + } else { + _, err = exec.ExecuteInternal(context.TODO(), `UPDATE mysql.bind_info SET status = %?, update_time = %? WHERE original_sql = %? AND update_time < %? 
AND bind_sql = %?`, + deleted, updateTs, originalSQL, updateTs, binding.BindSQL) } - _, err = exec.ExecuteInternal(context.TODO(), h.logicalDeleteBindInfoSQL(originalSQL, db, updateTs, bindSQL)) deleteRows = int(h.sctx.Context.GetSessionVars().StmtCtx.AffectedRows()) return err } @@ -575,49 +600,13 @@ func (c cache) getBindRecord(hash, normdOrigSQL, db string) *BindRecord { return nil } -func (h *BindHandle) deleteBindInfoSQL(normdOrigSQL, db, bindSQL string) string { - sql := fmt.Sprintf( - `DELETE FROM mysql.bind_info WHERE original_sql=%s`, - expression.Quote(normdOrigSQL), - ) - if bindSQL == "" { - return sql - } - return sql + fmt.Sprintf(` and bind_sql = %s`, expression.Quote(bindSQL)) -} - -func (h *BindHandle) insertBindInfoSQL(orignalSQL string, db string, info Binding) string { - return fmt.Sprintf(`INSERT INTO mysql.bind_info VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)`, - expression.Quote(orignalSQL), - expression.Quote(info.BindSQL), - expression.Quote(db), - expression.Quote(info.Status), - expression.Quote(info.CreateTime.String()), - expression.Quote(info.UpdateTime.String()), - expression.Quote(info.Charset), - expression.Quote(info.Collation), - expression.Quote(info.Source), - ) -} - // LockBindInfoSQL simulates LOCK TABLE by updating a same row in each pessimistic transaction. func (h *BindHandle) LockBindInfoSQL() string { - return fmt.Sprintf("UPDATE mysql.bind_info SET source=%s WHERE original_sql=%s", - expression.Quote(Builtin), - expression.Quote(BuiltinPseudoSQL4BindLock)) -} - -func (h *BindHandle) logicalDeleteBindInfoSQL(originalSQL, db string, updateTs types.Time, bindingSQL string) string { - updateTsStr := updateTs.String() - sql := fmt.Sprintf(`UPDATE mysql.bind_info SET status=%s,update_time=%s WHERE original_sql=%s and update_time<%s`, - expression.Quote(deleted), - expression.Quote(updateTsStr), - expression.Quote(originalSQL), - expression.Quote(updateTsStr)) - if bindingSQL == "" { - return sql + sql, err := sqlexec.EscapeSQL("UPDATE mysql.bind_info SET source= %? WHERE original_sql= %?", Builtin, BuiltinPseudoSQL4BindLock) + if err != nil { + return "" } - return sql + fmt.Sprintf(` and bind_sql = %s`, expression.Quote(bindingSQL)) + return sql } // CaptureBaselines is used to automatically capture plan baselines. 
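Note: the hunks above show the pattern this patch applies everywhere: SQL built with fmt.Sprintf and expression.Quote is replaced by template parsing with `%?` value placeholders (and `%n` for identifiers), so arguments are bound rather than concatenated into the statement text. A minimal sketch of the ParseWithParams / ExecRestrictedStmt flow used in this diff; the function name, query text, and variable `name` here are illustrative only, not part of the patch:

    // Illustrative sketch: query a system table with a bound value.
    import (
        "context"

        "github.com/pingcap/tidb/sessionctx"
        "github.com/pingcap/tidb/util/chunk"
        "github.com/pingcap/tidb/util/sqlexec"
    )

    func getGlobalVar(ctx sessionctx.Context, name string) ([]chunk.Row, error) {
        exec := ctx.(sqlexec.RestrictedSQLExecutor)
        // %? is replaced by the bound argument after parsing, so quotes in
        // `name` cannot change the structure of the statement.
        stmt, err := exec.ParseWithParams(context.TODO(),
            "SELECT variable_name, variable_value FROM mysql.global_variables WHERE variable_name = %?", name)
        if err != nil {
            return nil, err
        }
        rows, _, err := exec.ExecRestrictedStmt(context.TODO(), stmt)
        return rows, err
    }

This mirrors the rewritten Update and getEvolveParameters in this file.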
@@ -661,16 +650,14 @@ func (h *BindHandle) CaptureBaselines() { func getHintsForSQL(sctx sessionctx.Context, sql string) (string, error) { oriVals := sctx.GetSessionVars().UsePlanBaselines sctx.GetSessionVars().UsePlanBaselines = false - recordSets, err := sctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), fmt.Sprintf("explain format='hint' %s", sql)) + rs, err := sctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), fmt.Sprintf("explain format='hint' %s", sql)) sctx.GetSessionVars().UsePlanBaselines = oriVals - if len(recordSets) > 0 { - defer terror.Log(recordSets[0].Close()) - } if err != nil { return "", err } - chk := recordSets[0].NewChunk() - err = recordSets[0].Next(context.TODO(), chk) + defer terror.Call(rs.Close) + chk := rs.NewChunk() + err = rs.Next(context.TODO(), chk) if err != nil { return "", err } @@ -766,9 +753,17 @@ func (h *BindHandle) SaveEvolveTasksToStore() { } func getEvolveParameters(ctx sessionctx.Context) (time.Duration, time.Time, time.Time, error) { - sql := fmt.Sprintf("select variable_name, variable_value from mysql.global_variables where variable_name in ('%s', '%s', '%s')", - variable.TiDBEvolvePlanTaskMaxTime, variable.TiDBEvolvePlanTaskStartTime, variable.TiDBEvolvePlanTaskEndTime) - rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql) + stmt, err := ctx.(sqlexec.RestrictedSQLExecutor).ParseWithParams( + context.TODO(), + "SELECT variable_name, variable_value FROM mysql.global_variables WHERE variable_name IN (%?, %?, %?)", + variable.TiDBEvolvePlanTaskMaxTime, + variable.TiDBEvolvePlanTaskStartTime, + variable.TiDBEvolvePlanTaskEndTime, + ) + if err != nil { + return 0, time.Time{}, time.Time{}, err + } + rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedStmt(context.TODO(), stmt) if err != nil { return 0, time.Time{}, time.Time{}, err } @@ -839,7 +834,7 @@ func (h *BindHandle) getOnePendingVerifyJob() (string, string, Binding) { func (h *BindHandle) getRunningDuration(sctx sessionctx.Context, db, sql string, maxTime time.Duration) (time.Duration, error) { ctx := context.TODO() if db != "" { - _, err := sctx.(sqlexec.SQLExecutor).ExecuteInternal(ctx, fmt.Sprintf("use `%s`", db)) + _, err := sctx.(sqlexec.SQLExecutor).ExecuteInternal(ctx, "use %n", db) if err != nil { return 0, err } @@ -873,23 +868,20 @@ func runSQL(ctx context.Context, sctx sessionctx.Context, sql string, resultChan resultChan <- fmt.Errorf("run sql panicked: %v", string(buf)) } }() - recordSets, err := sctx.(sqlexec.SQLExecutor).ExecuteInternal(ctx, sql) + rs, err := sctx.(sqlexec.SQLExecutor).ExecuteInternal(ctx, sql) if err != nil { - if len(recordSets) > 0 { - terror.Call(recordSets[0].Close) - } + terror.Call(rs.Close) resultChan <- err return } - recordSet := recordSets[0] - chk := recordSets[0].NewChunk() + chk := rs.NewChunk() for { - err = recordSet.Next(ctx, chk) + err = rs.Next(ctx, chk) if err != nil || chk.NumRows() == 0 { break } } - terror.Call(recordSets[0].Close) + terror.Call(rs.Close) resultChan <- err } diff --git a/cmd/explaintest/r/partition_pruning.result b/cmd/explaintest/r/partition_pruning.result index d8d024a551dd6..921ed0012593f 100644 --- a/cmd/explaintest/r/partition_pruning.result +++ b/cmd/explaintest/r/partition_pruning.result @@ -3859,16 +3859,9 @@ TableReader_8 10.00 root data:Selection_7 └─TableFullScan_6 10000.00 cop[tikv] table:t1, partition:p0 keep order:false, stats:pseudo explain select * from t2 where a=0; id estRows task access object operator info -PartitionUnion_9 30.00 root 
-├─TableReader_12 10.00 root data:Selection_11 -│ └─Selection_11 10.00 cop[tikv] eq(test.t2.a, 0) -│ └─TableFullScan_10 10000.00 cop[tikv] table:t2, partition:p0 keep order:false, stats:pseudo -├─TableReader_15 10.00 root data:Selection_14 -│ └─Selection_14 10.00 cop[tikv] eq(test.t2.a, 0) -│ └─TableFullScan_13 10000.00 cop[tikv] table:t2, partition:p1 keep order:false, stats:pseudo -└─TableReader_18 10.00 root data:Selection_17 - └─Selection_17 10.00 cop[tikv] eq(test.t2.a, 0) - └─TableFullScan_16 10000.00 cop[tikv] table:t2, partition:p2 keep order:false, stats:pseudo +TableReader_8 10.00 root data:Selection_7 +└─Selection_7 10.00 cop[tikv] eq(test.t2.a, 0) + └─TableFullScan_6 10000.00 cop[tikv] table:t2, partition:p0 keep order:false, stats:pseudo explain select * from t1 where a=0xFE; id estRows task access object operator info TableReader_8 10.00 root data:Selection_7 @@ -3876,31 +3869,17 @@ TableReader_8 10.00 root data:Selection_7 └─TableFullScan_6 10000.00 cop[tikv] table:t1, partition:p2 keep order:false, stats:pseudo explain select * from t2 where a=0xFE; id estRows task access object operator info -PartitionUnion_9 30.00 root -├─TableReader_12 10.00 root data:Selection_11 -│ └─Selection_11 10.00 cop[tikv] eq(test.t2.a, 254) -│ └─TableFullScan_10 10000.00 cop[tikv] table:t2, partition:p0 keep order:false, stats:pseudo -├─TableReader_15 10.00 root data:Selection_14 -│ └─Selection_14 10.00 cop[tikv] eq(test.t2.a, 254) -│ └─TableFullScan_13 10000.00 cop[tikv] table:t2, partition:p1 keep order:false, stats:pseudo -└─TableReader_18 10.00 root data:Selection_17 - └─Selection_17 10.00 cop[tikv] eq(test.t2.a, 254) - └─TableFullScan_16 10000.00 cop[tikv] table:t2, partition:p2 keep order:false, stats:pseudo +TableReader_8 10.00 root data:Selection_7 +└─Selection_7 10.00 cop[tikv] eq(test.t2.a, 254) + └─TableFullScan_6 10000.00 cop[tikv] table:t2, partition:p2 keep order:false, stats:pseudo explain select * from t1 where a > 0xFE AND a <= 0xFF; id estRows task access object operator info TableDual_6 0.00 root rows:0 explain select * from t2 where a > 0xFE AND a <= 0xFF; id estRows task access object operator info -PartitionUnion_9 750.00 root -├─TableReader_12 250.00 root data:Selection_11 -│ └─Selection_11 250.00 cop[tikv] gt(test.t2.a, 254), le(test.t2.a, 255) -│ └─TableFullScan_10 10000.00 cop[tikv] table:t2, partition:p0 keep order:false, stats:pseudo -├─TableReader_15 250.00 root data:Selection_14 -│ └─Selection_14 250.00 cop[tikv] gt(test.t2.a, 254), le(test.t2.a, 255) -│ └─TableFullScan_13 10000.00 cop[tikv] table:t2, partition:p1 keep order:false, stats:pseudo -└─TableReader_18 250.00 root data:Selection_17 - └─Selection_17 250.00 cop[tikv] gt(test.t2.a, 254), le(test.t2.a, 255) - └─TableFullScan_16 10000.00 cop[tikv] table:t2, partition:p2 keep order:false, stats:pseudo +TableReader_8 250.00 root data:Selection_7 +└─Selection_7 250.00 cop[tikv] gt(test.t2.a, 254), le(test.t2.a, 255) + └─TableFullScan_6 10000.00 cop[tikv] table:t2, partition:p2 keep order:false, stats:pseudo explain select * from t1 where a >= 0xFE AND a <= 0xFF; id estRows task access object operator info TableReader_8 250.00 root data:Selection_7 @@ -3908,16 +3887,9 @@ TableReader_8 250.00 root data:Selection_7 └─TableFullScan_6 10000.00 cop[tikv] table:t1, partition:p2 keep order:false, stats:pseudo explain select * from t2 where a >= 0xFE AND a <= 0xFF; id estRows task access object operator info -PartitionUnion_9 750.00 root -├─TableReader_12 250.00 root data:Selection_11 -│ └─Selection_11 250.00 cop[tikv] 
ge(test.t2.a, 254), le(test.t2.a, 255) -│ └─TableFullScan_10 10000.00 cop[tikv] table:t2, partition:p0 keep order:false, stats:pseudo -├─TableReader_15 250.00 root data:Selection_14 -│ └─Selection_14 250.00 cop[tikv] ge(test.t2.a, 254), le(test.t2.a, 255) -│ └─TableFullScan_13 10000.00 cop[tikv] table:t2, partition:p1 keep order:false, stats:pseudo -└─TableReader_18 250.00 root data:Selection_17 - └─Selection_17 250.00 cop[tikv] ge(test.t2.a, 254), le(test.t2.a, 255) - └─TableFullScan_16 10000.00 cop[tikv] table:t2, partition:p2 keep order:false, stats:pseudo +TableReader_8 250.00 root data:Selection_7 +└─Selection_7 250.00 cop[tikv] ge(test.t2.a, 254), le(test.t2.a, 255) + └─TableFullScan_6 10000.00 cop[tikv] table:t2, partition:p2 keep order:false, stats:pseudo explain select * from t1 where a < 64 AND a >= 63; id estRows task access object operator info TableReader_8 250.00 root data:Selection_7 @@ -3925,16 +3897,13 @@ TableReader_8 250.00 root data:Selection_7 └─TableFullScan_6 10000.00 cop[tikv] table:t1, partition:p0 keep order:false, stats:pseudo explain select * from t2 where a < 64 AND a >= 63; id estRows task access object operator info -PartitionUnion_9 750.00 root -├─TableReader_12 250.00 root data:Selection_11 -│ └─Selection_11 250.00 cop[tikv] ge(test.t2.a, 63), lt(test.t2.a, 64) -│ └─TableFullScan_10 10000.00 cop[tikv] table:t2, partition:p0 keep order:false, stats:pseudo -├─TableReader_15 250.00 root data:Selection_14 -│ └─Selection_14 250.00 cop[tikv] ge(test.t2.a, 63), lt(test.t2.a, 64) -│ └─TableFullScan_13 10000.00 cop[tikv] table:t2, partition:p1 keep order:false, stats:pseudo -└─TableReader_18 250.00 root data:Selection_17 - └─Selection_17 250.00 cop[tikv] ge(test.t2.a, 63), lt(test.t2.a, 64) - └─TableFullScan_16 10000.00 cop[tikv] table:t2, partition:p2 keep order:false, stats:pseudo +PartitionUnion_8 500.00 root +├─TableReader_11 250.00 root data:Selection_10 +│ └─Selection_10 250.00 cop[tikv] ge(test.t2.a, 63), lt(test.t2.a, 64) +│ └─TableFullScan_9 10000.00 cop[tikv] table:t2, partition:p0 keep order:false, stats:pseudo +└─TableReader_14 250.00 root data:Selection_13 + └─Selection_13 250.00 cop[tikv] ge(test.t2.a, 63), lt(test.t2.a, 64) + └─TableFullScan_12 10000.00 cop[tikv] table:t2, partition:p1 keep order:false, stats:pseudo explain select * from t1 where a <= 64 AND a >= 63; id estRows task access object operator info PartitionUnion_8 500.00 root @@ -3946,16 +3915,13 @@ PartitionUnion_8 500.00 root └─TableFullScan_12 10000.00 cop[tikv] table:t1, partition:p1 keep order:false, stats:pseudo explain select * from t2 where a <= 64 AND a >= 63; id estRows task access object operator info -PartitionUnion_9 750.00 root -├─TableReader_12 250.00 root data:Selection_11 -│ └─Selection_11 250.00 cop[tikv] ge(test.t2.a, 63), le(test.t2.a, 64) -│ └─TableFullScan_10 10000.00 cop[tikv] table:t2, partition:p0 keep order:false, stats:pseudo -├─TableReader_15 250.00 root data:Selection_14 -│ └─Selection_14 250.00 cop[tikv] ge(test.t2.a, 63), le(test.t2.a, 64) -│ └─TableFullScan_13 10000.00 cop[tikv] table:t2, partition:p1 keep order:false, stats:pseudo -└─TableReader_18 250.00 root data:Selection_17 - └─Selection_17 250.00 cop[tikv] ge(test.t2.a, 63), le(test.t2.a, 64) - └─TableFullScan_16 10000.00 cop[tikv] table:t2, partition:p2 keep order:false, stats:pseudo +PartitionUnion_8 500.00 root +├─TableReader_11 250.00 root data:Selection_10 +│ └─Selection_10 250.00 cop[tikv] ge(test.t2.a, 63), le(test.t2.a, 64) +│ └─TableFullScan_9 10000.00 cop[tikv] table:t2, partition:p0 keep 
order:false, stats:pseudo +└─TableReader_14 250.00 root data:Selection_13 + └─Selection_13 250.00 cop[tikv] ge(test.t2.a, 63), le(test.t2.a, 64) + └─TableFullScan_12 10000.00 cop[tikv] table:t2, partition:p1 keep order:false, stats:pseudo drop table t1; drop table t2; create table t1(a bigint unsigned not null) partition by range(a+0) ( @@ -3969,35 +3935,13 @@ insert into t1 values (9),(19),(0xFFFF0000FFFF000-1), (0xFFFF0000FFFFFFF-1); explain select * from t1 where a >= 2305561538531885056-10 and a <= 2305561538531885056-8; id estRows task access object operator info -PartitionUnion_10 1000.00 root -├─TableReader_13 250.00 root data:Selection_12 -│ └─Selection_12 250.00 cop[tikv] ge(test.t1.a, 2305561538531885046), le(test.t1.a, 2305561538531885048) -│ └─TableFullScan_11 10000.00 cop[tikv] table:t1, partition:p1 keep order:false, stats:pseudo -├─TableReader_16 250.00 root data:Selection_15 -│ └─Selection_15 250.00 cop[tikv] ge(test.t1.a, 2305561538531885046), le(test.t1.a, 2305561538531885048) -│ └─TableFullScan_14 10000.00 cop[tikv] table:t1, partition:p2 keep order:false, stats:pseudo -├─TableReader_19 250.00 root data:Selection_18 -│ └─Selection_18 250.00 cop[tikv] ge(test.t1.a, 2305561538531885046), le(test.t1.a, 2305561538531885048) -│ └─TableFullScan_17 10000.00 cop[tikv] table:t1, partition:p3 keep order:false, stats:pseudo -└─TableReader_22 250.00 root data:Selection_21 - └─Selection_21 250.00 cop[tikv] ge(test.t1.a, 2305561538531885046), le(test.t1.a, 2305561538531885048) - └─TableFullScan_20 10000.00 cop[tikv] table:t1, partition:p4 keep order:false, stats:pseudo +TableReader_8 250.00 root data:Selection_7 +└─Selection_7 250.00 cop[tikv] ge(test.t1.a, 2305561538531885046), le(test.t1.a, 2305561538531885048) + └─TableFullScan_6 10000.00 cop[tikv] table:t1, partition:p3 keep order:false, stats:pseudo explain select * from t1 where a > 0xFFFFFFFFFFFFFFEC and a < 0xFFFFFFFFFFFFFFEE; id estRows task access object operator info -PartitionUnion_10 1000.00 root -├─TableReader_13 250.00 root data:Selection_12 -│ └─Selection_12 250.00 cop[tikv] gt(test.t1.a, 18446744073709551596), lt(test.t1.a, 18446744073709551598) -│ └─TableFullScan_11 10000.00 cop[tikv] table:t1, partition:p1 keep order:false, stats:pseudo -├─TableReader_16 250.00 root data:Selection_15 -│ └─Selection_15 250.00 cop[tikv] gt(test.t1.a, 18446744073709551596), lt(test.t1.a, 18446744073709551598) -│ └─TableFullScan_14 10000.00 cop[tikv] table:t1, partition:p2 keep order:false, stats:pseudo -├─TableReader_19 250.00 root data:Selection_18 -│ └─Selection_18 250.00 cop[tikv] gt(test.t1.a, 18446744073709551596), lt(test.t1.a, 18446744073709551598) -│ └─TableFullScan_17 10000.00 cop[tikv] table:t1, partition:p3 keep order:false, stats:pseudo -└─TableReader_22 250.00 root data:Selection_21 - └─Selection_21 250.00 cop[tikv] gt(test.t1.a, 18446744073709551596), lt(test.t1.a, 18446744073709551598) - └─TableFullScan_20 10000.00 cop[tikv] table:t1, partition:p4 keep order:false, stats:pseudo +TableDual_6 0.00 root rows:0 explain select * from t1 where a>=0 and a <= 0xFFFFFFFFFFFFFFFF; id estRows task access object operator info PartitionUnion_10 13293.33 root @@ -4023,19 +3967,9 @@ partition p4 values less than (1000) insert into t1 values (-15),(-5),(5),(15),(-15),(-5),(5),(15); explain select * from t1 where a>-2 and a <=0; id estRows task access object operator info -PartitionUnion_10 1000.00 root -├─TableReader_13 250.00 root data:Selection_12 -│ └─Selection_12 250.00 cop[tikv] gt(test.t1.a, -2), le(test.t1.a, 0) -│ 
└─TableFullScan_11 10000.00 cop[tikv] table:t1, partition:p1 keep order:false, stats:pseudo -├─TableReader_16 250.00 root data:Selection_15 -│ └─Selection_15 250.00 cop[tikv] gt(test.t1.a, -2), le(test.t1.a, 0) -│ └─TableFullScan_14 10000.00 cop[tikv] table:t1, partition:p2 keep order:false, stats:pseudo -├─TableReader_19 250.00 root data:Selection_18 -│ └─Selection_18 250.00 cop[tikv] gt(test.t1.a, -2), le(test.t1.a, 0) -│ └─TableFullScan_17 10000.00 cop[tikv] table:t1, partition:p3 keep order:false, stats:pseudo -└─TableReader_22 250.00 root data:Selection_21 - └─Selection_21 250.00 cop[tikv] gt(test.t1.a, -2), le(test.t1.a, 0) - └─TableFullScan_20 10000.00 cop[tikv] table:t1, partition:p4 keep order:false, stats:pseudo +TableReader_8 250.00 root data:Selection_7 +└─Selection_7 250.00 cop[tikv] gt(test.t1.a, -2), le(test.t1.a, 0) + └─TableFullScan_6 10000.00 cop[tikv] table:t1, partition:p3 keep order:false, stats:pseudo drop table t1; CREATE TABLE t1 ( recdate DATETIME NOT NULL ) PARTITION BY RANGE( TO_DAYS(recdate) ) ( @@ -4086,28 +4020,14 @@ partition p2 values less than (255) insert into t1 select A.a + 10*B.a from t0 A, t0 B; explain select * from t1 where a between 10 and 13; id estRows task access object operator info -PartitionUnion_9 750.00 root -├─TableReader_12 250.00 root data:Selection_11 -│ └─Selection_11 250.00 cop[tikv] ge(test.t1.a, 10), le(test.t1.a, 13) -│ └─TableFullScan_10 10000.00 cop[tikv] table:t1, partition:p0 keep order:false, stats:pseudo -├─TableReader_15 250.00 root data:Selection_14 -│ └─Selection_14 250.00 cop[tikv] ge(test.t1.a, 10), le(test.t1.a, 13) -│ └─TableFullScan_13 10000.00 cop[tikv] table:t1, partition:p1 keep order:false, stats:pseudo -└─TableReader_18 250.00 root data:Selection_17 - └─Selection_17 250.00 cop[tikv] ge(test.t1.a, 10), le(test.t1.a, 13) - └─TableFullScan_16 10000.00 cop[tikv] table:t1, partition:p2 keep order:false, stats:pseudo +TableReader_8 250.00 root data:Selection_7 +└─Selection_7 250.00 cop[tikv] ge(test.t1.a, 10), le(test.t1.a, 13) + └─TableFullScan_6 10000.00 cop[tikv] table:t1, partition:p0 keep order:false, stats:pseudo explain select * from t1 where a between 10 and 10+33; id estRows task access object operator info -PartitionUnion_9 750.00 root -├─TableReader_12 250.00 root data:Selection_11 -│ └─Selection_11 250.00 cop[tikv] ge(test.t1.a, 10), le(test.t1.a, 43) -│ └─TableFullScan_10 10000.00 cop[tikv] table:t1, partition:p0 keep order:false, stats:pseudo -├─TableReader_15 250.00 root data:Selection_14 -│ └─Selection_14 250.00 cop[tikv] ge(test.t1.a, 10), le(test.t1.a, 43) -│ └─TableFullScan_13 10000.00 cop[tikv] table:t1, partition:p1 keep order:false, stats:pseudo -└─TableReader_18 250.00 root data:Selection_17 - └─Selection_17 250.00 cop[tikv] ge(test.t1.a, 10), le(test.t1.a, 43) - └─TableFullScan_16 10000.00 cop[tikv] table:t1, partition:p2 keep order:false, stats:pseudo +TableReader_8 250.00 root data:Selection_7 +└─Selection_7 250.00 cop[tikv] ge(test.t1.a, 10), le(test.t1.a, 43) + └─TableFullScan_6 10000.00 cop[tikv] table:t1, partition:p0 keep order:false, stats:pseudo drop table t0, t1; drop table if exists t; create table t(a timestamp) partition by range(unix_timestamp(a)) (partition p0 values less than(unix_timestamp('2019-02-16 14:20:00')), partition p1 values less than (maxvalue)); diff --git a/ddl/column.go b/ddl/column.go index d264f669c1f37..42e5de30efa4f 100644 --- a/ddl/column.go +++ b/ddl/column.go @@ -14,6 +14,7 @@ package ddl import ( + "context" "fmt" "math/bits" "strings" @@ -538,16 +539,25 @@ 
func checkAndApplyNewAutoRandomBits(job *model.Job, t *meta.Meta, tblInfo *model // checkForNullValue ensures there are no null values of the column of this table. // `isDataTruncated` indicates whether the new field and the old field type are the same, in order to be compatible with mysql. func checkForNullValue(ctx sessionctx.Context, isDataTruncated bool, schema, table, newCol model.CIStr, oldCols ...*model.ColumnInfo) error { - colsStr := "" + var buf strings.Builder + buf.WriteString("select 1 from %n.%n where ") + paramsList := make([]interface{}, 0, 2+len(oldCols)) + paramsList = append(paramsList, schema.L, table.L) for i, col := range oldCols { if i == 0 { - colsStr += "`" + col.Name.L + "` is null" + buf.WriteString("%n is null") + paramsList = append(paramsList, col.Name.L) } else { - colsStr += " or `" + col.Name.L + "` is null" + buf.WriteString(" or %n is null") + paramsList = append(paramsList, col.Name.L) } } - sql := fmt.Sprintf("select 1 from `%s`.`%s` where %s limit 1;", schema.L, table.L, colsStr) - rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql) + buf.WriteString(" limit 1") + stmt, err := ctx.(sqlexec.RestrictedSQLExecutor).ParseWithParams(context.Background(), buf.String(), paramsList...) + if err != nil { + return errors.Trace(err) + } + rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedStmt(context.Background(), stmt) if err != nil { return errors.Trace(err) } diff --git a/ddl/db_test.go b/ddl/db_test.go index 262677eeca2a4..c8afb966a8714 100644 --- a/ddl/db_test.go +++ b/ddl/db_test.go @@ -4733,7 +4733,7 @@ func (s *testSerialDBSuite) TestAddIndexFailOnCaseWhenCanExit(c *C) { tk.MustExec("insert into t values(1, 1)") _, err := tk.Exec("alter table t add index idx(b)") c.Assert(err, NotNil) - c.Assert(err.Error(), Equals, "[ddl:8214]Cancelled DDL job") + c.Assert(err.Error(), Equals, "[ddl:-1]DDL job rollback, error msg: job.ErrCount:512, mock unknown type: ast.whenClause.") tk.MustExec("drop table if exists t") } @@ -4988,3 +4988,74 @@ func (s *testSerialDBSuite) TestColumnTypeChangeIgnoreDisplayLength(c *C) { tk.MustExec("alter table t modify column a bigint(1)") tk.MustExec("drop table if exists t") } + +// Close issue #23202 +func (s *testSerialDBSuite) TestDDLExitWhenCancelMeetPanic(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a int, b int)") + tk.MustExec("insert into t values(1,1),(2,2)") + tk.MustExec("alter table t add index(b)") + tk.MustExec("set @@global.tidb_ddl_error_count_limit=3") + + failpoint.Enable("github.com/pingcap/tidb/ddl/mockExceedErrorLimit", `return(true)`) + defer func() { + failpoint.Disable("github.com/pingcap/tidb/ddl/mockExceedErrorLimit") + }() + + originalHook := s.dom.DDL().GetHook() + defer s.dom.DDL().(ddl.DDLForTest).SetHook(originalHook) + + hook := &ddl.TestDDLCallback{Do: s.dom} + var jobID int64 + hook.OnJobRunBeforeExported = func(job *model.Job) { + if jobID != 0 { + return + } + if job.Type == model.ActionDropIndex { + jobID = job.ID + } + } + s.dom.DDL().(ddl.DDLForTest).SetHook(hook) + + // When it panics in write-reorg state, the job will be pulled up as a cancelling job. Since drop-index with + // write-reorg can't be cancelled, it will be converted to running state and try again (dead loop).
+ _, err := tk.Exec("alter table t drop index b") + c.Assert(err, NotNil) + c.Assert(err.Error(), Equals, "[ddl:-1]panic in handling DDL logic and error count beyond the limitation 3, cancelled") + c.Assert(jobID > 0, Equals, true) + + // Verification of the history job state. + var job *model.Job + err = kv.RunInNewTxn(s.store, false, func(txn kv.Transaction) error { + t := meta.NewMeta(txn) + var err1 error + job, err1 = t.GetHistoryDDLJob(jobID) + return errors.Trace(err1) + }) + c.Assert(err, IsNil) + c.Assert(job.ErrorCount, Equals, int64(4)) + c.Assert(job.Error.Error(), Equals, "[ddl:-1]panic in handling DDL logic and error count beyond the limitation 3, cancelled") +} + +// Close issue #23321. +// See https://github.com/pingcap/tidb/issues/23321 +func (s *testSerialDBSuite) TestJsonUnmarshalErrWhenPanicInCancellingPath(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + + tk.MustExec("drop table if exists test_add_index_after_add_col") + tk.MustExec("create table test_add_index_after_add_col(a int, b int not null default '0');") + tk.MustExec("insert into test_add_index_after_add_col values(1, 2),(2,2);") + tk.MustExec("alter table test_add_index_after_add_col add column c int not null default '0';") + + c.Assert(failpoint.Enable("github.com/pingcap/tidb/ddl/mockExceedErrorLimit", `return(true)`), IsNil) + defer func() { + c.Assert(failpoint.Disable("github.com/pingcap/tidb/ddl/mockExceedErrorLimit"), IsNil) + }() + + _, err := tk.Exec("alter table test_add_index_after_add_col add unique index cc(c);") + c.Assert(err.Error(), Equals, "[kv:1062]DDL job cancelled by panic in rollingback, error msg: Duplicate entry '0' for key 'cc'") +} diff --git a/ddl/ddl_api.go b/ddl/ddl_api.go index f21a431346cb1..5e12df0920af8 100644 --- a/ddl/ddl_api.go +++ b/ddl/ddl_api.go @@ -3716,11 +3716,8 @@ func (d *ddl) TruncateTable(ctx sessionctx.Context, ti ast.Ident) error { } return errors.Trace(err) } - oldTblInfo := tb.Meta() - if oldTblInfo.PreSplitRegions > 0 { - if _, tb, err := d.getSchemaAndTableByIdent(ctx, ti); err == nil { - d.preSplitAndScatter(ctx, tb.Meta(), tb.Meta().GetPartitionInfo()) - } + if _, tb, err := d.getSchemaAndTableByIdent(ctx, ti); err == nil { + d.preSplitAndScatter(ctx, tb.Meta(), tb.Meta().GetPartitionInfo()) } if !config.TableLockEnabled() { diff --git a/ddl/ddl_worker.go b/ddl/ddl_worker.go index 00280bbdd0b6a..0980a0e159391 100644 --- a/ddl/ddl_worker.go +++ b/ddl/ddl_worker.go @@ -582,12 +582,62 @@ func chooseLeaseTime(t, max time.Duration) time.Duration { return t } +// countForPanic records the error count for DDL job. +func (w *worker) countForPanic(job *model.Job) { + // If running the DDL job panics, just cancel the DDL job. + if job.State == model.JobStateRollingback { + job.State = model.JobStateCancelled + msg := fmt.Sprintf("DDL job cancelled by panic in rollingback, error msg: %s", terror.ToSQLError(job.Error).Message) + job.Error = terror.GetErrClass(job.Error).Synthesize(terror.ErrCode(job.Error.Code()), msg) + logutil.Logger(w.logCtx).Warn(msg) + return + } + job.State = model.JobStateCancelling + job.ErrorCount++ + + // Load global DDL variables.
+ if err1 := loadDDLVars(w); err1 != nil { + logutil.Logger(w.logCtx).Error("[ddl] load DDL global variable failed", zap.Error(err1)) + } + errorCount := variable.GetDDLErrorCountLimit() + + if job.ErrorCount > errorCount { + msg := fmt.Sprintf("panic in handling DDL logic and error count beyond the limitation %d, cancelled", errorCount) + logutil.Logger(w.logCtx).Warn(msg) + job.Error = toTError(errors.New(msg)) + job.State = model.JobStateCancelled + } +} + +// countForError records the error count for DDL job. +func (w *worker) countForError(err error, job *model.Job) error { + job.Error = toTError(err) + job.ErrorCount++ + + // If job is cancelled, we shouldn't return an error and shouldn't load DDL variables. + if job.State == model.JobStateCancelled { + logutil.Logger(w.logCtx).Info("[ddl] DDL job is cancelled normally", zap.Error(err)) + return nil + } + logutil.Logger(w.logCtx).Error("[ddl] run DDL job error", zap.Error(err)) + + // Load global DDL variables. + if err1 := loadDDLVars(w); err1 != nil { + logutil.Logger(w.logCtx).Error("[ddl] load DDL global variable failed", zap.Error(err1)) + } + // Check error limit to avoid falling into an infinite loop. + if job.ErrorCount > variable.GetDDLErrorCountLimit() && job.State == model.JobStateRunning && admin.IsJobRollbackable(job) { + logutil.Logger(w.logCtx).Warn("[ddl] DDL job error count exceed the limit, cancelling it now", zap.Int64("jobID", job.ID), zap.Int64("errorCountLimit", variable.GetDDLErrorCountLimit())) + job.State = model.JobStateCancelling + } + return err +} + // runDDLJob runs a DDL job. It returns the current schema version in this transaction and the error. func (w *worker) runDDLJob(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, err error) { defer tidbutil.Recover(metrics.LabelDDLWorker, fmt.Sprintf("%s runDDLJob", w), func() { - // If run DDL job panic, just cancel the DDL jobs. - job.State = model.JobStateCancelling + w.countForPanic(job) }, false) // Mock for run ddl job panic. @@ -690,27 +740,9 @@ func (w *worker) runDDLJob(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, err = errInvalidDDLJob.GenWithStack("invalid ddl job type: %v", job.Type) } - // Save errors in job, so that others can know errors happened. + // Save errors in job if any, so that others can know errors happened. if err != nil { - job.Error = toTError(err) - job.ErrorCount++ - - // If job is cancelled, we shouldn't return an error and shouldn't load DDL variables. - if job.State == model.JobStateCancelled { - logutil.Logger(w.logCtx).Info("[ddl] DDL job is cancelled normally", zap.Error(err)) - return ver, nil - } - logutil.Logger(w.logCtx).Error("[ddl] run DDL job error", zap.Error(err)) - - // Load global ddl variables. - if err1 := loadDDLVars(w); err1 != nil { - logutil.Logger(w.logCtx).Error("[ddl] load DDL global variable failed", zap.Error(err1)) - } - // Check error limit to avoid falling into an infinite loop. 
- if job.ErrorCount > variable.GetDDLErrorCountLimit() && job.State == model.JobStateRunning && admin.IsJobRollbackable(job) { - logutil.Logger(w.logCtx).Warn("[ddl] DDL job error count exceed the limit, cancelling it now", zap.Int64("jobID", job.ID), zap.Int64("errorCountLimit", variable.GetDDLErrorCountLimit())) - job.State = model.JobStateCancelling - } + err = w.countForError(err, job) } return } diff --git a/ddl/delete_range.go b/ddl/delete_range.go index 45059716a824d..bf06b6392fea6 100644 --- a/ddl/delete_range.go +++ b/ddl/delete_range.go @@ -16,8 +16,8 @@ package ddl import ( "context" "encoding/hex" - "fmt" "math" + "strings" "sync" "sync/atomic" @@ -35,7 +35,7 @@ import ( const ( insertDeleteRangeSQLPrefix = `INSERT IGNORE INTO mysql.gc_delete_range VALUES ` - insertDeleteRangeSQLValue = `("%d", "%d", "%s", "%s", "%d")` + insertDeleteRangeSQLValue = `(%?, %?, %?, %?, %?)` insertDeleteRangeSQL = insertDeleteRangeSQLPrefix + insertDeleteRangeSQLValue delBatchSize = 65536 @@ -350,25 +350,27 @@ func doInsert(s sqlexec.SQLExecutor, jobID int64, elementID int64, startKey, end logutil.BgLogger().Info("[ddl] insert into delete-range table", zap.Int64("jobID", jobID), zap.Int64("elementID", elementID)) startKeyEncoded := hex.EncodeToString(startKey) endKeyEncoded := hex.EncodeToString(endKey) - sql := fmt.Sprintf(insertDeleteRangeSQL, jobID, elementID, startKeyEncoded, endKeyEncoded, ts) - _, err := s.Execute(context.Background(), sql) + _, err := s.ExecuteInternal(context.Background(), insertDeleteRangeSQL, jobID, elementID, startKeyEncoded, endKeyEncoded, ts) return errors.Trace(err) } func doBatchInsert(s sqlexec.SQLExecutor, jobID int64, tableIDs []int64, ts uint64) error { logutil.BgLogger().Info("[ddl] batch insert into delete-range table", zap.Int64("jobID", jobID), zap.Int64s("elementIDs", tableIDs)) - sql := insertDeleteRangeSQLPrefix + var buf strings.Builder + buf.WriteString(insertDeleteRangeSQLPrefix) + paramsList := make([]interface{}, 0, len(tableIDs)*5) for i, tableID := range tableIDs { startKey := tablecodec.EncodeTablePrefix(tableID) endKey := tablecodec.EncodeTablePrefix(tableID + 1) startKeyEncoded := hex.EncodeToString(startKey) endKeyEncoded := hex.EncodeToString(endKey) - sql += fmt.Sprintf(insertDeleteRangeSQLValue, jobID, tableID, startKeyEncoded, endKeyEncoded, ts) + buf.WriteString(insertDeleteRangeSQLValue) if i != len(tableIDs)-1 { - sql += "," + buf.WriteString(",") } + paramsList = append(paramsList, jobID, tableID, startKeyEncoded, endKeyEncoded, ts) } - _, err := s.Execute(context.Background(), sql) + _, err := s.ExecuteInternal(context.Background(), buf.String(), paramsList...) return errors.Trace(err) } diff --git a/ddl/index.go b/ddl/index.go index a05a125d9af94..e6cb693cd1c3c 100644 --- a/ddl/index.go +++ b/ddl/index.go @@ -618,6 +618,12 @@ func onDropIndex(t *meta.Meta, job *model.Job) (ver int64, _ error) { // Set column index flag. 
dropIndexColumnFlag(tblInfo, indexInfo) + failpoint.Inject("mockExceedErrorLimit", func(val failpoint.Value) { + if val.(bool) { + panic("panic test in cancelling add index") + } + }) + tblInfo.Columns = tblInfo.Columns[:len(tblInfo.Columns)-len(dependentHiddenCols)] ver, err = updateVersionAndTableInfo(t, job, tblInfo, originalState != model.StateNone) diff --git a/ddl/reorg.go b/ddl/reorg.go index 1eb68a822a265..2d5a9dd7b4c31 100644 --- a/ddl/reorg.go +++ b/ddl/reorg.go @@ -198,8 +198,12 @@ func getTableTotalCount(w *worker, tblInfo *model.TableInfo) int64 { if !ok { return statistics.PseudoRowCount } - sql := fmt.Sprintf("select table_rows from information_schema.tables where tidb_table_id=%v;", tblInfo.ID) - rows, _, err := executor.ExecRestrictedSQL(sql) + sql := "select table_rows from information_schema.tables where tidb_table_id=%?;" + stmt, err := executor.ParseWithParams(context.Background(), sql, tblInfo.ID) + if err != nil { + return statistics.PseudoRowCount + } + rows, _, err := executor.ExecRestrictedStmt(context.Background(), stmt) if err != nil { return statistics.PseudoRowCount } diff --git a/ddl/rollingback.go b/ddl/rollingback.go index adb0355cc7395..4e205ad7cbfdf 100644 --- a/ddl/rollingback.go +++ b/ddl/rollingback.go @@ -14,11 +14,16 @@ package ddl import ( + "fmt" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" "github.com/pingcap/parser/ast" "github.com/pingcap/parser/model" "github.com/pingcap/parser/mysql" + "github.com/pingcap/parser/terror" "github.com/pingcap/tidb/meta" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/util/logutil" "go.uber.org/zap" ) @@ -37,8 +42,11 @@ func updateColsNull2NotNull(tblInfo *model.TableInfo, indexInfo *model.IndexInfo } func convertAddIdxJob2RollbackJob(t *meta.Meta, job *model.Job, tblInfo *model.TableInfo, indexInfo *model.IndexInfo, err error) (int64, error) { - job.State = model.JobStateRollingback - + failpoint.Inject("mockConvertAddIdxJob2RollbackJobError", func(val failpoint.Value) { + if val.(bool) { + failpoint.Return(0, errors.New("mock convert add index job to rollback job error")) + } + }) if indexInfo.Primary { nullCols, err := getNullColInfos(tblInfo, indexInfo) if err != nil { @@ -65,7 +73,7 @@ func convertAddIdxJob2RollbackJob(t *meta.Meta, job *model.Job, tblInfo *model.T if err1 != nil { return ver, errors.Trace(err1) } - + job.State = model.JobStateRollingback return ver, errors.Trace(err) } @@ -99,7 +107,6 @@ func convertNotStartAddIdxJob2RollbackJob(t *meta.Meta, job *model.Job, occuredE } func rollingbackAddColumn(t *meta.Meta, job *model.Job) (ver int64, err error) { - job.State = model.JobStateRollingback tblInfo, columnInfo, col, _, _, err := checkAddColumn(t, job) if err != nil { return ver, errors.Trace(err) @@ -118,11 +125,13 @@ func rollingbackAddColumn(t *meta.Meta, job *model.Job) (ver int64, err error) { if err != nil { return ver, errors.Trace(err) } + + job.State = model.JobStateRollingback return ver, errCancelledDDLJob } func rollingbackDropColumn(t *meta.Meta, job *model.Job) (ver int64, err error) { - tblInfo, colInfo, err := checkDropColumn(t, job) + _, colInfo, err := checkDropColumn(t, job) if err != nil { return ver, errors.Trace(err) } @@ -130,7 +139,6 @@ func rollingbackDropColumn(t *meta.Meta, job *model.Job) (ver int64, err error) // StatePublic means when the job is not running yet. 
if colInfo.State == model.StatePublic { job.State = model.JobStateCancelled - job.FinishTableJob(model.JobStateRollbackDone, model.StatePublic, ver, tblInfo) return ver, errCancelledDDLJob } // In the state of drop column `write only -> delete only -> reorganization`, @@ -140,12 +148,11 @@ } func rollingbackDropIndex(t *meta.Meta, job *model.Job) (ver int64, err error) { - tblInfo, indexInfo, err := checkDropIndex(t, job) + _, indexInfo, err := checkDropIndex(t, job) if err != nil { return ver, errors.Trace(err) } - originalState := indexInfo.State switch indexInfo.State { case model.StateWriteOnly, model.StateDeleteOnly, model.StateDeleteReorganization, model.StateNone: // We can not rollback now, so just continue to drop index. job.State = model.JobStateRunning return ver, nil case model.StatePublic: - job.State = model.JobStateRollbackDone - indexInfo.State = model.StatePublic + job.State = model.JobStateCancelled + return ver, errCancelledDDLJob default: return ver, ErrInvalidDDLState.GenWithStackByArgs("index", indexInfo.State) } - - job.SchemaState = indexInfo.State - job.Args = []interface{}{indexInfo.Name} - ver, err = updateVersionAndTableInfo(t, job, tblInfo, originalState != indexInfo.State) - if err != nil { - return ver, errors.Trace(err) - } - job.FinishTableJob(model.JobStateRollbackDone, model.StatePublic, ver, tblInfo) - return ver, errCancelledDDLJob } func rollingbackAddIndex(w *worker, d *ddlCtx, t *meta.Meta, job *model.Job, isPK bool) (ver int64, err error) { @@ -184,7 +182,6 @@ func rollingbackAddIndex(w *worker, d *ddlCtx, t *meta.Meta, job *model.Job, isP } func convertAddTablePartitionJob2RollbackJob(t *meta.Meta, job *model.Job, otherwiseErr error, tblInfo *model.TableInfo) (ver int64, err error) { - job.State = model.JobStateRollingback addingDefinitions := tblInfo.Partition.AddingDefinitions partNames := make([]string, 0, len(addingDefinitions)) for _, pd := range addingDefinitions { @@ -195,6 +192,7 @@ func convertAddTablePartitionJob2RollbackJob(t *meta.Meta, job *model.Job, other if err != nil { return ver, errors.Trace(err) } + job.State = model.JobStateRollingback return ver, errors.Trace(otherwiseErr) } @@ -321,14 +319,41 @@ func convertJob2RollbackJob(w *worker, d *ddlCtx, t *meta.Meta, job *model.Job) } if err != nil { + if job.Error == nil { + job.Error = toTError(err) + } + job.ErrorCount++ + + if errCancelledDDLJob.Equal(err) { + // The job is normally cancelled. + if !job.Error.Equal(errCancelledDDLJob) { + job.Error = terror.GetErrClass(job.Error).Synthesize(terror.ErrCode(job.Error.Code()), + fmt.Sprintf("DDL job rollback, error msg: %s", terror.ToSQLError(job.Error).Message)) + } + } else { + // A canceling job meets another error. + // + // Once `convertJob2RollbackJob` meets an error, the job state can't be set as `JobStateRollingback` since + // job state and args may not be correctly overwritten. The job will be fetched to run with the cancelling + // state again. So we should check the error count here.
+ if err1 := loadDDLVars(w); err1 != nil { + logutil.Logger(w.logCtx).Error("[ddl] load DDL global variable failed", zap.Error(err1)) + } + errorCount := variable.GetDDLErrorCountLimit() + if job.ErrorCount > errorCount { + logutil.Logger(w.logCtx).Warn("[ddl] rollback DDL job error count exceed the limit, cancelled it now", zap.Int64("jobID", job.ID), zap.Int64("errorCountLimit", errorCount)) + job.Error = toTError(errors.Errorf("rollback DDL job error count exceed the limit %d, cancelled it now", errorCount)) + job.State = model.JobStateCancelled + } + } + if job.State != model.JobStateRollingback && job.State != model.JobStateCancelled { logutil.Logger(w.logCtx).Error("[ddl] run DDL job failed", zap.String("job", job.String()), zap.Error(err)) } else { logutil.Logger(w.logCtx).Info("[ddl] the DDL job is cancelled normally", zap.String("job", job.String()), zap.Error(err)) + // If job is cancelled, we shouldn't return an error. + return ver, nil } - - job.Error = toTError(err) - job.ErrorCount++ } return } diff --git a/ddl/rollingback_test.go b/ddl/rollingback_test.go new file mode 100644 index 0000000000000..c7a188ecbd012 --- /dev/null +++ b/ddl/rollingback_test.go @@ -0,0 +1,101 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package ddl_test + +import ( + "context" + "strconv" + + . "github.com/pingcap/check" + errors2 "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/parser/model" + "github.com/pingcap/tidb/ddl" + "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/meta" + "github.com/pingcap/tidb/util/sqlexec" + "github.com/pingcap/tidb/util/testkit" +) + +var _ = SerialSuites(&testRollingBackSuite{&testDBSuite{}}) + +type testRollingBackSuite struct{ *testDBSuite } + +// TestCancelAddIndexJobError is used to test canceling a ddl job failure when converting the ddl job to a rollingback job.
+func (s *testRollingBackSuite) TestCancelAddIndexJobError(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk1 := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk1.MustExec("use test") + + tk.MustExec("create table t_cancel_add_index (a int)") + tk.MustExec("insert into t_cancel_add_index values(1),(2),(3)") + tk.MustExec("set @@global.tidb_ddl_error_count_limit=3") + + c.Assert(failpoint.Enable("github.com/pingcap/tidb/ddl/mockConvertAddIdxJob2RollbackJobError", `return(true)`), IsNil) + defer func() { + c.Assert(failpoint.Disable("github.com/pingcap/tidb/ddl/mockConvertAddIdxJob2RollbackJobError"), IsNil) + }() + + tbl := testGetTableByName(c, tk.Se, "test", "t_cancel_add_index") + c.Assert(tbl, NotNil) + + d := s.dom.DDL() + hook := &ddl.TestDDLCallback{Do: s.dom} + var ( + checkErr error + jobID int64 + res sqlexec.RecordSet + ) + hook.OnJobUpdatedExported = func(job *model.Job) { + if job.TableID != tbl.Meta().ID { + return + } + if job.Type != model.ActionAddIndex { + return + } + if job.SchemaState == model.StateDeleteOnly { + jobID = job.ID + res, checkErr = tk1.Exec("admin cancel ddl jobs " + strconv.Itoa(int(job.ID))) + // drain the result set here, otherwise the cancel action won't take effect immediately. + chk := res.NewChunk() + if err := res.Next(context.Background(), chk); err != nil { + checkErr = err + return + } + if err := res.Close(); err != nil { + checkErr = err + } + } + } + d.(ddl.DDLForTest).SetHook(hook) + + // This will hang on stateDeleteOnly, and the job will be canceled. + _, err := tk.Exec("alter table t_cancel_add_index add index idx(a)") + c.Assert(err, NotNil) + c.Assert(checkErr, IsNil) + c.Assert(err.Error(), Equals, "[ddl:-1]rollback DDL job error count exceed the limit 3, cancelled it now") + + // Verification of the history job state. + var job *model.Job + err = kv.RunInNewTxn(s.store, false, func(txn kv.Transaction) error { + t := meta.NewMeta(txn) + var err1 error + job, err1 = t.GetHistoryDDLJob(jobID) + return errors2.Trace(err1) + }) + c.Assert(err, IsNil) + c.Assert(job.ErrorCount, Equals, int64(4)) + c.Assert(job.Error.Error(), Equals, "[ddl:-1]rollback DDL job error count exceed the limit 3, cancelled it now") +} diff --git a/ddl/serial_test.go b/ddl/serial_test.go index 684e795482c4b..ea2e94f347873 100644 --- a/ddl/serial_test.go +++ b/ddl/serial_test.go @@ -745,7 +745,7 @@ func (s *testSerialSuite) TestCancelJobByErrorCountLimit(c *C) { _, err = tk.Exec("create table t (a int)") c.Assert(err, NotNil) - c.Assert(err.Error(), Equals, "[ddl:8214]Cancelled DDL job") + c.Assert(err.Error(), Equals, "[ddl:-1]DDL job rollback, error msg: mock do job error") } func (s *testSerialSuite) TestTruncateTableUpdateSchemaVersionErr(c *C) { @@ -763,7 +763,7 @@ func (s *testSerialSuite) TestTruncateTableUpdateSchemaVersionErr(c *C) { tk.MustExec("create table t (a int)") _, err = tk.Exec("truncate table t") c.Assert(err, NotNil) - c.Assert(err.Error(), Equals, "[ddl:8214]Cancelled DDL job") + c.Assert(err.Error(), Equals, "[ddl:-1]DDL job rollback, error msg: mock update version error") // Disable fail point. 
c.Assert(failpoint.Disable("github.com/pingcap/tidb/ddl/mockTruncateTableUpdateVersionError"), IsNil) tk.MustExec("truncate table t") diff --git a/ddl/util/util.go b/ddl/util/util.go index 69a6d3ced56f7..00c0f0e0aa021 100644 --- a/ddl/util/util.go +++ b/ddl/util/util.go @@ -14,12 +14,10 @@ package util import ( - "bytes" + "strings" + "context" "encoding/hex" - "fmt" - "strconv" - "github.com/pingcap/errors" "github.com/pingcap/parser/terror" "github.com/pingcap/tidb/kv" @@ -32,12 +30,13 @@ import ( const ( deleteRangesTable = `gc_delete_range` doneDeleteRangesTable = `gc_delete_range_done` - loadDeleteRangeSQL = `SELECT HIGH_PRIORITY job_id, element_id, start_key, end_key FROM mysql.%s WHERE ts < %v` - recordDoneDeletedRangeSQL = `INSERT IGNORE INTO mysql.gc_delete_range_done SELECT * FROM mysql.gc_delete_range WHERE job_id = %d AND element_id = %d` - completeDeleteRangeSQL = `DELETE FROM mysql.gc_delete_range WHERE job_id = %d AND element_id = %d` - completeDeleteMultiRangesSQL = `DELETE FROM mysql.gc_delete_range WHERE job_id = %d AND element_id in (%v)` - updateDeleteRangeSQL = `UPDATE mysql.gc_delete_range SET start_key = "%s" WHERE job_id = %d AND element_id = %d AND start_key = "%s"` - deleteDoneRecordSQL = `DELETE FROM mysql.gc_delete_range_done WHERE job_id = %d AND element_id = %d` + loadDeleteRangeSQL = `SELECT HIGH_PRIORITY job_id, element_id, start_key, end_key FROM mysql.%n WHERE ts < %?` + recordDoneDeletedRangeSQL = `INSERT IGNORE INTO mysql.gc_delete_range_done SELECT * FROM mysql.gc_delete_range WHERE job_id = %? AND element_id = %?` + completeDeleteRangeSQL = `DELETE FROM mysql.gc_delete_range WHERE job_id = %? AND element_id = %?` + completeDeleteMultiRangesSQL = `DELETE FROM mysql.gc_delete_range WHERE job_id = %? AND element_id in (` // + idList + ")" + updateDeleteRangeSQL = `UPDATE mysql.gc_delete_range SET start_key = %? WHERE job_id = %? AND element_id = %? AND start_key = %?` + deleteDoneRecordSQL = `DELETE FROM mysql.gc_delete_range_done WHERE job_id = %? AND element_id = %?` + loadGlobalVars = `SELECT HIGH_PRIORITY variable_name, variable_value from mysql.global_variables where variable_name in (%?)` ) // DelRangeTask is for run delete-range command in gc_worker. @@ -62,16 +61,14 @@ func LoadDoneDeleteRanges(ctx sessionctx.Context, safePoint uint64) (ranges []De } func loadDeleteRangesFromTable(ctx sessionctx.Context, table string, safePoint uint64) (ranges []DelRangeTask, _ error) { - sql := fmt.Sprintf(loadDeleteRangeSQL, table, safePoint) - rss, err := ctx.(sqlexec.SQLExecutor).Execute(context.TODO(), sql) - if len(rss) > 0 { - defer terror.Call(rss[0].Close) + rs, err := ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), loadDeleteRangeSQL, table, safePoint) + if rs != nil { + defer terror.Call(rs.Close) } if err != nil { return nil, errors.Trace(err) } - rs := rss[0] req := rs.NewChunk() it := chunk.NewIterator4Chunk(req) for { @@ -106,8 +103,7 @@ func loadDeleteRangesFromTable(ctx sessionctx.Context, table string, safePoint u // CompleteDeleteRange moves a record from gc_delete_range table to gc_delete_range_done table. // NOTE: This function WILL NOT start and run in a new transaction internally. 
func CompleteDeleteRange(ctx sessionctx.Context, dr DelRangeTask) error { - sql := fmt.Sprintf(recordDoneDeletedRangeSQL, dr.JobID, dr.ElementID) - _, err := ctx.(sqlexec.SQLExecutor).Execute(context.TODO(), sql) + _, err := ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), recordDoneDeletedRangeSQL, dr.JobID, dr.ElementID) if err != nil { return errors.Trace(err) } @@ -117,29 +113,31 @@ func CompleteDeleteRange(ctx sessionctx.Context, dr DelRangeTask) error { // RemoveFromGCDeleteRange is exported for ddl pkg to use. func RemoveFromGCDeleteRange(ctx sessionctx.Context, jobID, elementID int64) error { - sql := fmt.Sprintf(completeDeleteRangeSQL, jobID, elementID) - _, err := ctx.(sqlexec.SQLExecutor).Execute(context.TODO(), sql) + _, err := ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), completeDeleteRangeSQL, jobID, elementID) return errors.Trace(err) } // RemoveMultiFromGCDeleteRange is exported for ddl pkg to use. func RemoveMultiFromGCDeleteRange(ctx sessionctx.Context, jobID int64, elementIDs []int64) error { - var buf bytes.Buffer + var buf strings.Builder + buf.WriteString(completeDeleteMultiRangesSQL) + paramIDs := make([]interface{}, 0, 1+len(elementIDs)) + paramIDs = append(paramIDs, jobID) for i, elementID := range elementIDs { if i > 0 { buf.WriteString(", ") } - buf.WriteString(strconv.FormatInt(elementID, 10)) + buf.WriteString("%?") + paramIDs = append(paramIDs, elementID) } - sql := fmt.Sprintf(completeDeleteMultiRangesSQL, jobID, buf.String()) - _, err := ctx.(sqlexec.SQLExecutor).Execute(context.TODO(), sql) + buf.WriteString(")") + _, err := ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), buf.String(), paramIDs...) return errors.Trace(err) } // DeleteDoneRecord removes a record from gc_delete_range_done table. func DeleteDoneRecord(ctx sessionctx.Context, dr DelRangeTask) error { - sql := fmt.Sprintf(deleteDoneRecordSQL, dr.JobID, dr.ElementID) - _, err := ctx.(sqlexec.SQLExecutor).Execute(context.TODO(), sql) + _, err := ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), deleteDoneRecordSQL, dr.JobID, dr.ElementID) return errors.Trace(err) } @@ -147,8 +145,7 @@ func DeleteDoneRecord(ctx sessionctx.Context, dr DelRangeTask) error { func UpdateDeleteRange(ctx sessionctx.Context, dr DelRangeTask, newStartKey, oldStartKey kv.Key) error { newStartKeyHex := hex.EncodeToString(newStartKey) oldStartKeyHex := hex.EncodeToString(oldStartKey) - sql := fmt.Sprintf(updateDeleteRangeSQL, newStartKeyHex, dr.JobID, dr.ElementID, oldStartKeyHex) - _, err := ctx.(sqlexec.SQLExecutor).Execute(context.TODO(), sql) + _, err := ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), updateDeleteRangeSQL, newStartKeyHex, dr.JobID, dr.ElementID, oldStartKeyHex) return errors.Trace(err) } @@ -162,20 +159,14 @@ func LoadDDLVars(ctx sessionctx.Context) error { return LoadGlobalVars(ctx, []string{variable.TiDBDDLErrorCountLimit}) } -const loadGlobalVarsSQL = "select HIGH_PRIORITY variable_name, variable_value from mysql.global_variables where variable_name in (%s)" - // LoadGlobalVars loads global variable from mysql.global_variables. 
func LoadGlobalVars(ctx sessionctx.Context, varNames []string) error { if sctx, ok := ctx.(sqlexec.RestrictedSQLExecutor); ok { - nameList := "" - for i, name := range varNames { - if i > 0 { - nameList += ", " - } - nameList += fmt.Sprintf("'%s'", name) + stmt, err := sctx.ParseWithParams(context.Background(), loadGlobalVars, varNames) + if err != nil { + return errors.Trace(err) } - sql := fmt.Sprintf(loadGlobalVarsSQL, nameList) - rows, _, err := sctx.ExecRestrictedSQL(sql) + rows, _, err := sctx.ExecRestrictedStmt(context.Background(), stmt) if err != nil { return errors.Trace(err) } diff --git a/distsql/select_result.go b/distsql/select_result.go index 6952a9e792f6f..825882e2887f6 100644 --- a/distsql/select_result.go +++ b/distsql/select_result.go @@ -277,6 +277,10 @@ func (r *selectResult) updateCopRuntimeStats(ctx context.Context, copStats *tikv } r.stats.mergeCopRuntimeStats(copStats, respTime) + if copStats.ScanDetail != nil && len(r.copPlanIDs) > 0 { + r.ctx.GetSessionVars().StmtCtx.RuntimeStatsColl.RecordScanDetail(r.copPlanIDs[len(r.copPlanIDs)-1], copStats.ScanDetail) + } + for i, detail := range r.selectResp.GetExecutionSummaries() { if detail != nil && detail.TimeProcessedNs != nil && detail.NumProducedRows != nil && detail.NumIterations != nil { @@ -338,13 +342,17 @@ type selectResultRuntimeStats struct { func (s *selectResultRuntimeStats) mergeCopRuntimeStats(copStats *tikv.CopRuntimeStats, respTime time.Duration) { s.copRespTime = append(s.copRespTime, respTime) - s.procKeys = append(s.procKeys, copStats.ProcessedKeys) + if copStats.ScanDetail != nil { + s.procKeys = append(s.procKeys, copStats.ScanDetail.ProcessedKeys) + } else { + s.procKeys = append(s.procKeys, 0) + } for k, v := range copStats.BackoffSleep { s.backoffSleep[k] += v } - s.totalProcessTime += copStats.ProcessTime - s.totalWaitTime += copStats.WaitTime + s.totalProcessTime += copStats.TimeDetail.ProcessTime + s.totalWaitTime += copStats.TimeDetail.WaitTime s.rpcStat.Merge(copStats.RegionRequestRuntimeStats) if copStats.CoprCacheHit { s.CoprCacheHitNum++ diff --git a/domain/domain.go b/domain/domain.go index 649b3e7b1287d..c581ea037fcc7 100644 --- a/domain/domain.go +++ b/domain/domain.go @@ -1015,7 +1015,8 @@ func (do *Domain) StatsHandle() *handle.Handle { // CreateStatsHandle is used only for test. func (do *Domain) CreateStatsHandle(ctx sessionctx.Context) { - atomic.StorePointer(&do.statsHandle, unsafe.Pointer(handle.NewHandle(ctx, do.statsLease))) + h := handle.NewHandle(ctx, do.statsLease, do.sysSessionPool) + atomic.StorePointer(&do.statsHandle, unsafe.Pointer(h)) } // StatsUpdating checks if the stats worker is updating. @@ -1040,7 +1041,7 @@ var RunAutoAnalyze = true // It should be called only once in BootstrapSession. func (do *Domain) UpdateTableStatsLoop(ctx sessionctx.Context) error { ctx.GetSessionVars().InRestrictedSQL = true - statsHandle := handle.NewHandle(ctx, do.statsLease) + statsHandle := handle.NewHandle(ctx, do.statsLease, do.sysSessionPool) atomic.StorePointer(&do.statsHandle, unsafe.Pointer(statsHandle)) do.ddl.RegisterEventCh(statsHandle.DDLEventCh()) // Negative stats lease indicates that it is in test, it does not need update. 
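The recurring pattern in this patch replaces fmt.Sprintf-assembled SQL with parameterized statements: %n binds an identifier (for example a table name) and %? binds a value, so both are escaped by the parser rather than by hand-written quoting. Below is a minimal sketch of a read path under that pattern, using the same sqlexec interfaces that LoadGlobalVars above uses; readGlobalVariable is a hypothetical helper for illustration, not part of the patch.

func readGlobalVariable(sctx sessionctx.Context, name string) (string, error) {
	exec := sctx.(sqlexec.RestrictedSQLExecutor)
	// %n is expanded to a quoted identifier, %? to a safely-encoded value.
	stmt, err := exec.ParseWithParams(context.Background(),
		"SELECT variable_value FROM mysql.%n WHERE variable_name = %?",
		"global_variables", name)
	if err != nil {
		return "", err
	}
	// ExecRestrictedStmt runs on a background session, so the session
	// context lock does not need to be held by the caller.
	rows, _, err := exec.ExecRestrictedStmt(context.Background(), stmt)
	if err != nil || len(rows) == 0 {
		return "", err
	}
	return rows[0].GetString(0), nil
}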
@@ -1215,9 +1216,12 @@ func (do *Domain) NotifyUpdatePrivilege(ctx sessionctx.Context) { } } // update locally - _, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(`FLUSH PRIVILEGES`) - if err != nil { - logutil.BgLogger().Error("unable to update privileges", zap.Error(err)) + exec := ctx.(sqlexec.RestrictedSQLExecutor) + if stmt, err := exec.ParseWithParams(context.Background(), `FLUSH PRIVILEGES`); err == nil { + _, _, err := exec.ExecRestrictedStmt(context.Background(), stmt) + if err != nil { + logutil.BgLogger().Error("unable to update privileges", zap.Error(err)) + } } } diff --git a/executor/adapter.go b/executor/adapter.go index cc07ea991a6c6..5166cca747b41 100644 --- a/executor/adapter.go +++ b/executor/adapter.go @@ -827,6 +827,14 @@ func (a *ExecStmt) FinishExecuteStmt(txnTS uint64, succ bool, hasMoreResults boo } sessVars.StmtCtx.RuntimeStatsColl.RegisterStats(a.Plan.ID(), statsWithCommit) } + // Record related SLI metrics. + if execDetail.CommitDetail != nil && execDetail.CommitDetail.WriteSize > 0 { + a.Ctx.GetTxnWriteThroughputSLI().AddTxnWriteSize(execDetail.CommitDetail.WriteSize, execDetail.CommitDetail.WriteKeys) + } + if execDetail.ScanDetail != nil && execDetail.ScanDetail.ProcessedKeys > 0 && sessVars.StmtCtx.AffectedRows() > 0 { + // Only record the read keys of write statements that affect more than 0 rows. + a.Ctx.GetTxnWriteThroughputSLI().AddReadKeys(execDetail.ScanDetail.ProcessedKeys) + } // `LogSlowQuery` and `SummaryStmt` must be called before recording `PrevStmt`. a.LogSlowQuery(txnTS, succ, hasMoreResults) a.SummaryStmt(succ) @@ -907,6 +915,8 @@ func (a *ExecStmt) LogSlowQuery(txnTS uint64, succ bool, hasMoreResults bool) { TimeTotal: costTime, TimeParse: sessVars.DurationParse, TimeCompile: sessVars.DurationCompile, + TimeOptimize: sessVars.DurationOptimization, + TimeWaitTS: sessVars.DurationWaitTS, IndexNames: indexNames, StatsInfos: statsInfos, CopTasks: copTaskInfo, @@ -919,6 +929,7 @@ Prepared: a.isPreparedStmt, HasMoreResults: hasMoreResults, PlanFromCache: sessVars.FoundInPlanCache, + PlanFromBinding: sessVars.FoundInBinding, KVTotal: time.Duration(atomic.LoadInt64(&stmtDetail.WaitKVRespDuration)), PDTotal: time.Duration(atomic.LoadInt64(&stmtDetail.WaitPDRespDuration)), BackoffTotal: time.Duration(atomic.LoadInt64(&stmtDetail.BackoffDuration)), @@ -941,12 +952,12 @@ logutil.SlowQueryLogger.Warn(sessVars.SlowLogFormat(slowItems)) if sessVars.InRestrictedSQL { totalQueryProcHistogramInternal.Observe(costTime.Seconds()) - totalCopProcHistogramInternal.Observe(execDetail.ProcessTime.Seconds()) - totalCopWaitHistogramInternal.Observe(execDetail.WaitTime.Seconds()) + totalCopProcHistogramInternal.Observe(execDetail.TimeDetail.ProcessTime.Seconds()) + totalCopWaitHistogramInternal.Observe(execDetail.TimeDetail.WaitTime.Seconds()) } else { totalQueryProcHistogramGeneral.Observe(costTime.Seconds()) - totalCopProcHistogramGeneral.Observe(execDetail.ProcessTime.Seconds()) - totalCopWaitHistogramGeneral.Observe(execDetail.WaitTime.Seconds()) + totalCopProcHistogramGeneral.Observe(execDetail.TimeDetail.ProcessTime.Seconds()) + totalCopWaitHistogramGeneral.Observe(execDetail.TimeDetail.WaitTime.Seconds()) } var userString string if sessVars.User != nil { @@ -1111,6 +1122,7 @@ func (a *ExecStmt) SummaryStmt(succ bool) { IsInternal: sessVars.InRestrictedSQL, Succeed: succ, PlanInCache:
sessVars.FoundInPlanCache, + PlanInBinding: sessVars.FoundInBinding, ExecRetryCount: a.retryCount, StmtExecDetails: stmtDetail, Prepared: a.isPreparedStmt, diff --git a/executor/aggfuncs/builder.go b/executor/aggfuncs/builder.go index 45e153aed93e9..3681be119fed1 100644 --- a/executor/aggfuncs/builder.go +++ b/executor/aggfuncs/builder.go @@ -150,9 +150,13 @@ func buildApproxPercentile(sctx sessionctx.Context, aggFuncDesc *aggregation.Agg base := basePercentile{percent: int(percent), baseAggFunc: baseAggFunc{args: aggFuncDesc.Args, ordinal: ordinal}} + evalType := aggFuncDesc.Args[0].GetType().EvalType() + if aggFuncDesc.Args[0].GetType().Tp == mysql.TypeBit { + evalType = types.ETString // same as other aggregate functions + } switch aggFuncDesc.Mode { case aggregation.CompleteMode, aggregation.Partial1Mode, aggregation.FinalMode: - switch aggFuncDesc.Args[0].GetType().EvalType() { + switch evalType { case types.ETInt: return &percentileOriginal4Int{base} case types.ETReal: diff --git a/executor/aggfuncs/func_group_concat.go b/executor/aggfuncs/func_group_concat.go index d13bc754b0842..c2d3b745d12fc 100644 --- a/executor/aggfuncs/func_group_concat.go +++ b/executor/aggfuncs/func_group_concat.go @@ -236,6 +236,11 @@ type topNRows struct { currSize uint64 limitSize uint64 sepSize uint64 + // If sep is truncated, we need to append part of sep to the result. + // In the following example, session.group_concat_max_len is 10 and sep is '---'. + // ('---', 'ccc') should be popped from the heap, so '-' should be appended to the result. + // e.g. 'aaa---bbb---ccc' -> 'aaa---bbb-' + isSepTruncated bool } func (h topNRows) Len() int { @@ -296,6 +301,7 @@ func (h *topNRows) tryToAdd(row sortRow) (truncated bool) { } else { h.currSize -= uint64(h.rows[0].buffer.Len()) + h.sepSize heap.Pop(h) + h.isSepTruncated = true } } return true @@ -316,10 +322,11 @@ func (h *topNRows) concat(sep string, truncated bool) string { } buffer.Write(row.buffer.Bytes()) } - if truncated && uint64(buffer.Len()) < h.limitSize { - // append the last separator, because the last separator may be truncated in tryToAdd. + if h.isSepTruncated { buffer.WriteString(sep) - buffer.Truncate(int(h.limitSize)) + if uint64(buffer.Len()) > h.limitSize { + buffer.Truncate(int(h.limitSize)) + } } return buffer.String() } @@ -349,10 +356,11 @@ func (e *groupConcatOrder) AllocPartialResult() PartialResult { } p := &partialResult4GroupConcatOrder{ topN: &topNRows{ - desc: desc, - currSize: 0, - limitSize: e.maxLen, - sepSize: uint64(len(e.sep)), + desc: desc, + currSize: 0, + limitSize: e.maxLen, + sepSize: uint64(len(e.sep)), + isSepTruncated: false, }, } return PartialResult(p) @@ -449,10 +457,11 @@ func (e *groupConcatDistinctOrder) AllocPartialResult() PartialResult { } p := &partialResult4GroupConcatOrderDistinct{ topN: &topNRows{ - desc: desc, - currSize: 0, - limitSize: e.maxLen, - sepSize: uint64(len(e.sep)), + desc: desc, + currSize: 0, + limitSize: e.maxLen, + sepSize: uint64(len(e.sep)), + isSepTruncated: false, }, valSet: set.NewStringSet(), } diff --git a/executor/aggregate.go b/executor/aggregate.go index bcbb9495470d7..34230f1eb63c2 100644 --- a/executor/aggregate.go +++ b/executor/aggregate.go @@ -166,9 +166,10 @@ type HashAggExec struct { isChildReturnEmpty bool // After we support parallel execution for aggregation functions with distinct, // we can remove this attribute.
- isUnparallelExec bool - prepared bool - executed bool + isUnparallelExec bool + parallelExecInitialized bool + prepared bool + executed bool memTracker *memory.Tracker // track memory usage. } @@ -204,36 +205,42 @@ func (d *HashAggIntermData) getPartialResultBatch(sc *stmtctx.StatementContext, // Close implements the Executor Close interface. func (e *HashAggExec) Close() error { if e.isUnparallelExec { - e.memTracker.Consume(-e.childResult.MemoryUsage()) e.childResult = nil e.groupSet = nil e.partialResultMap = nil + if e.memTracker != nil { + e.memTracker.ReplaceBytesUsed(0) + } return e.baseExecutor.Close() } - // `Close` may be called after `Open` without calling `Next` in test. - if !e.prepared { - close(e.inputCh) + if e.parallelExecInitialized { + // `Close` may be called after `Open` without calling `Next` in test. + if !e.prepared { + close(e.inputCh) + for _, ch := range e.partialOutputChs { + close(ch) + } + for _, ch := range e.partialInputChs { + close(ch) + } + close(e.finalOutputCh) + } + close(e.finishCh) for _, ch := range e.partialOutputChs { - close(ch) + for range ch { + } } for _, ch := range e.partialInputChs { - close(ch) + for range ch { + } } - close(e.finalOutputCh) - } - close(e.finishCh) - for _, ch := range e.partialOutputChs { - for range ch { + for range e.finalOutputCh { } - } - for _, ch := range e.partialInputChs { - for chk := range ch { - e.memTracker.Consume(-chk.MemoryUsage()) + e.executed = false + if e.memTracker != nil { + e.memTracker.ReplaceBytesUsed(0) } } - for range e.finalOutputCh { - } - e.executed = false if e.runtimeStats != nil { var partialConcurrency, finalConcurrency int @@ -258,6 +265,11 @@ func (e *HashAggExec) Open(ctx context.Context) error { if err := e.baseExecutor.Open(ctx); err != nil { return err } + failpoint.Inject("mockHashAggExecBaseExecutorOpenReturnedError", func(val failpoint.Value) { + if val.(bool) { + failpoint.Return(errors.New("mock HashAggExec.baseExecutor.Open returned error")) + } + }) e.prepared = false e.memTracker = memory.NewTracker(e.id, -1) @@ -340,6 +352,8 @@ func (e *HashAggExec) initForParallelExec(ctx sessionctx.Context) { } e.finalWorkers[i].finalResultHolderCh <- newFirstChunk(e) } + + e.parallelExecInitialized = true } func (w *HashAggPartialWorker) getChildInput() bool { @@ -838,6 +852,11 @@ func (e *StreamAggExec) Open(ctx context.Context) error { if err := e.baseExecutor.Open(ctx); err != nil { return err } + failpoint.Inject("mockStreamAggExecBaseExecutorOpenReturnedError", func(val failpoint.Value) { + if val.(bool) { + failpoint.Return(errors.New("mock StreamAggExec.baseExecutor.Open returned error")) + } + }) e.childResult = newFirstChunk(e.children[0]) e.executed = false e.isChildReturnEmpty = true @@ -858,8 +877,10 @@ func (e *StreamAggExec) Open(ctx context.Context) error { // Close implements the Executor Close interface. 
func (e *StreamAggExec) Close() error { - e.memTracker.Consume(-e.childResult.MemoryUsage()) - e.childResult = nil + if e.childResult != nil { + e.memTracker.Consume(-e.childResult.MemoryUsage()) + e.childResult = nil + } e.groupChecker.reset() return e.baseExecutor.Close() } diff --git a/executor/aggregate_test.go b/executor/aggregate_test.go index ad48bc68ae314..fdf643a4a0782 100644 --- a/executor/aggregate_test.go +++ b/executor/aggregate_test.go @@ -630,6 +630,23 @@ func (s *testSuiteAgg) TestGroupConcatAggr(c *C) { // issue #9920 tk.MustQuery("select group_concat(123, null)").Check(testkit.Rows("")) + + // issue #23129 + tk.MustExec("drop table if exists t1;") + tk.MustExec("create table t1(cid int, sname varchar(100));") + tk.MustExec("insert into t1 values(1, 'Bob'), (1, 'Alice');") + tk.MustExec("insert into t1 values(3, 'Ace');") + tk.MustExec("set @@group_concat_max_len=5;") + rows := tk.MustQuery("select group_concat(sname order by sname) from t1 group by cid;") + rows.Check(testkit.Rows("Alice", "Ace")) + + tk.MustExec("drop table if exists t1;") + tk.MustExec("create table t1(c1 varchar(10));") + tk.MustExec("insert into t1 values('0123456789');") + tk.MustExec("insert into t1 values('12345');") + tk.MustExec("set @@group_concat_max_len=8;") + rows = tk.MustQuery("select group_concat(c1 order by c1) from t1 group by c1;") + rows.Check(testkit.Rows("01234567", "12345")) } func (s *testSuiteAgg) TestSelectDistinct(c *C) { @@ -1155,3 +1172,34 @@ func (s *testSuiteAgg) TestIssue19426(c *C) { tk.MustQuery("select a, b, sum(case when a < 1000 then b else 0.0 end) over (order by a) from t"). Check(testkit.Rows("1 11 11.0", "2 22 33.0", "3 33 66.0", "4 44 110.0")) } + +func (s *testSuiteAgg) TestIssue23277(c *C) { + tk := testkit.NewTestKitWithInit(c, s.store) + tk.MustExec("use test;") + tk.MustExec("drop table if exists t;") + + tk.MustExec("create table t(a tinyint(1));") + tk.MustExec("insert into t values (-120), (127);") + tk.MustQuery("select avg(a) from t group by a").Sort().Check(testkit.Rows("-120.0000", "127.0000")) + tk.MustExec("drop table t;") + + tk.MustExec("create table t(a smallint(1));") + tk.MustExec("insert into t values (-120), (127);") + tk.MustQuery("select avg(a) from t group by a").Sort().Check(testkit.Rows("-120.0000", "127.0000")) + tk.MustExec("drop table t;") + + tk.MustExec("create table t(a mediumint(1));") + tk.MustExec("insert into t values (-120), (127);") + tk.MustQuery("select avg(a) from t group by a").Sort().Check(testkit.Rows("-120.0000", "127.0000")) + tk.MustExec("drop table t;") + + tk.MustExec("create table t(a int(1));") + tk.MustExec("insert into t values (-120), (127);") + tk.MustQuery("select avg(a) from t group by a").Sort().Check(testkit.Rows("-120.0000", "127.0000")) + tk.MustExec("drop table t;") + + tk.MustExec("create table t(a bigint(1));") + tk.MustExec("insert into t values (-120), (127);") + tk.MustQuery("select avg(a) from t group by a").Sort().Check(testkit.Rows("-120.0000", "127.0000")) + tk.MustExec("drop table t;") +} diff --git a/executor/analyze.go b/executor/analyze.go index bcea8925221ce..3e894b99c776f 100755 --- a/executor/analyze.go +++ b/executor/analyze.go @@ -523,6 +523,11 @@ func (e *AnalyzeColumnsExec) buildStats(ranges []*ranger.Range) (hists []*statis if err != nil { return nil, nil, err } + // When collation is enabled, we store the Key representation of the sampling data. So we set it to kind `Bytes` here + // to avoid converting it to its Key representation once more.
+ if collectors[i].Samples[j].Value.Kind() == types.KindString { + collectors[i].Samples[j].Value.SetBytes(collectors[i].Samples[j].Value.GetBytes()) + } } hg, err := statistics.BuildColumn(e.ctx, int64(e.opts[ast.AnalyzeOptNumBuckets]), col.ID, collectors[i], &col.FieldType) if err != nil { diff --git a/executor/analyze_test.go b/executor/analyze_test.go index b12646a4cb279..764b1c38ce692 100644 --- a/executor/analyze_test.go +++ b/executor/analyze_test.go @@ -42,6 +42,7 @@ import ( "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/codec" + "github.com/pingcap/tidb/util/collate" "github.com/pingcap/tidb/util/testkit" ) @@ -646,3 +647,26 @@ func (s *testSuite1) TestDefaultValForAnalyze(c *C) { tk.MustQuery("explain select * from t where a = 1").Check(testkit.Rows("IndexReader_6 1.00 root index:IndexRangeScan_5", "└─IndexRangeScan_5 1.00 cop[tikv] table:t, index:a(a) range:[1,1], keep order:false")) } + +func (s *testSerialSuite2) TestIssue20874(c *C) { + collate.SetNewCollationEnabledForTest(true) + defer collate.SetNewCollationEnabledForTest(false) + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t (a char(10) collate utf8mb4_unicode_ci not null, b char(20) collate utf8mb4_general_ci not null)") + tk.MustExec("insert into t values ('#', 'C'), ('$', 'c'), ('a', 'a')") + tk.MustExec("analyze table t") + tk.MustQuery("show stats_buckets where db_name = 'test' and table_name = 't'").Sort().Check(testkit.Rows( + "test t a 0 0 1 1 \x02\xd2 \x02\xd2", + "test t a 0 1 2 1 \x0e\x0f \x0e\x0f", + "test t a 0 2 3 1 \x0e3 \x0e3", + "test t b 0 0 1 1 \x00A \x00A", + "test t b 0 1 3 2 \x00C \x00C", + )) + + tk.MustExec("drop table if exists t") + tk.MustExec("create table t (a char(10) collate utf8mb4_general_ci not null)") + tk.MustExec("insert into t values ('汉字'), ('中文'), ('汉字'), ('中文'), ('汉字'), ('中文'), ('汉字'), ('中文'), ('汉字'), ('中文'), ('汉字'), ('中文'), ('汉字'), ('中文')") + tk.MustExec("analyze table t") +} diff --git a/executor/batch_point_get.go b/executor/batch_point_get.go index d6c1257651697..71c7f2f574fce 100644 --- a/executor/batch_point_get.go +++ b/executor/batch_point_get.go @@ -180,6 +180,9 @@ func (e *BatchPointGetExec) initialize(ctx context.Context) error { if err1 != nil && !kv.ErrNotExist.Equal(err1) { return err1 } + if idxKey == nil { + continue + } s := hack.String(idxKey) if _, found := dedup[s]; found { continue diff --git a/executor/brie.go b/executor/brie.go index 4dd86aa9fdb1b..492a3d68d8121 100644 --- a/executor/brie.go +++ b/executor/brie.go @@ -42,6 +42,7 @@ import ( "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/printer" "github.com/pingcap/tidb/util/sqlexec" ) @@ -403,6 +404,7 @@ func (gs *tidbGlueSession) CreateSession(store kv.Storage) (glue.Session, error) // Execute implements glue.Session func (gs *tidbGlueSession) Execute(ctx context.Context, sql string) error { + // FIXME: br relies on a deprecated API, it may be unsafe _, err := gs.se.(sqlexec.SQLExecutor).Execute(ctx, sql) return err } @@ -465,3 +467,7 @@ func (gs *tidbGlueSession) Record(name string, value uint64) { gs.info.archiveSize = value } } + +func (gs *tidbGlueSession) GetVersion() string { + return "TiDB\n" + printer.GetTiDBInfo() +} diff --git a/executor/brie_test.go b/executor/brie_test.go new file mode 100644 index 0000000000000..0116009efc368 --- /dev/null +++ 
b/executor/brie_test.go @@ -0,0 +1,28 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package executor + +import . "github.com/pingcap/check" + +type testBRIESuite struct{} + +var _ = Suite(&testBRIESuite{}) + +func (s *testBRIESuite) TestGlueGetVersion(c *C) { + g := tidbGlueSession{} + version := g.GetVersion() + c.Assert(version, Matches, `(.|\n)*Release Version(.|\n)*`) + c.Assert(version, Matches, `(.|\n)*Git Commit Hash(.|\n)*`) + c.Assert(version, Matches, `(.|\n)*GoVersion(.|\n)*`) +} diff --git a/executor/builder.go b/executor/builder.go index 4e1fd5a1b4161..2462293a8d69c 100644 --- a/executor/builder.go +++ b/executor/builder.go @@ -737,6 +737,8 @@ func (b *executorBuilder) buildLoadData(v *plannercore.LoadData) Executor { Table: tbl, Columns: v.Columns, GenExprs: v.GenCols.Exprs, + isLoadData: true, + txnInUse: sync.Mutex{}, } err := insertVal.initInsertColumns() if err != nil { @@ -1315,11 +1317,11 @@ func (b *executorBuilder) getSnapshotTS() (uint64, error) { } snapshotTS := b.ctx.GetSessionVars().SnapshotTS - txn, err := b.ctx.Txn(true) - if err != nil { - return 0, err - } if snapshotTS == 0 { + txn, err := b.ctx.Txn(true) + if err != nil { + return 0, err + } snapshotTS = txn.StartTS() } b.snapshotTS = snapshotTS @@ -2295,26 +2297,6 @@ func containsLimit(execs []*tipb.Executor) bool { return false } -// When allow batch cop is 1, only agg / topN uses batch cop. -// When allow batch cop is 2, every query uses batch cop. 
-func (e *TableReaderExecutor) setBatchCop(v *plannercore.PhysicalTableReader) { - if e.storeType != kv.TiFlash || e.keepOrder { - return - } - switch e.ctx.GetSessionVars().AllowBatchCop { - case 1: - for _, p := range v.TablePlans { - switch p.(type) { - case *plannercore.PhysicalHashAgg, *plannercore.PhysicalStreamAgg, *plannercore.PhysicalTopN, *plannercore.PhysicalBroadCastJoin: - e.batchCop = true - } - } - case 2: - e.batchCop = true - } - return -} - func buildNoRangeTableReader(b *executorBuilder, v *plannercore.PhysicalTableReader) (*TableReaderExecutor, error) { tablePlans := v.TablePlans if v.StoreType == kv.TiFlash { @@ -2349,8 +2331,8 @@ func buildNoRangeTableReader(b *executorBuilder, v *plannercore.PhysicalTableRea plans: v.TablePlans, tablePlan: v.GetTablePlan(), storeType: v.StoreType, + batchCop: v.BatchCop, } - e.setBatchCop(v) e.buildVirtualColumnInfo() if containsLimit(dagReq.Executors) { e.feedback = statistics.NewQueryFeedback(0, nil, 0, ts.Desc) @@ -2628,6 +2610,9 @@ func buildNoRangeIndexMergeReader(b *executorBuilder, v *plannercore.PhysicalInd } collectTable := false e.tableRequest.CollectRangeCounts = &collectTable + if v.ExtraHandleCol != nil { + e.extraHandleIdx = v.ExtraHandleCol.Index + } return e, nil } diff --git a/executor/ddl.go b/executor/ddl.go index ae2d685a65ccf..ef903bd4e377d 100644 --- a/executor/ddl.go +++ b/executor/ddl.go @@ -309,8 +309,12 @@ func (e *DDLExec) dropTableObject(objects []*ast.TableName, obt objectType, ifEx zap.String("database", fullti.Schema.O), zap.String("table", fullti.Name.O), ) - sql := fmt.Sprintf("admin check table `%s`.`%s`", fullti.Schema.O, fullti.Name.O) - _, _, err = e.ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql) + exec := e.ctx.(sqlexec.RestrictedSQLExecutor) + stmt, err := exec.ParseWithParams(context.TODO(), "admin check table %n.%n", fullti.Schema.O, fullti.Name.O) + if err != nil { + return err + } + _, _, err = exec.ExecRestrictedStmt(context.TODO(), stmt) if err != nil { return err } diff --git a/executor/ddl_test.go b/executor/ddl_test.go index 9b0b5c3443da1..e46302e545f16 100644 --- a/executor/ddl_test.go +++ b/executor/ddl_test.go @@ -291,6 +291,16 @@ func (s *testSuite6) TestViewRecursion(c *C) { tk.MustExec("drop view recursive_view1, t") } +func (s *testSuite6) TestIssue23027(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t23027") + tk.MustExec("create table t23027(a char(10))") + tk.MustExec("insert into t23027 values ('a'), ('a')") + tk.MustExec("create definer='root'@'localhost' view v23027 as select group_concat(a) from t23027;") + tk.MustQuery("select * from v23027").Check(testkit.Rows("a,a")) +} + func (s *testSuite6) TestIssue16250(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") diff --git a/executor/executor.go b/executor/executor.go index 852b05a72d702..818c7099634ba 100644 --- a/executor/executor.go +++ b/executor/executor.go @@ -1188,6 +1188,11 @@ func (e *SelectionExec) Open(ctx context.Context) error { if err := e.baseExecutor.Open(ctx); err != nil { return err } + failpoint.Inject("mockSelectionExecBaseExecutorOpenReturnedError", func(val failpoint.Value) { + if val.(bool) { + failpoint.Return(errors.New("mock SelectionExec.baseExecutor.Open returned error")) + } + }) return e.open(ctx) } @@ -1207,8 +1212,10 @@ func (e *SelectionExec) open(ctx context.Context) error { // Close implements plannercore.Plan Close interface. 
func (e *SelectionExec) Close() error { - e.memTracker.Consume(-e.childResult.MemoryUsage()) - e.childResult = nil + if e.childResult != nil { + e.memTracker.Consume(-e.childResult.MemoryUsage()) + e.childResult = nil + } e.selected = nil return e.baseExecutor.Close() } @@ -1699,7 +1706,8 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { sc.IgnoreZeroInDate = !vars.SQLMode.HasNoZeroInDateMode() || !vars.SQLMode.HasNoZeroDateMode() || !vars.StrictSQLMode || stmt.IgnoreErr || sc.AllowInvalidDate sc.Priority = stmt.Priority case *ast.CreateTableStmt, *ast.AlterTableStmt: - // Make sure the sql_mode is strict when checking column default value. + sc.AllowInvalidDate = vars.SQLMode.HasAllowInvalidDatesMode() + sc.IgnoreZeroInDate = !vars.SQLMode.HasNoZeroInDateMode() || !vars.SQLMode.HasNoZeroDateMode() || !vars.StrictSQLMode || sc.AllowInvalidDate case *ast.LoadDataStmt: sc.DupKeyAsWarning = true sc.BadNullAsWarning = true @@ -1772,6 +1780,8 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { vars.StmtCtx = sc vars.PrevFoundInPlanCache = vars.FoundInPlanCache vars.FoundInPlanCache = false + vars.PrevFoundInBinding = vars.FoundInBinding + vars.FoundInBinding = false return } diff --git a/executor/executor_test.go b/executor/executor_test.go index f5b82c73c98aa..c2ccaf7cc2044 100644 --- a/executor/executor_test.go +++ b/executor/executor_test.go @@ -104,6 +104,7 @@ var _ = Suite(&testSuite{&baseTestSuite{}}) var _ = Suite(&testSuiteP1{&baseTestSuite{}}) var _ = Suite(&testSuiteP2{&baseTestSuite{}}) var _ = Suite(&testSuite1{}) +var _ = SerialSuites(&testSerialSuite2{}) var _ = Suite(&testSuite2{&baseTestSuite{}}) var _ = Suite(&testSuite3{&baseTestSuite{}}) var _ = Suite(&testSuite4{&baseTestSuite{}}) @@ -1134,6 +1135,32 @@ func (s *testSuiteP1) TestIssue5055(c *C) { result.Check(testkit.Rows("1 1")) } +// issue-23038: wrong key range of index scan for year column +func (s *testSuiteWithData) TestIndexScanWithYearCol(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test;") + tk.MustExec("drop table if exists t;") + tk.MustExec("create table t (c1 year(4), c2 int, key(c1));") + tk.MustExec("insert into t values(2001, 1);") + + var input []string + var output []struct { + SQL string + Plan []string + Res []string + } + s.testData.GetTestCases(c, &input, &output) + for i, tt := range input { + s.testData.OnRecord(func() { + output[i].SQL = tt + output[i].Plan = s.testData.ConvertRowsToStrings(tk.MustQuery("explain " + tt).Rows()) + output[i].Res = s.testData.ConvertRowsToStrings(tk.MustQuery(tt).Sort().Rows()) + }) + tk.MustQuery("explain " + tt).Check(testkit.Rows(output[i].Plan...)) + tk.MustQuery(tt).Sort().Check(testkit.Rows(output[i].Res...)) + } +} + func (s *testSuiteP2) TestUnion(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") @@ -1372,6 +1399,12 @@ func (s *testSuiteP2) TestUnion(c *C) { tk.MustExec("create table t(a int, b decimal(6, 3))") tk.MustExec("insert into t values(1, 1.000)") tk.MustQuery("select count(distinct a), sum(distinct a), avg(distinct a) from (select a from t union all select b from t) tmp;").Check(testkit.Rows("1 1.000 1.0000000")) + + // #issue 23832 + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a bit(20), b float, c double, d int)") + tk.MustExec("insert into t values(10, 10, 10, 10), (1, -1, 2, -2), (2, -2, 1, 1), (2, 1.1, 2.1, 10.1)") + tk.MustQuery("select a from t union select 10 order by a").Check(testkit.Rows("1", "2", "10")) } func (s 
*testSuite2) TestUnionLimit(c *C) { @@ -3007,6 +3040,10 @@ type testSuite1 struct { testSuiteWithCliBase } +type testSerialSuite2 struct { + testSuiteWithCliBase +} + func (s *testSuiteWithCliBase) SetUpSuite(c *C) { cli := &checkRequestClient{} hijackClient := func(c tikv.Client) tikv.Client { @@ -4064,6 +4101,23 @@ func (s *testSuiteP1) TestSelectPartition(c *C) { tk.MustQuery("select a, b from th where b>10").Check(testkit.Rows("11 11")) tk.MustExec("commit") tk.MustQuery("select a, b from th where b>10").Check(testkit.Rows("11 11")) + + // test partition function is scalar func + tk.MustExec("drop table if exists tscalar") + tk.MustExec(`create table tscalar (c1 int) partition by range (c1 % 30) ( + partition p0 values less than (0), + partition p1 values less than (10), + partition p2 values less than (20), + partition pm values less than (maxvalue));`) + tk.MustExec("insert into tscalar values(0), (10), (40), (50), (55)") + // test IN expression + tk.MustExec("insert into tscalar values(-0), (-10), (-40), (-50), (-55)") + tk.MustQuery("select * from tscalar where c1 in (55, 55)").Check(testkit.Rows("55")) + tk.MustQuery("select * from tscalar where c1 in (40, 40)").Check(testkit.Rows("40")) + tk.MustQuery("select * from tscalar where c1 in (40)").Check(testkit.Rows("40")) + tk.MustQuery("select * from tscalar where c1 in (-40)").Check(testkit.Rows("-40")) + tk.MustQuery("select * from tscalar where c1 in (-40, -40)").Check(testkit.Rows("-40")) + tk.MustQuery("select * from tscalar where c1 in (-1)").Check(testkit.Rows()) } func (s *testSuiteP1) TestDeletePartition(c *C) { @@ -6454,6 +6508,15 @@ func (s *testSuite) TestIssue20305(c *C) { tk.MustQuery("SELECT * FROM `t3` where y <= a").Check(testkit.Rows("2155 2156")) } +func (s *testSuite) TestIssue22817(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t3") + tk.MustExec("create table t3 (a year)") + tk.MustExec("insert into t3 values (1991), (\"1992\"), (\"93\"), (94)") + tk.MustQuery("select * from t3 where a >= NULL").Check(testkit.Rows()) +} + func (s *testSuite) TestIssue13953(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") @@ -6487,6 +6550,32 @@ func (s *testSuite) TestZeroDateTimeCompatibility(c *C) { } } +// https://github.com/pingcap/tidb/issues/24165. +func (s *testSuite) TestInvalidDateValueInCreateTable(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test;") + tk.MustExec("drop table if exists t;") + tk.MustExec("set @@sql_mode='STRICT_TRANS_TABLES,NO_ZERO_IN_DATE,NO_ZERO_DATE';") + tk.MustGetErrCode("create table t (a datetime default '2999-00-00 00:00:00');", mysql.ErrInvalidDefault) + tk.MustGetErrCode("create table t (a datetime default '2999-02-30 00:00:00');", mysql.ErrInvalidDefault) + tk.MustExec("create table t (a datetime);") + tk.MustGetErrCode("alter table t modify column a datetime default '2999-00-00 00:00:00';", mysql.ErrInvalidDefault) + tk.MustExec("drop table if exists t;") + + tk.MustExec("set @@sql_mode = (select replace(@@sql_mode,'NO_ZERO_IN_DATE',''));") + tk.MustExec("set @@sql_mode = (select replace(@@sql_mode,'NO_ZERO_DATE',''));") + tk.MustExec("set @@sql_mode=(select concat(@@sql_mode, ',ALLOW_INVALID_DATES'));") + // Test create table with zero datetime as a default value. + tk.MustExec("create table t (a datetime default '2999-00-00 00:00:00');") + tk.MustExec("drop table if exists t;") + // Test create table with invalid datetime(02-30) as a default value. 
+ tk.MustExec("create table t (a datetime default '2999-02-30 00:00:00');") + tk.MustExec("drop table if exists t;") + tk.MustExec("create table t (a datetime);") + tk.MustExec("alter table t modify column a datetime default '2999-00-00 00:00:00';") + tk.MustExec("drop table if exists t;") +} + func (s *testSuite) TestOOMActionPriority(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") @@ -6647,3 +6736,153 @@ func (s *testSuite) TestIssue22201(c *C) { tk.MustQuery("SELECT HEX(WEIGHT_STRING('ab' AS char(1000000000000000000)));").Check(testkit.Rows("")) tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1301 Result of weight_string() was larger than max_allowed_packet (67108864) - truncated")) } + +func (s *testSuiteP1) TestIssue22941(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists m, mp") + tk.MustExec(`CREATE TABLE m ( + mid varchar(50) NOT NULL, + ParentId varchar(50) DEFAULT NULL, + PRIMARY KEY (mid), + KEY ind_bm_parent (ParentId,mid) + )`) + // mp should have more columns than m + tk.MustExec(`CREATE TABLE mp ( + mpid bigint(20) unsigned NOT NULL DEFAULT '0', + mid varchar(50) DEFAULT NULL COMMENT '模块主键', + sid int, + PRIMARY KEY (mpid) + );`) + + tk.MustExec(`insert into mp values("1","1","0");`) + tk.MustExec(`insert into m values("0", "0");`) + rs := tk.MustQuery(`SELECT ( SELECT COUNT(1) FROM m WHERE ParentId = c.mid ) expand, bmp.mpid, bmp.mpid IS NULL,bmp.mpid IS NOT NULL, sid FROM m c LEFT JOIN mp bmp ON c.mid = bmp.mid WHERE c.ParentId = '0'`) + rs.Check(testkit.Rows("1 1 0 ")) + + rs = tk.MustQuery(`SELECT bmp.mpid, bmp.mpid IS NULL,bmp.mpid IS NOT NULL FROM m c LEFT JOIN mp bmp ON c.mid = bmp.mid WHERE c.ParentId = '0'`) + rs.Check(testkit.Rows(" 1 0")) +} + +func (s *testSerialSuite2) TestTxnWriteThroughputSLI(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t (a int key, b int)") + c.Assert(failpoint.Enable("github.com/pingcap/tidb/util/sli/CheckTxnWriteThroughput", "return(true)"), IsNil) + defer func() { + err := failpoint.Disable("github.com/pingcap/tidb/util/sli/CheckTxnWriteThroughput") + c.Assert(err, IsNil) + }() + + mustExec := func(sql string) { + tk.MustExec(sql) + tk.Se.GetTxnWriteThroughputSLI().FinishExecuteStmt(time.Second, tk.Se.AffectedRows(), tk.Se.GetSessionVars().InTxn()) + } + errExec := func(sql string) { + _, err := tk.Exec(sql) + c.Assert(err, NotNil) + tk.Se.GetTxnWriteThroughputSLI().FinishExecuteStmt(time.Second, tk.Se.AffectedRows(), tk.Se.GetSessionVars().InTxn()) + } + + // Test insert in small txn + mustExec("insert into t values (1,3),(2,4)") + writeSLI := tk.Se.GetTxnWriteThroughputSLI() + c.Assert(writeSLI.IsInvalid(), Equals, false) + c.Assert(writeSLI.IsSmallTxn(), Equals, true) + c.Assert(tk.Se.GetTxnWriteThroughputSLI().String(), Equals, "invalid: false, affectRow: 2, writeSize: 58, readKeys: 0, writeKeys: 2, writeTime: 1s") + tk.Se.GetTxnWriteThroughputSLI().Reset() + + // Test insert ... select ... 
from + mustExec("insert into t select b, a from t") + c.Assert(writeSLI.IsInvalid(), Equals, true) + c.Assert(writeSLI.IsSmallTxn(), Equals, true) + c.Assert(tk.Se.GetTxnWriteThroughputSLI().String(), Equals, "invalid: true, affectRow: 2, writeSize: 58, readKeys: 0, writeKeys: 2, writeTime: 1s") + tk.Se.GetTxnWriteThroughputSLI().Reset() + + // Test for delete + mustExec("delete from t") + c.Assert(tk.Se.GetTxnWriteThroughputSLI().String(), Equals, "invalid: false, affectRow: 4, writeSize: 76, readKeys: 0, writeKeys: 4, writeTime: 1s") + tk.Se.GetTxnWriteThroughputSLI().Reset() + + // Test insert not in small txn + mustExec("begin") + for i := 0; i < 20; i++ { + mustExec(fmt.Sprintf("insert into t values (%v,%v)", i, i)) + c.Assert(writeSLI.IsSmallTxn(), Equals, true) + } + // Statements that affect 0 rows shouldn't be recorded in the write time. + mustExec("select count(*) from t") + mustExec("select * from t") + mustExec("insert into t values (20,20)") + c.Assert(writeSLI.IsSmallTxn(), Equals, false) + mustExec("commit") + c.Assert(writeSLI.IsInvalid(), Equals, false) + c.Assert(tk.Se.GetTxnWriteThroughputSLI().String(), Equals, "invalid: false, affectRow: 21, writeSize: 609, readKeys: 0, writeKeys: 21, writeTime: 22s") + tk.Se.GetTxnWriteThroughputSLI().Reset() + + // Test that the SLI becomes invalid when the transaction has a replace ... select ... from ... statement. + mustExec("delete from t") + tk.Se.GetTxnWriteThroughputSLI().Reset() + mustExec("begin") + mustExec("insert into t values (1,3),(2,4)") + mustExec("replace into t select b, a from t") + mustExec("commit") + c.Assert(writeSLI.IsInvalid(), Equals, true) + c.Assert(tk.Se.GetTxnWriteThroughputSLI().String(), Equals, "invalid: true, affectRow: 4, writeSize: 116, readKeys: 0, writeKeys: 4, writeTime: 3s") + tk.Se.GetTxnWriteThroughputSLI().Reset() + + // Test that the information of the last failed transaction is cleaned up.
+ err := failpoint.Disable("github.com/pingcap/tidb/util/sli/CheckTxnWriteThroughput") + c.Assert(err, IsNil) + mustExec("begin") + mustExec("insert into t values (1,3),(2,4)") + errExec("commit") + c.Assert(tk.Se.GetTxnWriteThroughputSLI().String(), Equals, "invalid: false, affectRow: 0, writeSize: 0, readKeys: 0, writeKeys: 0, writeTime: 0s") + + c.Assert(failpoint.Enable("github.com/pingcap/tidb/util/sli/CheckTxnWriteThroughput", "return(true)"), IsNil) + mustExec("begin") + mustExec("insert into t values (5, 6)") + mustExec("commit") + c.Assert(tk.Se.GetTxnWriteThroughputSLI().String(), Equals, "invalid: false, affectRow: 1, writeSize: 29, readKeys: 0, writeKeys: 1, writeTime: 2s") + + // Test for reset + tk.Se.GetTxnWriteThroughputSLI().Reset() + c.Assert(tk.Se.GetTxnWriteThroughputSLI().String(), Equals, "invalid: false, affectRow: 0, writeSize: 0, readKeys: 0, writeKeys: 0, writeTime: 0s") +} + +func (s *testSerialSuite1) TestIssue24210(c *C) { + tk := testkit.NewTestKit(c, s.store) + + // for ProjectionExec + c.Assert(failpoint.Enable("github.com/pingcap/tidb/executor/mockProjectionExecBaseExecutorOpenReturnedError", `return(true)`), IsNil) + _, err := tk.Exec("select a from (select 1 as a, 2 as b) t") + c.Assert(err, NotNil) + c.Assert(err.Error(), Equals, "mock ProjectionExec.baseExecutor.Open returned error") + err = failpoint.Disable("github.com/pingcap/tidb/executor/mockProjectionExecBaseExecutorOpenReturnedError") + c.Assert(err, IsNil) + + // for HashAggExec + c.Assert(failpoint.Enable("github.com/pingcap/tidb/executor/mockHashAggExecBaseExecutorOpenReturnedError", `return(true)`), IsNil) + _, err = tk.Exec("select sum(a) from (select 1 as a, 2 as b) t group by b") + c.Assert(err, NotNil) + c.Assert(err.Error(), Equals, "mock HashAggExec.baseExecutor.Open returned error") + err = failpoint.Disable("github.com/pingcap/tidb/executor/mockHashAggExecBaseExecutorOpenReturnedError") + c.Assert(err, IsNil) + + // for StreamAggExec + c.Assert(failpoint.Enable("github.com/pingcap/tidb/executor/mockStreamAggExecBaseExecutorOpenReturnedError", `return(true)`), IsNil) + _, err = tk.Exec("select sum(a) from (select 1 as a, 2 as b) t") + c.Assert(err, NotNil) + c.Assert(err.Error(), Equals, "mock StreamAggExec.baseExecutor.Open returned error") + err = failpoint.Disable("github.com/pingcap/tidb/executor/mockStreamAggExecBaseExecutorOpenReturnedError") + c.Assert(err, IsNil) + + // for SelectionExec + c.Assert(failpoint.Enable("github.com/pingcap/tidb/executor/mockSelectionExecBaseExecutorOpenReturnedError", `return(true)`), IsNil) + _, err = tk.Exec("select * from (select 1 as a) t where a > 0") + c.Assert(err, NotNil) + c.Assert(err.Error(), Equals, "mock SelectionExec.baseExecutor.Open returned error") + err = failpoint.Disable("github.com/pingcap/tidb/executor/mockSelectionExecBaseExecutorOpenReturnedError") + c.Assert(err, IsNil) +} diff --git a/executor/explainfor_test.go b/executor/explainfor_test.go index bdadf3e0fac92..efa72e7e2ca2d 100644 --- a/executor/explainfor_test.go +++ b/executor/explainfor_test.go @@ -104,8 +104,8 @@ func (s *testSuite9) TestExplainFor(c *C) { } } c.Assert(buf.String(), Matches, ""+ - "TableReader_5 10000.00 0 root time:.*, loops:1, cop_task:.*num: 1, max:.*, proc_keys: 0, rpc_num: 1, rpc_time: .*data:TableFullScan_4 N/A N/A\n"+ - "└─TableFullScan_4 10000.00 0 cop.* table:t1 .*keep order:false, stats:pseudo N/A N/A") + "TableReader_5 10000.00 0 root time:.*, loops:1, cop_task: {num:.*, max:.*, proc_keys: 0, rpc_num: 1, rpc_time:.*} data:TableFullScan_4 
N/A N/A\n"+ + "└─TableFullScan_4 10000.00 0 cop.* table:t1 tikv_task:{time:.*, loops:.*} keep order:false, stats:pseudo N/A N/A") } tkRoot.MustQuery("select * from t1;") check() diff --git a/executor/grant.go b/executor/grant.go index 097e2d5dcc349..86b279727aaa5 100644 --- a/executor/grant.go +++ b/executor/grant.go @@ -16,7 +16,6 @@ package executor import ( "context" "encoding/json" - "fmt" "strings" "github.com/pingcap/errors" @@ -106,7 +105,7 @@ func (e *GrantExec) Next(ctx context.Context, req *chunk.Chunk) error { } defer func() { if !isCommit { - _, err := internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), "rollback") + _, err := internalSession.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), "rollback") if err != nil { logutil.BgLogger().Error("rollback error occur at grant privilege", zap.Error(err)) } @@ -114,7 +113,7 @@ func (e *GrantExec) Next(ctx context.Context, req *chunk.Chunk) error { e.releaseSysSession(internalSession) }() - _, err = internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), "begin") + _, err = internalSession.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), "begin") if err != nil { return err } @@ -132,9 +131,7 @@ func (e *GrantExec) Next(ctx context.Context, req *chunk.Chunk) error { if !ok { return errors.Trace(ErrPasswordFormat) } - user := fmt.Sprintf(`('%s', '%s', '%s')`, user.User.Hostname, user.User.Username, pwd) - sql := fmt.Sprintf(`INSERT INTO %s.%s (Host, User, authentication_string) VALUES %s;`, mysql.SystemDB, mysql.UserTable, user) - _, err := internalSession.(sqlexec.SQLExecutor).Execute(ctx, sql) + _, err := internalSession.(sqlexec.SQLExecutor).ExecuteInternal(ctx, `INSERT INTO %n.%n (Host, User, authentication_string) VALUES (%?, %?, %?);`, mysql.SystemDB, mysql.UserTable, user.User.Hostname, user.User.Username, pwd) if err != nil { return err } @@ -193,7 +190,7 @@ func (e *GrantExec) Next(ctx context.Context, req *chunk.Chunk) error { } } - _, err = internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), "commit") + _, err = internalSession.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), "commit") if err != nil { return err } @@ -274,29 +271,25 @@ func (e *GrantExec) checkAndInitColumnPriv(user string, host string, cols []*ast // initGlobalPrivEntry inserts a new row into mysql.DB with empty privilege. func initGlobalPrivEntry(ctx sessionctx.Context, user string, host string) error { - sql := fmt.Sprintf(`INSERT INTO %s.%s (Host, User, PRIV) VALUES ('%s', '%s', '%s')`, mysql.SystemDB, mysql.GlobalPrivTable, host, user, "{}") - _, err := ctx.(sqlexec.SQLExecutor).Execute(context.Background(), sql) + _, err := ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), `INSERT INTO %n.%n (Host, User, PRIV) VALUES (%?, %?, %?)`, mysql.SystemDB, mysql.GlobalPrivTable, host, user, "{}") return err } // initDBPrivEntry inserts a new row into mysql.DB with empty privilege. func initDBPrivEntry(ctx sessionctx.Context, user string, host string, db string) error { - sql := fmt.Sprintf(`INSERT INTO %s.%s (Host, User, DB) VALUES ('%s', '%s', '%s')`, mysql.SystemDB, mysql.DBTable, host, user, db) - _, err := ctx.(sqlexec.SQLExecutor).Execute(context.Background(), sql) + _, err := ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), `INSERT INTO %n.%n (Host, User, DB) VALUES (%?, %?, %?)`, mysql.SystemDB, mysql.DBTable, host, user, db) return err } // initTablePrivEntry inserts a new row into mysql.Tables_priv with empty privilege. 
func initTablePrivEntry(ctx sessionctx.Context, user string, host string, db string, tbl string) error { - sql := fmt.Sprintf(`INSERT INTO %s.%s (Host, User, DB, Table_name, Table_priv, Column_priv) VALUES ('%s', '%s', '%s', '%s', '', '')`, mysql.SystemDB, mysql.TablePrivTable, host, user, db, tbl) - _, err := ctx.(sqlexec.SQLExecutor).Execute(context.Background(), sql) + _, err := ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), `INSERT INTO %n.%n (Host, User, DB, Table_name, Table_priv, Column_priv) VALUES (%?, %?, %?, %?, '', '')`, mysql.SystemDB, mysql.TablePrivTable, host, user, db, tbl) return err } // initColumnPrivEntry inserts a new row into mysql.Columns_priv with empty privilege. func initColumnPrivEntry(ctx sessionctx.Context, user string, host string, db string, tbl string, col string) error { - sql := fmt.Sprintf(`INSERT INTO %s.%s (Host, User, DB, Table_name, Column_name, Column_priv) VALUES ('%s', '%s', '%s', '%s', '%s', '')`, mysql.SystemDB, mysql.ColumnPrivTable, host, user, db, tbl, col) - _, err := ctx.(sqlexec.SQLExecutor).Execute(context.Background(), sql) + _, err := ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), `INSERT INTO %n.%n (Host, User, DB, Table_name, Column_name, Column_priv) VALUES (%?, %?, %?, %?, %?, '')`, mysql.SystemDB, mysql.ColumnPrivTable, host, user, db, tbl, col) return err } @@ -309,8 +302,7 @@ func (e *GrantExec) grantGlobalPriv(ctx sessionctx.Context, user *ast.UserSpec) if err != nil { return errors.Trace(err) } - sql := fmt.Sprintf(`UPDATE %s.%s SET PRIV = '%s' WHERE User='%s' AND Host='%s'`, mysql.SystemDB, mysql.GlobalPrivTable, priv, user.User.Username, user.User.Hostname) - _, err = ctx.(sqlexec.SQLExecutor).Execute(context.Background(), sql) + _, err = ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), `UPDATE %n.%n SET PRIV=%? WHERE User=%? AND Host=%?`, mysql.SystemDB, mysql.GlobalPrivTable, priv, user.User.Username, user.User.Hostname) return err } @@ -415,12 +407,16 @@ func (e *GrantExec) grantGlobalLevel(priv *ast.PrivElem, user *ast.UserSpec, int if priv.Priv == 0 { return nil } - asgns, err := composeGlobalPrivUpdate(priv.Priv, "Y") + + sql := new(strings.Builder) + sqlexec.MustFormatSQL(sql, `UPDATE %n.%n SET `, mysql.SystemDB, mysql.UserTable) + err := composeGlobalPrivUpdate(sql, priv.Priv, "Y") if err != nil { return err } - sql := fmt.Sprintf(`UPDATE %s.%s SET %s WHERE User='%s' AND Host='%s'`, mysql.SystemDB, mysql.UserTable, asgns, user.User.Username, user.User.Hostname) - _, err = internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), sql) + sqlexec.MustFormatSQL(sql, ` WHERE User=%? AND Host=%?`, user.User.Username, user.User.Hostname) + + _, err = internalSession.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), sql.String()) return err } @@ -430,12 +426,16 @@ func (e *GrantExec) grantDBLevel(priv *ast.PrivElem, user *ast.UserSpec, interna if len(dbName) == 0 { dbName = e.ctx.GetSessionVars().CurrentDB } - asgns, err := composeDBPrivUpdate(priv.Priv, "Y") + + sql := new(strings.Builder) + sqlexec.MustFormatSQL(sql, "UPDATE %n.%n SET ", mysql.SystemDB, mysql.DBTable) + err := composeDBPrivUpdate(sql, priv.Priv, "Y") if err != nil { return err } - sql := fmt.Sprintf(`UPDATE %s.%s SET %s WHERE User='%s' AND Host='%s' AND DB='%s';`, mysql.SystemDB, mysql.DBTable, asgns, user.User.Username, user.User.Hostname, dbName) - _, err = internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), sql) + sqlexec.MustFormatSQL(sql, " WHERE User=%? AND Host=%? 
AND DB=%?", user.User.Username, user.User.Hostname, dbName) + + _, err = internalSession.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), sql.String()) return err } @@ -446,12 +446,16 @@ func (e *GrantExec) grantTableLevel(priv *ast.PrivElem, user *ast.UserSpec, inte dbName = e.ctx.GetSessionVars().CurrentDB } tblName := e.Level.TableName - asgns, err := composeTablePrivUpdateForGrant(internalSession, priv.Priv, user.User.Username, user.User.Hostname, dbName, tblName) + + sql := new(strings.Builder) + sqlexec.MustFormatSQL(sql, "UPDATE %n.%n SET ", mysql.SystemDB, mysql.TablePrivTable) + err := composeTablePrivUpdateForGrant(internalSession, sql, priv.Priv, user.User.Username, user.User.Hostname, dbName, tblName) if err != nil { return err } - sql := fmt.Sprintf(`UPDATE %s.%s SET %s WHERE User='%s' AND Host='%s' AND DB='%s' AND Table_name='%s';`, mysql.SystemDB, mysql.TablePrivTable, asgns, user.User.Username, user.User.Hostname, dbName, tblName) - _, err = internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), sql) + sqlexec.MustFormatSQL(sql, " WHERE User=%? AND Host=%? AND DB=%? AND Table_name=%?", user.User.Username, user.User.Hostname, dbName, tblName) + + _, err = internalSession.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), sql.String()) return err } @@ -467,12 +471,16 @@ func (e *GrantExec) grantColumnLevel(priv *ast.PrivElem, user *ast.UserSpec, int if col == nil { return errors.Errorf("Unknown column: %s", c) } - asgns, err := composeColumnPrivUpdateForGrant(internalSession, priv.Priv, user.User.Username, user.User.Hostname, dbName, tbl.Meta().Name.O, col.Name.O) + + sql := new(strings.Builder) + sqlexec.MustFormatSQL(sql, "UPDATE %n.%n SET ", mysql.SystemDB, mysql.ColumnPrivTable) + err := composeColumnPrivUpdateForGrant(internalSession, sql, priv.Priv, user.User.Username, user.User.Hostname, dbName, tbl.Meta().Name.O, col.Name.O) if err != nil { return err } - sql := fmt.Sprintf(`UPDATE %s.%s SET %s WHERE User='%s' AND Host='%s' AND DB='%s' AND Table_name='%s' AND Column_name='%s';`, mysql.SystemDB, mysql.ColumnPrivTable, asgns, user.User.Username, user.User.Hostname, dbName, tbl.Meta().Name.O, col.Name.O) - _, err = internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), sql) + sqlexec.MustFormatSQL(sql, " WHERE User=%? AND Host=%? AND DB=%? AND Table_name=%? AND Column_name=%?", user.User.Username, user.User.Hostname, dbName, tbl.Meta().Name.O, col.Name.O) + + _, err = internalSession.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), sql.String()) if err != nil { return err } @@ -481,178 +489,143 @@ func (e *GrantExec) grantColumnLevel(priv *ast.PrivElem, user *ast.UserSpec, int } // composeGlobalPrivUpdate composes update stmt assignment list string for global scope privilege update. 
-func composeGlobalPrivUpdate(priv mysql.PrivilegeType, value string) (string, error) { - if priv == mysql.AllPriv { - strs := make([]string, 0, len(mysql.Priv2UserCol)) - for _, v := range mysql.AllGlobalPrivs { - strs = append(strs, fmt.Sprintf(`%s='%s'`, mysql.Priv2UserCol[v], value)) +func composeGlobalPrivUpdate(sql *strings.Builder, priv mysql.PrivilegeType, value string) error { + if priv != mysql.AllPriv { + col, ok := mysql.Priv2UserCol[priv] + if !ok { + return errors.Errorf("Unknown priv: %v", priv) } - return strings.Join(strs, ", "), nil + sqlexec.MustFormatSQL(sql, "%n=%?", col, value) + return nil } - col, ok := mysql.Priv2UserCol[priv] - if !ok { - return "", errors.Errorf("Unknown priv: %v", priv) + + for i, v := range mysql.AllGlobalPrivs { + if i > 0 { + sqlexec.MustFormatSQL(sql, ",") + } + + k, ok := mysql.Priv2UserCol[v] + if !ok { + return errors.Errorf("Unknown priv %v", priv) + } + + sqlexec.MustFormatSQL(sql, "%n=%?", k, value) } - return fmt.Sprintf(`%s='%s'`, col, value), nil + return nil } // composeDBPrivUpdate composes update stmt assignment list for db scope privilege update. -func composeDBPrivUpdate(priv mysql.PrivilegeType, value string) (string, error) { - if priv == mysql.AllPriv { - strs := make([]string, 0, len(mysql.AllDBPrivs)) - for _, p := range mysql.AllDBPrivs { - v, ok := mysql.Priv2UserCol[p] - if !ok { - return "", errors.Errorf("Unknown db privilege %v", priv) - } - strs = append(strs, fmt.Sprintf(`%s='%s'`, v, value)) +func composeDBPrivUpdate(sql *strings.Builder, priv mysql.PrivilegeType, value string) error { + if priv != mysql.AllPriv { + col, ok := mysql.Priv2UserCol[priv] + if !ok { + return errors.Errorf("Unknown priv: %v", priv) } - return strings.Join(strs, ", "), nil - } - col, ok := mysql.Priv2UserCol[priv] - if !ok { - return "", errors.Errorf("Unknown priv: %v", priv) + sqlexec.MustFormatSQL(sql, "%n=%?", col, value) + return nil } - return fmt.Sprintf(`%s='%s'`, col, value), nil -} -// composeTablePrivUpdateForGrant composes update stmt assignment list for table scope privilege update. 
-func composeTablePrivUpdateForGrant(ctx sessionctx.Context, priv mysql.PrivilegeType, name string, host string, db string, tbl string) (string, error) { - var newTablePriv, newColumnPriv string - if priv == mysql.AllPriv { - for _, p := range mysql.AllTablePrivs { - v, ok := mysql.Priv2SetStr[p] - if !ok { - return "", errors.Errorf("Unknown table privilege %v", p) - } - newTablePriv = addToSet(newTablePriv, v) - } - for _, p := range mysql.AllColumnPrivs { - v, ok := mysql.Priv2SetStr[p] - if !ok { - return "", errors.Errorf("Unknown column privilege %v", p) - } - newColumnPriv = addToSet(newColumnPriv, v) - } - } else { - currTablePriv, currColumnPriv, err := getTablePriv(ctx, name, host, db, tbl) - if err != nil { - return "", err + for i, p := range mysql.AllDBPrivs { + if i > 0 { + sqlexec.MustFormatSQL(sql, ",") } - p, ok := mysql.Priv2SetStr[priv] + + v, ok := mysql.Priv2UserCol[p] if !ok { - return "", errors.Errorf("Unknown priv: %v", priv) + return errors.Errorf("Unknown priv %v", priv) } - newTablePriv = addToSet(currTablePriv, p) - for _, cp := range mysql.AllColumnPrivs { - if priv == cp { - newColumnPriv = addToSet(currColumnPriv, p) - break - } - } + sqlexec.MustFormatSQL(sql, "%n=%?", v, value) } - return fmt.Sprintf(`Table_priv='%s', Column_priv='%s', Grantor='%s'`, newTablePriv, newColumnPriv, ctx.GetSessionVars().User), nil + return nil } -func composeTablePrivUpdateForRevoke(ctx sessionctx.Context, priv mysql.PrivilegeType, name string, host string, db string, tbl string) (string, error) { - var newTablePriv, newColumnPriv string - if priv == mysql.AllPriv { - newTablePriv = "" - newColumnPriv = "" - } else { +func privUpdateForGrant(cur []string, priv mysql.PrivilegeType) ([]string, error) { + p, ok := mysql.Priv2SetStr[priv] + if !ok { + return nil, errors.Errorf("Unknown priv: %v", priv) + } + cur = addToSet(cur, p) + return cur, nil +} + +// composeTablePrivUpdateForGrant composes update stmt assignment list for table scope privilege update. +func composeTablePrivUpdateForGrant(ctx sessionctx.Context, sql *strings.Builder, priv mysql.PrivilegeType, name string, host string, db string, tbl string) error { + var newTablePriv, newColumnPriv []string + var tblPrivs, colPrivs []mysql.PrivilegeType + if priv != mysql.AllPriv { currTablePriv, currColumnPriv, err := getTablePriv(ctx, name, host, db, tbl) if err != nil { - return "", err - } - p, ok := mysql.Priv2SetStr[priv] - if !ok { - return "", errors.Errorf("Unknown priv: %v", priv) + return err } - newTablePriv = deleteFromSet(currTablePriv, p) - + newTablePriv = setFromString(currTablePriv) + newColumnPriv = setFromString(currColumnPriv) + tblPrivs = []mysql.PrivilegeType{priv} for _, cp := range mysql.AllColumnPrivs { - if priv == cp { - newColumnPriv = deleteFromSet(currColumnPriv, p) + // in case it is not a column priv + if cp == priv { + colPrivs = []mysql.PrivilegeType{priv} break } } + } else { + tblPrivs = mysql.AllTablePrivs + colPrivs = mysql.AllColumnPrivs } - return fmt.Sprintf(`Table_priv='%s', Column_priv='%s', Grantor='%s'`, newTablePriv, newColumnPriv, ctx.GetSessionVars().User), nil -} -// addToSet add a value to the set, e.g: -// addToSet("Select,Insert", "Update") returns "Select,Insert,Update". 
-func addToSet(set string, value string) string { - if set == "" { - return value + var err error + for _, p := range tblPrivs { + newTablePriv, err = privUpdateForGrant(newTablePriv, p) + if err != nil { + return err + } } - return fmt.Sprintf("%s,%s", set, value) -} -// deleteFromSet delete the value from the set, e.g: -// deleteFromSet("Select,Insert,Update", "Update") returns "Select,Insert". -func deleteFromSet(set string, value string) string { - sets := strings.Split(set, ",") - res := make([]string, 0, len(sets)) - for _, v := range sets { - if v != value { - res = append(res, v) + for _, p := range colPrivs { + newColumnPriv, err = privUpdateForGrant(newColumnPriv, p) + if err != nil { + return err } } - return strings.Join(res, ",") + + sqlexec.MustFormatSQL(sql, `Table_priv=%?, Column_priv=%?, Grantor=%?`, strings.Join(newTablePriv, ","), strings.Join(newColumnPriv, ","), ctx.GetSessionVars().User.String()) + return nil } // composeColumnPrivUpdateForGrant composes update stmt assignment list for column scope privilege update. -func composeColumnPrivUpdateForGrant(ctx sessionctx.Context, priv mysql.PrivilegeType, name string, host string, db string, tbl string, col string) (string, error) { - newColumnPriv := "" - if priv == mysql.AllPriv { - for _, p := range mysql.AllColumnPrivs { - v, ok := mysql.Priv2SetStr[p] - if !ok { - return "", errors.Errorf("Unknown column privilege %v", p) - } - newColumnPriv = addToSet(newColumnPriv, v) - } - } else { +func composeColumnPrivUpdateForGrant(ctx sessionctx.Context, sql *strings.Builder, priv mysql.PrivilegeType, name string, host string, db string, tbl string, col string) error { + var newColumnPriv []string + var colPrivs []mysql.PrivilegeType + if priv != mysql.AllPriv { currColumnPriv, err := getColumnPriv(ctx, name, host, db, tbl, col) if err != nil { - return "", err - } - p, ok := mysql.Priv2SetStr[priv] - if !ok { - return "", errors.Errorf("Unknown priv: %v", priv) + return err } - newColumnPriv = addToSet(currColumnPriv, p) + newColumnPriv = setFromString(currColumnPriv) + colPrivs = []mysql.PrivilegeType{priv} + } else { + colPrivs = mysql.AllColumnPrivs } - return fmt.Sprintf(`Column_priv='%s'`, newColumnPriv), nil -} -func composeColumnPrivUpdateForRevoke(ctx sessionctx.Context, priv mysql.PrivilegeType, name string, host string, db string, tbl string, col string) (string, error) { - newColumnPriv := "" - if priv == mysql.AllPriv { - newColumnPriv = "" - } else { - currColumnPriv, err := getColumnPriv(ctx, name, host, db, tbl, col) + var err error + for _, p := range colPrivs { + newColumnPriv, err = privUpdateForGrant(newColumnPriv, p) if err != nil { - return "", err - } - p, ok := mysql.Priv2SetStr[priv] - if !ok { - return "", errors.Errorf("Unknown priv: %v", priv) + return err } - newColumnPriv = deleteFromSet(currColumnPriv, p) } - return fmt.Sprintf(`Column_priv='%s'`, newColumnPriv), nil + + sqlexec.MustFormatSQL(sql, `Column_priv=%?`, strings.Join(newColumnPriv, ",")) + return nil } // recordExists is a helper function to check if the sql returns any row. -func recordExists(ctx sessionctx.Context, sql string) (bool, error) { - recordSets, err := ctx.(sqlexec.SQLExecutor).Execute(context.Background(), sql) +func recordExists(ctx sessionctx.Context, sql string, args ...interface{}) (bool, error) { + rs, err := ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), sql, args...) 
if err != nil { return false, err } - rows, _, err := getRowsAndFields(ctx, recordSets) + rows, _, err := getRowsAndFields(ctx, rs) if err != nil { return false, err } @@ -661,43 +634,35 @@ func recordExists(ctx sessionctx.Context, sql string) (bool, error) { // globalPrivEntryExists checks if there is an entry with key user-host in mysql.global_priv. func globalPrivEntryExists(ctx sessionctx.Context, name string, host string) (bool, error) { - sql := fmt.Sprintf(`SELECT * FROM %s.%s WHERE User='%s' AND Host='%s';`, mysql.SystemDB, mysql.GlobalPrivTable, name, host) - return recordExists(ctx, sql) + return recordExists(ctx, `SELECT * FROM %n.%n WHERE User=%? AND Host=%?;`, mysql.SystemDB, mysql.GlobalPrivTable, name, host) } // dbUserExists checks if there is an entry with key user-host-db in mysql.DB. func dbUserExists(ctx sessionctx.Context, name string, host string, db string) (bool, error) { - sql := fmt.Sprintf(`SELECT * FROM %s.%s WHERE User='%s' AND Host='%s' AND DB='%s';`, mysql.SystemDB, mysql.DBTable, name, host, db) - return recordExists(ctx, sql) + return recordExists(ctx, `SELECT * FROM %n.%n WHERE User=%? AND Host=%? AND DB=%?;`, mysql.SystemDB, mysql.DBTable, name, host, db) } // tableUserExists checks if there is an entry with key user-host-db-tbl in mysql.Tables_priv. func tableUserExists(ctx sessionctx.Context, name string, host string, db string, tbl string) (bool, error) { - sql := fmt.Sprintf(`SELECT * FROM %s.%s WHERE User='%s' AND Host='%s' AND DB='%s' AND Table_name='%s';`, mysql.SystemDB, mysql.TablePrivTable, name, host, db, tbl) - return recordExists(ctx, sql) + return recordExists(ctx, `SELECT * FROM %n.%n WHERE User=%? AND Host=%? AND DB=%? AND Table_name=%?;`, mysql.SystemDB, mysql.TablePrivTable, name, host, db, tbl) } // columnPrivEntryExists checks if there is an entry with key user-host-db-tbl-col in mysql.Columns_priv. func columnPrivEntryExists(ctx sessionctx.Context, name string, host string, db string, tbl string, col string) (bool, error) { - sql := fmt.Sprintf(`SELECT * FROM %s.%s WHERE User='%s' AND Host='%s' AND DB='%s' AND Table_name='%s' AND Column_name='%s';`, mysql.SystemDB, mysql.ColumnPrivTable, name, host, db, tbl, col) - return recordExists(ctx, sql) + return recordExists(ctx, `SELECT * FROM %n.%n WHERE User=%? AND Host=%? AND DB=%? AND Table_name=%? AND Column_name=%?;`, mysql.SystemDB, mysql.ColumnPrivTable, name, host, db, tbl, col) } // getTablePriv gets current table scope privilege set from mysql.Tables_priv. // Return Table_priv and Column_priv. func getTablePriv(ctx sessionctx.Context, name string, host string, db string, tbl string) (string, string, error) { - sql := fmt.Sprintf(`SELECT Table_priv, Column_priv FROM %s.%s WHERE User='%s' AND Host='%s' AND DB='%s' AND Table_name='%s';`, mysql.SystemDB, mysql.TablePrivTable, name, host, db, tbl) - rs, err := ctx.(sqlexec.SQLExecutor).Execute(context.Background(), sql) + rs, err := ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), `SELECT Table_priv, Column_priv FROM %n.%n WHERE User=%? AND Host=%? AND DB=%? 
AND Table_name=%?`, mysql.SystemDB, mysql.TablePrivTable, name, host, db, tbl) if err != nil { return "", "", err } - if len(rs) < 1 { - return "", "", errors.Errorf("get table privilege fail for %s %s %s %s", name, host, db, tbl) - } var tPriv, cPriv string rows, fields, err := getRowsAndFields(ctx, rs) if err != nil { - return "", "", err + return "", "", errors.Errorf("get table privilege fail for %s %s %s %s: %v", name, host, db, tbl, err) } if len(rows) < 1 { return "", "", errors.Errorf("get table privilege fail for %s %s %s %s", name, host, db, tbl) @@ -717,17 +682,13 @@ func getTablePriv(ctx sessionctx.Context, name string, host string, db string, t // getColumnPriv gets current column scope privilege set from mysql.Columns_priv. // Return Column_priv. func getColumnPriv(ctx sessionctx.Context, name string, host string, db string, tbl string, col string) (string, error) { - sql := fmt.Sprintf(`SELECT Column_priv FROM %s.%s WHERE User='%s' AND Host='%s' AND DB='%s' AND Table_name='%s' AND Column_name='%s';`, mysql.SystemDB, mysql.ColumnPrivTable, name, host, db, tbl, col) - rs, err := ctx.(sqlexec.SQLExecutor).Execute(context.Background(), sql) + rs, err := ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), `SELECT Column_priv FROM %n.%n WHERE User=%? AND Host=%? AND DB=%? AND Table_name=%? AND Column_name=%?;`, mysql.SystemDB, mysql.ColumnPrivTable, name, host, db, tbl, col) if err != nil { return "", err } - if len(rs) < 1 { - return "", errors.Errorf("get column privilege fail for %s %s %s %s", name, host, db, tbl) - } rows, fields, err := getRowsAndFields(ctx, rs) if err != nil { - return "", err + return "", errors.Errorf("get column privilege fail for %s %s %s %s: %s", name, host, db, tbl, err) } if len(rows) < 1 { return "", errors.Errorf("get column privilege fail for %s %s %s %s %s", name, host, db, tbl, col) @@ -757,27 +718,18 @@ func getTargetSchemaAndTable(ctx sessionctx.Context, dbName, tableName string, i } // getRowsAndFields is used to extract rows from record sets. -func getRowsAndFields(ctx sessionctx.Context, recordSets []sqlexec.RecordSet) ([]chunk.Row, []*ast.ResultField, error) { - var ( - rows []chunk.Row - fields []*ast.ResultField - ) - - for i, rs := range recordSets { - tmp, err := getRowFromRecordSet(context.Background(), ctx, rs) - if err != nil { - return nil, nil, err - } - if err = rs.Close(); err != nil { - return nil, nil, err - } - - if i == 0 { - rows = tmp - fields = rs.Fields() - } +func getRowsAndFields(ctx sessionctx.Context, rs sqlexec.RecordSet) ([]chunk.Row, []*ast.ResultField, error) { + if rs == nil { + return nil, nil, errors.Errorf("nil recordset") + } + rows, err := getRowFromRecordSet(context.Background(), ctx, rs) + if err != nil { + return nil, nil, err + } + if err = rs.Close(); err != nil { + return nil, nil, err } - return rows, fields, nil + return rows, rs.Fields(), nil } func getRowFromRecordSet(ctx context.Context, se sessionctx.Context, rs sqlexec.RecordSet) ([]chunk.Row, error) { diff --git a/executor/index_lookup_join.go b/executor/index_lookup_join.go index 52db90a83d587..d6e335c030522 100644 --- a/executor/index_lookup_join.go +++ b/executor/index_lookup_join.go @@ -553,8 +553,8 @@ func (iw *innerWorker) constructLookupContent(task *lookUpJoinTask) ([]*indexJoi if iw.hasPrefixCol { for i := range iw.outerCtx.keyCols { // If it's a prefix column. Try to fix it. 
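The one-line fix that follows changes which slice index is used: keyCols maps the i-th join key to its column offset in the inner row schema, and colLens is indexed by that offset, not by the key ordinal. A toy illustration with made-up slices (not the executor's real structures):

package main

import "fmt"

// unspecifiedLength stands in for types.UnspecifiedLength.
const unspecifiedLength = -1

func main() {
	// The inner row schema has three columns; only column 2 is a prefix
	// column, indexed on its first 4 bytes.
	colLens := []int{unspecifiedLength, unspecifiedLength, 4}
	// The single join key is built from column 2.
	keyCols := []int{2}

	for i := range keyCols {
		fmt.Println("wrong, by key ordinal:  ", colLens[i])          // -1: misses the prefix
		fmt.Println("right, by column offset:", colLens[keyCols[i]]) // 4: datum gets cut
	}
}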
- if iw.colLens[i] != types.UnspecifiedLength { - ranger.CutDatumByPrefixLen(&dLookUpKey[i], iw.colLens[i], iw.rowTypes[iw.keyCols[i]]) + if iw.colLens[iw.keyCols[i]] != types.UnspecifiedLength { + ranger.CutDatumByPrefixLen(&dLookUpKey[i], iw.colLens[iw.keyCols[i]], iw.rowTypes[iw.keyCols[i]]) } } // dLookUpKey is sorted and deduplicated at sortAndDedupLookUpContents. diff --git a/executor/index_lookup_join_test.go b/executor/index_lookup_join_test.go index 0e3e11316848d..48b4177b45fb6 100644 --- a/executor/index_lookup_join_test.go +++ b/executor/index_lookup_join_test.go @@ -252,3 +252,28 @@ func (s *testSuite5) TestIndexJoinEnumSetIssue19233(c *C) { } } } + +func (s *testSuite5) TestIssue23653(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t1, t2") + tk.MustExec("create table t1 (c_int int, c_str varchar(40), primary key(c_str), unique key(c_int), unique key(c_str))") + tk.MustExec("create table t2 (c_int int, c_str varchar(40), primary key(c_int, c_str(4)), key(c_int), unique key(c_str))") + tk.MustExec("insert into t1 values (1, 'cool buck'), (2, 'reverent keller')") + tk.MustExec("insert into t2 select * from t1") + tk.MustQuery("select /*+ inl_join(t2) */ * from t1, t2 where t1.c_str = t2.c_str and t1.c_int = t2.c_int and t1.c_int = 2").Check(testkit.Rows( + "2 reverent keller 2 reverent keller")) +} + +func (s *testSuite5) TestIssue23656(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t1, t2") + tk.MustExec("create table t1 (c_int int, c_str varchar(40), primary key(c_int, c_str(4)))") + tk.MustExec("create table t2 like t1") + tk.MustExec("insert into t1 values (1, 'clever jang'), (2, 'blissful aryabhata')") + tk.MustExec("insert into t2 select * from t1") + tk.MustQuery("select /*+ inl_join(t2) */ * from t1 join t2 on t1.c_str = t2.c_str where t1.c_int = t2.c_int;").Check(testkit.Rows( + "1 clever jang 1 clever jang", + "2 blissful aryabhata 2 blissful aryabhata")) +} diff --git a/executor/index_lookup_merge_join_test.go b/executor/index_lookup_merge_join_test.go index 93d1d9799d58b..7958867ca4f5c 100644 --- a/executor/index_lookup_merge_join_test.go +++ b/executor/index_lookup_merge_join_test.go @@ -9,7 +9,6 @@ import ( ) // TODO: reopen the index merge join in future. - //func (s *testSuite9) TestIndexLookupMergeJoinHang(c *C) { // c.Assert(failpoint.Enable("github.com/pingcap/tidb/executor/IndexMergeJoinMockOOM", `return(true)`), IsNil) // defer func() { diff --git a/executor/index_merge_reader.go b/executor/index_merge_reader.go index edf868913afd7..8e5e0c6a49f0d 100644 --- a/executor/index_merge_reader.go +++ b/executor/index_merge_reader.go @@ -102,6 +102,10 @@ type IndexMergeReaderExecutor struct { corColInAccess bool idxCols [][]*expression.Column colLens [][]int + + // extraHandleIdx indicates the index of extraHandleCol when the partial + // reader is TableReader. 
+ extraHandleIdx int } // Open implements the Executor Open interface @@ -280,7 +284,7 @@ func (e *IndexMergeReaderExecutor) startPartialTableWorker(ctx context.Context, var err error util.WithRecovery( func() { - _, err = worker.fetchHandles(ctx1, exitCh, fetchCh, e.resultCh, e.finished) + _, err = worker.fetchHandles(ctx1, exitCh, fetchCh, e.resultCh, e.finished, e.extraHandleIdx) }, e.handleHandlesFetcherPanic(ctx, e.resultCh, "partialTableWorker"), ) @@ -305,7 +309,7 @@ type partialTableWorker struct { } func (w *partialTableWorker) fetchHandles(ctx context.Context, exitCh <-chan struct{}, fetchCh chan<- *lookupTableTask, resultCh chan<- *lookupTableTask, - finished <-chan struct{}) (count int64, err error) { + finished <-chan struct{}, extraHandleIdx int) (count int64, err error) { var chk *chunk.Chunk handleOffset := -1 if w.tableInfo.PKIsHandle { @@ -318,7 +322,7 @@ func (w *partialTableWorker) fetchHandles(ctx context.Context, exitCh <-chan str } } } else { - return 0, errors.Errorf("cannot find the column for handle") + handleOffset = extraHandleIdx } chk = chunk.NewChunkWithCapacity(retTypes(w.tableReader), w.maxChunkSize) diff --git a/executor/index_merge_reader_test.go b/executor/index_merge_reader_test.go index 647ab6911e358..06a39841cd7f9 100644 --- a/executor/index_merge_reader_test.go +++ b/executor/index_merge_reader_test.go @@ -14,8 +14,10 @@ package executor_test import ( + "fmt" . "github.com/pingcap/check" "github.com/pingcap/tidb/util/testkit" + "strings" ) func (s *testSuite1) TestSingleTableRead(c *C) { @@ -78,3 +80,38 @@ func (s *testSuite1) TestIssue16910(c *C) { tk.MustExec("insert into t2 values (0, 0), (1, 1), (2, 2), (3, 3), (4, 4), (5, 5), (6, 6), (7, 7), (8, 8), (9, 9), (10, 10), (11, 11), (12, 12), (13, 13), (14, 14), (15, 15), (16, 16), (17, 17), (18, 18), (19, 19), (20, 20), (21, 21), (22, 22), (23, 23);") tk.MustQuery("select /*+ USE_INDEX_MERGE(t1, a, b) */ * from t1 partition (p0) join t2 partition (p1) on t1.a = t2.a where t1.a < 40 or t1.b < 30;").Check(testkit.Rows("1 1 1 1")) } + +func (s *testSuite1) TestIssue23569(c *C) { + tk := testkit.NewTestKitWithInit(c, s.store) + tk.MustExec("use test;") + tk.MustExec("drop table if exists tt;") + tk.MustExec(`create table tt(id bigint(20) NOT NULL,create_time bigint(20) NOT NULL DEFAULT '0' ,driver varchar(64), PRIMARY KEY (id,create_time)) + PARTITION BY RANGE ( create_time ) ( + PARTITION p201901 VALUES LESS THAN (1577808000), + PARTITION p202001 VALUES LESS THAN (1585670400), + PARTITION p202002 VALUES LESS THAN (1593532800), + PARTITION p202003 VALUES LESS THAN (1601481600), + PARTITION p202004 VALUES LESS THAN (1609430400), + PARTITION p202101 VALUES LESS THAN (1617206400), + PARTITION p202102 VALUES LESS THAN (1625068800), + PARTITION p202103 VALUES LESS THAN (1633017600), + PARTITION p202104 VALUES LESS THAN (1640966400), + PARTITION p202201 VALUES LESS THAN (1648742400), + PARTITION p202202 VALUES LESS THAN (1656604800), + PARTITION p202203 VALUES LESS THAN (1664553600), + PARTITION p202204 VALUES LESS THAN (1672502400), + PARTITION p202301 VALUES LESS THAN (1680278400) + );`) + tk.MustExec("insert tt value(1, 1577807000, 'jack'), (2, 1577809000, 'mike'), (3, 1585670500, 'right'), (4, 1601481500, 'hello');") + tk.MustExec("set @@tidb_enable_index_merge=true;") + rows := tk.MustQuery("explain select count(*) from tt partition(p202003) where _tidb_rowid is null or (_tidb_rowid>=1 and _tidb_rowid<100);").Rows() + containsIndexMerge := false + for _, r := range rows { + if 
strings.Contains(fmt.Sprintf("%s", r[0]), "IndexMerge") { + containsIndexMerge = true + break + } + } + c.Assert(containsIndexMerge, IsTrue) + tk.MustQuery("select count(*) from tt partition(p202003) where _tidb_rowid is null or (_tidb_rowid>=1 and _tidb_rowid<100);").Check(testkit.Rows("1")) +} diff --git a/executor/infoschema_reader.go b/executor/infoschema_reader.go index de9189548d16d..23cacc1558e76 100644 --- a/executor/infoschema_reader.go +++ b/executor/infoschema_reader.go @@ -154,10 +154,16 @@ func (e *memtableRetriever) retrieve(ctx context.Context, sctx sessionctx.Contex } func getRowCountAllTable(ctx sessionctx.Context) (map[int64]uint64, error) { - rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL("select table_id, count from mysql.stats_meta") + exec := ctx.(sqlexec.RestrictedSQLExecutor) + stmt, err := exec.ParseWithParams(context.TODO(), "select table_id, count from mysql.stats_meta") if err != nil { return nil, err } + rows, _, err := exec.ExecRestrictedStmt(context.TODO(), stmt) + if err != nil { + return nil, err + } + rowCountMap := make(map[int64]uint64, len(rows)) for _, row := range rows { tableID := row.GetInt64(0) @@ -173,10 +179,16 @@ type tableHistID struct { } func getColLengthAllTables(ctx sessionctx.Context) (map[tableHistID]uint64, error) { - rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL("select table_id, hist_id, tot_col_size from mysql.stats_histograms where is_index = 0") + exec := ctx.(sqlexec.RestrictedSQLExecutor) + stmt, err := exec.ParseWithParams(context.TODO(), "select table_id, hist_id, tot_col_size from mysql.stats_histograms where is_index = 0") if err != nil { return nil, err } + rows, _, err := exec.ExecRestrictedStmt(context.TODO(), stmt) + if err != nil { + return nil, err + } + colLengthMap := make(map[tableHistID]uint64, len(rows)) for _, row := range rows { tableID := row.GetInt64(0) @@ -684,6 +696,7 @@ func (e *memtableRetriever) setDataFromPartitions(ctx sessionctx.Context, schema nil, // PARTITION_COMMENT nil, // NODEGROUP nil, // TABLESPACE_NAME + nil, // TIDB_PARTITION_ID ) rows = append(rows, record) } else { @@ -727,6 +740,7 @@ func (e *memtableRetriever) setDataFromPartitions(ctx sessionctx.Context, schema pi.Comment, // PARTITION_COMMENT nil, // NODEGROUP nil, // TABLESPACE_NAME + pi.ID, // TIDB_PARTITION_ID ) rows = append(rows, record) } diff --git a/executor/infoschema_reader_test.go b/executor/infoschema_reader_test.go index 993226ef65963..98dff1fbddb1d 100644 --- a/executor/infoschema_reader_test.go +++ b/executor/infoschema_reader_test.go @@ -456,7 +456,10 @@ func (s *testInfoschemaTableSerialSuite) TestPartitionsTable(c *C) { tk.MustQuery("select PARTITION_NAME, TABLE_ROWS, AVG_ROW_LENGTH, DATA_LENGTH, INDEX_LENGTH from information_schema.PARTITIONS where table_name='test_partitions_1';").Check( testkit.Rows(" 3 18 54 6")) - tk.MustExec("DROP TABLE `test_partitions`;") + pid, err := strconv.Atoi(tk.MustQuery("select TIDB_PARTITION_ID from information_schema.partitions where table_name = 'test_partitions';").Rows()[0][0].(string)) + c.Assert(err, IsNil) + c.Assert(pid, Greater, 0) + tk.MustExec("drop table test_partitions") } func (s *testInfoschemaTableSuite) TestMetricTables(c *C) { diff --git a/executor/insert_common.go b/executor/insert_common.go index 6c2f27700ed10..d6fbcfa74b353 100644 --- a/executor/insert_common.go +++ b/executor/insert_common.go @@ -18,6 +18,7 @@ import ( "context" "fmt" "math" + "sync" "time" "github.com/opentracing/opentracing-go" @@ -81,6 +82,12 @@ type 
InsertValues struct {
 	memTracker *memory.Tracker
 	stats *InsertRuntimeStat
+
+	// LoadData uses two goroutines: one generates batch data and the other
+	// commits tasks, which invalidates the current txn. The mutex protects
+	// the generating routine from using an invalidated txn.
+	isLoadData bool
+	txnInUse   sync.Mutex
 }
 
 type defaultVal struct {
@@ -418,6 +425,9 @@ func insertRowsFromSelect(ctx context.Context, base insertCommon) error {
 	batchSize := sessVars.DMLBatchSize
 	memUsageOfRows := int64(0)
 	memTracker := e.memTracker
+	// To keep the `transaction write throughput` SLI statistics correct, simply
+	// ignore transactions that contain an `insert|replace into ... select ... from ...` statement.
+	e.ctx.GetTxnWriteThroughputSLI().SetInvalid()
 	for {
 		err := Next(ctx, selectExec, chk)
 		if err != nil {
@@ -857,10 +867,6 @@ func (e *InsertValues) adjustAutoRandomDatum(ctx context.Context, d types.Datum,
 	// Change NULL to auto id.
 	// Change value 0 to auto id, if NoAutoValueOnZero SQL mode is not set.
 	if d.IsNull() || e.ctx.GetSessionVars().SQLMode&mysql.ModeNoAutoValueOnZero == 0 {
-		_, err := e.ctx.Txn(true)
-		if err != nil {
-			return types.Datum{}, errors.Trace(err)
-		}
 		recordID, err = e.allocAutoRandomID(&c.FieldType)
 		if err != nil {
 			return types.Datum{}, err
@@ -894,6 +900,14 @@ func (e *InsertValues) allocAutoRandomID(fieldType *types.FieldType) (int64, err
 	if tables.OverflowShardBits(autoRandomID, tableInfo.AutoRandomBits, layout.TypeBitsLength, layout.HasSignBit) {
 		return 0, autoid.ErrAutoRandReadFailed
 	}
+	if e.isLoadData {
+		e.txnInUse.Lock()
+		defer e.txnInUse.Unlock()
+	}
+	_, err = e.ctx.Txn(true)
+	if err != nil {
+		return 0, err
+	}
 	shard := tables.CalcShard(tableInfo.AutoRandomBits, e.ctx.GetSessionVars().TxnCtx.StartTS, layout.TypeBitsLength, layout.HasSignBit)
 	autoRandomID |= shard
 	return autoRandomID, nil
diff --git a/executor/inspection_profile.go b/executor/inspection_profile.go
index f243db364f0d8..f15dd6ef5e6ff 100644
--- a/executor/inspection_profile.go
+++ b/executor/inspection_profile.go
@@ -161,7 +161,13 @@ func (n *metricNode) getLabelValue(label string) *metricValue {
 }
 
 func (n *metricNode) queryRowsByLabel(pb *profileBuilder, query string, handleRowFn func(label string, v float64)) error {
-	rows, _, err := pb.sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithContext(context.Background(), query)
+	exec := pb.sctx.(sqlexec.RestrictedSQLExecutor)
+	stmt, err := exec.ParseWithParams(context.TODO(), query)
+	if err != nil {
+		return err
+	}
+
+	rows, _, err := pb.sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedStmt(context.TODO(), stmt)
 	if err != nil {
 		return err
 	}
diff --git a/executor/inspection_result.go b/executor/inspection_result.go
index 4d8bd69ca3edb..5bb858a8224e9 100644
--- a/executor/inspection_result.go
+++ b/executor/inspection_result.go
@@ -141,8 +141,12 @@ func (e *inspectionResultRetriever) retrieve(ctx context.Context, sctx sessionct
 	// Get cluster info.
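For intuition on the `autoRandomID |= shard` step in the allocAutoRandomID hunk above: CalcShard yields shard bits already positioned at the top of the 64-bit ID, so an OR merges them with the sequentially allocated low bits. A toy sketch with an assumed 5-bit shard layout (the real width comes from the column's AutoRandomBits):

package main

import "fmt"

func main() {
	const shardBits = 5 // assumed width; the real value is tableInfo.AutoRandomBits
	const signBit = 1   // a signed BIGINT keeps the sign bit clear

	base := int64(12345) // sequentially allocated part, lives in the low bits
	shard := int64(0b10110) << (64 - shardBits - signBit)

	id := base | shard // the two parts occupy disjoint bit ranges
	fmt.Printf("base=%d shard=%#x id=%d\n", base, shard, id)
}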
e.instanceToStatusAddress = make(map[string]string) e.statusToInstanceAddress = make(map[string]string) - sql := "select instance,status_address from information_schema.cluster_info;" - rows, _, err := sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithContext(ctx, sql) + var rows []chunk.Row + exec := sctx.(sqlexec.RestrictedSQLExecutor) + stmt, err := exec.ParseWithParams(ctx, "select instance,status_address from information_schema.cluster_info;") + if err == nil { + rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) + } if err != nil { sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("get cluster info failed: %v", err)) } @@ -247,16 +251,22 @@ func (configInspection) inspectDiffConfig(ctx context.Context, sctx sessionctx.C "storage.data-dir", "storage.block-cache.capacity", } - sql := fmt.Sprintf("select type, `key`, count(distinct value) as c from information_schema.cluster_config where `key` not in ('%s') group by type, `key` having c > 1", - strings.Join(ignoreConfigKey, "','")) - rows, _, err := sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithContext(ctx, sql) + var rows []chunk.Row + exec := sctx.(sqlexec.RestrictedSQLExecutor) + stmt, err := exec.ParseWithParams(ctx, "select type, `key`, count(distinct value) as c from information_schema.cluster_config where `key` not in (%?) group by type, `key` having c > 1", ignoreConfigKey) + if err == nil { + rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) + } if err != nil { sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("check configuration consistency failed: %v", err)) } generateDetail := func(tp, item string) string { - query := fmt.Sprintf("select value, instance from information_schema.cluster_config where type='%s' and `key`='%s';", tp, item) - rows, _, err := sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithContext(ctx, query) + var rows []chunk.Row + stmt, err := exec.ParseWithParams(ctx, "select value, instance from information_schema.cluster_config where type=%? and `key`=%?;", tp, item) + if err == nil { + rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) + } if err != nil { sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("check configuration consistency failed: %v", err)) return fmt.Sprintf("the cluster has different config value of %[2]s, execute the sql to see more detail: select * from information_schema.cluster_config where type='%[1]s' and `key`='%[2]s'", @@ -318,13 +328,18 @@ func (c configInspection) inspectCheckConfig(ctx context.Context, sctx sessionct } var results []inspectionResult + var rows []chunk.Row + sql := new(strings.Builder) + exec := sctx.(sqlexec.RestrictedSQLExecutor) for _, cas := range cases { if !filter.enable(cas.key) { continue } - sql := fmt.Sprintf("select instance from information_schema.cluster_config where type = '%s' and `key` = '%s' and value = '%s'", - cas.tp, cas.key, cas.value) - rows, _, err := sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithContext(ctx, sql) + sql.Reset() + stmt, err := exec.ParseWithParams(ctx, "select instance from information_schema.cluster_config where type = %? and %n = %? 
and value = %?", cas.tp, "key", cas.key, cas.value) + if err == nil { + rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) + } if err != nil { sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("check configuration in reason failed: %v", err)) } @@ -350,8 +365,12 @@ func (c configInspection) checkTiKVBlockCacheSizeConfig(ctx context.Context, sct if !filter.enable(item) { return nil } - sql := "select instance,value from information_schema.cluster_config where type='tikv' and `key` = 'storage.block-cache.capacity'" - rows, _, err := sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithContext(ctx, sql) + var rows []chunk.Row + exec := sctx.(sqlexec.RestrictedSQLExecutor) + stmt, err := exec.ParseWithParams(ctx, "select instance,value from information_schema.cluster_config where type='tikv' and `key` = 'storage.block-cache.capacity'") + if err == nil { + rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) + } if err != nil { sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("check configuration in reason failed: %v", err)) } @@ -375,8 +394,10 @@ func (c configInspection) checkTiKVBlockCacheSizeConfig(ctx context.Context, sct ipToCount[ip]++ } - sql = "select instance, value from metrics_schema.node_total_memory where time=now()" - rows, _, err = sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithContext(ctx, sql) + stmt, err = exec.ParseWithParams(ctx, "select instance, value from metrics_schema.node_total_memory where time=now()") + if err == nil { + rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) + } if err != nil { sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("check configuration in reason failed: %v", err)) } @@ -438,9 +459,13 @@ func (configInspection) convertReadableSizeToByteSize(sizeStr string) (uint64, e } func (versionInspection) inspect(ctx context.Context, sctx sessionctx.Context, filter inspectionFilter) []inspectionResult { + exec := sctx.(sqlexec.RestrictedSQLExecutor) + var rows []chunk.Row // check the configuration consistent - sql := "select type, count(distinct git_hash) as c from information_schema.cluster_info group by type having c > 1;" - rows, _, err := sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithContext(ctx, sql) + stmt, err := exec.ParseWithParams(ctx, "select type, count(distinct git_hash) as c from information_schema.cluster_info group by type having c > 1;") + if err == nil { + rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) + } if err != nil { sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("check version consistency failed: %v", err)) } @@ -594,6 +619,9 @@ func (criticalErrorInspection) inspectError(ctx context.Context, sctx sessionctx condition := filter.timeRange.Condition() var results []inspectionResult + var rows []chunk.Row + exec := sctx.(sqlexec.RestrictedSQLExecutor) + sql := new(strings.Builder) for _, rule := range rules { if filter.enable(rule.item) { def, found := infoschema.MetricTableMap[rule.tbl] @@ -601,9 +629,13 @@ func (criticalErrorInspection) inspectError(ctx context.Context, sctx sessionctx sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("metrics table: %s not found", rule.tbl)) continue } - sql := fmt.Sprintf("select `%[1]s`,sum(value) as total from `%[2]s`.`%[3]s` %[4]s group by `%[1]s` having total>=1.0", + sql.Reset() + fmt.Fprintf(sql, "select `%[1]s`,sum(value) as total from `%[2]s`.`%[3]s` %[4]s group by `%[1]s` having total>=1.0", strings.Join(def.Labels, "`,`"), util.MetricSchemaName.L, rule.tbl, condition) - rows, _, err := 
sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithContext(ctx, sql) + stmt, err := exec.ParseWithParams(ctx, sql.String()) + if err == nil { + rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) + } if err != nil { sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("execute '%s' failed: %v", sql, err)) continue @@ -649,10 +681,16 @@ func (criticalErrorInspection) inspectForServerDown(ctx context.Context, sctx se return nil } condition := filter.timeRange.Condition() - sql := fmt.Sprintf(`select t1.job,t1.instance, t2.min_time from + exec := sctx.(sqlexec.RestrictedSQLExecutor) + sql := new(strings.Builder) + fmt.Fprintf(sql, `select t1.job,t1.instance, t2.min_time from (select instance,job from metrics_schema.up %[1]s group by instance,job having max(value)-min(value)>0) as t1 join (select instance,min(time) as min_time from metrics_schema.up %[1]s and value=0 group by instance,job) as t2 on t1.instance=t2.instance order by job`, condition) - rows, _, err := sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithContext(ctx, sql) + var rows []chunk.Row + stmt, err := exec.ParseWithParams(ctx, sql.String()) + if err == nil { + rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) + } if err != nil { sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("execute '%s' failed: %v", sql, err)) } @@ -675,8 +713,12 @@ func (criticalErrorInspection) inspectForServerDown(ctx context.Context, sctx se results = append(results, result) } // Check from log. - sql = fmt.Sprintf("select type,instance,time from information_schema.cluster_log %s and level = 'info' and message like '%%Welcome to'", condition) - rows, _, err = sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithContext(ctx, sql) + sql.Reset() + fmt.Fprintf(sql, "select type,instance,time from information_schema.cluster_log %s and level = 'info' and message like '%%Welcome to'", condition) + stmt, err = exec.ParseWithParams(ctx, sql.String()) + if err == nil { + rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) + } if err != nil { sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("execute '%s' failed: %v", sql, err)) } @@ -790,24 +832,30 @@ func (thresholdCheckInspection) inspectThreshold1(ctx context.Context, sctx sess condition := filter.timeRange.Condition() var results []inspectionResult + var rows []chunk.Row + exec := sctx.(sqlexec.RestrictedSQLExecutor) + sql := new(strings.Builder) for _, rule := range rules { if !filter.enable(rule.item) { continue } - var sql string + sql.Reset() if len(rule.configKey) > 0 { - sql = fmt.Sprintf("select t1.status_address, t1.cpu, (t2.value * %[2]f) as threshold, t2.value from "+ - "(select status_address, max(sum_value) as cpu from (select instance as status_address, sum(value) as sum_value from metrics_schema.tikv_thread_cpu %[4]s and name like '%[1]s' group by instance, time) as tmp group by tmp.status_address) as t1 join "+ - "(select instance, value from information_schema.cluster_config where type='tikv' and `key` = '%[3]s') as t2 join "+ - "(select instance,status_address from information_schema.cluster_info where type='tikv') as t3 "+ - "on t1.status_address=t3.status_address and t2.instance=t3.instance where t1.cpu > (t2.value * %[2]f)", rule.component, rule.threshold, rule.configKey, condition) + fmt.Fprintf(sql, `select t1.status_address, t1.cpu, (t2.value * %[2]f) as threshold, t2.value from + (select status_address, max(sum_value) as cpu from (select instance as status_address, sum(value) as sum_value from metrics_schema.tikv_thread_cpu %[4]s and name like '%[1]s' 
group by instance, time) as tmp group by tmp.status_address) as t1 join + (select instance, value from information_schema.cluster_config where type='tikv' and %[5]s = '%[3]s') as t2 join + (select instance,status_address from information_schema.cluster_info where type='tikv') as t3 + on t1.status_address=t3.status_address and t2.instance=t3.instance where t1.cpu > (t2.value * %[2]f)`, rule.component, rule.threshold, rule.configKey, condition, "`key`") } else { - sql = fmt.Sprintf("select t1.instance, t1.cpu, %[2]f from "+ - "(select instance, max(value) as cpu from metrics_schema.tikv_thread_cpu %[3]s and name like '%[1]s' group by instance) as t1 "+ - "where t1.cpu > %[2]f;", rule.component, rule.threshold, condition) + fmt.Fprintf(sql, `select t1.instance, t1.cpu, %[2]f from + (select instance, max(value) as cpu from metrics_schema.tikv_thread_cpu %[3]s and name like '%[1]s' group by instance) as t1 + where t1.cpu > %[2]f;`, rule.component, rule.threshold, condition) + } + stmt, err := exec.ParseWithParams(ctx, sql.String()) + if err == nil { + rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) } - rows, _, err := sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithContext(ctx, sql) if err != nil { sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("execute '%s' failed: %v", sql, err)) continue @@ -957,11 +1005,13 @@ func (thresholdCheckInspection) inspectThreshold2(ctx context.Context, sctx sess condition := filter.timeRange.Condition() var results []inspectionResult + var rows []chunk.Row + sql := new(strings.Builder) + exec := sctx.(sqlexec.RestrictedSQLExecutor) for _, rule := range rules { if !filter.enable(rule.item) { continue } - var sql string cond := condition if len(rule.condition) > 0 { cond = fmt.Sprintf("%s and %s", cond, rule.condition) @@ -969,12 +1019,16 @@ func (thresholdCheckInspection) inspectThreshold2(ctx context.Context, sctx sess if rule.factor == 0 { rule.factor = 1 } + sql.Reset() if rule.isMin { - sql = fmt.Sprintf("select instance, min(value)/%.0f as min_value from metrics_schema.%s %s group by instance having min_value < %f;", rule.factor, rule.tbl, cond, rule.threshold) + fmt.Fprintf(sql, "select instance, min(value)/%.0f as min_value from metrics_schema.%s %s group by instance having min_value < %f;", rule.factor, rule.tbl, cond, rule.threshold) } else { - sql = fmt.Sprintf("select instance, max(value)/%.0f as max_value from metrics_schema.%s %s group by instance having max_value > %f;", rule.factor, rule.tbl, cond, rule.threshold) + fmt.Fprintf(sql, "select instance, max(value)/%.0f as max_value from metrics_schema.%s %s group by instance having max_value > %f;", rule.factor, rule.tbl, cond, rule.threshold) + } + stmt, err := exec.ParseWithParams(ctx, sql.String()) + if err == nil { + rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) } - rows, _, err := sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithContext(ctx, sql) if err != nil { sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("execute '%s' failed: %v", sql, err)) continue @@ -1150,12 +1204,17 @@ func (thresholdCheckInspection) inspectThreshold3(ctx context.Context, sctx sess func checkRules(ctx context.Context, sctx sessionctx.Context, filter inspectionFilter, rules []ruleChecker) []inspectionResult { var results []inspectionResult + var rows []chunk.Row + exec := sctx.(sqlexec.RestrictedSQLExecutor) for _, rule := range rules { if !filter.enable(rule.getItem()) { continue } sql := rule.genSQL(filter.timeRange) - rows, _, err := 
sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithContext(ctx, sql) + stmt, err := exec.ParseWithParams(ctx, sql) + if err == nil { + rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) + } if err != nil { sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("execute '%s' failed: %v", sql, err)) continue @@ -1170,8 +1229,15 @@ func checkRules(ctx context.Context, sctx sessionctx.Context, filter inspectionF func (c thresholdCheckInspection) inspectForLeaderDrop(ctx context.Context, sctx sessionctx.Context, filter inspectionFilter) []inspectionResult { condition := filter.timeRange.Condition() threshold := 50.0 - sql := fmt.Sprintf(`select address,min(value) as mi,max(value) as mx from metrics_schema.pd_scheduler_store_status %s and type='leader_count' group by address having mx-mi>%v`, condition, threshold) - rows, _, err := sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithContext(ctx, sql) + sql := new(strings.Builder) + fmt.Fprintf(sql, `select address,min(value) as mi,max(value) as mx from metrics_schema.pd_scheduler_store_status %s and type='leader_count' group by address having mx-mi>%v`, condition, threshold) + exec := sctx.(sqlexec.RestrictedSQLExecutor) + + var rows []chunk.Row + stmt, err := exec.ParseWithParams(ctx, sql.String()) + if err == nil { + rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) + } if err != nil { sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("execute '%s' failed: %v", sql, err)) return nil @@ -1179,12 +1245,18 @@ func (c thresholdCheckInspection) inspectForLeaderDrop(ctx context.Context, sctx var results []inspectionResult for _, row := range rows { address := row.GetString(0) - sql := fmt.Sprintf(`select time, value from metrics_schema.pd_scheduler_store_status %s and type='leader_count' and address = '%s' order by time`, condition, address) - subRows, _, err := sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithContext(ctx, sql) + sql.Reset() + fmt.Fprintf(sql, `select time, value from metrics_schema.pd_scheduler_store_status %s and type='leader_count' and address = '%s' order by time`, condition, address) + var subRows []chunk.Row + stmt, err := exec.ParseWithParams(ctx, sql.String()) + if err == nil { + subRows, _, err = exec.ExecRestrictedStmt(ctx, stmt) + } if err != nil { sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("execute '%s' failed: %v", sql, err)) continue } + lastValue := float64(0) for i, subRows := range subRows { v := subRows.GetFloat64(1) diff --git a/executor/inspection_summary.go b/executor/inspection_summary.go index 37aef042a62f9..2a01cab6dc402 100644 --- a/executor/inspection_summary.go +++ b/executor/inspection_summary.go @@ -458,7 +458,12 @@ func (e *inspectionSummaryRetriever) retrieve(ctx context.Context, sctx sessionc sql = fmt.Sprintf("select avg(value),min(value),max(value) from `%s`.`%s` %s", util.MetricSchemaName.L, name, cond) } - rows, _, err := sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithContext(ctx, sql) + exec := sctx.(sqlexec.RestrictedSQLExecutor) + stmt, err := exec.ParseWithParams(ctx, sql) + if err != nil { + return nil, errors.Errorf("execute '%s' failed: %v", sql, err) + } + rows, _, err := exec.ExecRestrictedStmt(ctx, stmt) if err != nil { return nil, errors.Errorf("execute '%s' failed: %v", sql, err) } diff --git a/executor/load_data.go b/executor/load_data.go index d69e66c33df6d..0449a78c4b6f7 100644 --- a/executor/load_data.go +++ b/executor/load_data.go @@ -199,6 +199,8 @@ func (e *LoadDataInfo) CommitOneTask(ctx context.Context, task CommitTask) error 
logutil.Logger(ctx).Error("commit error commit", zap.Error(err)) return err } + e.txnInUse.Lock() + defer e.txnInUse.Unlock() // Make sure that there are no retries when committing. if err = e.Ctx.RefreshTxnCtx(ctx); err != nil { logutil.Logger(ctx).Error("commit error refresh", zap.Error(err)) diff --git a/executor/metrics_reader.go b/executor/metrics_reader.go index 91cb8159bfc27..ff582c23b9935 100644 --- a/executor/metrics_reader.go +++ b/executor/metrics_reader.go @@ -189,7 +189,7 @@ type MetricsSummaryRetriever struct { retrieved bool } -func (e *MetricsSummaryRetriever) retrieve(_ context.Context, sctx sessionctx.Context) ([][]types.Datum, error) { +func (e *MetricsSummaryRetriever) retrieve(ctx context.Context, sctx sessionctx.Context) ([][]types.Datum, error) { if e.retrieved || e.extractor.SkipRequest { return nil, nil } @@ -229,7 +229,12 @@ func (e *MetricsSummaryRetriever) retrieve(_ context.Context, sctx sessionctx.Co name, util.MetricSchemaName.L, condition) } - rows, _, err := sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql) + exec := sctx.(sqlexec.RestrictedSQLExecutor) + stmt, err := exec.ParseWithParams(ctx, sql) + if err != nil { + return nil, errors.Errorf("execute '%s' failed: %v", sql, err) + } + rows, _, err := exec.ExecRestrictedStmt(ctx, stmt) if err != nil { return nil, errors.Errorf("execute '%s' failed: %v", sql, err) } @@ -306,7 +311,12 @@ func (e *MetricsSummaryByLabelRetriever) retrieve(ctx context.Context, sctx sess sql = fmt.Sprintf("select sum(value),avg(value),min(value),max(value) from `%s`.`%s` %s", util.MetricSchemaName.L, name, cond) } - rows, _, err := sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithContext(ctx, sql) + exec := sctx.(sqlexec.RestrictedSQLExecutor) + stmt, err := exec.ParseWithParams(ctx, sql) + if err != nil { + return nil, errors.Errorf("execute '%s' failed: %v", sql, err) + } + rows, _, err := exec.ExecRestrictedStmt(ctx, stmt) if err != nil { return nil, errors.Errorf("execute '%s' failed: %v", sql, err) } diff --git a/executor/opt_rule_blacklist.go b/executor/opt_rule_blacklist.go index 8bb55c16f52e5..76cdc74ea1d11 100644 --- a/executor/opt_rule_blacklist.go +++ b/executor/opt_rule_blacklist.go @@ -35,8 +35,12 @@ func (e *ReloadOptRuleBlacklistExec) Next(ctx context.Context, _ *chunk.Chunk) e // LoadOptRuleBlacklist loads the latest data from table mysql.opt_rule_blacklist. 
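The locking just added to CommitOneTask pairs with the txnInUse mutex on InsertValues from earlier in this diff: RefreshTxnCtx swaps the session's txn, so the encoding goroutine must not fetch it mid-swap. A stripped-down sketch of the handshake, with toy types standing in for the real executor:

package main

import (
	"fmt"
	"sync"
)

type loader struct {
	txnInUse sync.Mutex
	txn      string // stands in for the session's kv.Transaction
}

// commitOneTask commits a batch and then swaps in a fresh txn, holding the
// lock so readers never observe the txn mid-swap.
func (l *loader) commitOneTask(i int) {
	l.txnInUse.Lock()
	defer l.txnInUse.Unlock()
	l.txn = fmt.Sprintf("txn-%d", i) // the RefreshTxnCtx equivalent
}

// encodeRow reads the txn, e.g. to derive an auto-random shard from StartTS.
func (l *loader) encodeRow() string {
	l.txnInUse.Lock()
	defer l.txnInUse.Unlock()
	return l.txn
}

func main() {
	l := &loader{txn: "txn-0"}
	var wg sync.WaitGroup
	wg.Add(2)
	go func() {
		defer wg.Done()
		for i := 1; i <= 3; i++ {
			l.commitOneTask(i)
		}
	}()
	go func() {
		defer wg.Done()
		for i := 0; i < 3; i++ {
			_ = l.encodeRow()
		}
	}()
	wg.Wait()
	fmt.Println("final:", l.txn)
}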
func LoadOptRuleBlacklist(ctx sessionctx.Context) (err error) { - sql := "select HIGH_PRIORITY name from mysql.opt_rule_blacklist" - rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql) + exec := ctx.(sqlexec.RestrictedSQLExecutor) + stmt, err := exec.ParseWithParams(context.TODO(), "select HIGH_PRIORITY name from mysql.opt_rule_blacklist") + if err != nil { + return err + } + rows, _, err := exec.ExecRestrictedStmt(context.TODO(), stmt) if err != nil { return err } diff --git a/executor/point_get.go b/executor/point_get.go index 0fdeca2de81fe..f9d784e95ce6c 100644 --- a/executor/point_get.go +++ b/executor/point_get.go @@ -271,6 +271,9 @@ func (e *PointGetExecutor) getAndLock(ctx context.Context, key kv.Key) (val []by } func (e *PointGetExecutor) lockKeyIfNeeded(ctx context.Context, key []byte) error { + if len(key) == 0 { + return nil + } if e.lock { seVars := e.ctx.GetSessionVars() lockCtx := newLockCtx(seVars, e.lockWaitTime) @@ -345,8 +348,11 @@ func encodeIndexKey(e *baseExecutor, tblInfo *model.TableInfo, idxInfo *model.In str, err = idxVals[i].ToString() idxVals[i].SetString(str, colInfo.FieldType.Collate) } else { + // If a truncated error or an overflow error is thrown when converting the type of `idxVal[i]` to + // the type of `colInfo`, the `idxVal` does not exist in the `idxInfo` for sure. idxVals[i], err = table.CastValue(e.ctx, idxVals[i], colInfo, true, false) - if types.ErrOverflow.Equal(err) { + if types.ErrOverflow.Equal(err) || types.ErrDataTooLong.Equal(err) || + types.ErrTruncated.Equal(err) || types.ErrTruncatedWrongVal.Equal(err) { return nil, false, kv.ErrNotExist } } diff --git a/executor/point_get_test.go b/executor/point_get_test.go index 0269a1f0d65d5..b70677c17be72 100644 --- a/executor/point_get_test.go +++ b/executor/point_get_test.go @@ -125,6 +125,25 @@ func (s *testPointGetSuite) TestPointGetOverflow(c *C) { tk.MustQuery("SELECT t0.c1 FROM t0 WHERE t0.c1=127").Check(testkit.Rows("127")) } +// Close issue #22839 +func (s *testPointGetSuite) TestPointGetDataTooLong(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists PK_1389;") + tk.MustExec("CREATE TABLE `PK_1389` ( " + + " `COL1` bit(1) NOT NULL," + + " `COL2` varchar(20) DEFAULT NULL," + + " `COL3` datetime DEFAULT NULL," + + " `COL4` bigint(20) DEFAULT NULL," + + " `COL5` float DEFAULT NULL," + + " PRIMARY KEY (`COL1`)" + + ");") + tk.MustExec("insert into PK_1389 values(0, \"皟钹糁泅埞礰喾皑杏灚暋蛨歜檈瓗跾咸滐梀揉\", \"7701-12-27 23:58:43\", 4806951672419474695, -1.55652e38);") + tk.MustQuery("select count(1) from PK_1389 where col1 = 0x30;").Check(testkit.Rows("0")) + tk.MustQuery("select count(1) from PK_1389 where col1 in ( 0x30);").Check(testkit.Rows("0")) + tk.MustExec("drop table if exists PK_1389;") +} + func (s *testPointGetSuite) TestPointGetCharPK(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec(`use test;`) diff --git a/executor/prepared.go b/executor/prepared.go index f57a1806b1ed9..9f8a427b84afa 100644 --- a/executor/prepared.go +++ b/executor/prepared.go @@ -116,6 +116,7 @@ func (e *PrepareExec) Next(ctx context.Context, req *chunk.Chunk) error { err error ) if sqlParser, ok := e.ctx.(sqlexec.SQLParser); ok { + // FIXME: ok... yet another parse API, may need some api interface clean. 
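The FIXME above is about parse-API sprawl; the shape this diff migrates call sites to is parse-once-with-parameters, then execute the prepared statement. A condensed sketch of the before/after (queryStatsMeta is a hypothetical helper, not from the patch):

package executor

import (
	"context"

	"github.com/pingcap/tidb/sessionctx"
	"github.com/pingcap/tidb/util/chunk"
	"github.com/pingcap/tidb/util/sqlexec"
)

// Before: values were fmt.Sprintf-ed into the SQL text and run via
// ExecRestrictedSQL, leaving quoting and escaping to each call site.
// After: %? binds the value during parsing, so nothing is spliced in raw.
func queryStatsMeta(sctx sessionctx.Context, tableID int64) ([]chunk.Row, error) {
	exec := sctx.(sqlexec.RestrictedSQLExecutor)
	stmt, err := exec.ParseWithParams(context.TODO(),
		"select table_id, count from mysql.stats_meta where table_id = %?", tableID)
	if err != nil {
		return nil, err
	}
	rows, _, err := exec.ExecRestrictedStmt(context.TODO(), stmt)
	return rows, err
}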
stmts, err = sqlParser.ParseSQL(e.sqlText, charset, collation) } else { p := parser.New() diff --git a/executor/projection.go b/executor/projection.go index c36d435ab231a..776502d07554f 100644 --- a/executor/projection.go +++ b/executor/projection.go @@ -21,6 +21,7 @@ import ( "sync/atomic" "github.com/pingcap/errors" + "github.com/pingcap/failpoint" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/util" @@ -85,6 +86,11 @@ func (e *ProjectionExec) Open(ctx context.Context) error { if err := e.baseExecutor.Open(ctx); err != nil { return err } + failpoint.Inject("mockProjectionExecBaseExecutorOpenReturnedError", func(val failpoint.Value) { + if val.(bool) { + failpoint.Return(errors.New("mock ProjectionExec.baseExecutor.Open returned error")) + } + }) return e.open(ctx) } @@ -290,7 +296,9 @@ func (e *ProjectionExec) drainOutputCh(ch chan *projectionOutput) { // Close implements the Executor Close interface. func (e *ProjectionExec) Close() error { - if e.isUnparallelExec() { + // if e.baseExecutor.Open returns error, e.childResult will be nil, see https://github.com/pingcap/tidb/issues/24210 + // for more information + if e.isUnparallelExec() && e.childResult != nil { e.memTracker.Consume(-e.childResult.MemoryUsage()) e.childResult = nil } diff --git a/executor/reload_expr_pushdown_blacklist.go b/executor/reload_expr_pushdown_blacklist.go index 3d0752e08463d..5783438813954 100644 --- a/executor/reload_expr_pushdown_blacklist.go +++ b/executor/reload_expr_pushdown_blacklist.go @@ -37,8 +37,12 @@ func (e *ReloadExprPushdownBlacklistExec) Next(ctx context.Context, _ *chunk.Chu // LoadExprPushdownBlacklist loads the latest data from table mysql.expr_pushdown_blacklist. func LoadExprPushdownBlacklist(ctx sessionctx.Context) (err error) { - sql := "select HIGH_PRIORITY name, store_type from mysql.expr_pushdown_blacklist" - rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql) + exec := ctx.(sqlexec.RestrictedSQLExecutor) + stmt, err := exec.ParseWithParams(context.TODO(), "select HIGH_PRIORITY name, store_type from mysql.expr_pushdown_blacklist") + if err != nil { + return err + } + rows, _, err := exec.ExecRestrictedStmt(context.TODO(), stmt) if err != nil { return err } diff --git a/executor/revoke.go b/executor/revoke.go index 7e51aa8ac82a4..fb722c89acf40 100644 --- a/executor/revoke.go +++ b/executor/revoke.go @@ -15,7 +15,7 @@ package executor import ( "context" - "fmt" + "strings" "github.com/pingcap/errors" "github.com/pingcap/parser/ast" @@ -73,7 +73,7 @@ func (e *RevokeExec) Next(ctx context.Context, req *chunk.Chunk) error { } defer func() { if !isCommit { - _, err := internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), "rollback") + _, err := internalSession.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), "rollback") if err != nil { logutil.BgLogger().Error("rollback error occur at grant privilege", zap.Error(err)) } @@ -81,7 +81,7 @@ func (e *RevokeExec) Next(ctx context.Context, req *chunk.Chunk) error { e.releaseSysSession(internalSession) }() - _, err = internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), "begin") + _, err = internalSession.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), "begin") if err != nil { return err } @@ -103,7 +103,7 @@ func (e *RevokeExec) Next(ctx context.Context, req *chunk.Chunk) error { } } - _, err = internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), "commit") + _, err = 
internalSession.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), "commit") if err != nil { return err } @@ -166,12 +166,15 @@ func (e *RevokeExec) revokePriv(internalSession sessionctx.Context, priv *ast.Pr } func (e *RevokeExec) revokeGlobalPriv(internalSession sessionctx.Context, priv *ast.PrivElem, user, host string) error { - asgns, err := composeGlobalPrivUpdate(priv.Priv, "N") + sql := new(strings.Builder) + sqlexec.MustFormatSQL(sql, "UPDATE %n.%n SET ", mysql.SystemDB, mysql.UserTable) + err := composeGlobalPrivUpdate(sql, priv.Priv, "N") if err != nil { return err } - sql := fmt.Sprintf(`UPDATE %s.%s SET %s WHERE User='%s' AND Host='%s'`, mysql.SystemDB, mysql.UserTable, asgns, user, host) - _, err = internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), sql) + sqlexec.MustFormatSQL(sql, " WHERE User=%? AND Host=%?", user, host) + + _, err = internalSession.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), sql.String()) return err } @@ -180,12 +183,16 @@ func (e *RevokeExec) revokeDBPriv(internalSession sessionctx.Context, priv *ast. if len(dbName) == 0 { dbName = e.ctx.GetSessionVars().CurrentDB } - asgns, err := composeDBPrivUpdate(priv.Priv, "N") + + sql := new(strings.Builder) + sqlexec.MustFormatSQL(sql, "UPDATE %n.%n SET ", mysql.SystemDB, mysql.DBTable) + err := composeDBPrivUpdate(sql, priv.Priv, "N") if err != nil { return err } - sql := fmt.Sprintf(`UPDATE %s.%s SET %s WHERE User='%s' AND Host='%s' AND DB='%s';`, mysql.SystemDB, mysql.DBTable, asgns, userName, host, dbName) - _, err = internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), sql) + sqlexec.MustFormatSQL(sql, " WHERE User=%? AND Host=%? AND DB=%?", userName, host, dbName) + + _, err = internalSession.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), sql.String()) return err } @@ -194,12 +201,16 @@ func (e *RevokeExec) revokeTablePriv(internalSession sessionctx.Context, priv *a if err != nil { return err } - asgns, err := composeTablePrivUpdateForRevoke(internalSession, priv.Priv, user, host, dbName, tbl.Meta().Name.O) + + sql := new(strings.Builder) + sqlexec.MustFormatSQL(sql, "UPDATE %n.%n SET ", mysql.SystemDB, mysql.TablePrivTable) + err = composeTablePrivUpdateForRevoke(internalSession, sql, priv.Priv, user, host, dbName, tbl.Meta().Name.O) if err != nil { return err } - sql := fmt.Sprintf(`UPDATE %s.%s SET %s WHERE User='%s' AND Host='%s' AND DB='%s' AND Table_name='%s';`, mysql.SystemDB, mysql.TablePrivTable, asgns, user, host, dbName, tbl.Meta().Name.O) - _, err = internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), sql) + sqlexec.MustFormatSQL(sql, " WHERE User=%? AND Host=%? AND DB=%? 
AND Table_name=%?", user, host, dbName, tbl.Meta().Name.O) + + _, err = internalSession.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), sql.String()) return err } @@ -208,20 +219,80 @@ func (e *RevokeExec) revokeColumnPriv(internalSession sessionctx.Context, priv * if err != nil { return err } + sql := new(strings.Builder) for _, c := range priv.Cols { col := table.FindCol(tbl.Cols(), c.Name.L) if col == nil { return errors.Errorf("Unknown column: %s", c) } - asgns, err := composeColumnPrivUpdateForRevoke(internalSession, priv.Priv, user, host, dbName, tbl.Meta().Name.O, col.Name.O) + + sql.Reset() + sqlexec.MustFormatSQL(sql, "UPDATE %n.%n SET ", mysql.SystemDB, mysql.ColumnPrivTable) + err = composeColumnPrivUpdateForRevoke(internalSession, sql, priv.Priv, user, host, dbName, tbl.Meta().Name.O, col.Name.O) + if err != nil { + return err + } + sqlexec.MustFormatSQL(sql, " WHERE User=%? AND Host=%? AND DB=%? AND Table_name=%? AND Column_name=%?", user, host, dbName, tbl.Meta().Name.O, col.Name.O) + + _, err = internalSession.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), sql.String()) + if err != nil { + return err + } + } + return nil +} + +func privUpdateForRevoke(cur []string, priv mysql.PrivilegeType) ([]string, error) { + p, ok := mysql.Priv2SetStr[priv] + if !ok { + return nil, errors.Errorf("Unknown priv: %v", priv) + } + cur = deleteFromSet(cur, p) + return cur, nil +} + +func composeTablePrivUpdateForRevoke(ctx sessionctx.Context, sql *strings.Builder, priv mysql.PrivilegeType, name string, host string, db string, tbl string) error { + var newTablePriv, newColumnPriv []string + + if priv != mysql.AllPriv { + currTablePriv, currColumnPriv, err := getTablePriv(ctx, name, host, db, tbl) + if err != nil { + return err + } + + newTablePriv = setFromString(currTablePriv) + newTablePriv, err = privUpdateForRevoke(newTablePriv, priv) + if err != nil { + return err + } + + newColumnPriv = setFromString(currColumnPriv) + newColumnPriv, err = privUpdateForRevoke(newColumnPriv, priv) + if err != nil { + return err + } + } + + sqlexec.MustFormatSQL(sql, `Table_priv=%?, Column_priv=%?, Grantor=%?`, strings.Join(newTablePriv, ","), strings.Join(newColumnPriv, ","), ctx.GetSessionVars().User.String()) + return nil +} + +func composeColumnPrivUpdateForRevoke(ctx sessionctx.Context, sql *strings.Builder, priv mysql.PrivilegeType, name string, host string, db string, tbl string, col string) error { + var newColumnPriv []string + + if priv != mysql.AllPriv { + currColumnPriv, err := getColumnPriv(ctx, name, host, db, tbl, col) if err != nil { return err } - sql := fmt.Sprintf(`UPDATE %s.%s SET %s WHERE User='%s' AND Host='%s' AND DB='%s' AND Table_name='%s' AND Column_name='%s';`, mysql.SystemDB, mysql.ColumnPrivTable, asgns, user, host, dbName, tbl.Meta().Name.O, col.Name.O) - _, err = internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), sql) + + newColumnPriv = setFromString(currColumnPriv) + newColumnPriv, err = privUpdateForRevoke(newColumnPriv, priv) if err != nil { return err } } + + sqlexec.MustFormatSQL(sql, `Column_priv=%?`, strings.Join(newColumnPriv, ",")) return nil } diff --git a/executor/set.go b/executor/set.go index 75b6b0db95c60..c1753f3ed1313 100644 --- a/executor/set.go +++ b/executor/set.go @@ -184,6 +184,10 @@ func (e *SetExecutor) setSysVariable(name string, v *expression.VarAssignment) e sessionVars.StmtCtx.AppendWarning(fmt.Errorf("Set operation for '%s' will not take effect", variable.TiDBFoundInPlanCache)) return nil } + if name == 
variable.TiDBFoundInBinding { + sessionVars.StmtCtx.AppendWarning(fmt.Errorf("Set operation for '%s' will not take effect", variable.TiDBFoundInBinding)) + return nil + } err = variable.SetSessionSystemVar(sessionVars, name, value) if err != nil { return err diff --git a/executor/show.go b/executor/show.go index bb856ac7657e0..85b4bbf31e784 100644 --- a/executor/show.go +++ b/executor/show.go @@ -284,12 +284,17 @@ func (e *ShowExec) fetchShowBind() error { } func (e *ShowExec) fetchShowEngines() error { - sql := `SELECT * FROM information_schema.engines` - rows, _, err := e.ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql) + exec := e.ctx.(sqlexec.RestrictedSQLExecutor) + stmt, err := exec.ParseWithParams(context.TODO(), `SELECT * FROM information_schema.engines`) if err != nil { return errors.Trace(err) } + rows, _, err := exec.ExecRestrictedStmt(context.TODO(), stmt) + if err != nil { + return errors.Trace(err) + } + for _, row := range rows { e.result.AppendRow(row) } @@ -410,16 +415,32 @@ func (e *ShowExec) fetchShowTableStatus() error { return ErrBadDB.GenWithStackByArgs(e.DBName) } - sql := fmt.Sprintf(`SELECT - table_name, engine, version, row_format, table_rows, - avg_row_length, data_length, max_data_length, index_length, - data_free, auto_increment, create_time, update_time, check_time, - table_collation, IFNULL(checksum,''), create_options, table_comment - FROM information_schema.tables - WHERE table_schema='%s' ORDER BY table_name`, e.DBName) + exec := e.ctx.(sqlexec.RestrictedSQLExecutor) - rows, _, err := e.ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithSnapshot(sql) + stmt, err := exec.ParseWithParams(context.TODO(), `SELECT + table_name, engine, version, row_format, table_rows, + avg_row_length, data_length, max_data_length, index_length, + data_free, auto_increment, create_time, update_time, check_time, + table_collation, IFNULL(checksum,''), create_options, table_comment + FROM information_schema.tables + WHERE lower(table_schema)=%? ORDER BY table_name`, e.DBName.L) + if err != nil { + return errors.Trace(err) + } + var snapshot uint64 + txn, err := e.ctx.Txn(false) + if err != nil { + return errors.Trace(err) + } + if txn.Valid() { + snapshot = txn.StartTS() + } + if e.ctx.GetSessionVars().SnapshotTS != 0 { + snapshot = e.ctx.GetSessionVars().SnapshotTS + } + + rows, _, err := exec.ExecRestrictedStmt(context.TODO(), stmt, sqlexec.ExecOptionWithSnapshot(snapshot)) if err != nil { return errors.Trace(err) } @@ -1199,22 +1220,32 @@ func (e *ShowExec) fetchShowCreateUser() error { } } - sql := fmt.Sprintf(`SELECT * FROM %s.%s WHERE User='%s' AND Host='%s';`, - mysql.SystemDB, mysql.UserTable, userName, hostName) - rows, _, err := e.ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql) + exec := e.ctx.(sqlexec.RestrictedSQLExecutor) + + stmt, err := exec.ParseWithParams(context.TODO(), `SELECT * FROM %n.%n WHERE User=%? 
AND Host=%?`, mysql.SystemDB, mysql.UserTable, userName, hostName) + if err != nil { + return errors.Trace(err) + } + rows, _, err := exec.ExecRestrictedStmt(context.TODO(), stmt) if err != nil { return errors.Trace(err) } + if len(rows) == 0 { + // FIXME: the error returned is not escaped safely return ErrCannotUser.GenWithStackByArgs("SHOW CREATE USER", fmt.Sprintf("'%s'@'%s'", e.User.Username, e.User.Hostname)) } - sql = fmt.Sprintf(`SELECT PRIV FROM %s.%s WHERE User='%s' AND Host='%s'`, - mysql.SystemDB, mysql.GlobalPrivTable, userName, hostName) - rows, _, err = e.ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql) + + stmt, err = exec.ParseWithParams(context.TODO(), `SELECT Priv FROM %n.%n WHERE User=%? AND Host=%?`, mysql.SystemDB, mysql.GlobalPrivTable, userName, hostName) if err != nil { return errors.Trace(err) } + rows, _, err = exec.ExecRestrictedStmt(context.TODO(), stmt) + if err != nil { + return errors.Trace(err) + } + require := "NONE" if len(rows) == 1 { privData := rows[0].GetString(0) @@ -1225,6 +1256,7 @@ func (e *ShowExec) fetchShowCreateUser() error { } require = privValue.RequireStr() } + // FIXME: the returned string is not escaped safely showStr := fmt.Sprintf("CREATE USER '%s'@'%s' IDENTIFIED WITH 'mysql_native_password' AS '%s' REQUIRE %s PASSWORD EXPIRE DEFAULT ACCOUNT UNLOCK", e.User.Username, e.User.Hostname, checker.GetEncodedPassword(e.User.Username, e.User.Hostname), require) e.appendRow([]interface{}{showStr}) diff --git a/executor/show_test.go b/executor/show_test.go index 19c65fc8e6159..251b0efba8567 100644 --- a/executor/show_test.go +++ b/executor/show_test.go @@ -19,7 +19,6 @@ import ( "strings" . "github.com/pingcap/check" - "github.com/pingcap/errors" "github.com/pingcap/failpoint" "github.com/pingcap/parser/auth" "github.com/pingcap/parser/model" @@ -144,6 +143,16 @@ func (s *testSuite5) TestShowErrors(c *C) { tk.MustQuery("show errors").Check(testutil.RowsWithSep("|", "Error|1050|Table 'test.show_errors' already exists")) } +func (s *testSuite5) TestShowWarningsForExprPushdown(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + testSQL := `create table if not exists show_warnings_expr_pushdown (a int, value date)` + tk.MustExec(testSQL) + tk.MustExec("explain select * from show_warnings_expr_pushdown where date_add(value, interval 1 day) = '2020-01-01'") + c.Assert(tk.Se.GetSessionVars().StmtCtx.WarningCount(), Equals, uint16(1)) + tk.MustQuery("show warnings").Check(testutil.RowsWithSep("|", "Warning|1105|Scalar function 'date_add'(signature: AddDateDatetimeInt) can not be pushed to tikv")) +} + func (s *testSuite5) TestShowGrantsPrivilege(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("create user show_grants") @@ -481,12 +490,13 @@ func (s *testSuite5) TestShowTableStatus(c *C) { // It's not easy to test the result contents because every time the test runs, "Create_time" changed. 
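The tests below exercise fetchShowTableStatus, whose new snapshot plumbing (earlier in this diff) follows a simple precedence: an explicit tidb_snapshot setting wins over the open transaction's start TS, and zero means read the latest data. A minimal sketch of that choice (pickSnapshotTS is a hypothetical helper, not in the patch):

// pickSnapshotTS mirrors the selection in fetchShowTableStatus: prefer the
// session's SnapshotTS when set, else fall back to the open txn's StartTS.
func pickSnapshotTS(txnValid bool, startTS, sessionSnapshotTS uint64) uint64 {
	var snapshot uint64
	if txnValid {
		snapshot = startTS
	}
	if sessionSnapshotTS != 0 {
		snapshot = sessionSnapshotTS
	}
	return snapshot
}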
tk.MustExec("show table status;") rs, err := tk.Exec("show table status;") - c.Assert(errors.ErrorStack(err), Equals, "") + c.Assert(err, IsNil) c.Assert(rs, NotNil) rows, err := session.GetRows4Test(context.Background(), tk.Se, rs) - c.Assert(errors.ErrorStack(err), Equals, "") + c.Assert(err, IsNil) err = rs.Close() - c.Assert(errors.ErrorStack(err), Equals, "") + c.Assert(err, IsNil) + c.Assert(len(rows), Equals, 1) for i := range rows { row := rows[i] @@ -503,10 +513,34 @@ func (s *testSuite5) TestShowTableStatus(c *C) { partition p2 values less than (maxvalue) );`) rs, err = tk.Exec("show table status from test like 'tp';") - c.Assert(errors.ErrorStack(err), Equals, "") + c.Assert(err, IsNil) rows, err = session.GetRows4Test(context.Background(), tk.Se, rs) - c.Assert(errors.ErrorStack(err), Equals, "") + c.Assert(err, IsNil) c.Assert(rows[0].GetString(16), Equals, "partitioned") + + tk.MustExec("create database UPPER_CASE") + tk.MustExec("use UPPER_CASE") + tk.MustExec("create table t (i int)") + rs, err = tk.Exec("show table status") + c.Assert(err, IsNil) + c.Assert(rs, NotNil) + rows, err = session.GetRows4Test(context.Background(), tk.Se, rs) + c.Assert(err, IsNil) + err = rs.Close() + c.Assert(err, IsNil) + c.Assert(len(rows), Equals, 1) + + tk.MustExec("use upper_case") + rs, err = tk.Exec("show table status") + c.Assert(err, IsNil) + c.Assert(rs, NotNil) + rows, err = session.GetRows4Test(context.Background(), tk.Se, rs) + c.Assert(err, IsNil) + err = rs.Close() + c.Assert(err, IsNil) + c.Assert(len(rows), Equals, 1) + + tk.MustExec("drop database UPPER_CASE") } func (s *testSuite5) TestShowSlow(c *C) { diff --git a/executor/shuffle.go b/executor/shuffle.go index 7de58e7139880..2395819726c9e 100644 --- a/executor/shuffle.go +++ b/executor/shuffle.go @@ -127,6 +127,7 @@ func (e *ShuffleExec) Open(ctx context.Context) error { // Close implements the Executor Close interface. func (e *ShuffleExec) Close() error { + var firstErr error if !e.prepared { for _, w := range e.workers { close(w.inputHolderCh) @@ -139,6 +140,9 @@ func (e *ShuffleExec) Close() error { for _, w := range e.workers { for range w.inputCh { } + if err := w.childExec.Close(); err != nil && firstErr == nil { + firstErr = err + } } for range e.outputCh { // workers exit before `e.outputCh` is closed. 
} @@ -150,12 +154,13 @@ func (e *ShuffleExec) Close() error { e.ctx.GetSessionVars().StmtCtx.RuntimeStatsColl.RegisterStats(e.id, runtimeStats) } - err := e.dataSource.Close() - err1 := e.baseExecutor.Close() - if err != nil { - return errors.Trace(err) + if err := e.dataSource.Close(); err != nil && firstErr == nil { + firstErr = err + } + if err := e.baseExecutor.Close(); err != nil && firstErr == nil { + firstErr = err } - return errors.Trace(err1) + return errors.Trace(firstErr) } func (e *ShuffleExec) prepare4ParallelExec(ctx context.Context) { diff --git a/executor/simple.go b/executor/simple.go index 2f1dca0cb982d..9a91191035cbd 100644 --- a/executor/simple.go +++ b/executor/simple.go @@ -80,7 +80,7 @@ func (e *baseExecutor) releaseSysSession(ctx sessionctx.Context) { } dom := domain.GetDomain(e.ctx) sysSessionPool := dom.SysSessionPool() - if _, err := ctx.(sqlexec.SQLExecutor).Execute(context.Background(), "rollback"); err != nil { + if _, err := ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), "rollback"); err != nil { ctx.(pools.Resource).Close() return } @@ -151,23 +151,25 @@ func (e *SimpleExec) setDefaultRoleNone(s *ast.SetDefaultRoleStmt) error { } defer e.releaseSysSession(restrictedCtx) sqlExecutor := restrictedCtx.(sqlexec.SQLExecutor) - if _, err := sqlExecutor.Execute(context.Background(), "begin"); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "begin"); err != nil { return err } + sql := new(strings.Builder) for _, u := range s.UserList { if u.Hostname == "" { u.Hostname = "%" } - sql := fmt.Sprintf("DELETE IGNORE FROM mysql.default_roles WHERE USER='%s' AND HOST='%s';", u.Username, u.Hostname) - if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil { + sql.Reset() + sqlexec.MustFormatSQL(sql, "DELETE IGNORE FROM mysql.default_roles WHERE USER=%? AND HOST=%?;", u.Username, u.Hostname) + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), sql.String()); err != nil { logutil.BgLogger().Error(fmt.Sprintf("Error occur when executing %s", sql)) - if _, rollbackErr := sqlExecutor.Execute(context.Background(), "rollback"); rollbackErr != nil { + if _, rollbackErr := sqlExecutor.ExecuteInternal(context.TODO(), "rollback"); rollbackErr != nil { return rollbackErr } return err } } - if _, err := sqlExecutor.Execute(context.Background(), "commit"); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "commit"); err != nil { return err } return nil @@ -199,42 +201,45 @@ func (e *SimpleExec) setDefaultRoleRegular(s *ast.SetDefaultRoleStmt) error { } defer e.releaseSysSession(restrictedCtx) sqlExecutor := restrictedCtx.(sqlexec.SQLExecutor) - if _, err := sqlExecutor.Execute(context.Background(), "begin"); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "begin"); err != nil { return err } + sql := new(strings.Builder) for _, user := range s.UserList { if user.Hostname == "" { user.Hostname = "%" } - sql := fmt.Sprintf("DELETE IGNORE FROM mysql.default_roles WHERE USER='%s' AND HOST='%s';", user.Username, user.Hostname) - if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil { + sql.Reset() + sqlexec.MustFormatSQL(sql, "DELETE IGNORE FROM mysql.default_roles WHERE USER=%? 
AND HOST=%?;", user.Username, user.Hostname) + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), sql.String()); err != nil { logutil.BgLogger().Error(fmt.Sprintf("Error occur when executing %s", sql)) - if _, rollbackErr := sqlExecutor.Execute(context.Background(), "rollback"); rollbackErr != nil { + if _, rollbackErr := sqlExecutor.ExecuteInternal(context.TODO(), "rollback"); rollbackErr != nil { return rollbackErr } return err } for _, role := range s.RoleList { - sql := fmt.Sprintf("INSERT IGNORE INTO mysql.default_roles values('%s', '%s', '%s', '%s');", user.Hostname, user.Username, role.Hostname, role.Username) checker := privilege.GetPrivilegeManager(e.ctx) ok := checker.FindEdge(e.ctx, role, user) if ok { - if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil { + sql.Reset() + sqlexec.MustFormatSQL(sql, "INSERT IGNORE INTO mysql.default_roles values(%?, %?, %?, %?);", user.Hostname, user.Username, role.Hostname, role.Username) + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), sql.String()); err != nil { logutil.BgLogger().Error(fmt.Sprintf("Error occur when executing %s", sql)) - if _, rollbackErr := sqlExecutor.Execute(context.Background(), "rollback"); rollbackErr != nil { + if _, rollbackErr := sqlExecutor.ExecuteInternal(context.TODO(), "rollback"); rollbackErr != nil { return rollbackErr } return err } } else { - if _, rollbackErr := sqlExecutor.Execute(context.Background(), "rollback"); rollbackErr != nil { + if _, rollbackErr := sqlExecutor.ExecuteInternal(context.TODO(), "rollback"); rollbackErr != nil { return rollbackErr } return ErrRoleNotGranted.GenWithStackByArgs(role.String(), user.String()) } } } - if _, err := sqlExecutor.Execute(context.Background(), "commit"); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "commit"); err != nil { return err } return nil @@ -256,31 +261,34 @@ func (e *SimpleExec) setDefaultRoleAll(s *ast.SetDefaultRoleStmt) error { } defer e.releaseSysSession(restrictedCtx) sqlExecutor := restrictedCtx.(sqlexec.SQLExecutor) - if _, err := sqlExecutor.Execute(context.Background(), "begin"); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "begin"); err != nil { return err } + sql := new(strings.Builder) for _, user := range s.UserList { if user.Hostname == "" { user.Hostname = "%" } - sql := fmt.Sprintf("DELETE IGNORE FROM mysql.default_roles WHERE USER='%s' AND HOST='%s';", user.Username, user.Hostname) - if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil { + sql.Reset() + sqlexec.MustFormatSQL(sql, "DELETE IGNORE FROM mysql.default_roles WHERE USER=%? 
AND HOST=%?;", user.Username, user.Hostname) + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), sql.String()); err != nil { logutil.BgLogger().Error(fmt.Sprintf("Error occur when executing %s", sql)) - if _, rollbackErr := sqlExecutor.Execute(context.Background(), "rollback"); rollbackErr != nil { + if _, rollbackErr := sqlExecutor.ExecuteInternal(context.TODO(), "rollback"); rollbackErr != nil { return rollbackErr } return err } - sql = fmt.Sprintf("INSERT IGNORE INTO mysql.default_roles(HOST,USER,DEFAULT_ROLE_HOST,DEFAULT_ROLE_USER) "+ - "SELECT TO_HOST,TO_USER,FROM_HOST,FROM_USER FROM mysql.role_edges WHERE TO_HOST='%s' AND TO_USER='%s';", user.Hostname, user.Username) - if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil { - if _, rollbackErr := sqlExecutor.Execute(context.Background(), "rollback"); rollbackErr != nil { + sql.Reset() + sqlexec.MustFormatSQL(sql, "INSERT IGNORE INTO mysql.default_roles(HOST,USER,DEFAULT_ROLE_HOST,DEFAULT_ROLE_USER) SELECT TO_HOST,TO_USER,FROM_HOST,FROM_USER FROM mysql.role_edges WHERE TO_HOST=%? AND TO_USER=%?;", user.Hostname, user.Username) + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), sql.String()); err != nil { + logutil.BgLogger().Error(fmt.Sprintf("Error occur when executing %s", sql)) + if _, rollbackErr := sqlExecutor.ExecuteInternal(context.TODO(), "rollback"); rollbackErr != nil { return rollbackErr } return err } } - if _, err := sqlExecutor.Execute(context.Background(), "commit"); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "commit"); err != nil { return err } return nil @@ -288,29 +296,10 @@ func (e *SimpleExec) setDefaultRoleAll(s *ast.SetDefaultRoleStmt) error { func (e *SimpleExec) setDefaultRoleForCurrentUser(s *ast.SetDefaultRoleStmt) (err error) { checker := privilege.GetPrivilegeManager(e.ctx) - user, sql := s.UserList[0], "" + user := s.UserList[0] if user.Hostname == "" { user.Hostname = "%" } - switch s.SetRoleOpt { - case ast.SetRoleNone: - sql = fmt.Sprintf("DELETE IGNORE FROM mysql.default_roles WHERE USER='%s' AND HOST='%s';", user.Username, user.Hostname) - case ast.SetRoleAll: - sql = fmt.Sprintf("INSERT IGNORE INTO mysql.default_roles(HOST,USER,DEFAULT_ROLE_HOST,DEFAULT_ROLE_USER) "+ - "SELECT TO_HOST,TO_USER,FROM_HOST,FROM_USER FROM mysql.role_edges WHERE TO_HOST='%s' AND TO_USER='%s';", user.Hostname, user.Username) - case ast.SetRoleRegular: - sql = "INSERT IGNORE INTO mysql.default_roles values" - for i, role := range s.RoleList { - ok := checker.FindEdge(e.ctx, role, user) - if !ok { - return ErrRoleNotGranted.GenWithStackByArgs(role.String(), user.String()) - } - sql += fmt.Sprintf("('%s', '%s', '%s', '%s')", user.Hostname, user.Username, role.Hostname, role.Username) - if i != len(s.RoleList)-1 { - sql += "," - } - } - } restrictedCtx, err := e.getSysSession() if err != nil { @@ -319,27 +308,48 @@ func (e *SimpleExec) setDefaultRoleForCurrentUser(s *ast.SetDefaultRoleStmt) (er defer e.releaseSysSession(restrictedCtx) sqlExecutor := restrictedCtx.(sqlexec.SQLExecutor) - if _, err := sqlExecutor.Execute(context.Background(), "begin"); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "begin"); err != nil { return err } - deleteSQL := fmt.Sprintf("DELETE IGNORE FROM mysql.default_roles WHERE USER='%s' AND HOST='%s';", user.Username, user.Hostname) - if _, err := sqlExecutor.Execute(context.Background(), deleteSQL); err != nil { + sql := new(strings.Builder) + sqlexec.MustFormatSQL(sql, "DELETE IGNORE FROM mysql.default_roles WHERE 
USER=%? AND HOST=%?;", user.Username, user.Hostname) + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), sql.String()); err != nil { logutil.BgLogger().Error(fmt.Sprintf("Error occur when executing %s", sql)) - if _, rollbackErr := sqlExecutor.Execute(context.Background(), "rollback"); rollbackErr != nil { + if _, rollbackErr := sqlExecutor.ExecuteInternal(context.TODO(), "rollback"); rollbackErr != nil { return rollbackErr } return err } - if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil { + sql.Reset() + switch s.SetRoleOpt { + case ast.SetRoleNone: + sqlexec.MustFormatSQL(sql, "DELETE IGNORE FROM mysql.default_roles WHERE USER=%? AND HOST=%?;", user.Username, user.Hostname) + case ast.SetRoleAll: + sqlexec.MustFormatSQL(sql, "INSERT IGNORE INTO mysql.default_roles(HOST,USER,DEFAULT_ROLE_HOST,DEFAULT_ROLE_USER) SELECT TO_HOST,TO_USER,FROM_HOST,FROM_USER FROM mysql.role_edges WHERE TO_HOST=%? AND TO_USER=%?;", user.Hostname, user.Username) + case ast.SetRoleRegular: + sqlexec.MustFormatSQL(sql, "INSERT IGNORE INTO mysql.default_roles values") + for i, role := range s.RoleList { + if i > 0 { + sqlexec.MustFormatSQL(sql, ",") + } + ok := checker.FindEdge(e.ctx, role, user) + if !ok { + return ErrRoleNotGranted.GenWithStackByArgs(role.String(), user.String()) + } + sqlexec.MustFormatSQL(sql, "(%?, %?, %?, %?)", user.Hostname, user.Username, role.Hostname, role.Username) + } + } + + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), sql.String()); err != nil { logutil.BgLogger().Error(fmt.Sprintf("Error occur when executing %s", sql)) - if _, rollbackErr := sqlExecutor.Execute(context.Background(), "rollback"); rollbackErr != nil { + if _, rollbackErr := sqlExecutor.ExecuteInternal(context.TODO(), "rollback"); rollbackErr != nil { return rollbackErr } return err } - if _, err := sqlExecutor.Execute(context.Background(), "commit"); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "commit"); err != nil { return err } return nil @@ -595,16 +605,17 @@ func (e *SimpleExec) executeRevokeRole(s *ast.RevokeRoleStmt) error { sqlExecutor := restrictedCtx.(sqlexec.SQLExecutor) // begin a transaction to insert role graph edges. - if _, err := sqlExecutor.Execute(context.Background(), "begin"); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "begin"); err != nil { return errors.Trace(err) } + sql := new(strings.Builder) for _, user := range s.Users { exists, err := userExists(e.ctx, user.Username, user.Hostname) if err != nil { return errors.Trace(err) } if !exists { - if _, err := sqlExecutor.Execute(context.Background(), "rollback"); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "rollback"); err != nil { return errors.Trace(err) } return ErrCannotUser.GenWithStackByArgs("REVOKE ROLE", user.String()) @@ -613,23 +624,26 @@ func (e *SimpleExec) executeRevokeRole(s *ast.RevokeRoleStmt) error { if role.Hostname == "" { role.Hostname = "%" } - sql := fmt.Sprintf(`DELETE IGNORE FROM %s.%s WHERE FROM_HOST='%s' and FROM_USER='%s' and TO_HOST='%s' and TO_USER='%s'`, mysql.SystemDB, mysql.RoleEdgeTable, role.Hostname, role.Username, user.Hostname, user.Username) - if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil { - if _, err := sqlExecutor.Execute(context.Background(), "rollback"); err != nil { + sql.Reset() + sqlexec.MustFormatSQL(sql, `DELETE IGNORE FROM %n.%n WHERE FROM_HOST=%? and FROM_USER=%? and TO_HOST=%? 
and TO_USER=%?`, mysql.SystemDB, mysql.RoleEdgeTable, role.Hostname, role.Username, user.Hostname, user.Username) + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), sql.String()); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "rollback"); err != nil { return errors.Trace(err) } return ErrCannotUser.GenWithStackByArgs("REVOKE ROLE", role.String()) } - sql = fmt.Sprintf(`DELETE IGNORE FROM %s.%s WHERE DEFAULT_ROLE_HOST='%s' and DEFAULT_ROLE_USER='%s' and HOST='%s' and USER='%s'`, mysql.SystemDB, mysql.DefaultRoleTable, role.Hostname, role.Username, user.Hostname, user.Username) - if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil { - if _, err := sqlExecutor.Execute(context.Background(), "rollback"); err != nil { + + sql.Reset() + sqlexec.MustFormatSQL(sql, `DELETE IGNORE FROM %n.%n WHERE DEFAULT_ROLE_HOST=%? and DEFAULT_ROLE_USER=%? and HOST=%? and USER=%?`, mysql.SystemDB, mysql.DefaultRoleTable, role.Hostname, role.Username, user.Hostname, user.Username) + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), sql.String()); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "rollback"); err != nil { return errors.Trace(err) } return ErrCannotUser.GenWithStackByArgs("REVOKE ROLE", role.String()) } } } - if _, err := sqlExecutor.Execute(context.Background(), "commit"); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "commit"); err != nil { return err } domain.GetDomain(e.ctx).NotifyUpdatePrivilege(e.ctx) @@ -687,9 +701,18 @@ func (e *SimpleExec) executeCreateUser(ctx context.Context, s *ast.CreateUserStm return err } - users := make([]string, 0, len(s.Specs)) - privs := make([]string, 0, len(s.Specs)) + sql := new(strings.Builder) + if s.IsCreateRole { + sqlexec.MustFormatSQL(sql, `INSERT INTO %n.%n (Host, User, authentication_string, Account_locked) VALUES `, mysql.SystemDB, mysql.UserTable) + } else { + sqlexec.MustFormatSQL(sql, `INSERT INTO %n.%n (Host, User, authentication_string) VALUES `, mysql.SystemDB, mysql.UserTable) + } + + users := make([]*auth.UserIdentity, 0, len(s.Specs)) for _, spec := range s.Specs { + if len(users) > 0 { + sqlexec.MustFormatSQL(sql, ",") + } exists, err1 := userExists(e.ctx, spec.User.Username, spec.User.Hostname) if err1 != nil { return err1 @@ -710,26 +733,17 @@ func (e *SimpleExec) executeCreateUser(ctx context.Context, s *ast.CreateUserStm if !ok { return errors.Trace(ErrPasswordFormat) } - user := fmt.Sprintf(`('%s', '%s', '%s')`, spec.User.Hostname, spec.User.Username, pwd) if s.IsCreateRole { - user = fmt.Sprintf(`('%s', '%s', '%s', 'Y')`, spec.User.Hostname, spec.User.Username, pwd) - } - users = append(users, user) - - if len(privData) != 0 { - priv := fmt.Sprintf(`('%s', '%s', '%s')`, spec.User.Hostname, spec.User.Username, hack.String(privData)) - privs = append(privs, priv) + sqlexec.MustFormatSQL(sql, `(%?, %?, %?, %?)`, spec.User.Hostname, spec.User.Username, pwd, "Y") + } else { + sqlexec.MustFormatSQL(sql, `(%?, %?, %?)`, spec.User.Hostname, spec.User.Username, pwd) } + users = append(users, spec.User) } if len(users) == 0 { return nil } - sql := fmt.Sprintf(`INSERT INTO %s.%s (Host, User, authentication_string) VALUES %s;`, mysql.SystemDB, mysql.UserTable, strings.Join(users, ", ")) - if s.IsCreateRole { - sql = fmt.Sprintf(`INSERT INTO %s.%s (Host, User, authentication_string, Account_locked) VALUES %s;`, mysql.SystemDB, mysql.UserTable, strings.Join(users, ", ")) - } - restrictedCtx, err := e.getSysSession() if err != nil { return err 
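The executor changes above all follow one recipe: accumulate the statement in a reused strings.Builder via sqlexec.MustFormatSQL, binding identifiers with %n and values with %?, then execute the finished string with ExecuteInternal. The sketch below shows the shape of that recipe; formatSQL and its escaping rules are simplified illustrations for this document, not TiDB's actual implementation:

package main

import (
	"fmt"
	"strings"
)

// formatSQL is a simplified, illustrative stand-in for sqlexec.MustFormatSQL:
// %n binds an identifier (backtick-quoted), %? binds a value (quoted and
// escaped as a string literal). The real helper supports more verbs and
// argument types; this sketch only demonstrates the shape of the API.
func formatSQL(w *strings.Builder, format string, args ...interface{}) {
	argIdx := 0
	for i := 0; i < len(format); i++ {
		if format[i] != '%' || i+1 >= len(format) {
			w.WriteByte(format[i])
			continue
		}
		i++
		switch format[i] {
		case 'n': // identifier argument
			w.WriteString("`" + strings.ReplaceAll(args[argIdx].(string), "`", "``") + "`")
		case '?': // value argument
			w.WriteString("'" + strings.ReplaceAll(fmt.Sprint(args[argIdx]), "'", "''") + "'")
		default: // leave unknown verbs untouched
			w.WriteByte('%')
			w.WriteByte(format[i])
			continue
		}
		argIdx++
	}
}

func main() {
	sql := new(strings.Builder)
	formatSQL(sql, "INSERT INTO %n.%n (Host, User) VALUES ", "mysql", "user")
	users := []struct{ host, name string }{{"%", "u1"}, {"%", "o'brien"}}
	for i, u := range users {
		if i > 0 {
			sql.WriteString(",")
		}
		formatSQL(sql, "(%?, %?)", u.host, u.name)
	}
	fmt.Println(sql.String())
	// INSERT INTO `mysql`.`user` (Host, User) VALUES ('%', 'u1'),('%', 'o''brien')
}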
@@ -737,27 +751,34 @@ func (e *SimpleExec) executeCreateUser(ctx context.Context, s *ast.CreateUserStm defer e.releaseSysSession(restrictedCtx) sqlExecutor := restrictedCtx.(sqlexec.SQLExecutor) - if _, err := sqlExecutor.Execute(context.Background(), "begin"); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "begin"); err != nil { return errors.Trace(err) } - _, err = sqlExecutor.Execute(context.Background(), sql) + _, err = sqlExecutor.ExecuteInternal(context.TODO(), sql.String()) if err != nil { - if _, rollbackErr := sqlExecutor.Execute(context.Background(), "rollback"); rollbackErr != nil { + if _, rollbackErr := sqlExecutor.ExecuteInternal(context.TODO(), "rollback"); rollbackErr != nil { return rollbackErr } return err } - if len(privs) != 0 { - sql = fmt.Sprintf("INSERT IGNORE INTO %s.%s (Host, User, Priv) VALUES %s", mysql.SystemDB, mysql.GlobalPrivTable, strings.Join(privs, ", ")) - _, err = sqlExecutor.Execute(context.Background(), sql) + if len(privData) != 0 { + sql.Reset() + sqlexec.MustFormatSQL(sql, "INSERT IGNORE INTO %n.%n (Host, User, Priv) VALUES ", mysql.SystemDB, mysql.GlobalPrivTable) + for i, user := range users { + if i > 0 { + sqlexec.MustFormatSQL(sql, ",") + } + sqlexec.MustFormatSQL(sql, `(%?, %?, %?)`, user.Hostname, user.Username, string(hack.String(privData))) + } + _, err = sqlExecutor.ExecuteInternal(context.TODO(), sql.String()) if err != nil { - if _, rollbackErr := sqlExecutor.Execute(context.Background(), "rollback"); rollbackErr != nil { + if _, rollbackErr := sqlExecutor.ExecuteInternal(context.TODO(), "rollback"); rollbackErr != nil { return rollbackErr } return err } } - if _, err := sqlExecutor.Execute(context.Background(), "commit"); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "commit"); err != nil { return errors.Trace(err) } domain.GetDomain(e.ctx).NotifyUpdatePrivilege(e.ctx) @@ -806,17 +827,22 @@ func (e *SimpleExec) executeAlterUser(s *ast.AlterUserStmt) error { if !ok { return errors.Trace(ErrPasswordFormat) } - sql := fmt.Sprintf(`UPDATE %s.%s SET authentication_string = '%s' WHERE Host = '%s' and User = '%s';`, - mysql.SystemDB, mysql.UserTable, pwd, spec.User.Hostname, spec.User.Username) - _, _, err = e.ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql) + exec := e.ctx.(sqlexec.RestrictedSQLExecutor) + stmt, err := exec.ParseWithParams(context.TODO(), `UPDATE %n.%n SET authentication_string=%? WHERE Host=%? and User=%?;`, mysql.SystemDB, mysql.UserTable, pwd, spec.User.Hostname, spec.User.Username) + if err != nil { + return err + } + _, _, err = exec.ExecRestrictedStmt(context.TODO(), stmt) if err != nil { failedUsers = append(failedUsers, spec.User.String()) } if len(privData) > 0 { - sql = fmt.Sprintf("INSERT INTO %s.%s (Host, User, Priv) VALUES ('%s','%s','%s') ON DUPLICATE KEY UPDATE Priv = values(Priv)", - mysql.SystemDB, mysql.GlobalPrivTable, spec.User.Hostname, spec.User.Username, hack.String(privData)) - _, _, err = e.ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql) + stmt, err = exec.ParseWithParams(context.TODO(), "INSERT INTO %n.%n (Host, User, Priv) VALUES (%?,%?,%?) 
ON DUPLICATE KEY UPDATE Priv = values(Priv)", mysql.SystemDB, mysql.GlobalPrivTable, spec.User.Hostname, spec.User.Username, string(hack.String(privData))) + if err != nil { + return err + } + _, _, err = exec.ExecRestrictedStmt(context.TODO(), stmt) if err != nil { failedUsers = append(failedUsers, spec.User.String()) } @@ -880,23 +906,25 @@ func (e *SimpleExec) executeGrantRole(s *ast.GrantRoleStmt) error { sqlExecutor := restrictedCtx.(sqlexec.SQLExecutor) // begin a transaction to insert role graph edges. - if _, err := sqlExecutor.Execute(context.Background(), "begin"); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "begin"); err != nil { return err } + sql := new(strings.Builder) for _, user := range s.Users { for _, role := range s.Roles { - sql := fmt.Sprintf(`INSERT IGNORE INTO %s.%s (FROM_HOST, FROM_USER, TO_HOST, TO_USER) VALUES ('%s','%s','%s','%s')`, mysql.SystemDB, mysql.RoleEdgeTable, role.Hostname, role.Username, user.Hostname, user.Username) - if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil { + sql.Reset() + sqlexec.MustFormatSQL(sql, `INSERT IGNORE INTO %n.%n (FROM_HOST, FROM_USER, TO_HOST, TO_USER) VALUES (%?,%?,%?,%?)`, mysql.SystemDB, mysql.RoleEdgeTable, role.Hostname, role.Username, user.Hostname, user.Username) + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), sql.String()); err != nil { logutil.BgLogger().Error(fmt.Sprintf("Error occur when executing %s", sql)) - if _, err := sqlExecutor.Execute(context.Background(), "rollback"); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "rollback"); err != nil { return err } return ErrCannotUser.GenWithStackByArgs("GRANT ROLE", user.String()) } } } - if _, err := sqlExecutor.Execute(context.Background(), "commit"); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "commit"); err != nil { return err } domain.GetDomain(e.ctx).NotifyUpdatePrivilege(e.ctx) @@ -933,10 +961,11 @@ func (e *SimpleExec) executeDropUser(s *ast.DropUserStmt) error { } sqlExecutor := sysSession.(sqlexec.SQLExecutor) - if _, err := sqlExecutor.Execute(context.Background(), "begin"); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "begin"); err != nil { return err } + sql := new(strings.Builder) for _, user := range s.UserList { exists, err := userExists(e.ctx, user.Username, user.Hostname) if err != nil { @@ -952,58 +981,66 @@ func (e *SimpleExec) executeDropUser(s *ast.DropUserStmt) error { } // begin a transaction to delete a user. - sql := fmt.Sprintf(`DELETE FROM %s.%s WHERE Host = '%s' and User = '%s';`, mysql.SystemDB, mysql.UserTable, user.Hostname, user.Username) - if _, err = sqlExecutor.Execute(context.Background(), sql); err != nil { + sql.Reset() + sqlexec.MustFormatSQL(sql, `DELETE FROM %n.%n WHERE Host = %? and User = %?;`, mysql.SystemDB, mysql.UserTable, user.Hostname, user.Username) + if _, err = sqlExecutor.ExecuteInternal(context.TODO(), sql.String()); err != nil { failedUsers = append(failedUsers, user.String()) break } // delete privileges from mysql.global_priv - sql = fmt.Sprintf(`DELETE FROM %s.%s WHERE Host = '%s' and User = '%s';`, mysql.SystemDB, mysql.GlobalPrivTable, user.Hostname, user.Username) - if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil { + sql.Reset() + sqlexec.MustFormatSQL(sql, `DELETE FROM %n.%n WHERE Host = %? 
and User = %?;`, mysql.SystemDB, mysql.GlobalPrivTable, user.Hostname, user.Username) + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), sql.String()); err != nil { failedUsers = append(failedUsers, user.String()) - if _, err := sqlExecutor.Execute(context.Background(), "rollback"); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "rollback"); err != nil { return err } continue } // delete privileges from mysql.db - sql = fmt.Sprintf(`DELETE FROM %s.%s WHERE Host = '%s' and User = '%s';`, mysql.SystemDB, mysql.DBTable, user.Hostname, user.Username) - if _, err = sqlExecutor.Execute(context.Background(), sql); err != nil { + sql.Reset() + sqlexec.MustFormatSQL(sql, `DELETE FROM %n.%n WHERE Host = %? and User = %?;`, mysql.SystemDB, mysql.DBTable, user.Hostname, user.Username) + if _, err = sqlExecutor.ExecuteInternal(context.TODO(), sql.String()); err != nil { failedUsers = append(failedUsers, user.String()) break } // delete privileges from mysql.tables_priv - sql = fmt.Sprintf(`DELETE FROM %s.%s WHERE Host = '%s' and User = '%s';`, mysql.SystemDB, mysql.TablePrivTable, user.Hostname, user.Username) - if _, err = sqlExecutor.Execute(context.Background(), sql); err != nil { + sql.Reset() + sqlexec.MustFormatSQL(sql, `DELETE FROM %n.%n WHERE Host = %? and User = %?;`, mysql.SystemDB, mysql.TablePrivTable, user.Hostname, user.Username) + if _, err = sqlExecutor.ExecuteInternal(context.TODO(), sql.String()); err != nil { failedUsers = append(failedUsers, user.String()) break } // delete relationship from mysql.role_edges - sql = fmt.Sprintf(`DELETE FROM %s.%s WHERE TO_HOST = '%s' and TO_USER = '%s';`, mysql.SystemDB, mysql.RoleEdgeTable, user.Hostname, user.Username) - if _, err = sqlExecutor.Execute(context.Background(), sql); err != nil { + sql.Reset() + sqlexec.MustFormatSQL(sql, `DELETE FROM %n.%n WHERE TO_HOST = %? and TO_USER = %?;`, mysql.SystemDB, mysql.RoleEdgeTable, user.Hostname, user.Username) + if _, err = sqlExecutor.ExecuteInternal(context.TODO(), sql.String()); err != nil { failedUsers = append(failedUsers, user.String()) break } - sql = fmt.Sprintf(`DELETE FROM %s.%s WHERE FROM_HOST = '%s' and FROM_USER = '%s';`, mysql.SystemDB, mysql.RoleEdgeTable, user.Hostname, user.Username) - if _, err = sqlExecutor.Execute(context.Background(), sql); err != nil { + sql.Reset() + sqlexec.MustFormatSQL(sql, `DELETE FROM %n.%n WHERE FROM_HOST = %? and FROM_USER = %?;`, mysql.SystemDB, mysql.RoleEdgeTable, user.Hostname, user.Username) + if _, err = sqlExecutor.ExecuteInternal(context.TODO(), sql.String()); err != nil { failedUsers = append(failedUsers, user.String()) break } // delete relationship from mysql.default_roles - sql = fmt.Sprintf(`DELETE FROM %s.%s WHERE DEFAULT_ROLE_HOST = '%s' and DEFAULT_ROLE_USER = '%s';`, mysql.SystemDB, mysql.DefaultRoleTable, user.Hostname, user.Username) - if _, err = sqlExecutor.Execute(context.Background(), sql); err != nil { + sql.Reset() + sqlexec.MustFormatSQL(sql, `DELETE FROM %n.%n WHERE DEFAULT_ROLE_HOST = %? 
and DEFAULT_ROLE_USER = %?;`, mysql.SystemDB, mysql.DefaultRoleTable, user.Hostname, user.Username) + if _, err = sqlExecutor.ExecuteInternal(context.TODO(), sql.String()); err != nil { failedUsers = append(failedUsers, user.String()) break } - sql = fmt.Sprintf(`DELETE FROM %s.%s WHERE HOST = '%s' and USER = '%s';`, mysql.SystemDB, mysql.DefaultRoleTable, user.Hostname, user.Username) - if _, err = sqlExecutor.Execute(context.Background(), sql); err != nil { + sql.Reset() + sqlexec.MustFormatSQL(sql, `DELETE FROM %n.%n WHERE HOST = %? and USER = %?;`, mysql.SystemDB, mysql.DefaultRoleTable, user.Hostname, user.Username) + if _, err = sqlExecutor.ExecuteInternal(context.TODO(), sql.String()); err != nil { failedUsers = append(failedUsers, user.String()) break } @@ -1011,11 +1048,11 @@ func (e *SimpleExec) executeDropUser(s *ast.DropUserStmt) error { } if len(failedUsers) == 0 { - if _, err := sqlExecutor.Execute(context.Background(), "commit"); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "commit"); err != nil { return err } } else { - if _, err := sqlExecutor.Execute(context.Background(), "rollback"); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "rollback"); err != nil { return err } if s.IsDropRole { @@ -1028,8 +1065,12 @@ func (e *SimpleExec) executeDropUser(s *ast.DropUserStmt) error { } func userExists(ctx sessionctx.Context, name string, host string) (bool, error) { - sql := fmt.Sprintf(`SELECT * FROM %s.%s WHERE User='%s' AND Host='%s';`, mysql.SystemDB, mysql.UserTable, name, host) - rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql) + exec := ctx.(sqlexec.RestrictedSQLExecutor) + stmt, err := exec.ParseWithParams(context.TODO(), `SELECT * FROM %n.%n WHERE User=%? AND Host=%?;`, mysql.SystemDB, mysql.UserTable, name, host) + if err != nil { + return false, err + } + rows, _, err := exec.ExecRestrictedStmt(context.TODO(), stmt) if err != nil { return false, err } @@ -1062,8 +1103,12 @@ func (e *SimpleExec) executeSetPwd(s *ast.SetPwdStmt) error { } // update mysql.user - sql := fmt.Sprintf(`UPDATE %s.%s SET authentication_string='%s' WHERE User='%s' AND Host='%s';`, mysql.SystemDB, mysql.UserTable, auth.EncodePassword(s.Password), u, h) - _, _, err = e.ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql) + exec := e.ctx.(sqlexec.RestrictedSQLExecutor) + stmt, err := exec.ParseWithParams(context.TODO(), `UPDATE %n.%n SET authentication_string=%? WHERE User=%? 
AND Host=%?;`, mysql.SystemDB, mysql.UserTable, auth.EncodePassword(s.Password), u, h) + if err != nil { + return err + } + _, _, err = exec.ExecRestrictedStmt(context.TODO(), stmt) domain.GetDomain(e.ctx).NotifyUpdatePrivilege(e.ctx) return err } diff --git a/executor/slow_query.go b/executor/slow_query.go index b4bd930127ab0..84bdeedf691f6 100755 --- a/executor/slow_query.go +++ b/executor/slow_query.go @@ -457,6 +457,8 @@ type slowQueryTuple struct { rewriteTime float64 preprocSubqueries uint64 preprocSubQueryTime float64 + optimizeTime float64 + waitTSTime float64 preWriteTime float64 waitPrewriteBinlogTime float64 commitTime float64 @@ -496,6 +498,7 @@ type slowQueryTuple struct { isInternal bool succ bool planFromCache bool + planFromBinding bool prepared bool kvTotal float64 pdTotal float64 @@ -563,6 +566,10 @@ func (st *slowQueryTuple) setFieldValue(tz *time.Location, field, value string, st.parseTime, err = strconv.ParseFloat(value, 64) case variable.SlowLogCompileTimeStr: st.compileTime, err = strconv.ParseFloat(value, 64) + case variable.SlowLogOptimizeTimeStr: + st.optimizeTime, err = strconv.ParseFloat(value, 64) + case variable.SlowLogWaitTSTimeStr: + st.waitTSTime, err = strconv.ParseFloat(value, 64) case execdetails.PreWriteTimeStr: st.preWriteTime, err = strconv.ParseFloat(value, 64) case execdetails.WaitPrewriteBinlogTimeStr: @@ -635,6 +642,8 @@ func (st *slowQueryTuple) setFieldValue(tz *time.Location, field, value string, st.succ, err = strconv.ParseBool(value) case variable.SlowLogPlanFromCache: st.planFromCache, err = strconv.ParseBool(value) + case variable.SlowLogPlanFromBinding: + st.planFromBinding, err = strconv.ParseBool(value) case variable.SlowLogPlan: st.plan = value case variable.SlowLogPlanDigest: @@ -687,6 +696,8 @@ func (st *slowQueryTuple) convertToDatumRow() []types.Datum { record = append(record, types.NewFloat64Datum(st.rewriteTime)) record = append(record, types.NewUintDatum(st.preprocSubqueries)) record = append(record, types.NewFloat64Datum(st.preprocSubQueryTime)) + record = append(record, types.NewFloat64Datum(st.optimizeTime)) + record = append(record, types.NewFloat64Datum(st.waitTSTime)) record = append(record, types.NewFloat64Datum(st.preWriteTime)) record = append(record, types.NewFloat64Datum(st.waitPrewriteBinlogTime)) record = append(record, types.NewFloat64Datum(st.commitTime)) @@ -742,6 +753,11 @@ func (st *slowQueryTuple) convertToDatumRow() []types.Datum { } else { record = append(record, types.NewIntDatum(0)) } + if st.planFromBinding { + record = append(record, types.NewIntDatum(1)) + } else { + record = append(record, types.NewIntDatum(0)) + } record = append(record, types.NewStringDatum(parsePlan(st.plan))) record = append(record, types.NewStringDatum(st.planDigest)) record = append(record, types.NewStringDatum(st.prevStmt)) diff --git a/executor/slow_query_test.go b/executor/slow_query_test.go index 64c63a08e5f94..9d627d26f1687 100644 --- a/executor/slow_query_test.go +++ b/executor/slow_query_test.go @@ -68,6 +68,7 @@ func (s *testExecSuite) TestParseSlowLogPanic(c *C) { # Mem_max: 70724 # Disk_max: 65536 # Plan_from_cache: true +# Plan_from_binding: true # Succ: false # Plan_digest: 60e9378c746d9a2be1c791047e008967cf252eb6de9167ad3aa6098fa2d523f4 # Prev_stmt: update t set i = 1; @@ -106,6 +107,7 @@ func (s *testExecSuite) TestParseSlowLogFile(c *C) { # Mem_max: 70724 # Disk_max: 65536 # Plan_from_cache: true +# Plan_from_binding: true # Succ: false # Plan_digest: 60e9378c746d9a2be1c791047e008967cf252eb6de9167ad3aa6098fa2d523f4 
# Prev_stmt: update t set i = 1; @@ -130,10 +132,10 @@ select * from t;` } expectRecordString := `2019-04-28 15:24:04.309074,` + `405888132465033227,root,localhost,0,57,0.12,0.216905,` + - `0,0,0,0,0,0,0,0,0,0,,0,0,0,0,0,0,0.38,0.021,0,0,0,1,637,0,,,1,42a1c8aae6f133e934d4bf0147491709a8812ea05ff8819ec522780fe657b772,t1:1,t2:2,` + + `0,0,0,0,0,0,0,0,0,0,0,0,,0,0,0,0,0,0,0.38,0.021,0,0,0,1,637,0,,,1,42a1c8aae6f133e934d4bf0147491709a8812ea05ff8819ec522780fe657b772,t1:1,t2:2,` + `0.1,0.2,0.03,127.0.0.1:20160,0.05,0.6,0.8,0.0.0.0:20160,70724,65536,0,0,0,0,` + `Cop_backoff_regionMiss_total_times: 200 Cop_backoff_regionMiss_total_time: 0.2 Cop_backoff_regionMiss_max_time: 0.2 Cop_backoff_regionMiss_max_addr: 127.0.0.1 Cop_backoff_regionMiss_avg_time: 0.2 Cop_backoff_regionMiss_p90_time: 0.2 Cop_backoff_rpcPD_total_times: 200 Cop_backoff_rpcPD_total_time: 0.2 Cop_backoff_rpcPD_max_time: 0.2 Cop_backoff_rpcPD_max_addr: 127.0.0.1 Cop_backoff_rpcPD_avg_time: 0.2 Cop_backoff_rpcPD_p90_time: 0.2 Cop_backoff_rpcTiKV_total_times: 200 Cop_backoff_rpcTiKV_total_time: 0.2 Cop_backoff_rpcTiKV_max_time: 0.2 Cop_backoff_rpcTiKV_max_addr: 127.0.0.1 Cop_backoff_rpcTiKV_avg_time: 0.2 Cop_backoff_rpcTiKV_p90_time: 0.2,` + - `0,0,1,,60e9378c746d9a2be1c791047e008967cf252eb6de9167ad3aa6098fa2d523f4,` + + `0,0,1,1,,60e9378c746d9a2be1c791047e008967cf252eb6de9167ad3aa6098fa2d523f4,` + `update t set i = 1;,select * from t;` c.Assert(recordString, Equals, expectRecordString) diff --git a/executor/testdata/executor_suite_in.json b/executor/testdata/executor_suite_in.json index 960471c958458..cb15278ee9b9f 100644 --- a/executor/testdata/executor_suite_in.json +++ b/executor/testdata/executor_suite_in.json @@ -7,5 +7,17 @@ "select * from t1 natural right join t2 order by a", "SELECT * FROM t1 NATURAL LEFT JOIN t2 WHERE not(t1.a <=> t2.a)" ] + }, + { + "name": "TestIndexScanWithYearCol", + "cases": [ + "select t1.c1, t2.c1 from t as t1 inner join t as t2 on t1.c1 = t2.c1 where t1.c1 != NULL", + "select * from t as t1 inner join t as t2 on t1.c1 = t2.c1 where t1.c1 != NULL", + "select count(*) from t as t1 inner join t as t2 on t1.c1 = t2.c1 where t1.c1 != NULL", + "select t1.c1, t2.c1 from t as t1 left join t as t2 on t1.c1 = t2.c1 where t1.c1 != NULL", + "select * from t as t1 left join t as t2 on t1.c1 = t2.c1 where t1.c1 != NULL", + "select count(*) from t as t1 left join t as t2 on t1.c1 = t2.c1 where t1.c1 != NULL", + "select * from t as t1 left join t as t2 on t1.c1 = t2.c1 where t1.c1 is not NULL" + ] } ] diff --git a/executor/testdata/executor_suite_out.json b/executor/testdata/executor_suite_out.json index ef43a96afef6b..bb635bc2951f6 100644 --- a/executor/testdata/executor_suite_out.json +++ b/executor/testdata/executor_suite_out.json @@ -70,5 +70,89 @@ ] } ] + }, + { + "Name": "TestIndexScanWithYearCol", + "Cases": [ + { + "SQL": "select t1.c1, t2.c1 from t as t1 inner join t as t2 on t1.c1 = t2.c1 where t1.c1 != NULL", + "Plan": [ + "MergeJoin_8 0.00 root inner join, left key:test.t.c1, right key:test.t.c1", + "├─TableDual_24(Build) 0.00 root rows:0", + "└─TableDual_23(Probe) 0.00 root rows:0" + ], + "Res": [ + ] + }, + { + "SQL": "select * from t as t1 inner join t as t2 on t1.c1 = t2.c1 where t1.c1 != NULL", + "Plan": [ + "MergeJoin_8 0.00 root inner join, left key:test.t.c1, right key:test.t.c1", + "├─TableDual_26(Build) 0.00 root rows:0", + "└─TableDual_25(Probe) 0.00 root rows:0" + ], + "Res": [ + ] + }, + { + "SQL": "select count(*) from t as t1 inner join t as t2 on t1.c1 = t2.c1 where t1.c1 != NULL", + 
"Plan": [ + "StreamAgg_10 1.00 root funcs:count(1)->Column#7", + "└─MergeJoin_11 0.00 root inner join, left key:test.t.c1, right key:test.t.c1", + " ├─TableDual_27(Build) 0.00 root rows:0", + " └─TableDual_26(Probe) 0.00 root rows:0" + ], + "Res": [ + "0" + ] + }, + { + "SQL": "select t1.c1, t2.c1 from t as t1 left join t as t2 on t1.c1 = t2.c1 where t1.c1 != NULL", + "Plan": [ + "MergeJoin_7 0.00 root left outer join, left key:test.t.c1, right key:test.t.c1", + "├─TableDual_17(Build) 0.00 root rows:0", + "└─TableDual_16(Probe) 0.00 root rows:0" + ], + "Res": [ + ] + }, + { + "SQL": "select * from t as t1 left join t as t2 on t1.c1 = t2.c1 where t1.c1 != NULL", + "Plan": [ + "MergeJoin_7 0.00 root left outer join, left key:test.t.c1, right key:test.t.c1", + "├─TableDual_18(Build) 0.00 root rows:0", + "└─TableDual_17(Probe) 0.00 root rows:0" + ], + "Res": [ + ] + }, + { + "SQL": "select count(*) from t as t1 left join t as t2 on t1.c1 = t2.c1 where t1.c1 != NULL", + "Plan": [ + "StreamAgg_9 1.00 root funcs:count(1)->Column#7", + "└─MergeJoin_10 0.00 root left outer join, left key:test.t.c1, right key:test.t.c1", + " ├─TableDual_20(Build) 0.00 root rows:0", + " └─TableDual_19(Probe) 0.00 root rows:0" + ], + "Res": [ + "0" + ] + }, + { + "SQL": "select * from t as t1 left join t as t2 on t1.c1 = t2.c1 where t1.c1 is not NULL", + "Plan": [ + "HashJoin_15 12487.50 root left outer join, equal:[eq(test.t.c1, test.t.c1)]", + "├─TableReader_33(Build) 9990.00 root data:Selection_32", + "│ └─Selection_32 9990.00 cop[tikv] not(isnull(test.t.c1))", + "│ └─TableFullScan_31 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo", + "└─TableReader_27(Probe) 9990.00 root data:Selection_26", + " └─Selection_26 9990.00 cop[tikv] not(isnull(test.t.c1))", + " └─TableFullScan_25 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo" + ], + "Res": [ + "2001 1 2001 1" + ] + } + ] } ] diff --git a/executor/trace.go b/executor/trace.go index bf9150f357081..8eb6592c4822f 100644 --- a/executor/trace.go +++ b/executor/trace.go @@ -132,24 +132,21 @@ func (e *TraceExec) nextRowJSON(ctx context.Context, se sqlexec.SQLExecutor, req } func (e *TraceExec) executeChild(ctx context.Context, se sqlexec.SQLExecutor) { - recordSets, err := se.Execute(ctx, e.stmtNode.Text()) - if len(recordSets) == 0 { - if err != nil { - var errCode uint16 - if te, ok := err.(*terror.Error); ok { - errCode = terror.ToSQLError(te).Code - } - logutil.Eventf(ctx, "execute with error(%d): %s", errCode, err.Error()) - } else { - logutil.Eventf(ctx, "execute done, modify row: %d", e.ctx.GetSessionVars().StmtCtx.AffectedRows()) + rs, err := se.ExecuteStmt(ctx, e.stmtNode) + if err != nil { + var errCode uint16 + if te, ok := err.(*terror.Error); ok { + errCode = terror.ToSQLError(te).Code } + logutil.Eventf(ctx, "execute with error(%d): %s", errCode, err.Error()) } - for _, rs := range recordSets { + if rs != nil { drainRecordSet(ctx, e.ctx, rs) if err = rs.Close(); err != nil { logutil.Logger(ctx).Error("run trace close result with error", zap.Error(err)) } } + logutil.Eventf(ctx, "execute done, modify row: %d", e.ctx.GetSessionVars().StmtCtx.AffectedRows()) } func drainRecordSet(ctx context.Context, sctx sessionctx.Context, rs sqlexec.RecordSet) { diff --git a/executor/update_test.go b/executor/update_test.go index 1e6559dcf9ccd..cd28590a97d4a 100644 --- a/executor/update_test.go +++ b/executor/update_test.go @@ -259,3 +259,12 @@ func (s *testPointGetSuite) TestIssue21447(c *C) { tk1.MustQuery("select * from t1 where id in (1, 2) for 
update").Check(testkit.Rows("1 xyz")) tk1.MustExec("commit") } + +func (s *testPointGetSuite) TestIssue23553(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec(`use test`) + tk.MustExec(`drop table if exists tt`) + tk.MustExec(`create table tt (m0 varchar(64), status tinyint not null)`) + tk.MustExec(`insert into tt values('1',0),('1',0),('1',0)`) + tk.MustExec(`update tt a inner join (select m0 from tt where status!=1 group by m0 having count(*)>1) b on a.m0=b.m0 set a.status=1`) +} diff --git a/executor/utils.go b/executor/utils.go new file mode 100644 index 0000000000000..fbc9ab4dcff30 --- /dev/null +++ b/executor/utils.go @@ -0,0 +1,46 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package executor + +import "strings" + +func setFromString(value string) []string { + if len(value) == 0 { + return nil + } + return strings.Split(value, ",") +} + +// addToSet add a value to the set, e.g: +// addToSet("Select,Insert,Update", "Update") returns "Select,Insert,Update". +func addToSet(set []string, value string) []string { + for _, v := range set { + if v == value { + return set + } + } + return append(set, value) +} + +// deleteFromSet delete the value from the set, e.g: +// deleteFromSet("Select,Insert,Update", "Update") returns "Select,Insert". +func deleteFromSet(set []string, value string) []string { + for i, v := range set { + if v == value { + copy(set[i:], set[i+1:]) + return set[:len(set)-1] + } + } + return set +} diff --git a/executor/write.go b/executor/write.go index d607167ac4fe7..06fed54c90e36 100644 --- a/executor/write.go +++ b/executor/write.go @@ -186,6 +186,14 @@ func updateRecord(ctx context.Context, sctx sessionctx.Context, h int64, oldData } return false, handleChanged, newHandle, err } + err = tables.CheckUniqueKeyExistForUpdateIgnoreOrInsertOnDupIgnore(ctx, sctx, t, newHandle, newData, modified) + if err != nil { + if terr, ok := errors.Cause(err).(*terror.Error); sctx.GetSessionVars().StmtCtx.IgnoreNoPartition && ok && terr.Code() == errno.ErrNoPartitionForGivenValue { + //return false, nil + return false, handleChanged, newHandle, nil + } + return false, handleChanged, newHandle, err + } } if err = t.RemoveRecord(sctx, h, oldData); err != nil { return false, false, 0, err diff --git a/executor/write_test.go b/executor/write_test.go index 4120fafea4a67..1dc1c9b389ae9 100644 --- a/executor/write_test.go +++ b/executor/write_test.go @@ -761,6 +761,28 @@ func (s *testSuite4) TestInsertIgnoreOnDup(c *C) { testSQL = `select * from t;` r = tk.MustQuery(testSQL) r.Check(testkit.Rows("1 1", "2 2")) + + tk.MustExec("drop table if exists t4") + tk.MustExec("create table t4(id int primary key, k int, v int, unique key uk1(k))") + tk.MustExec("insert into t4 values (1, 10, 100), (3, 30, 300)") + tk.MustExec("insert ignore into t4 (id, k, v) values(1, 0, 0) on duplicate key update id = 2, k = 30") + tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1062 Duplicate entry '30' for key 'uk1'")) + tk.MustQuery("select * from t4").Check(testkit.Rows("1 10 100", "3 30 300")) + + tk.MustExec("drop 
table if exists t5") + tk.MustExec("create table t5(k2 int primary key, uk1 int, v int, unique key ukk1(uk1), unique key ukk2(v))") + tk.MustExec("insert into t5(k2, uk1, v) values(1, 1, '100'), (3, 2, '200')") + tk.MustExec("update ignore t5 set k2 = '2', uk1 = 2 where k2 = 1") + tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1062 Duplicate entry '2' for key 'ukk1'")) + tk.MustQuery("select * from t5").Check(testkit.Rows("1 1 100", "3 2 200")) + + tk.MustExec("drop table if exists t6") + tk.MustExec("create table t6 (a int, b int, c int, primary key(a), unique key idx_14(b), unique key idx_15(b), unique key idx_16(a, b))") + tk.MustExec("insert into t6 select 10, 10, 20") + tk.MustExec("insert ignore into t6 set a = 20, b = 10 on duplicate key update a = 100") + tk.MustQuery("select * from t6").Check(testkit.Rows("100 10 20")) + tk.MustExec("insert ignore into t6 set a = 200, b= 10 on duplicate key update c = 1000") + tk.MustQuery("select * from t6").Check(testkit.Rows("100 10 1000")) } func (s *testSuite4) TestInsertSetWithDefault(c *C) { diff --git a/expression/aggregation/base_func.go b/expression/aggregation/base_func.go index 1c54dec75b6e5..1a1a7543ef29a 100644 --- a/expression/aggregation/base_func.go +++ b/expression/aggregation/base_func.go @@ -212,7 +212,12 @@ func (a *baseFuncDesc) typeInfer4Sum(ctx sessionctx.Context) { // Because child returns integer or decimal type. func (a *baseFuncDesc) typeInfer4Avg(ctx sessionctx.Context) { switch a.Args[0].GetType().Tp { - case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong, mysql.TypeYear, mysql.TypeNewDecimal: + case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong: + a.RetTp = types.NewFieldType(mysql.TypeNewDecimal) + a.RetTp.Decimal = types.DivFracIncr + flen, _ := mysql.GetDefaultFieldLengthAndDecimal(a.Args[0].GetType().Tp) + a.RetTp.Flen = flen + types.DivFracIncr + case mysql.TypeYear, mysql.TypeNewDecimal: a.RetTp = types.NewFieldType(mysql.TypeNewDecimal) if a.Args[0].GetType().Decimal < 0 { a.RetTp.Decimal = mysql.MaxDecimalScale diff --git a/expression/builtin_cast.go b/expression/builtin_cast.go index 3f21d4303c448..e10a7f5d43c74 100644 --- a/expression/builtin_cast.go +++ b/expression/builtin_cast.go @@ -1899,7 +1899,11 @@ func WrapWithCastAsString(ctx sessionctx.Context, expr Expression) Expression { argLen = -1 } tp := types.NewFieldType(mysql.TypeVarString) - tp.Charset, tp.Collate = expr.CharsetAndCollation(ctx) + if expr.Coercibility() == CoercibilityExplicit { + tp.Charset, tp.Collate = expr.CharsetAndCollation(ctx) + } else { + tp.Charset, tp.Collate = ctx.GetSessionVars().GetCharsetInfo() + } tp.Flen, tp.Decimal = argLen, types.UnspecifiedLength return BuildCastFunction(ctx, expr, tp) } diff --git a/expression/builtin_compare.go b/expression/builtin_compare.go index 671158a5e0138..5fe32e958249f 100644 --- a/expression/builtin_compare.go +++ b/expression/builtin_compare.go @@ -385,14 +385,18 @@ func ResolveType4Between(args [3]Expression) types.EvalType { hasTemporal := false if cmpTp == types.ETString { - for _, arg := range args { - if types.IsTypeTemporal(arg.GetType().Tp) { - hasTemporal = true - break + if args[0].GetType().Tp == mysql.TypeDuration { + cmpTp = types.ETDuration + } else { + for _, arg := range args { + if types.IsTypeTemporal(arg.GetType().Tp) { + hasTemporal = true + break + } + } + if hasTemporal { + cmpTp = types.ETDatetime } - } - if hasTemporal { - cmpTp = types.ETDatetime } } @@ -1305,15 +1309,15 @@ func 
RefineComparedConstant(ctx sessionctx.Context, targetFieldType types.FieldT // We try to convert the string constant to double. // If the double result equals the int result, we can return the int result; // otherwise, the compare function will be false. + // **Notice** + // We can not compare the double result to the int result directly, because the year type changes the stored value (for example, + // 2 becomes 2002); here we just check whether the double value equals its integer truncation, i.e. double value == int(double value). var doubleDatum types.Datum doubleDatum, err = dt.ConvertTo(sc, types.NewFieldType(mysql.TypeDouble)) if err != nil { return con, false } - if c, err = doubleDatum.CompareDatum(sc, &intDatum); err != nil { - return con, false - } - if c != 0 { + if doubleDatum.GetFloat64() != math.Trunc(doubleDatum.GetFloat64()) { return con, true } return &Constant{ @@ -1344,8 +1348,13 @@ func (c *compareFunctionClass) refineArgs(ctx sessionctx.Context, args []Express // int non-constant [cmp] non-int constant if arg0IsInt && !arg0IsCon && !arg1IsInt && arg1IsCon { arg1, isExceptional = RefineComparedConstant(ctx, *arg0Type, arg1, c.op) - finalArg1 = arg1 - if isExceptional && arg1.GetType().EvalType() == types.ETInt { + // Why we check the not-null flag: + // e.g. int_col > const_val (where const_val is less than min_int32); + // if int_col is null, the compare result can not be true, so the refined constant is only kept for not-null columns. + if !isExceptional || (isExceptional && mysql.HasNotNullFlag(arg0Type.Flag)) { + finalArg1 = arg1 + } + if isExceptional && arg1.GetType().EvalType() == types.ETInt && mysql.HasNotNullFlag(arg0Type.Flag) { // Judge it is inf or -inf // For int: // inf: 01111111 & 1 == 1 @@ -1363,8 +1372,10 @@ func (c *compareFunctionClass) refineArgs(ctx sessionctx.Context, args []Express // non-int constant [cmp] int non-constant if arg1IsInt && !arg1IsCon && !arg0IsInt && arg0IsCon { arg0, isExceptional = RefineComparedConstant(ctx, *arg1Type, arg0, symmetricOp[c.op]) - finalArg0 = arg0 - if isExceptional && arg0.GetType().EvalType() == types.ETInt { + if !isExceptional || (isExceptional && mysql.HasNotNullFlag(arg1Type.Flag)) { + finalArg0 = arg0 + } + if isExceptional && arg0.GetType().EvalType() == types.ETInt && mysql.HasNotNullFlag(arg1Type.Flag) { if arg0.Value.GetInt64()&1 == 1 { isNegativeInfinite = true } else { @@ -1373,7 +1384,7 @@ func (c *compareFunctionClass) refineArgs(ctx sessionctx.Context, args []Express } } // int constant [cmp] year type - if arg0IsCon && arg0IsInt && arg1Type.Tp == mysql.TypeYear { + if arg0IsCon && arg0IsInt && arg1Type.Tp == mysql.TypeYear && !arg0.Value.IsNull() { adjusted, failed := types.AdjustYear(arg0.Value.GetInt64(), false) if failed == nil { arg0.Value.SetInt64(adjusted) @@ -1381,7 +1392,7 @@ func (c *compareFunctionClass) refineArgs(ctx sessionctx.Context, args []Express } } // year type [cmp] int constant - if arg1IsCon && arg1IsInt && arg0Type.Tp == mysql.TypeYear { + if arg1IsCon && arg1IsInt && arg0Type.Tp == mysql.TypeYear && !arg1.Value.IsNull() { adjusted, failed := types.AdjustYear(arg1.Value.GetInt64(), false) if failed == nil { arg1.Value.SetInt64(adjusted) diff --git a/expression/builtin_compare_test.go b/expression/builtin_compare_test.go index f8b8e6f08843c..1e3908b856030 100644 --- a/expression/builtin_compare_test.go +++ b/expression/builtin_compare_test.go @@ -27,7 +27,7 @@ import ( ) func (s *testEvaluatorSuite) TestCompareFunctionWithRefine(c *C) { - tblInfo := newTestTableBuilder("").add("a", mysql.TypeLong, mysql.NotNullFlag).build() tests := 
[]struct { exprStr string result string diff --git a/expression/builtin_control.go b/expression/builtin_control.go index 69b61323bc535..2f0b409e6f5e7 100644 --- a/expression/builtin_control.go +++ b/expression/builtin_control.go @@ -67,6 +67,8 @@ func InferType4ControlFuncs(lhs, rhs *types.FieldType) *types.FieldType { resultFieldType := &types.FieldType{} if lhs.Tp == mysql.TypeNull { *resultFieldType = *rhs + // If any of arg is NULL, result type need unset NotNullFlag. + types.SetTypeFlag(&resultFieldType.Flag, mysql.NotNullFlag, false) // If both arguments are NULL, make resulting type BINARY(0). if rhs.Tp == mysql.TypeNull { resultFieldType.Tp = mysql.TypeString @@ -75,6 +77,7 @@ func InferType4ControlFuncs(lhs, rhs *types.FieldType) *types.FieldType { } } else if rhs.Tp == mysql.TypeNull { *resultFieldType = *lhs + types.SetTypeFlag(&resultFieldType.Flag, mysql.NotNullFlag, false) } else { resultFieldType = types.AggFieldType([]*types.FieldType{lhs, rhs}) evalType := types.AggregateEvalType([]*types.FieldType{lhs, rhs}, &resultFieldType.Flag) @@ -173,6 +176,9 @@ func (c *caseWhenFunctionClass) getFunction(ctx sessionctx.Context, args []Expre } fieldTp := types.AggFieldType(fieldTps) + // Here we turn off NotNullFlag. Because if all when-clauses are false, + // the result of case-when expr is NULL. + types.SetTypeFlag(&fieldTp.Flag, mysql.NotNullFlag, false) tp := fieldTp.EvalType() if tp == types.ETInt { diff --git a/expression/builtin_control_test.go b/expression/builtin_control_test.go index 6ea1655e1c874..7f6e35aaa8626 100644 --- a/expression/builtin_control_test.go +++ b/expression/builtin_control_test.go @@ -116,7 +116,7 @@ func (s *testEvaluatorSuite) TestIfNull(c *C) { {tm, nil, tm, false, false}, {nil, duration, duration, false, false}, {nil, types.NewDecFromFloatForTest(123.123), types.NewDecFromFloatForTest(123.123), false, false}, - {nil, types.NewBinaryLiteralFromUint(0x01, -1), uint64(1), false, false}, + {nil, types.NewBinaryLiteralFromUint(0x01, -1), "\x01", false, false}, {nil, types.Set{Value: 1, Name: "abc"}, "abc", false, false}, {nil, jsonInt.GetMysqlJSON(), jsonInt.GetMysqlJSON(), false, false}, {"abc", nil, "abc", false, false}, diff --git a/expression/builtin_other.go b/expression/builtin_other.go index 513dab9f206ad..faa41cd0a256a 100644 --- a/expression/builtin_other.go +++ b/expression/builtin_other.go @@ -155,8 +155,10 @@ func (c *inFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) type baseInSig struct { baseBuiltinFunc - nonConstArgs []Expression - hasNull bool + // nonConstArgsIdx stores the indices of non-constant args in the baseBuiltinFunc.args (the first arg is not included). + // It works with builtinInXXXSig.hashset to accelerate 'eval'. 
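As the comment above says, constant IN-list items are folded into a hash set once at build time, and only the non-constant items, tracked by index into args rather than by cloned Expression values, are evaluated per row. A minimal sketch of that strategy, using a toy expr type in place of TiDB's Expression interface and omitting the real code's NULL (hasNull) handling:

package main

import "fmt"

// expr is a toy stand-in for TiDB's Expression: either a constant
// value or a per-row evaluation function.
type expr struct {
	constant bool
	val      int64        // used when constant
	eval     func() int64 // used when non-constant
}

// buildInInt folds constant list items into a hash set and remembers
// only the indices of non-constant items, like buildHashMapForConstArgs.
func buildInInt(list []expr) func(x int64) bool {
	set := make(map[int64]struct{}, len(list))
	var nonConstIdx []int
	for i, e := range list {
		if e.constant {
			set[e.val] = struct{}{}
		} else {
			nonConstIdx = append(nonConstIdx, i)
		}
	}
	return func(x int64) bool {
		if _, ok := set[x]; ok { // O(1) hit on the constant part
			return true
		}
		for _, i := range nonConstIdx { // fall back to per-row evaluation
			if list[i].eval() == x {
				return true
			}
		}
		return false
	}
}

func main() {
	cur := int64(7) // stands in for a correlated, non-constant argument
	in := buildInInt([]expr{
		{constant: true, val: 1},
		{constant: true, val: 5},
		{constant: false, eval: func() int64 { return cur }},
	})
	fmt.Println(in(5), in(7), in(9)) // true true false
}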
+ nonConstArgsIdx []int + hasNull bool } // builtinInIntSig see https://dev.mysql.com/doc/refman/5.7/en/comparison-operators.html#function_in @@ -167,7 +169,7 @@ type builtinInIntSig struct { } func (b *builtinInIntSig) buildHashMapForConstArgs(ctx sessionctx.Context) error { - b.nonConstArgs = []Expression{b.args[0]} + b.nonConstArgsIdx = make([]int, 0) b.hashSet = make(map[int64]bool, len(b.args)-1) for i := 1; i < len(b.args); i++ { if b.args[i].ConstItem(b.ctx.GetSessionVars().StmtCtx) { @@ -181,7 +183,7 @@ func (b *builtinInIntSig) buildHashMapForConstArgs(ctx sessionctx.Context) error } b.hashSet[val] = mysql.HasUnsignedFlag(b.args[i].GetType().Flag) } else { - b.nonConstArgs = append(b.nonConstArgs, b.args[i]) + b.nonConstArgsIdx = append(b.nonConstArgsIdx, i) } } return nil @@ -190,10 +192,8 @@ func (b *builtinInIntSig) buildHashMapForConstArgs(ctx sessionctx.Context) error func (b *builtinInIntSig) Clone() builtinFunc { newSig := &builtinInIntSig{} newSig.cloneFrom(&b.baseBuiltinFunc) - newSig.nonConstArgs = make([]Expression, 0, len(b.nonConstArgs)) - for _, arg := range b.nonConstArgs { - newSig.nonConstArgs = append(newSig.nonConstArgs, arg.Clone()) - } + newSig.nonConstArgsIdx = make([]int, len(b.nonConstArgsIdx)) + copy(newSig.nonConstArgsIdx, b.nonConstArgsIdx) newSig.hashSet = b.hashSet newSig.hasNull = b.hasNull return newSig @@ -206,9 +206,8 @@ func (b *builtinInIntSig) evalInt(row chunk.Row) (int64, bool, error) { } isUnsigned0 := mysql.HasUnsignedFlag(b.args[0].GetType().Flag) - args := b.args + args := b.args[1:] if len(b.hashSet) != 0 { - args = b.nonConstArgs if isUnsigned, ok := b.hashSet[arg0]; ok { if (isUnsigned0 && isUnsigned) || (!isUnsigned0 && !isUnsigned) { return 1, false, nil @@ -217,10 +216,14 @@ func (b *builtinInIntSig) evalInt(row chunk.Row) (int64, bool, error) { return 1, false, nil } } + args = args[:0] + for _, i := range b.nonConstArgsIdx { + args = append(args, b.args[i]) + } } hasNull := b.hasNull - for _, arg := range args[1:] { + for _, arg := range args { evaledArg, isNull, err := arg.EvalInt(b.ctx, row) if err != nil { return 0, true, err @@ -258,7 +261,7 @@ type builtinInStringSig struct { } func (b *builtinInStringSig) buildHashMapForConstArgs(ctx sessionctx.Context) error { - b.nonConstArgs = []Expression{b.args[0]} + b.nonConstArgsIdx = make([]int, 0) b.hashSet = set.NewStringSet() collator := collate.GetCollator(b.collation) for i := 1; i < len(b.args); i++ { @@ -273,7 +276,7 @@ func (b *builtinInStringSig) buildHashMapForConstArgs(ctx sessionctx.Context) er } b.hashSet.Insert(string(collator.Key(val))) // should do memory copy here } else { - b.nonConstArgs = append(b.nonConstArgs, b.args[i]) + b.nonConstArgsIdx = append(b.nonConstArgsIdx, i) } } @@ -283,10 +286,8 @@ func (b *builtinInStringSig) buildHashMapForConstArgs(ctx sessionctx.Context) er func (b *builtinInStringSig) Clone() builtinFunc { newSig := &builtinInStringSig{} newSig.cloneFrom(&b.baseBuiltinFunc) - newSig.nonConstArgs = make([]Expression, 0, len(b.nonConstArgs)) - for _, arg := range b.nonConstArgs { - newSig.nonConstArgs = append(newSig.nonConstArgs, arg.Clone()) - } + newSig.nonConstArgsIdx = make([]int, len(b.nonConstArgsIdx)) + copy(newSig.nonConstArgsIdx, b.nonConstArgsIdx) newSig.hashSet = b.hashSet newSig.hasNull = b.hasNull return newSig @@ -298,17 +299,20 @@ func (b *builtinInStringSig) evalInt(row chunk.Row) (int64, bool, error) { return 0, isNull0, err } - args := b.args + args := b.args[1:] collator := collate.GetCollator(b.collation) if len(b.hashSet) != 
0 { - args = b.nonConstArgs if b.hashSet.Exist(string(collator.Key(arg0))) { return 1, false, nil } + args = args[:0] + for _, i := range b.nonConstArgsIdx { + args = append(args, b.args[i]) + } } hasNull := b.hasNull - for _, arg := range args[1:] { + for _, arg := range args { evaledArg, isNull, err := arg.EvalString(b.ctx, row) if err != nil { return 0, true, err @@ -331,7 +335,7 @@ type builtinInRealSig struct { } func (b *builtinInRealSig) buildHashMapForConstArgs(ctx sessionctx.Context) error { - b.nonConstArgs = []Expression{b.args[0]} + b.nonConstArgsIdx = make([]int, 0) b.hashSet = set.NewFloat64Set() for i := 1; i < len(b.args); i++ { if b.args[i].ConstItem(b.ctx.GetSessionVars().StmtCtx) { @@ -345,7 +349,7 @@ func (b *builtinInRealSig) buildHashMapForConstArgs(ctx sessionctx.Context) erro } b.hashSet.Insert(val) } else { - b.nonConstArgs = append(b.nonConstArgs, b.args[i]) + b.nonConstArgsIdx = append(b.nonConstArgsIdx, i) } } @@ -355,10 +359,8 @@ func (b *builtinInRealSig) buildHashMapForConstArgs(ctx sessionctx.Context) erro func (b *builtinInRealSig) Clone() builtinFunc { newSig := &builtinInRealSig{} newSig.cloneFrom(&b.baseBuiltinFunc) - newSig.nonConstArgs = make([]Expression, 0, len(b.nonConstArgs)) - for _, arg := range b.nonConstArgs { - newSig.nonConstArgs = append(newSig.nonConstArgs, arg.Clone()) - } + newSig.nonConstArgsIdx = make([]int, len(b.nonConstArgsIdx)) + copy(newSig.nonConstArgsIdx, b.nonConstArgsIdx) newSig.hashSet = b.hashSet newSig.hasNull = b.hasNull return newSig @@ -369,15 +371,19 @@ func (b *builtinInRealSig) evalInt(row chunk.Row) (int64, bool, error) { if isNull0 || err != nil { return 0, isNull0, err } - args := b.args + args := b.args[1:] if len(b.hashSet) != 0 { - args = b.nonConstArgs if b.hashSet.Exist(arg0) { return 1, false, nil } + args = args[:0] + for _, i := range b.nonConstArgsIdx { + args = append(args, b.args[i]) + } } + hasNull := b.hasNull - for _, arg := range args[1:] { + for _, arg := range args { evaledArg, isNull, err := arg.EvalReal(b.ctx, row) if err != nil { return 0, true, err @@ -400,7 +406,7 @@ type builtinInDecimalSig struct { } func (b *builtinInDecimalSig) buildHashMapForConstArgs(ctx sessionctx.Context) error { - b.nonConstArgs = []Expression{b.args[0]} + b.nonConstArgsIdx = make([]int, 0) b.hashSet = set.NewStringSet() for i := 1; i < len(b.args); i++ { if b.args[i].ConstItem(b.ctx.GetSessionVars().StmtCtx) { @@ -418,7 +424,7 @@ func (b *builtinInDecimalSig) buildHashMapForConstArgs(ctx sessionctx.Context) e } b.hashSet.Insert(string(key)) } else { - b.nonConstArgs = append(b.nonConstArgs, b.args[i]) + b.nonConstArgsIdx = append(b.nonConstArgsIdx, i) } } @@ -428,10 +434,8 @@ func (b *builtinInDecimalSig) buildHashMapForConstArgs(ctx sessionctx.Context) e func (b *builtinInDecimalSig) Clone() builtinFunc { newSig := &builtinInDecimalSig{} newSig.cloneFrom(&b.baseBuiltinFunc) - newSig.nonConstArgs = make([]Expression, 0, len(b.nonConstArgs)) - for _, arg := range b.nonConstArgs { - newSig.nonConstArgs = append(newSig.nonConstArgs, arg.Clone()) - } + newSig.nonConstArgsIdx = make([]int, len(b.nonConstArgsIdx)) + copy(newSig.nonConstArgsIdx, b.nonConstArgsIdx) newSig.hashSet = b.hashSet newSig.hasNull = b.hasNull return newSig @@ -443,20 +447,23 @@ func (b *builtinInDecimalSig) evalInt(row chunk.Row) (int64, bool, error) { return 0, isNull0, err } - args := b.args + args := b.args[1:] key, err := arg0.ToHashKey() if err != nil { return 0, true, err } if len(b.hashSet) != 0 { - args = b.nonConstArgs if 
b.hashSet.Exist(string(key)) { return 1, false, nil } + args = args[:0] + for _, i := range b.nonConstArgsIdx { + args = append(args, b.args[i]) + } } hasNull := b.hasNull - for _, arg := range args[1:] { + for _, arg := range args { evaledArg, isNull, err := arg.EvalDecimal(b.ctx, row) if err != nil { return 0, true, err @@ -479,7 +486,7 @@ type builtinInTimeSig struct { } func (b *builtinInTimeSig) buildHashMapForConstArgs(ctx sessionctx.Context) error { - b.nonConstArgs = []Expression{b.args[0]} + b.nonConstArgsIdx = make([]int, 0) b.hashSet = make(map[types.CoreTime]struct{}, len(b.args)-1) for i := 1; i < len(b.args); i++ { if b.args[i].ConstItem(b.ctx.GetSessionVars().StmtCtx) { @@ -493,7 +500,7 @@ func (b *builtinInTimeSig) buildHashMapForConstArgs(ctx sessionctx.Context) erro } b.hashSet[val.CoreTime()] = struct{}{} } else { - b.nonConstArgs = append(b.nonConstArgs, b.args[i]) + b.nonConstArgsIdx = append(b.nonConstArgsIdx, i) } } @@ -503,10 +510,8 @@ func (b *builtinInTimeSig) buildHashMapForConstArgs(ctx sessionctx.Context) erro func (b *builtinInTimeSig) Clone() builtinFunc { newSig := &builtinInTimeSig{} newSig.cloneFrom(&b.baseBuiltinFunc) - newSig.nonConstArgs = make([]Expression, 0, len(b.nonConstArgs)) - for _, arg := range b.nonConstArgs { - newSig.nonConstArgs = append(newSig.nonConstArgs, arg.Clone()) - } + newSig.nonConstArgsIdx = make([]int, len(b.nonConstArgsIdx)) + copy(newSig.nonConstArgsIdx, b.nonConstArgsIdx) newSig.hashSet = b.hashSet newSig.hasNull = b.hasNull return newSig @@ -517,15 +522,19 @@ func (b *builtinInTimeSig) evalInt(row chunk.Row) (int64, bool, error) { if isNull0 || err != nil { return 0, isNull0, err } - args := b.args + args := b.args[1:] if len(b.hashSet) != 0 { - args = b.nonConstArgs if _, ok := b.hashSet[arg0.CoreTime()]; ok { return 1, false, nil } + args = args[:0] + for _, i := range b.nonConstArgsIdx { + args = append(args, b.args[i]) + } } + hasNull := b.hasNull - for _, arg := range args[1:] { + for _, arg := range args { evaledArg, isNull, err := arg.EvalTime(b.ctx, row) if err != nil { return 0, true, err @@ -548,7 +557,7 @@ type builtinInDurationSig struct { } func (b *builtinInDurationSig) buildHashMapForConstArgs(ctx sessionctx.Context) error { - b.nonConstArgs = []Expression{b.args[0]} + b.nonConstArgsIdx = make([]int, 0) b.hashSet = make(map[time.Duration]struct{}, len(b.args)-1) for i := 1; i < len(b.args); i++ { if b.args[i].ConstItem(b.ctx.GetSessionVars().StmtCtx) { @@ -562,7 +571,7 @@ func (b *builtinInDurationSig) buildHashMapForConstArgs(ctx sessionctx.Context) } b.hashSet[val.Duration] = struct{}{} } else { - b.nonConstArgs = append(b.nonConstArgs, b.args[i]) + b.nonConstArgsIdx = append(b.nonConstArgsIdx, i) } } @@ -572,10 +581,8 @@ func (b *builtinInDurationSig) buildHashMapForConstArgs(ctx sessionctx.Context) func (b *builtinInDurationSig) Clone() builtinFunc { newSig := &builtinInDurationSig{} newSig.cloneFrom(&b.baseBuiltinFunc) - newSig.nonConstArgs = make([]Expression, 0, len(b.nonConstArgs)) - for _, arg := range b.nonConstArgs { - newSig.nonConstArgs = append(newSig.nonConstArgs, arg.Clone()) - } + newSig.nonConstArgsIdx = make([]int, len(b.nonConstArgsIdx)) + copy(newSig.nonConstArgsIdx, b.nonConstArgsIdx) newSig.hashSet = b.hashSet newSig.hasNull = b.hasNull return newSig @@ -586,15 +593,19 @@ func (b *builtinInDurationSig) evalInt(row chunk.Row) (int64, bool, error) { if isNull0 || err != nil { return 0, isNull0, err } - args := b.args + args := b.args[1:] if len(b.hashSet) != 0 { - args = b.nonConstArgs if 
_, ok := b.hashSet[arg0.Duration]; ok { return 1, false, nil } + args = args[:0] + for _, i := range b.nonConstArgsIdx { + args = append(args, b.args[i]) + } } + hasNull := b.hasNull - for _, arg := range args[1:] { + for _, arg := range args { evaledArg, isNull, err := arg.EvalDuration(b.ctx, row) if err != nil { return 0, true, err diff --git a/expression/builtin_other_vec_generated.go b/expression/builtin_other_vec_generated.go index e44f2f6759ca0..0cdf900f0793d 100644 --- a/expression/builtin_other_vec_generated.go +++ b/expression/builtin_other_vec_generated.go @@ -53,9 +53,8 @@ func (b *builtinInIntSig) vecEvalInt(input *chunk.Chunk, result *chunk.Column) e } isUnsigned0 := mysql.HasUnsignedFlag(b.args[0].GetType().Flag) var compareResult int - args := b.args + args := b.args[1:] if len(b.hashSet) != 0 { - args = b.nonConstArgs for i := 0; i < n; i++ { if buf0.IsNull(i) { hasNull[i] = true @@ -73,9 +72,13 @@ func (b *builtinInIntSig) vecEvalInt(input *chunk.Chunk, result *chunk.Column) e } } } + args = args[:0] + for _, i := range b.nonConstArgsIdx { + args = append(args, b.args[i]) + } } - for j := 1; j < len(args); j++ { + for j := 0; j < len(args); j++ { if err := args[j].VecEvalInt(b.ctx, input, buf1); err != nil { return err } @@ -153,10 +156,9 @@ func (b *builtinInStringSig) vecEvalInt(input *chunk.Chunk, result *chunk.Column } } var compareResult int - args := b.args + args := b.args[1:] if len(b.hashSet) != 0 { collator := collate.GetCollator(b.collation) - args = b.nonConstArgs for i := 0; i < n; i++ { if buf0.IsNull(i) { hasNull[i] = true @@ -168,9 +170,13 @@ func (b *builtinInStringSig) vecEvalInt(input *chunk.Chunk, result *chunk.Column result.SetNull(i, false) } } + args = args[:0] + for _, i := range b.nonConstArgsIdx { + args = append(args, b.args[i]) + } } - for j := 1; j < len(args); j++ { + for j := 0; j < len(args); j++ { if err := args[j].VecEvalString(b.ctx, input, buf1); err != nil { return err } @@ -232,9 +238,8 @@ func (b *builtinInDecimalSig) vecEvalInt(input *chunk.Chunk, result *chunk.Colum } } var compareResult int - args := b.args + args := b.args[1:] if len(b.hashSet) != 0 { - args = b.nonConstArgs for i := 0; i < n; i++ { if buf0.IsNull(i) { hasNull[i] = true @@ -250,9 +255,13 @@ func (b *builtinInDecimalSig) vecEvalInt(input *chunk.Chunk, result *chunk.Colum result.SetNull(i, false) } } + args = args[:0] + for _, i := range b.nonConstArgsIdx { + args = append(args, b.args[i]) + } } - for j := 1; j < len(args); j++ { + for j := 0; j < len(args); j++ { if err := args[j].VecEvalDecimal(b.ctx, input, buf1); err != nil { return err } @@ -319,9 +328,8 @@ func (b *builtinInRealSig) vecEvalInt(input *chunk.Chunk, result *chunk.Column) } } var compareResult int - args := b.args + args := b.args[1:] if len(b.hashSet) != 0 { - args = b.nonConstArgs for i := 0; i < n; i++ { if buf0.IsNull(i) { hasNull[i] = true @@ -333,9 +341,13 @@ func (b *builtinInRealSig) vecEvalInt(input *chunk.Chunk, result *chunk.Column) result.SetNull(i, false) } } + args = args[:0] + for _, i := range b.nonConstArgsIdx { + args = append(args, b.args[i]) + } } - for j := 1; j < len(args); j++ { + for j := 0; j < len(args); j++ { if err := args[j].VecEvalReal(b.ctx, input, buf1); err != nil { return err } @@ -399,9 +411,8 @@ func (b *builtinInTimeSig) vecEvalInt(input *chunk.Chunk, result *chunk.Column) } } var compareResult int - args := b.args + args := b.args[1:] if len(b.hashSet) != 0 { - args = b.nonConstArgs for i := 0; i < n; i++ { if buf0.IsNull(i) { hasNull[i] = true @@ -413,9 
+424,13 @@ func (b *builtinInTimeSig) vecEvalInt(input *chunk.Chunk, result *chunk.Column)
 				result.SetNull(i, false)
 			}
 		}
+		args = args[:0]
+		for _, i := range b.nonConstArgsIdx {
+			args = append(args, b.args[i])
+		}
 	}
-	for j := 1; j < len(args); j++ {
+	for j := 0; j < len(args); j++ {
 		if err := args[j].VecEvalTime(b.ctx, input, buf1); err != nil {
 			return err
 		}
@@ -479,9 +494,8 @@ func (b *builtinInDurationSig) vecEvalInt(input *chunk.Chunk, result *chunk.Colu
 		}
 	}
 	var compareResult int
-	args := b.args
+	args := b.args[1:]
 	if len(b.hashSet) != 0 {
-		args = b.nonConstArgs
 		for i := 0; i < n; i++ {
 			if buf0.IsNull(i) {
 				hasNull[i] = true
@@ -493,9 +507,13 @@ func (b *builtinInDurationSig) vecEvalInt(input *chunk.Chunk, result *chunk.Colu
 				result.SetNull(i, false)
 			}
 		}
+		args = args[:0]
+		for _, i := range b.nonConstArgsIdx {
+			args = append(args, b.args[i])
+		}
 	}
-	for j := 1; j < len(args); j++ {
+	for j := 0; j < len(args); j++ {
 		if err := args[j].VecEvalDuration(b.ctx, input, buf1); err != nil {
 			return err
 		}
@@ -553,9 +571,9 @@ func (b *builtinInJSONSig) vecEvalInt(input *chunk.Chunk, result *chunk.Column)
 	}
 	hasNull := make([]bool, n)
 	var compareResult int
-	args := b.args
+	args := b.args[1:]
 
-	for j := 1; j < len(args); j++ {
+	for j := 0; j < len(args); j++ {
 		if err := args[j].VecEvalJSON(b.ctx, input, buf1); err != nil {
 			return err
 		}
diff --git a/expression/builtin_string.go b/expression/builtin_string.go
index d07536c42bee5..749d3453d0775 100644
--- a/expression/builtin_string.go
+++ b/expression/builtin_string.go
@@ -180,6 +180,11 @@ func SetBinFlagOrBinStr(argTp *types.FieldType, resTp *types.FieldType) {
 	}
 }
 
+// addBinFlag adds the binary flag to `tp` if its charset is binary.
+func addBinFlag(tp *types.FieldType) {
+	SetBinFlagOrBinStr(tp, tp)
+}
+
 type lengthFunctionClass struct {
 	baseFunctionClass
 }
@@ -275,10 +280,10 @@ func (c *concatFunctionClass) getFunction(ctx sessionctx.Context, args []Express
 	if err != nil {
 		return nil, err
 	}
+	addBinFlag(bf.tp)
 	bf.tp.Flen = 0
 	for i := range args {
 		argType := args[i].GetType()
-		SetBinFlagOrBinStr(argType, bf.tp)
 
 		if argType.Flen < 0 {
 			bf.tp.Flen = mysql.MaxBlobWidth
@@ -350,9 +355,9 @@ func (c *concatWSFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
 	}
 	bf.tp.Flen = 0
+	addBinFlag(bf.tp)
 	for i := range args {
 		argType := args[i].GetType()
-		SetBinFlagOrBinStr(argType, bf.tp)
 
 		// skip separator param
 		if i != 0 {
@@ -2000,8 +2005,7 @@ func (c *lpadFunctionClass) getFunction(ctx sessionctx.Context, args []Expressio
 		return nil, err
 	}
 	bf.tp.Flen = getFlen4LpadAndRpad(bf.ctx, args[1])
-	SetBinFlagOrBinStr(args[0].GetType(), bf.tp)
-	SetBinFlagOrBinStr(args[2].GetType(), bf.tp)
+	addBinFlag(bf.tp)
 
 	valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket)
 	maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64)
@@ -2133,8 +2137,7 @@ func (c *rpadFunctionClass) getFunction(ctx sessionctx.Context, args []Expressio
 		return nil, err
 	}
 	bf.tp.Flen = getFlen4LpadAndRpad(bf.ctx, args[1])
-	SetBinFlagOrBinStr(args[0].GetType(), bf.tp)
-	SetBinFlagOrBinStr(args[2].GetType(), bf.tp)
+	addBinFlag(bf.tp)
 
 	valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket)
 	maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64)
@@ -2668,9 +2671,7 @@ func (c *makeSetFunctionClass) getFunction(ctx sessionctx.Context, args []Expres
 	if err != nil {
 		return nil, err
 	}
-	for i, length := 0, len(args); i < length; i++ {
-		SetBinFlagOrBinStr(args[i].GetType(), bf.tp)
-	}
+	addBinFlag(bf.tp)
 	bf.tp.Flen = c.getFlen(bf.ctx, args)
 	if
bf.tp.Flen > mysql.MaxBlobWidth { bf.tp.Flen = mysql.MaxBlobWidth @@ -3589,8 +3590,7 @@ func (c *insertFunctionClass) getFunction(ctx sessionctx.Context, args []Express return nil, err } bf.tp.Flen = mysql.MaxBlobWidth - SetBinFlagOrBinStr(args[0].GetType(), bf.tp) - SetBinFlagOrBinStr(args[3].GetType(), bf.tp) + addBinFlag(bf.tp) valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket) maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64) diff --git a/expression/builtin_string_test.go b/expression/builtin_string_test.go index 8821133082836..67d81a0ae277c 100644 --- a/expression/builtin_string_test.go +++ b/expression/builtin_string_test.go @@ -1149,7 +1149,7 @@ func (s *testEvaluatorSuite) TestHexFunc(c *C) { {-1, false, false, "FFFFFFFFFFFFFFFF"}, {-12.3, false, false, "FFFFFFFFFFFFFFF4"}, {-12.8, false, false, "FFFFFFFFFFFFFFF3"}, - {types.NewBinaryLiteralFromUint(0xC, -1), false, false, "C"}, + {types.NewBinaryLiteralFromUint(0xC, -1), false, false, "0C"}, {0x12, false, false, "12"}, {nil, true, false, ""}, {errors.New("must err"), false, true, ""}, diff --git a/expression/builtin_time_test.go b/expression/builtin_time_test.go index 3ee725e78d578..16fcaa8fbfaa9 100644 --- a/expression/builtin_time_test.go +++ b/expression/builtin_time_test.go @@ -1411,6 +1411,15 @@ func (s *testEvaluatorSuite) TestStrToDate(c *C) { {"15-01-2001 1:9:8.999", "%d-%m-%Y %H:%i:%S.%f", true, time.Date(2001, 1, 15, 1, 9, 8, 999000000, time.Local)}, {"2003-01-02 10:11:12 PM", "%Y-%m-%d %H:%i:%S %p", false, time.Time{}}, {"10:20:10AM", "%H:%i:%S%p", false, time.Time{}}, + // test %@(skip alpha), %#(skip number), %.(skip punct) + {"2020-10-10ABCD", "%Y-%m-%d%@", true, time.Date(2020, 10, 10, 0, 0, 0, 0, time.Local)}, + {"2020-10-101234", "%Y-%m-%d%#", true, time.Date(2020, 10, 10, 0, 0, 0, 0, time.Local)}, + {"2020-10-10....", "%Y-%m-%d%.", true, time.Date(2020, 10, 10, 0, 0, 0, 0, time.Local)}, + {"2020-10-10.1", "%Y-%m-%d%.%#%@", true, time.Date(2020, 10, 10, 0, 0, 0, 0, time.Local)}, + {"abcd2020-10-10.1", "%@%Y-%m-%d%.%#%@", true, time.Date(2020, 10, 10, 0, 0, 0, 0, time.Local)}, + {"abcd-2020-10-10.1", "%@-%Y-%m-%d%.%#%@", true, time.Date(2020, 10, 10, 0, 0, 0, 0, time.Local)}, + {"2020-10-10", "%Y-%m-%d%@", true, time.Date(2020, 10, 10, 0, 0, 0, 0, time.Local)}, + {"2020-10-10abcde123abcdef", "%Y-%m-%d%@%#", true, time.Date(2020, 10, 10, 0, 0, 0, 0, time.Local)}, } fc := funcs[ast.StrToDate] diff --git a/expression/constant_propagation.go b/expression/constant_propagation.go index f5e8bca8a30fd..7c521d7ea37ae 100644 --- a/expression/constant_propagation.go +++ b/expression/constant_propagation.go @@ -532,6 +532,9 @@ func (s *propOuterJoinConstSolver) deriveConds(outerCol, innerCol *Column, schem // 'expression(..., outerCol, ...)' does not reference columns outside children schemas of join node. // Derived new expressions must be appended into join condition, not filter condition. 
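// Note: when the solver is null-sensitive, the early return below now skips the
// whole propagation instead of merely skipping the not-null derivation inside the
// loop. As the old in-loop comment explained with
// `select *, t1.a in (select t2.b from t t2) from t t1`, a LeftOuterSemiJoin must
// distinguish NULL from 0 when no row satisfies t2.b = t1.a, and conditions
// derived from column equivalence could change that outcome.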
func (s *propOuterJoinConstSolver) propagateColumnEQ() { + if s.nullSensitive { + return + } visited := make([]bool, 2*len(s.joinConds)+len(s.filterConds)) s.unionSet = disjointset.NewIntSet(len(s.columns)) var outerCol, innerCol *Column @@ -552,9 +555,6 @@ func (s *propOuterJoinConstSolver) propagateColumnEQ() { // `select *, t1.a in (select t2.b from t t2) from t t1` // rows with t2.b is null would impact whether LeftOuterSemiJoin should output 0 or null if there // is no row satisfying t2.b = t1.a - if s.nullSensitive { - continue - } childCol := s.innerSchema.RetrieveColumn(innerCol) if !mysql.HasNotNullFlag(childCol.RetType.Flag) { notNullExpr := BuildNotNullExpr(s.ctx, childCol) diff --git a/expression/constant_test.go b/expression/constant_test.go index 50ea87d010f7b..49929497069cf 100644 --- a/expression/constant_test.go +++ b/expression/constant_test.go @@ -271,8 +271,8 @@ func (*testExpressionSuite) TestDeferredParamNotNull(c *C) { c.Assert(mysql.TypeTimestamp, Equals, cstTime.GetType().Tp) c.Assert(mysql.TypeDuration, Equals, cstDuration.GetType().Tp) c.Assert(mysql.TypeBlob, Equals, cstBytes.GetType().Tp) - c.Assert(mysql.TypeBit, Equals, cstBinary.GetType().Tp) - c.Assert(mysql.TypeBit, Equals, cstBit.GetType().Tp) + c.Assert(mysql.TypeVarString, Equals, cstBinary.GetType().Tp) + c.Assert(mysql.TypeVarString, Equals, cstBit.GetType().Tp) c.Assert(mysql.TypeFloat, Equals, cstFloat32.GetType().Tp) c.Assert(mysql.TypeDouble, Equals, cstFloat64.GetType().Tp) c.Assert(mysql.TypeEnum, Equals, cstEnum.GetType().Tp) diff --git a/expression/expression.go b/expression/expression.go index f5fcbbf44d95f..602383d1e76c1 100644 --- a/expression/expression.go +++ b/expression/expression.go @@ -1143,6 +1143,13 @@ func canScalarFuncPushDown(scalarFunc *ScalarFunction, pc PbConverter, storeType // Check whether this function can be pushed. 
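// If this check fails while planning an EXPLAIN statement, a warning naming the
// offending function is now appended, so the user can see why a plan was not
// pushed down. The message follows the format built below, e.g. (the function
// name and signature here are illustrative):
//   Scalar function 'some_func'(signature: SomeFuncSig) can not be pushed to tiflash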
if !canFuncBePushed(scalarFunc, storeType) { + if pc.sc.InExplainStmt { + storageName := storeType.Name() + if storeType == kv.UnSpecified { + storageName = "storage layer" + } + pc.sc.AppendWarning(errors.New("Scalar function '" + scalarFunc.FuncName.L + "'(signature: " + scalarFunc.Function.PbCode().String() + ") can not be pushed to " + storageName)) + } return false } @@ -1166,6 +1173,9 @@ func canScalarFuncPushDown(scalarFunc *ScalarFunction, pc PbConverter, storeType func canExprPushDown(expr Expression, pc PbConverter, storeType kv.StoreType) bool { if storeType == kv.TiFlash && expr.GetType().Tp == mysql.TypeDuration { + if pc.sc.InExplainStmt { + pc.sc.AppendWarning(errors.New("Expr '" + expr.String() + "' can not be pushed to TiFlash because it contains Duration type")) + } return false } switch x := expr.(type) { diff --git a/expression/expression_test.go b/expression/expression_test.go index 2ecdc12f465eb..40eed2c946207 100644 --- a/expression/expression_test.go +++ b/expression/expression_test.go @@ -35,7 +35,7 @@ func (s *testEvaluatorSuite) TestNewValuesFunc(c *C) { } func (s *testEvaluatorSuite) TestEvaluateExprWithNull(c *C) { - tblInfo := newTestTableBuilder("").add("col0", mysql.TypeLonglong).add("col1", mysql.TypeLonglong).build() + tblInfo := newTestTableBuilder("").add("col0", mysql.TypeLonglong, 0).add("col1", mysql.TypeLonglong, 0).build() schema := tableInfoToSchemaForTest(tblInfo) col0 := schema.Columns[0] col1 := schema.Columns[1] @@ -142,15 +142,17 @@ type testTableBuilder struct { tableName string columnNames []string tps []byte + flags []uint } func newTestTableBuilder(tableName string) *testTableBuilder { return &testTableBuilder{tableName: tableName} } -func (builder *testTableBuilder) add(name string, tp byte) *testTableBuilder { +func (builder *testTableBuilder) add(name string, tp byte, flag uint) *testTableBuilder { builder.columnNames = append(builder.columnNames, name) builder.tps = append(builder.tps, tp) + builder.flags = append(builder.flags, flag) return builder } @@ -165,6 +167,7 @@ func (builder *testTableBuilder) build() *model.TableInfo { fieldType := types.NewFieldType(tp) fieldType.Flen, fieldType.Decimal = mysql.GetDefaultFieldLengthAndDecimal(tp) fieldType.Charset, fieldType.Collate = types.DefaultCharsetForType(tp) + fieldType.Flag = builder.flags[i] ti.Columns = append(ti.Columns, &model.ColumnInfo{ ID: int64(i + 1), Name: model.NewCIStr(colName), diff --git a/expression/generator/other_vec.go b/expression/generator/other_vec.go index e04c629a53c3b..a1d574c80b901 100644 --- a/expression/generator/other_vec.go +++ b/expression/generator/other_vec.go @@ -144,13 +144,12 @@ func (b *{{.SigName}}) vecEvalInt(input *chunk.Chunk, result *chunk.Column) erro isUnsigned0 := mysql.HasUnsignedFlag(b.args[0].GetType().Flag) {{- end }} var compareResult int - args := b.args + args := b.args[1:] {{- if not $InputJSON}} if len(b.hashSet) != 0 { {{- if $InputString }} collator := collate.GetCollator(b.collation) {{- end }} - args = b.nonConstArgs for i := 0; i < n; i++ { if buf0.IsNull(i) { hasNull[i] = true @@ -202,10 +201,14 @@ func (b *{{.SigName}}) vecEvalInt(input *chunk.Chunk, result *chunk.Column) erro {{- end }} {{- end }} } + args = args[:0] + for _, i := range b.nonConstArgsIdx { + args = append(args, b.args[i]) + } } {{- end }} - for j := 1; j < len(args); j++ { + for j := 0; j < len(args); j++ { if err := args[j].VecEval{{ .Input.TypeName }}(b.ctx, input, buf1); err != nil { return err } diff --git a/expression/integration_test.go 
b/expression/integration_test.go index 623f72a1e6539..117e073609269 100755 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -2920,6 +2920,16 @@ func (s *testIntegrationSuite2) TestBuiltin(c *C) { result.Check(testkit.Rows(" 4")) result = tk.MustQuery("select * from t where b = case when a is null then 4 when a = 'str5' then 7 else 9 end") result.Check(testkit.Rows(" 4")) + + // return type of case when expr should not include NotNullFlag. issue-23036 + tk.MustExec("drop table if exists t1") + tk.MustExec("create table t1(c1 int not null)") + tk.MustExec("insert into t1 values(1)") + result = tk.MustQuery("select (case when null then c1 end) is null from t1") + result.Check(testkit.Rows("1")) + result = tk.MustQuery("select (case when null then c1 end) is not null from t1") + result.Check(testkit.Rows("0")) + // test warnings tk.MustQuery("select case when b=0 then 1 else 1/b end from t") tk.MustQuery("show warnings").Check(testkit.Rows()) @@ -3723,6 +3733,17 @@ func (s *testIntegrationSuite) TestCompareBuiltin(c *C) { result.Check(testkit.Rows("0")) } +// #23157: make sure if Nullif expr is correct combined with IsNull expr. +func (s *testIntegrationSuite) TestNullifWithIsNull(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a int not null);") + tk.MustExec("insert into t values(1),(2);") + rows := tk.MustQuery("select * from t where nullif(a,a) is null;") + rows.Check(testkit.Rows("1", "2")) +} + func (s *testIntegrationSuite) TestAggregationBuiltin(c *C) { defer s.cleanEnv(c) tk := testkit.NewTestKit(c, s.store) @@ -6076,6 +6097,15 @@ func (s *testIntegrationSerialSuite) TestCollationBasic(c *C) { tk.MustQuery("select * from t_ci where a='A'").Check(testkit.Rows("a")) tk.MustQuery("select * from t_ci where a='a '").Check(testkit.Rows("a")) tk.MustQuery("select * from t_ci where a='a '").Check(testkit.Rows("a")) + + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(c set('A', 'B') collate utf8mb4_general_ci);") + tk.MustExec("insert into t values('a');") + tk.MustExec("insert into t values('B');") + tk.MustQuery("select c from t where c = 'a';").Check(testkit.Rows("A")) + tk.MustQuery("select c from t where c = 'A';").Check(testkit.Rows("A")) + tk.MustQuery("select c from t where c = 'b';").Check(testkit.Rows("B")) + tk.MustQuery("select c from t where c = 'B';").Check(testkit.Rows("B")) } func (s *testIntegrationSerialSuite) TestWeightString(c *C) { @@ -7207,6 +7237,20 @@ func (s *testIntegrationSerialSuite) TestIssue18702(c *C) { tk.MustQuery("SELECT * FROM t FORCE INDEX(idx_bc);").Check(testkit.Rows("1 A 10 1", "2 B 20 1")) } +func (s *testIntegrationSerialSuite) TestCollationText(c *C) { + collate.SetNewCollationEnabledForTest(true) + defer collate.SetNewCollationEnabledForTest(false) + + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a TINYTEXT collate UTF8MB4_GENERAL_CI, UNIQUE KEY `a`(`a`(10)));") + tk.MustExec("insert into t (a) values ('A');") + tk.MustQuery("select * from t t1 inner join t t2 on t1.a = t2.a where t1.a = 'A';").Check(testkit.Rows("A A")) + tk.MustExec("update t set a = 'B';") + tk.MustExec("admin check table t;") +} + func (s *testIntegrationSuite) TestIssue18850(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") @@ -8010,6 +8054,16 @@ func (s *testSuite2) TestIssue12205(c *C) { testkit.Rows("Warning 1292 Truncated 
incorrect time value: '18446744072635875000'")) } +func (s *testIntegrationSuite) Test23262(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a year)") + tk.MustExec("insert into t values(2002)") + tk.MustQuery("select * from t where a=2").Check(testkit.Rows("2002")) + tk.MustQuery("select * from t where a='2'").Check(testkit.Rows("2002")) +} + func (s *testIntegrationSuite) TestIssue11333(c *C) { defer s.cleanEnv(c) tk := testkit.NewTestKit(c, s.store) @@ -8032,6 +8086,7 @@ func (s *testIntegrationSerialSuite) TestIssue19116(c *C) { defer collate.SetNewCollationEnabledForTest(false) tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") tk.MustExec("set names utf8mb4 collate utf8mb4_general_ci;") tk.MustQuery("select collation(concat(1 collate `binary`));").Check(testkit.Rows("binary")) tk.MustQuery("select coercibility(concat(1 collate `binary`));").Check(testkit.Rows("0")) @@ -8042,4 +8097,54 @@ func (s *testIntegrationSerialSuite) TestIssue19116(c *C) { tk.MustQuery("select collation(1);").Check(testkit.Rows("binary")) tk.MustQuery("select coercibility(1);").Check(testkit.Rows("5")) tk.MustQuery("select coercibility(1=1);").Check(testkit.Rows("5")) + + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a datetime)") + tk.MustExec("insert into t values ('2020-02-02')") + tk.MustQuery("select collation(concat(unix_timestamp(a))) from t;").Check(testkit.Rows("utf8mb4_general_ci")) + tk.MustQuery("select coercibility(concat(unix_timestamp(a))) from t;").Check(testkit.Rows("4")) +} + +func (s *testIntegrationSuite) TestApproximatePercentile(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t (a bit(10))") + tk.MustExec("insert into t values(b'1111')") + tk.MustQuery("select approx_percentile(a, 10) from t").Check(testkit.Rows("")) +} + +func (s *testIntegrationSuite) TestIssue23889(c *C) { + defer s.cleanEnv(c) + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists test_decimal,test_t;") + tk.MustExec("create table test_decimal(col_decimal decimal(10,0));") + tk.MustExec("insert into test_decimal values(null),(8);") + tk.MustExec("create table test_t(a int(11), b decimal(32,0));") + tk.MustExec("insert into test_t values(1,4),(2,4),(5,4),(7,4),(9,4);") + + tk.MustQuery("SELECT ( test_decimal . `col_decimal` , test_decimal . 
`col_decimal` ) IN ( select * from test_t ) as field1 FROM test_decimal;").Check( + testkit.Rows("", "0")) +} + +func (s *testIntegrationSuite) TestIssue23623(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t1;") + tk.MustExec("create table t1(c1 int);") + tk.MustExec("insert into t1 values(-2147483648), (-2147483648), (null);") + tk.MustQuery("select count(*) from t1 where c1 > (select sum(c1) from t1);").Check(testkit.Rows("2")) +} + +func (s *testIntegrationSerialSuite) TestCollationForBinaryLiteral(c *C) { + tk := testkit.NewTestKit(c, s.store) + collate.SetNewCollationEnabledForTest(true) + defer collate.SetNewCollationEnabledForTest(false) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("CREATE TABLE t (`COL1` tinyblob NOT NULL, `COL2` binary(1) NOT NULL, `COL3` bigint(11) NOT NULL, PRIMARY KEY (`COL1`(5),`COL2`,`COL3`) /*T![clustered_index] CLUSTERED */)") + tk.MustExec("insert into t values(0x1E,0xEC,6966939640596047133);") + tk.MustQuery("select * from t where col1 not in (0x1B,0x20) order by col1").Check(testkit.Rows("\x1e \xec 6966939640596047133")) + tk.MustExec("drop table t") } diff --git a/expression/partition_pruner_test.go b/expression/partition_pruner_test.go index a132b6b84377b..d54f3b3e5f05d 100644 --- a/expression/partition_pruner_test.go +++ b/expression/partition_pruner_test.go @@ -86,3 +86,14 @@ func (s *testSuite2) TestHashPartitionPruner(c *C) { tk.MustQuery(tt).Check(testkit.Rows(output[i].Result...)) } } + +func (s *testSuite2) TestIssue22898(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("USE test;") + tk.MustExec("DROP TABLE IF EXISTS test;") + tk.MustExec("CREATE TABLE NT_RP3763 (COL1 TINYINT(8) SIGNED COMMENT \"NUMERIC NO INDEX\" DEFAULT 41,COL2 VARCHAR(20),COL3 DATETIME,COL4 BIGINT,COL5 FLOAT) PARTITION BY RANGE (COL1 * COL3) (PARTITION P0 VALUES LESS THAN (0),PARTITION P1 VALUES LESS THAN (10),PARTITION P2 VALUES LESS THAN (20),PARTITION P3 VALUES LESS THAN (30),PARTITION P4 VALUES LESS THAN (40),PARTITION P5 VALUES LESS THAN (50),PARTITION PMX VALUES LESS THAN MAXVALUE);") + tk.MustExec("insert into NT_RP3763 (COL1,COL2,COL3,COL4,COL5) values(-82,\"夐齏醕皆磹漋甓崘潮嵙燷渏艂朼洛炷鉢儝鱈肇\",\"5748\\-06\\-26\\ 20:48:49\",-3133527360541070260,-2.624880003397658e+38);") + tk.MustExec("insert into NT_RP3763 (COL1,COL2,COL3,COL4,COL5) values(48,\"簖鹩筈匹眜赖泽騈爷詵赺玡婙Ɇ郝鮙廛賙疼舢\",\"7228\\-12\\-13\\ 02:59:54\",-6181009269190017937,2.7731105531290494e+38);") + tk.MustQuery("select * from `NT_RP3763` where `COL1` in (10, 48, -82);").Check(testkit.Rows("-82 夐齏醕皆磹漋甓崘潮嵙燷渏艂朼洛炷鉢儝鱈肇 5748-06-26 20:48:49 -3133527360541070260 -262488000000000000000000000000000000000", "48 簖鹩筈匹眜赖泽騈爷詵赺玡婙Ɇ郝鮙廛賙疼舢 7228-12-13 02:59:54 -6181009269190017937 277311060000000000000000000000000000000")) + tk.MustQuery("select * from `NT_RP3763` where `COL1` in (48);").Check(testkit.Rows("48 簖鹩筈匹眜赖泽騈爷詵赺玡婙Ɇ郝鮙廛賙疼舢 7228-12-13 02:59:54 -6181009269190017937 277311060000000000000000000000000000000")) +} diff --git a/expression/scalar_function.go b/expression/scalar_function.go index 9137e922df5e9..7b950813e8d3f 100755 --- a/expression/scalar_function.go +++ b/expression/scalar_function.go @@ -170,6 +170,7 @@ func typeInferForNull(args []Expression) { for _, arg := range args { if isNull(arg) { *arg.GetType() = *retFieldTp + arg.GetType().Flag &= ^mysql.NotNullFlag // Remove NotNullFlag of NullConst } } } @@ -428,29 +429,6 @@ func (sf *ScalarFunction) ResolveIndices(schema *Schema) (Expression, error) { } func (sf *ScalarFunction) 
resolveIndices(schema *Schema) error { - if sf.FuncName.L == ast.In { - args := []Expression{} - switch inFunc := sf.Function.(type) { - case *builtinInIntSig: - args = inFunc.nonConstArgs - case *builtinInStringSig: - args = inFunc.nonConstArgs - case *builtinInTimeSig: - args = inFunc.nonConstArgs - case *builtinInDurationSig: - args = inFunc.nonConstArgs - case *builtinInRealSig: - args = inFunc.nonConstArgs - case *builtinInDecimalSig: - args = inFunc.nonConstArgs - } - for _, arg := range args { - err := arg.resolveIndices(schema) - if err != nil { - return err - } - } - } for _, arg := range sf.GetArgs() { err := arg.resolveIndices(schema) if err != nil { diff --git a/expression/scalar_function_test.go b/expression/scalar_function_test.go index f3349a87c34fb..c5f52e3309532 100755 --- a/expression/scalar_function_test.go +++ b/expression/scalar_function_test.go @@ -21,6 +21,7 @@ import ( "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/util/chunk" ) func (s *testEvaluatorSuite) TestScalarFunction(c *C) { @@ -47,6 +48,21 @@ func (s *testEvaluatorSuite) TestScalarFunction(c *C) { c.Assert(ok, IsTrue) } +func (s *testEvaluatorSuite) TestIssue23309(c *C) { + a := &Column{ + UniqueID: 1, + RetType: types.NewFieldType(mysql.TypeDouble), + } + a.RetType.Flag |= mysql.NotNullFlag + null := NewNull() + null.RetType = types.NewFieldType(mysql.TypeNull) + sf, _ := newFunction(ast.NE, a, null).(*ScalarFunction) + v, err := sf.GetArgs()[1].Eval(chunk.Row{}) + c.Assert(err, IsNil) + c.Assert(v.IsNull(), IsTrue) + c.Assert(mysql.HasNotNullFlag(sf.GetArgs()[1].GetType().Flag), IsFalse) +} + func (s *testEvaluatorSuite) TestScalarFuncs2Exprs(c *C) { a := &Column{ UniqueID: 1, diff --git a/expression/typeinfer_test.go b/expression/typeinfer_test.go index 8a46b32664c24..7cafe7fb1c46b 100644 --- a/expression/typeinfer_test.go +++ b/expression/typeinfer_test.go @@ -240,8 +240,9 @@ func (s *testInferTypeSuite) createTestCase4StrFuncs() []typeInferTestCase { {"strcmp(c_char, c_char)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 2, 0}, {"space(c_int_d)", mysql.TypeLongBlob, mysql.DefaultCharset, 0, mysql.MaxBlobWidth, types.UnspecifiedLength}, {"CONCAT(c_binary, c_int_d)", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 40, types.UnspecifiedLength}, - {"CONCAT(c_bchar, c_int_d)", mysql.TypeVarString, charset.CharsetUTF8MB4, mysql.BinaryFlag, 40, types.UnspecifiedLength}, {"CONCAT('T', 'i', 'DB')", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 4, types.UnspecifiedLength}, + {"CONCAT(c_bchar, c_int_d)", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 40, types.UnspecifiedLength}, + {"CONCAT(c_bchar, 0x80)", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 23, types.UnspecifiedLength}, {"CONCAT('T', 'i', 'DB', c_binary)", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 24, types.UnspecifiedLength}, {"CONCAT_WS('-', 'T', 'i', 'DB')", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 6, types.UnspecifiedLength}, {"CONCAT_WS(',', 'TiDB', c_binary)", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 25, types.UnspecifiedLength}, @@ -451,8 +452,9 @@ func (s *testInferTypeSuite) createTestCase4StrFuncs() []typeInferTestCase { {"find_in_set(c_set , c_text_d)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 3, 0}, {"find_in_set(c_enum , c_text_d)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 3, 0}, - {"make_set(c_int_d , c_text_d)", mysql.TypeVarString, 
charset.CharsetUTF8MB4, mysql.BinaryFlag, 65535, types.UnspecifiedLength}, + {"make_set(c_int_d , c_text_d)", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 65535, types.UnspecifiedLength}, {"make_set(c_bigint_d , c_text_d, c_binary)", mysql.TypeMediumBlob, charset.CharsetBin, mysql.BinaryFlag, 65556, types.UnspecifiedLength}, + {"make_set(1 , c_text_d, 0x40)", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 65535, types.UnspecifiedLength}, {"quote(c_int_d )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 42, types.UnspecifiedLength}, {"quote(c_bigint_d )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 42, types.UnspecifiedLength}, @@ -465,6 +467,7 @@ func (s *testInferTypeSuite) createTestCase4StrFuncs() []typeInferTestCase { {"convert(c_text_d using 'binary')", mysql.TypeLongBlob, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxBlobWidth, types.UnspecifiedLength}, {"insert(c_varchar, c_int_d, c_int_d, c_varchar)", mysql.TypeLongBlob, charset.CharsetUTF8MB4, 0, mysql.MaxBlobWidth, types.UnspecifiedLength}, + {"insert(c_varchar, c_int_d, c_int_d, 0x40)", mysql.TypeLongBlob, charset.CharsetUTF8MB4, 0, mysql.MaxBlobWidth, types.UnspecifiedLength}, {"insert(c_varchar, c_int_d, c_int_d, c_binary)", mysql.TypeLongBlob, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxBlobWidth, types.UnspecifiedLength}, {"insert(c_binary, c_int_d, c_int_d, c_varchar)", mysql.TypeLongBlob, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxBlobWidth, types.UnspecifiedLength}, {"insert(c_binary, c_int_d, c_int_d, c_binary)", mysql.TypeLongBlob, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxBlobWidth, types.UnspecifiedLength}, diff --git a/go.mod b/go.mod index e967e7cbdab41..fcf11d9da816a 100644 --- a/go.mod +++ b/go.mod @@ -44,9 +44,9 @@ require ( github.com/pingcap/failpoint v0.0.0-20200702092429-9f69995143ce github.com/pingcap/fn v0.0.0-20200306044125-d5540d389059 github.com/pingcap/goleveldb v0.0.0-20191226122134-f82aafb29989 - github.com/pingcap/kvproto v0.0.0-20201126113434-70db5fb4b0dc + github.com/pingcap/kvproto v0.0.0-20210308075244-560097d1309b github.com/pingcap/log v0.0.0-20201112100606-8f1e84a3abc8 - github.com/pingcap/parser v0.0.0-20210107054750-53e33b4018fe + github.com/pingcap/parser v0.0.0-20210303062609-d1d977c9ceed github.com/pingcap/sysutil v0.0.0-20201130064824-f0c8aa6a6966 github.com/pingcap/tidb-tools v4.0.9-0.20201127090955-2707c97b3853+incompatible github.com/pingcap/tipb v0.0.0-20200618092958-4fad48b4c8c3 @@ -74,7 +74,7 @@ require ( golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d // indirect golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208 golang.org/x/sys v0.0.0-20200819171115-d785dc25833f - golang.org/x/text v0.3.4 + golang.org/x/text v0.3.5 golang.org/x/time v0.0.0-20191024005414-555d28b269f0 // indirect golang.org/x/tools v0.0.0-20200820010801-b793a1359eac golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect diff --git a/go.sum b/go.sum index 653411200fc3e..9d904cff63883 100644 --- a/go.sum +++ b/go.sum @@ -396,15 +396,15 @@ github.com/pingcap/kvproto v0.0.0-20191211054548-3c6b38ea5107/go.mod h1:WWLmULLO github.com/pingcap/kvproto v0.0.0-20200411081810-b85805c9476c/go.mod h1:IOdRDPLyda8GX2hE/jO7gqaCV/PNFh8BZQCQZXfIOqI= github.com/pingcap/kvproto v0.0.0-20200907074027-32a3a0accf7d h1:gvJScINTd/HFasp82W1paGTfbYe2Bnzn8QBOJXckLwY= github.com/pingcap/kvproto v0.0.0-20200907074027-32a3a0accf7d/go.mod h1:IOdRDPLyda8GX2hE/jO7gqaCV/PNFh8BZQCQZXfIOqI= -github.com/pingcap/kvproto v0.0.0-20201126113434-70db5fb4b0dc h1:BtszN3YR5EScxiGGTD3tAf4CQE90bczkOY0lLa07EJA= 
-github.com/pingcap/kvproto v0.0.0-20201126113434-70db5fb4b0dc/go.mod h1:IOdRDPLyda8GX2hE/jO7gqaCV/PNFh8BZQCQZXfIOqI= +github.com/pingcap/kvproto v0.0.0-20210308075244-560097d1309b h1:Jp0V5PDzdOy666n4XbDDaEjOKHsp2nk7b2uR6qjFI0s= +github.com/pingcap/kvproto v0.0.0-20210308075244-560097d1309b/go.mod h1:IOdRDPLyda8GX2hE/jO7gqaCV/PNFh8BZQCQZXfIOqI= github.com/pingcap/log v0.0.0-20191012051959-b742a5d432e9/go.mod h1:4rbK1p9ILyIfb6hU7OG2CiWSqMXnp3JMbiaVJ6mvoY8= github.com/pingcap/log v0.0.0-20200117041106-d28c14d3b1cd/go.mod h1:4rbK1p9ILyIfb6hU7OG2CiWSqMXnp3JMbiaVJ6mvoY8= github.com/pingcap/log v0.0.0-20200511115504-543df19646ad/go.mod h1:4rbK1p9ILyIfb6hU7OG2CiWSqMXnp3JMbiaVJ6mvoY8= github.com/pingcap/log v0.0.0-20201112100606-8f1e84a3abc8 h1:M+DNpOu/I3uDmwee6vcnoPd6GgSMqND4gxvDQ/W584U= github.com/pingcap/log v0.0.0-20201112100606-8f1e84a3abc8/go.mod h1:4rbK1p9ILyIfb6hU7OG2CiWSqMXnp3JMbiaVJ6mvoY8= -github.com/pingcap/parser v0.0.0-20210107054750-53e33b4018fe h1:sukVKRva68HNGZ4nuPvQS/wMvH7NMxTXV2NIhmoYP4Y= -github.com/pingcap/parser v0.0.0-20210107054750-53e33b4018fe/go.mod h1:GbEr2PgY72/4XqPZzmzstlOU/+il/wrjeTNFs6ihsSE= +github.com/pingcap/parser v0.0.0-20210303062609-d1d977c9ceed h1:+ENLMPNRG8+/YGNJChC5QRgfrcmFnsrHl9WoVLXRZok= +github.com/pingcap/parser v0.0.0-20210303062609-d1d977c9ceed/go.mod h1:GbEr2PgY72/4XqPZzmzstlOU/+il/wrjeTNFs6ihsSE= github.com/pingcap/sysutil v0.0.0-20200206130906-2bfa6dc40bcd/go.mod h1:EB/852NMQ+aRKioCpToQ94Wl7fktV+FNnxf3CX/TTXI= github.com/pingcap/sysutil v0.0.0-20201130064824-f0c8aa6a6966 h1:JI0wOAb8aQML0vAVLHcxTEEC0VIwrk6gtw3WjbHvJLA= github.com/pingcap/sysutil v0.0.0-20201130064824-f0c8aa6a6966/go.mod h1:EB/852NMQ+aRKioCpToQ94Wl7fktV+FNnxf3CX/TTXI= @@ -693,8 +693,8 @@ golang.org/x/sys v0.0.0-20200819171115-d785dc25833f/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= -golang.org/x/text v0.3.4 h1:0YWbFKbhXG/wIiuHDSKpS0Iy7FSA+u45VtBMfQcFTTc= -golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.5 h1:i6eZZ+zk0SOf0xgBpEpPD18qWcJda6q1sxt3S0kzyUQ= +golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= diff --git a/infoschema/perfschema/const.go b/infoschema/perfschema/const.go index 86a5f1766739f..157ce4c7ae7c9 100644 --- a/infoschema/perfschema/const.go +++ b/infoschema/perfschema/const.go @@ -415,6 +415,7 @@ const tableEventsStatementsSummaryByDigest = "CREATE TABLE if not exists perform "LAST_SEEN timestamp(6) NOT NULL DEFAULT '0000-00-00 00:00:00.000000'," + "PLAN_IN_CACHE bool NOT NULL," + "PLAN_CACHE_HITS bigint unsigned NOT NULL," + + "PLAN_IN_BINDING bool NOT NULL," + "QUANTILE_95 bigint unsigned NOT NULL," + "QUANTILE_99 bigint unsigned NOT NULL," + "QUANTILE_999 bigint unsigned NOT NULL," + diff --git a/infoschema/tables.go b/infoschema/tables.go index 41a4f2b180ba4..845f78877f271 100644 --- a/infoschema/tables.go +++ b/infoschema/tables.go @@ -475,6 +475,7 @@ var partitionsCols = []columnInfo{ {name: "PARTITION_COMMENT", 
tp: mysql.TypeVarchar, size: 80}, {name: "NODEGROUP", tp: mysql.TypeVarchar, size: 12}, {name: "TABLESPACE_NAME", tp: mysql.TypeVarchar, size: 64}, + {name: "TIDB_PARTITION_ID", tp: mysql.TypeLonglong, size: 21}, } var tableConstraintsCols = []columnInfo{ @@ -727,6 +728,8 @@ var slowQueryCols = []columnInfo{ {name: variable.SlowLogRewriteTimeStr, tp: mysql.TypeDouble, size: 22}, {name: variable.SlowLogPreprocSubQueriesStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.UnsignedFlag}, {name: variable.SlowLogPreProcSubQueryTimeStr, tp: mysql.TypeDouble, size: 22}, + {name: variable.SlowLogOptimizeTimeStr, tp: mysql.TypeDouble, size: 22}, + {name: variable.SlowLogWaitTSTimeStr, tp: mysql.TypeDouble, size: 22}, {name: execdetails.PreWriteTimeStr, tp: mysql.TypeDouble, size: 22}, {name: execdetails.WaitPrewriteBinlogTimeStr, tp: mysql.TypeDouble, size: 22}, {name: execdetails.CommitTimeStr, tp: mysql.TypeDouble, size: 22}, @@ -770,6 +773,7 @@ var slowQueryCols = []columnInfo{ {name: variable.SlowLogPrepared, tp: mysql.TypeTiny, size: 1}, {name: variable.SlowLogSucc, tp: mysql.TypeTiny, size: 1}, {name: variable.SlowLogPlanFromCache, tp: mysql.TypeTiny, size: 1}, + {name: variable.SlowLogPlanFromBinding, tp: mysql.TypeTiny, size: 1}, {name: variable.SlowLogPlan, tp: mysql.TypeLongBlob, size: types.UnspecifiedLength}, {name: variable.SlowLogPlanDigest, tp: mysql.TypeVarchar, size: 128}, {name: variable.SlowLogPrevStmt, tp: mysql.TypeLongBlob, size: types.UnspecifiedLength}, @@ -1231,6 +1235,7 @@ var tableStatementsSummaryCols = []columnInfo{ {name: "LAST_SEEN", tp: mysql.TypeTimestamp, size: 26, flag: mysql.NotNullFlag, comment: "The time these statements are seen for the last time"}, {name: "PLAN_IN_CACHE", tp: mysql.TypeTiny, size: 1, flag: mysql.NotNullFlag, comment: "Whether the last statement hit plan cache"}, {name: "PLAN_CACHE_HITS", tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag, comment: "The number of times these statements hit plan cache"}, + {name: "PLAN_IN_BINDING", tp: mysql.TypeTiny, size: 1, flag: mysql.NotNullFlag, comment: "Whether the last statement is matched with the hints in the binding"}, {name: "QUERY_SAMPLE_TEXT", tp: mysql.TypeBlob, size: types.UnspecifiedLength, comment: "Sampled original statement"}, {name: "PREV_SAMPLE_TEXT", tp: mysql.TypeBlob, size: types.UnspecifiedLength, comment: "The previous statement before commit"}, {name: "PLAN_DIGEST", tp: mysql.TypeVarchar, size: 64, comment: "Digest of its execution plan"}, diff --git a/infoschema/tables_test.go b/infoschema/tables_test.go index 072cd1079ba92..5e425730a9231 100644 --- a/infoschema/tables_test.go +++ b/infoschema/tables_test.go @@ -552,6 +552,8 @@ func prepareSlowLogfile(c *C, slowLogFileName string) { # Parse_time: 0.4 # Compile_time: 0.2 # Rewrite_time: 0.000000003 Preproc_subqueries: 2 Preproc_subqueries_time: 0.000000002 +# Optimize_time: 0.00000001 +# Wait_TS: 0.000000003 # LockKeys_time: 1.71 Request_count: 1 Prewrite_time: 0.19 Wait_prewrite_binlog_time: 0.21 Commit_time: 0.01 Commit_backoff_time: 0.18 Backoff_types: [txnLock] Resolve_lock_time: 0.03 Write_keys: 15 Write_size: 480 Prewrite_region: 1 Txn_retry: 8 # Cop_time: 0.3824278 Process_time: 0.161 Request_count: 1 Total_keys: 100001 Process_keys: 100000 # Wait_time: 0.101 @@ -636,10 +638,10 @@ func (s *testTableSuite) TestSlowQuery(c *C) { tk.MustExec("set time_zone = '+08:00';") re := tk.MustQuery("select * from information_schema.slow_query") re.Check(testutil.RowsWithSep("|", - "2019-02-12 
19:33:56.571953|406315658548871171|root|localhost|6|57|0.12|4.895492|0.4|0.2|0.000000003|2|0.000000002|0.19|0.21|0.01|0|0.18|[txnLock]|0.03|0|15|480|1|8|0.3824278|0.161|0.101|0.092|1.71|1|100001|100000|test||0|42a1c8aae6f133e934d4bf0147491709a8812ea05ff8819ec522780fe657b772|t1:1,t2:2|0.1|0.2|0.03|127.0.0.1:20160|0.05|0.6|0.8|0.0.0.0:20160|70724|65536|0|0|0|0||0|1|1|abcd|60e9378c746d9a2be1c791047e008967cf252eb6de9167ad3aa6098fa2d523f4|update t set i = 2;|select * from t_slim;")) + "2019-02-12 19:33:56.571953|406315658548871171|root|localhost|6|57|0.12|4.895492|0.4|0.2|0.000000003|2|0.000000002|0.00000001|0.000000003|0.19|0.21|0.01|0|0.18|[txnLock]|0.03|0|15|480|1|8|0.3824278|0.161|0.101|0.092|1.71|1|100001|100000|test||0|42a1c8aae6f133e934d4bf0147491709a8812ea05ff8819ec522780fe657b772|t1:1,t2:2|0.1|0.2|0.03|127.0.0.1:20160|0.05|0.6|0.8|0.0.0.0:20160|70724|65536|0|0|0|0||0|1|1|0|abcd|60e9378c746d9a2be1c791047e008967cf252eb6de9167ad3aa6098fa2d523f4|update t set i = 2;|select * from t_slim;")) tk.MustExec("set time_zone = '+00:00';") re = tk.MustQuery("select * from information_schema.slow_query") - re.Check(testutil.RowsWithSep("|", "2019-02-12 11:33:56.571953|406315658548871171|root|localhost|6|57|0.12|4.895492|0.4|0.2|0.000000003|2|0.000000002|0.19|0.21|0.01|0|0.18|[txnLock]|0.03|0|15|480|1|8|0.3824278|0.161|0.101|0.092|1.71|1|100001|100000|test||0|42a1c8aae6f133e934d4bf0147491709a8812ea05ff8819ec522780fe657b772|t1:1,t2:2|0.1|0.2|0.03|127.0.0.1:20160|0.05|0.6|0.8|0.0.0.0:20160|70724|65536|0|0|0|0||0|1|1|abcd|60e9378c746d9a2be1c791047e008967cf252eb6de9167ad3aa6098fa2d523f4|update t set i = 2;|select * from t_slim;")) + re.Check(testutil.RowsWithSep("|", "2019-02-12 11:33:56.571953|406315658548871171|root|localhost|6|57|0.12|4.895492|0.4|0.2|0.000000003|2|0.000000002|0.00000001|0.000000003|0.19|0.21|0.01|0|0.18|[txnLock]|0.03|0|15|480|1|8|0.3824278|0.161|0.101|0.092|1.71|1|100001|100000|test||0|42a1c8aae6f133e934d4bf0147491709a8812ea05ff8819ec522780fe657b772|t1:1,t2:2|0.1|0.2|0.03|127.0.0.1:20160|0.05|0.6|0.8|0.0.0.0:20160|70724|65536|0|0|0|0||0|1|1|0|abcd|60e9378c746d9a2be1c791047e008967cf252eb6de9167ad3aa6098fa2d523f4|update t set i = 2;|select * from t_slim;")) // Test for long query. f, err := os.OpenFile(slowLogFileName, os.O_CREATE|os.O_WRONLY, 0644) diff --git a/metrics/metrics.go b/metrics/metrics.go index 560557e723de8..75125d664d8d5 100644 --- a/metrics/metrics.go +++ b/metrics/metrics.go @@ -173,4 +173,6 @@ func RegisterMetrics() { prometheus.MustRegister(ServerInfo) prometheus.MustRegister(TokenGauge) prometheus.MustRegister(ConfigStatus) + prometheus.MustRegister(SmallTxnWriteDuration) + prometheus.MustRegister(TxnWriteThroughput) } diff --git a/metrics/sli.go b/metrics/sli.go new file mode 100644 index 0000000000000..2e926de099997 --- /dev/null +++ b/metrics/sli.go @@ -0,0 +1,40 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package metrics + +import ( + "github.com/prometheus/client_golang/prometheus" +) + +var ( + // SmallTxnWriteDuration uses to collect small transaction write duration. 
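+	// A histogram rather than a summary is used so that per-instance buckets can
+	// be aggregated across TiDB servers; as the bucket comments note, the 28
+	// exponential buckets below cover roughly 1ms through 74h.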
+ SmallTxnWriteDuration = prometheus.NewHistogram( + prometheus.HistogramOpts{ + Namespace: "tidb", + Subsystem: "sli", + Name: "small_txn_write_duration_seconds", + Help: "Bucketed histogram of small transaction write time (s).", + Buckets: prometheus.ExponentialBuckets(0.001, 2, 28), // 1ms ~ 74h + }) + + // TxnWriteThroughput uses to collect transaction write throughput which transaction is not small. + TxnWriteThroughput = prometheus.NewHistogram( + prometheus.HistogramOpts{ + Namespace: "tidb", + Subsystem: "sli", + Name: "txn_write_throughput", + Help: "Bucketed histogram of transaction write throughput (bytes/second).", + Buckets: prometheus.ExponentialBuckets(64, 1.3, 40), // 64 bytes/s ~ 2.3MB/s + }) +) diff --git a/planner/cascades/testdata/transformation_rules_suite_out.json b/planner/cascades/testdata/transformation_rules_suite_out.json index 36fe4b319c1c8..645fe6ac0ea97 100644 --- a/planner/cascades/testdata/transformation_rules_suite_out.json +++ b/planner/cascades/testdata/transformation_rules_suite_out.json @@ -2324,7 +2324,7 @@ "Group#2 Schema:[Column#13,Column#14,Column#15,test.t.b,Column#16,Column#17,Column#18,Column#19,Column#20,Column#14,Column#13]", " Selection_4 input:[Group#3], ge(Column#13, 0), ge(Column#14, 0)", "Group#3 Schema:[Column#13,Column#14,Column#15,test.t.b,Column#16,Column#17,Column#18,Column#19,Column#20,Column#14,Column#13]", - " Projection_8 input:[Group#4], 1->Column#13, cast(test.t.b, decimal(65,0) BINARY)->Column#14, cast(test.t.b, decimal(65,30) BINARY)->Column#15, test.t.b, cast(test.t.b, int(11))->Column#16, cast(test.t.b, int(11))->Column#17, ifnull(cast(test.t.b, bigint(21) UNSIGNED BINARY), 18446744073709551615)->Column#18, ifnull(cast(test.t.b, bigint(21) UNSIGNED BINARY), 0)->Column#19, ifnull(cast(test.t.b, bigint(21) UNSIGNED BINARY), 0)->Column#20, cast(test.t.b, decimal(65,0) BINARY)->Column#14, 1->Column#13", + " Projection_8 input:[Group#4], 1->Column#13, cast(test.t.b, decimal(65,0) BINARY)->Column#14, cast(test.t.b, decimal(15,4) BINARY)->Column#15, test.t.b, cast(test.t.b, int(11))->Column#16, cast(test.t.b, int(11))->Column#17, ifnull(cast(test.t.b, bigint(21) UNSIGNED BINARY), 18446744073709551615)->Column#18, ifnull(cast(test.t.b, bigint(21) UNSIGNED BINARY), 0)->Column#19, ifnull(cast(test.t.b, bigint(21) UNSIGNED BINARY), 0)->Column#20, cast(test.t.b, decimal(65,0) BINARY)->Column#14, 1->Column#13", "Group#4 Schema:[test.t.a,test.t.b], UniqueKey:[test.t.a]", " DataSource_1 table:t" ] @@ -2333,7 +2333,7 @@ "SQL": "select count(b), sum(b), avg(b), f, max(c), min(c), bit_and(c), bit_or(d), bit_xor(g) from t group by a", "Result": [ "Group#0 Schema:[Column#13,Column#14,Column#15,test.t.f,Column#16,Column#17,Column#18,Column#19,Column#20]", - " Projection_5 input:[Group#1], 1->Column#13, cast(test.t.b, decimal(65,0) BINARY)->Column#14, cast(test.t.b, decimal(65,30) BINARY)->Column#15, test.t.f, cast(test.t.c, int(11))->Column#16, cast(test.t.c, int(11))->Column#17, ifnull(cast(test.t.c, bigint(21) UNSIGNED BINARY), 18446744073709551615)->Column#18, ifnull(cast(test.t.d, bigint(21) UNSIGNED BINARY), 0)->Column#19, ifnull(cast(test.t.g, bigint(21) UNSIGNED BINARY), 0)->Column#20", + " Projection_5 input:[Group#1], 1->Column#13, cast(test.t.b, decimal(65,0) BINARY)->Column#14, cast(test.t.b, decimal(15,4) BINARY)->Column#15, test.t.f, cast(test.t.c, int(11))->Column#16, cast(test.t.c, int(11))->Column#17, ifnull(cast(test.t.c, bigint(21) UNSIGNED BINARY), 18446744073709551615)->Column#18, ifnull(cast(test.t.d, bigint(21) UNSIGNED 
BINARY), 0)->Column#19, ifnull(cast(test.t.g, bigint(21) UNSIGNED BINARY), 0)->Column#20", "Group#1 Schema:[test.t.a,test.t.b,test.t.c,test.t.d,test.t.f,test.t.g], UniqueKey:[test.t.f,test.t.f,test.t.g,test.t.a]", " DataSource_1 table:t" ] diff --git a/planner/core/cache.go b/planner/core/cache.go index 106ab9e84fe9a..9f31a08392a81 100644 --- a/planner/core/cache.go +++ b/planner/core/cache.go @@ -67,7 +67,6 @@ type pstmtPlanCacheKey struct { database string connID uint64 pstmtID uint32 - snapshot uint64 schemaVersion int64 sqlMode mysql.SQLMode timezoneOffset int @@ -90,7 +89,6 @@ func (key *pstmtPlanCacheKey) Hash() []byte { key.hash = append(key.hash, dbBytes...) key.hash = codec.EncodeInt(key.hash, int64(key.connID)) key.hash = codec.EncodeInt(key.hash, int64(key.pstmtID)) - key.hash = codec.EncodeInt(key.hash, int64(key.snapshot)) key.hash = codec.EncodeInt(key.hash, key.schemaVersion) key.hash = codec.EncodeInt(key.hash, int64(key.sqlMode)) key.hash = codec.EncodeInt(key.hash, int64(key.timezoneOffset)) @@ -134,7 +132,6 @@ func NewPSTMTPlanCacheKey(sessionVars *variable.SessionVars, pstmtID uint32, sch database: sessionVars.CurrentDB, connID: sessionVars.ConnectionID, pstmtID: pstmtID, - snapshot: sessionVars.SnapshotTS, schemaVersion: schemaVersion, sqlMode: sessionVars.SQLMode, timezoneOffset: timezoneOffset, diff --git a/planner/core/cache_test.go b/planner/core/cache_test.go index 8ff56ddb97e66..9f68dc27f286f 100644 --- a/planner/core/cache_test.go +++ b/planner/core/cache_test.go @@ -39,5 +39,5 @@ func (s *testCacheSuite) SetUpSuite(c *C) { func (s *testCacheSuite) TestCacheKey(c *C) { defer testleak.AfterTest(c)() key := NewPSTMTPlanCacheKey(s.ctx.GetSessionVars(), 1, 1) - c.Assert(key.Hash(), DeepEquals, []byte{0x74, 0x65, 0x73, 0x74, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x74, 0x69, 0x64, 0x62, 0x74, 0x69, 0x6b, 0x76, 0x74, 0x69, 0x66, 0x6c, 0x61, 0x73, 0x68}) + c.Assert(key.Hash(), DeepEquals, []byte{0x74, 0x65, 0x73, 0x74, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x74, 0x69, 0x64, 0x62, 0x74, 0x69, 0x6b, 0x76, 0x74, 0x69, 0x66, 0x6c, 0x61, 0x73, 0x68}) } diff --git a/planner/core/cbo_test.go b/planner/core/cbo_test.go index 3d42bcf1fc8b1..1cd6e4d0037c0 100644 --- a/planner/core/cbo_test.go +++ b/planner/core/cbo_test.go @@ -444,7 +444,7 @@ func (s *testAnalyzeSuite) TestPreparedNullParam(c *C) { testKit := testkit.NewTestKit(c, store) testKit.MustExec("use test") testKit.MustExec("drop table if exists t") - testKit.MustExec("create table t (id int, KEY id (id))") + testKit.MustExec("create table t (id int not null, KEY id (id))") testKit.MustExec("insert into t values (1), (2), (3)") sql := "select * from t where id = ?" 
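Dropping the snapshot timestamp from pstmtPlanCacheKey above means a prepared statement now keeps a single plan cache entry regardless of the session's snapshot setting; the shorter expected byte literal in cache_test.go is simply the old hash with the eight encoded snapshot bytes removed. Since the key is nothing more than the session fields encoded in a fixed order (plus trailing bytes for the isolation-read engines, visible as "tidbtikvtiflash" in the test), the pattern is easy to show in a self-contained sketch. Here encodeInt is a simplified stand-in for TiDB's util/codec.EncodeInt, and the field values are illustrative:

    package main

    import (
    	"encoding/binary"
    	"fmt"
    )

    // encodeInt appends an order-preserving big-endian encoding of v,
    // mimicking util/codec.EncodeInt: flipping the sign bit makes the
    // byte-wise order of the output agree with the numeric order of v.
    func encodeInt(b []byte, v int64) []byte {
    	var buf [8]byte
    	binary.BigEndian.PutUint64(buf[:], uint64(v)^(1<<63))
    	return append(b, buf[:]...)
    }

    func main() {
    	// Field order mirrors pstmtPlanCacheKey.Hash() after this patch:
    	// database, then connID, pstmtID, schemaVersion, sqlMode, and
    	// timezoneOffset -- and, notably, no snapshot term any more.
    	hash := []byte("test")
    	for _, field := range []int64{0, 1, 1, 0, 0} {
    		hash = encodeInt(hash, field)
    	}
    	fmt.Printf("%x\n", hash)
    }

Because each integer field has a fixed-width, order-preserving encoding, dropping a field is safe exactly when plans are reusable across its values, which is the point of this change.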
diff --git a/planner/core/common_plans.go b/planner/core/common_plans.go
index 186993c2cd599..a19304db7c883 100644
--- a/planner/core/common_plans.go
+++ b/planner/core/common_plans.go
@@ -188,6 +188,21 @@ type Execute struct {
 	Plan Plan
 }
 
+// isGetVarBinaryLiteral checks whether the result of a GetVar expression is a BinaryLiteral.
+// Because GetVar uses String to represent a BinaryLiteral, we need to convert the string back to a BinaryLiteral here.
+func isGetVarBinaryLiteral(sctx sessionctx.Context, expr expression.Expression) (res bool) {
+	scalarFunc, ok := expr.(*expression.ScalarFunction)
+	if ok && scalarFunc.FuncName.L == ast.GetVar {
+		name, isNull, err := scalarFunc.GetArgs()[0].EvalString(sctx, chunk.Row{})
+		if err != nil || isNull {
+			res = false
+		} else if dt, ok2 := sctx.GetSessionVars().Users[name]; ok2 {
+			res = (dt.Kind() == types.KindBinaryLiteral)
+		}
+	}
+	return res
+}
+
 // OptimizePreparedPlan optimizes the prepared statement.
 func (e *Execute) OptimizePreparedPlan(ctx context.Context, sctx sessionctx.Context, is infoschema.InfoSchema) error {
 	vars := sctx.GetSessionVars()
@@ -229,6 +244,13 @@ func (e *Execute) OptimizePreparedPlan(ctx context.Context, sctx sessionctx.Cont
 			return err
 		}
 		param := prepared.Params[i].(*driver.ParamMarkerExpr)
+		if isGetVarBinaryLiteral(sctx, usingVar) {
+			binVal, convErr := val.ToBytes()
+			if convErr != nil {
+				return convErr
+			}
+			val.SetBinaryLiteral(types.BinaryLiteral(binVal))
+		}
 		param.Datum = val
 		param.InExecute = true
 		vars.PreparedParams = append(vars.PreparedParams, val)
@@ -1012,7 +1034,11 @@ func (e *Explain) explainPlanInRowFormat(p Plan, taskType, driverSide, indent st
 			return errors.Errorf("the store type %v is unknown", x.StoreType)
 		}
 		storeType = x.StoreType.Name()
-		err = e.explainPlanInRowFormat(x.tablePlan, "cop["+storeType+"]", "", childIndent, true)
+		taskName := "cop"
+		if x.BatchCop {
+			taskName = "batchCop"
+		}
+		err = e.explainPlanInRowFormat(x.tablePlan, taskName+"["+storeType+"]", "", childIndent, true)
 	case *PhysicalIndexReader:
 		err = e.explainPlanInRowFormat(x.indexPlan, "cop[tikv]", "", childIndent, true)
 	case *PhysicalIndexLookUpReader:
diff --git a/expression/expression_rewriter.go b/expression/expression_rewriter.go
index 0cb1317116762..e36eb6efac784 100644
--- a/expression/expression_rewriter.go
+++ b/expression/expression_rewriter.go
@@ -1356,7 +1356,8 @@ func (er *expressionRewriter) inToExpression(lLen int, not bool, tp *types.Field
 		er.ctxStackAppend(expression.NewNull(), types.EmptyName)
 		return
 	}
-	if leftEt == types.ETInt {
+	containMut := expression.ContainMutableConst(er.sctx, args)
+	if !containMut && leftEt == types.ETInt {
 		for i := 1; i < len(args); i++ {
 			if c, ok := args[i].(*expression.Constant); ok {
 				var isExceptional bool
@@ -1547,6 +1548,11 @@ func (er *expressionRewriter) wrapExpWithCast() (expr, lexp, rexp expression.Exp
 			}
 			return expression.WrapWithCastAsString(ctx, e)
 		}
+	case types.ETDuration:
+		expr = expression.WrapWithCastAsTime(er.sctx, expr, types.NewFieldType(mysql.TypeDuration))
+		lexp = expression.WrapWithCastAsTime(er.sctx, lexp, types.NewFieldType(mysql.TypeDuration))
+		rexp = expression.WrapWithCastAsTime(er.sctx, rexp, types.NewFieldType(mysql.TypeDuration))
+		return
 	case types.ETDatetime:
 		expr = expression.WrapWithCastAsTime(er.sctx, expr, types.NewFieldType(mysql.TypeDatetime))
 		lexp = expression.WrapWithCastAsTime(er.sctx, lexp, types.NewFieldType(mysql.TypeDatetime))
diff --git a/planner/core/expression_rewriter_test.go b/planner/core/expression_rewriter_test.go
index 874aba58f7686..66bb860a52f0b 100644
---
a/planner/core/expression_rewriter_test.go +++ b/planner/core/expression_rewriter_test.go @@ -377,3 +377,21 @@ func (s *testExpressionRewriterSuite) TestCompareMultiFieldsInSubquery(c *C) { tk.MustQuery("SELECT * FROM t3 WHERE (SELECT c1, c2 FROM t3 LIMIT 1) != ALL(SELECT c1, c2 FROM t4);").Check(testkit.Rows()) } + +func (s *testExpressionRewriterSuite) TestIssue22818(c *C) { + defer testleak.AfterTest(c)() + store, dom, err := newStoreWithBootstrap() + c.Assert(err, IsNil) + tk := testkit.NewTestKit(c, store) + defer func() { + dom.Close() + store.Close() + }() + + tk.MustExec("use test;") + tk.MustExec("drop table if exists t;") + tk.MustExec("create table t(a time);") + tk.MustExec("insert into t values(\"23:22:22\");") + tk.MustQuery("select * from t where a between \"23:22:22\" and \"23:22:22\"").Check( + testkit.Rows("23:22:22")) +} diff --git a/planner/core/find_best_task.go b/planner/core/find_best_task.go index 7e8f0587f9902..59136c557b1a6 100644 --- a/planner/core/find_best_task.go +++ b/planner/core/find_best_task.go @@ -40,6 +40,10 @@ const ( // of a Selection or a JoinCondition, we can use this default value. SelectionFactor = 0.8 distinctFactor = 0.8 + + // If the actual row count is much more than the limit count, the unordered scan may cost much more than keep order. + // So when a limit exists, we don't apply the DescScanFactor. + smallScanThreshold = 10000 ) var aggFuncFactor = map[string]float64{ @@ -1184,6 +1188,9 @@ func (ds *DataSource) convertToTableScan(prop *property.PhysicalProperty, candid return invalidTask, nil } ts, cost, _ := ds.getOriginalPhysicalTableScan(prop, candidate.path, candidate.isMatchProp) + if ts.KeepOrder && ts.Desc && ts.StoreType == kv.TiFlash { + return invalidTask, nil + } copTask := &copTask{ tablePlan: ts, indexPlanFinished: true, @@ -1423,8 +1430,8 @@ func (ds *DataSource) getOriginalPhysicalTableScan(prop *property.PhysicalProper cost += rowCount * sessVars.NetworkFactor * rowSize } if isMatchProp { - if prop.Items[0].Desc { - ts.Desc = true + ts.Desc = prop.Items[0].Desc + if prop.Items[0].Desc && prop.ExpectedCnt >= smallScanThreshold { cost = rowCount * rowSize * sessVars.DescScanFactor } ts.KeepOrder = true @@ -1472,8 +1479,8 @@ func (ds *DataSource) getOriginalPhysicalIndexScan(prop *property.PhysicalProper sessVars := ds.ctx.GetSessionVars() cost := rowCount * rowSize * sessVars.ScanFactor if isMatchProp { - if prop.Items[0].Desc { - is.Desc = true + is.Desc = prop.Items[0].Desc + if prop.Items[0].Desc && prop.ExpectedCnt >= smallScanThreshold { cost = rowCount * rowSize * sessVars.DescScanFactor } is.KeepOrder = true diff --git a/planner/core/initialize.go b/planner/core/initialize.go index acea1c78de2d0..d238cf038dd7e 100644 --- a/planner/core/initialize.go +++ b/planner/core/initialize.go @@ -15,6 +15,7 @@ package core import ( "github.com/pingcap/tidb/expression" + "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/planner/property" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/types" @@ -428,6 +429,21 @@ func (p PhysicalTableReader) Init(ctx sessionctx.Context, offset int) *PhysicalT if p.tablePlan != nil { p.TablePlans = flattenPushDownPlan(p.tablePlan) p.schema = p.tablePlan.Schema() + if p.StoreType == kv.TiFlash && !p.GetTableScan().KeepOrder { + // When allow batch cop is 1, only agg / topN uses batch cop. + // When allow batch cop is 2, every query uses batch cop. 
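+ // Batch cop is only attempted when the TiFlash scan does not need to
+ // keep order (see the guard above), presumably because batched cop
+ // tasks cannot guarantee globally ordered output.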
+ switch ctx.GetSessionVars().AllowBatchCop { + case 1: + for _, plan := range p.TablePlans { + switch plan.(type) { + case *PhysicalHashAgg, *PhysicalStreamAgg, *PhysicalTopN, *PhysicalBroadCastJoin: + p.BatchCop = true + } + } + case 2: + p.BatchCop = true + } + } } return &p } diff --git a/planner/core/integration_test.go b/planner/core/integration_test.go index 67d93f910824f..2b0e487958e2b 100644 --- a/planner/core/integration_test.go +++ b/planner/core/integration_test.go @@ -349,6 +349,43 @@ func (s *testIntegrationSerialSuite) TestSelPushDownTiFlash(c *C) { } } +func (s *testIntegrationSerialSuite) TestPushDownToTiFlashWithKeepOrder(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a int primary key, b varchar(20))") + + // Create virtual tiflash replica info. + dom := domain.GetDomain(tk.Se) + is := dom.InfoSchema() + db, exists := is.SchemaByName(model.NewCIStr("test")) + c.Assert(exists, IsTrue) + for _, tblInfo := range db.Tables { + if tblInfo.Name.L == "t" { + tblInfo.TiFlashReplica = &model.TiFlashReplicaInfo{ + Count: 1, + Available: true, + } + } + } + + tk.MustExec("set @@session.tidb_isolation_read_engines = 'tiflash'") + var input []string + var output []struct { + SQL string + Plan []string + } + s.testData.GetTestCases(c, &input, &output) + for i, tt := range input { + s.testData.OnRecord(func() { + output[i].SQL = tt + output[i].Plan = s.testData.ConvertRowsToStrings(tk.MustQuery(tt).Rows()) + }) + res := tk.MustQuery(tt) + res.Check(testkit.Rows(output[i].Plan...)) + } +} + func (s *testIntegrationSerialSuite) TestBroadcastJoin(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") @@ -441,8 +478,8 @@ func (s *testIntegrationSerialSuite) TestAggPushDownEngine(c *C) { tk.MustQuery("desc select approx_count_distinct(a) from t").Check(testkit.Rows( "StreamAgg_16 1.00 root funcs:approx_count_distinct(Column#5)->Column#3", "└─TableReader_17 1.00 root data:StreamAgg_8", - " └─StreamAgg_8 1.00 cop[tiflash] funcs:approx_count_distinct(test.t.a)->Column#5", - " └─TableFullScan_15 10000.00 cop[tiflash] table:t keep order:false, stats:pseudo")) + " └─StreamAgg_8 1.00 batchCop[tiflash] funcs:approx_count_distinct(test.t.a)->Column#5", + " └─TableFullScan_15 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo")) tk.MustExec("set @@session.tidb_isolation_read_engines = 'tikv'") @@ -1145,10 +1182,10 @@ func (s *testIntegrationSerialSuite) TestIssue16837(c *C) { tk.MustExec("drop table if exists t") tk.MustExec("create table t(a int,b int,c int,d int,e int,unique key idx_ab(a,b),unique key(c),unique key(d))") tk.MustQuery("explain select /*+ use_index_merge(t,c,idx_ab) */ * from t where a = 1 or (e = 1 and c = 1)").Check(testkit.Rows( - "IndexMerge_9 0.01 root ", + "IndexMerge_9 8.80 root ", "├─IndexRangeScan_5(Build) 10.00 cop[tikv] table:t, index:idx_ab(a, b) range:[1,1], keep order:false, stats:pseudo", "├─IndexRangeScan_6(Build) 1.00 cop[tikv] table:t, index:c(c) range:[1,1], keep order:false, stats:pseudo", - "└─Selection_8(Probe) 0.01 cop[tikv] eq(test.t.e, 1)", + "└─Selection_8(Probe) 8.80 cop[tikv] or(eq(test.t.a, 1), and(eq(test.t.e, 1), eq(test.t.c, 1)))", " └─TableRowIDScan_7 11.00 cop[tikv] table:t keep order:false, stats:pseudo")) tk.MustQuery("show warnings").Check(testkit.Rows()) tk.MustExec("insert into t values (2, 1, 1, 1, 2)") @@ -1186,10 +1223,13 @@ func (s *testIntegrationSerialSuite) TestIndexMerge(c *C) { tk.MustQuery("show 
warnings").Check(testkit.Rows()) tk.MustQuery("desc select /*+ use_index_merge(t) */ * from t where (a=1 and length(b)=1) or (b=1 and length(a)=1)").Check(testkit.Rows( - "TableReader_7 8000.00 root data:Selection_6", - "└─Selection_6 8000.00 cop[tikv] or(and(eq(test.t.a, 1), eq(length(cast(test.t.b)), 1)), and(eq(test.t.b, 1), eq(length(cast(test.t.a)), 1)))", - " └─TableFullScan_5 10000.00 cop[tikv] table:t keep order:false, stats:pseudo")) - tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1105 IndexMerge is inapplicable or disabled")) + "IndexMerge_9 1.60 root ", + "├─IndexRangeScan_5(Build) 1.00 cop[tikv] table:t, index:a(a) range:[1,1], keep order:false, stats:pseudo", + "├─IndexRangeScan_6(Build) 1.00 cop[tikv] table:t, index:b(b) range:[1,1], keep order:false, stats:pseudo", + "└─Selection_8(Probe) 1.60 cop[tikv] or(and(eq(test.t.a, 1), eq(length(cast(test.t.b)), 1)), and(eq(test.t.b, 1), eq(length(cast(test.t.a)), 1)))", + " └─TableRowIDScan_7 2.00 cop[tikv] table:t keep order:false, stats:pseudo", + )) + tk.MustQuery("show warnings").Check(testkit.Rows()) } func (s *testIntegrationSerialSuite) TestIssue16407(c *C) { @@ -1198,10 +1238,10 @@ func (s *testIntegrationSerialSuite) TestIssue16407(c *C) { tk.MustExec("drop table if exists t") tk.MustExec("create table t(a int,b char(100),key(a),key(b(10)))") tk.MustQuery("explain select /*+ use_index_merge(t) */ * from t where a=10 or b='x'").Check(testkit.Rows( - "IndexMerge_9 0.02 root ", + "IndexMerge_9 16.00 root ", "├─IndexRangeScan_5(Build) 10.00 cop[tikv] table:t, index:a(a) range:[10,10], keep order:false, stats:pseudo", "├─IndexRangeScan_6(Build) 10.00 cop[tikv] table:t, index:b(b) range:[\"x\",\"x\"], keep order:false, stats:pseudo", - "└─Selection_8(Probe) 0.02 cop[tikv] eq(test.t.b, \"x\")", + "└─Selection_8(Probe) 16.00 cop[tikv] or(eq(test.t.a, 10), eq(test.t.b, \"x\"))", " └─TableRowIDScan_7 20.00 cop[tikv] table:t keep order:false, stats:pseudo")) tk.MustQuery("show warnings").Check(testkit.Rows()) tk.MustExec("insert into t values (1, 'xx')") @@ -1800,6 +1840,27 @@ func (s *testIntegrationSuite) TestIssue22105(c *C) { } } +func (s *testIntegrationSerialSuite) TestLimitIndexLookUpKeepOrder(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t;") + tk.MustExec("create table t(a int, b int, c int, d int, index idx(a,b,c));") + + var input []string + var output []struct { + SQL string + Plan []string + } + s.testData.GetTestCases(c, &input, &output) + for i, tt := range input { + s.testData.OnRecord(func() { + output[i].SQL = tt + output[i].Plan = s.testData.ConvertRowsToStrings(tk.MustQuery(tt).Rows()) + }) + tk.MustQuery(tt).Check(testkit.Rows(output[i].Plan...)) + } +} + func (s *testIntegrationSuite) TestReorderSimplifiedOuterJoins(c *C) { tk := testkit.NewTestKit(c, s.store) @@ -1823,3 +1884,155 @@ func (s *testIntegrationSuite) TestReorderSimplifiedOuterJoins(c *C) { tk.MustQuery(tt).Check(testkit.Rows(output[i].Plan...)) } } + +func (s *testIntegrationSerialSuite) TestDeleteStmt(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("create table t(a int)") + tk.MustExec("delete t from t;") + tk.MustExec("delete t from test.t as t;") + tk.MustGetErrCode("delete test.t from test.t as t;", mysql.ErrUnknownTable) + tk.MustExec("delete test.t from t;") + tk.MustExec("create database db1") + tk.MustExec("use db1") + tk.MustExec("create table t(a int)") + tk.MustGetErrCode("delete test.t from t;", mysql.ErrUnknownTable) +} + 
+// Test for issue https://github.com/pingcap/tidb/issues/21607. +func (s *testIntegrationSuite) TestConditionColPruneInPhysicalUnionScan(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test;") + tk.MustExec("drop table if exists t;") + tk.MustExec("create table t (a int, b int);") + tk.MustExec("begin;") + tk.MustExec("insert into t values (1, 2);") + tk.MustQuery("select count(*) from t where b = 1 and b in (3);"). + Check(testkit.Rows("0")) + + tk.MustExec("drop table t;") + tk.MustExec("create table t (a int, b int as (a + 1), c int as (b + 1));") + tk.MustExec("begin;") + tk.MustExec("insert into t (a) values (1);") + tk.MustQuery("select count(*) from t where b = 1 and b in (3);"). + Check(testkit.Rows("0")) + tk.MustQuery("select count(*) from t where c = 1 and c in (3);"). + Check(testkit.Rows("0")) +} + +func (s *testIntegrationSuite) TestIssue22071(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("create table t (a int);") + tk.MustExec("insert into t values(1),(2),(5)") + tk.MustQuery("select n in (1,2) from (select a in (1,2) as n from t) g;").Sort().Check(testkit.Rows("0", "1", "1")) + tk.MustQuery("select n in (1,n) from (select a in (1,2) as n from t) g;").Check(testkit.Rows("1", "1", "1")) +} + +func (s *testIntegrationSuite) TestIndexMergeTableFilter(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t;") + tk.MustExec("create table t(a int, b int, c int, d int, key(a), key(b));") + tk.MustExec("insert into t values(10,1,1,10)") + + tk.MustQuery("explain select /*+ use_index_merge(t) */ * from t where a=10 or (b=10 and c=10)").Check(testkit.Rows( + "IndexMerge_9 16.00 root ", + "├─IndexRangeScan_5(Build) 10.00 cop[tikv] table:t, index:a(a) range:[10,10], keep order:false, stats:pseudo", + "├─IndexRangeScan_6(Build) 10.00 cop[tikv] table:t, index:b(b) range:[10,10], keep order:false, stats:pseudo", + "└─Selection_8(Probe) 16.00 cop[tikv] or(eq(test.t.a, 10), and(eq(test.t.b, 10), eq(test.t.c, 10)))", + " └─TableRowIDScan_7 20.00 cop[tikv] table:t keep order:false, stats:pseudo", + )) + tk.MustQuery("select /*+ use_index_merge(t) */ * from t where a=10 or (b=10 and c=10)").Check(testkit.Rows( + "10 1 1 10", + )) + tk.MustQuery("explain select /*+ use_index_merge(t) */ * from t where (a=10 and d=10) or (b=10 and c=10)").Check(testkit.Rows( + "IndexMerge_9 16.00 root ", + "├─IndexRangeScan_5(Build) 10.00 cop[tikv] table:t, index:a(a) range:[10,10], keep order:false, stats:pseudo", + "├─IndexRangeScan_6(Build) 10.00 cop[tikv] table:t, index:b(b) range:[10,10], keep order:false, stats:pseudo", + "└─Selection_8(Probe) 16.00 cop[tikv] or(and(eq(test.t.a, 10), eq(test.t.d, 10)), and(eq(test.t.b, 10), eq(test.t.c, 10)))", + " └─TableRowIDScan_7 20.00 cop[tikv] table:t keep order:false, stats:pseudo", + )) + tk.MustQuery("select /*+ use_index_merge(t) */ * from t where (a=10 and d=10) or (b=10 and c=10)").Check(testkit.Rows( + "10 1 1 10", + )) +} + +// #22949: test HexLiteral Used in GetVar expr +func (s *testIntegrationSuite) TestGetVarExprWithHexLiteral(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test;") + tk.MustExec("drop table if exists t1_no_idx;") + tk.MustExec("create table t1_no_idx(id int, col_bit bit(16));") + tk.MustExec("insert into t1_no_idx values(1, 0x3135);") + tk.MustExec("insert into t1_no_idx values(2, 0x0f);") + + tk.MustExec("prepare stmt from 'select id from t1_no_idx where col_bit = ?';") + tk.MustExec("set @a = 0x3135;") + 
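+ // 0x3135 is the two bytes "15"; with the isGetVarBinaryLiteral fix the
+ // hex literal keeps its binary-literal kind through GetVar, so it can
+ // match row 1's bit(16) value.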
tk.MustQuery("execute stmt using @a;").Check(testkit.Rows("1")) + tk.MustExec("set @a = 0x0F;") + tk.MustQuery("execute stmt using @a;").Check(testkit.Rows("2")) + + // same test, but use IN expr + tk.MustExec("prepare stmt from 'select id from t1_no_idx where col_bit in (?)';") + tk.MustExec("set @a = 0x3135;") + tk.MustQuery("execute stmt using @a;").Check(testkit.Rows("1")) + tk.MustExec("set @a = 0x0F;") + tk.MustQuery("execute stmt using @a;").Check(testkit.Rows("2")) + + // same test, but use table with index on col_bit + tk.MustExec("drop table if exists t2_idx;") + tk.MustExec("create table t2_idx(id int, col_bit bit(16), key(col_bit));") + tk.MustExec("insert into t2_idx values(1, 0x3135);") + tk.MustExec("insert into t2_idx values(2, 0x0f);") + + tk.MustExec("prepare stmt from 'select id from t2_idx where col_bit = ?';") + tk.MustExec("set @a = 0x3135;") + tk.MustQuery("execute stmt using @a;").Check(testkit.Rows("1")) + tk.MustExec("set @a = 0x0F;") + tk.MustQuery("execute stmt using @a;").Check(testkit.Rows("2")) + + // same test, but use IN expr + tk.MustExec("prepare stmt from 'select id from t2_idx where col_bit in (?)';") + tk.MustExec("set @a = 0x3135;") + tk.MustQuery("execute stmt using @a;").Check(testkit.Rows("1")) + tk.MustExec("set @a = 0x0F;") + tk.MustQuery("execute stmt using @a;").Check(testkit.Rows("2")) + + // test col varchar with GetVar + tk.MustExec("drop table if exists t_varchar;") + tk.MustExec("create table t_varchar(id int, col_varchar varchar(100), key(col_varchar));") + tk.MustExec("insert into t_varchar values(1, '15');") + tk.MustExec("prepare stmt from 'select id from t_varchar where col_varchar = ?';") + tk.MustExec("set @a = 0x3135;") + tk.MustQuery("execute stmt using @a;").Check(testkit.Rows("1")) +} + +// test BitLiteral used with GetVar +func (s *testIntegrationSuite) TestGetVarExprWithBitLiteral(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test;") + tk.MustExec("drop table if exists t1_no_idx;") + tk.MustExec("create table t1_no_idx(id int, col_bit bit(16));") + tk.MustExec("insert into t1_no_idx values(1, 0x3135);") + tk.MustExec("insert into t1_no_idx values(2, 0x0f);") + + tk.MustExec("prepare stmt from 'select id from t1_no_idx where col_bit = ?';") + // 0b11000100110101 is 0x3135 + tk.MustExec("set @a = 0b11000100110101;") + tk.MustQuery("execute stmt using @a;").Check(testkit.Rows("1")) + + // same test, but use IN expr + tk.MustExec("prepare stmt from 'select id from t1_no_idx where col_bit in (?)';") + tk.MustExec("set @a = 0b11000100110101;") + tk.MustQuery("execute stmt using @a;").Check(testkit.Rows("1")) +} + +func (s *testIntegrationSuite) TestIssue23846(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a varbinary(10),UNIQUE KEY(a))") + tk.MustExec("insert into t values(0x00A4EEF4FA55D6706ED5)") + tk.MustQuery("select * from t where a=0x00A4EEF4FA55D6706ED5").Check(testkit.Rows("\x00\xa4\xee\xf4\xfaU\xd6pn\xd5")) // not empty +} diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index ca66fd458c536..f608ec87c9037 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -4296,53 +4296,50 @@ func (b *PlanBuilder) buildDelete(ctx context.Context, delete *ast.DeleteStmt) ( return nil, err } - var tableList []*ast.TableName - tableList = extractTableList(delete.TableRefs.TableRefs, tableList, true) - // Collect visitInfo. 
if delete.Tables != nil { // Delete a, b from a, b, c, d... add a and b. + updatableList := make(map[string]bool) + tbInfoList := make(map[string]*ast.TableName) + collectTableName(delete.TableRefs.TableRefs, &updatableList, &tbInfoList) for _, tn := range delete.Tables.Tables { - foundMatch := false - for _, v := range tableList { - dbName := v.Schema - if dbName.L == "" { - dbName = model.NewCIStr(b.ctx.GetSessionVars().CurrentDB) - } - if (tn.Schema.L == "" || tn.Schema.L == dbName.L) && tn.Name.L == v.Name.L { - tn.Schema = dbName - tn.DBInfo = v.DBInfo - tn.TableInfo = v.TableInfo - foundMatch = true - break - } + var canUpdate, foundMatch = false, false + name := tn.Name.L + if tn.Schema.L == "" { + canUpdate, foundMatch = updatableList[name] } + if !foundMatch { - var asNameList []string - asNameList = extractTableSourceAsNames(delete.TableRefs.TableRefs, asNameList, false) - for _, asName := range asNameList { - tblName := tn.Name.L - if tn.Schema.L != "" { - tblName = tn.Schema.L + "." + tblName - } - if asName == tblName { - // check sql like: `delete a from (select * from t) as a, t` - return nil, ErrNonUpdatableTable.GenWithStackByArgs(tn.Name.O, "DELETE") - } + if tn.Schema.L == "" { + name = model.NewCIStr(b.ctx.GetSessionVars().CurrentDB).L + "." + tn.Name.L + } else { + name = tn.Schema.L + "." + tn.Name.L } - // check sql like: `delete b from (select * from t) as a, t` + canUpdate, foundMatch = updatableList[name] + } + // check sql like: `delete b from (select * from t) as a, t` + if !foundMatch { return nil, ErrUnknownTable.GenWithStackByArgs(tn.Name.O, "MULTI DELETE") } + // check sql like: `delete a from (select * from t) as a, t` + if !canUpdate { + return nil, ErrNonUpdatableTable.GenWithStackByArgs(tn.Name.O, "DELETE") + } + tb := tbInfoList[name] + tn.DBInfo = tb.DBInfo + tn.TableInfo = tb.TableInfo if tn.TableInfo.IsView() { return nil, errors.Errorf("delete view %s is not supported now.", tn.Name.O) } if tn.TableInfo.IsSequence() { return nil, errors.Errorf("delete sequence %s is not supported now.", tn.Name.O) } - b.visitInfo = appendVisitInfo(b.visitInfo, mysql.DeletePriv, tn.Schema.L, tn.TableInfo.Name.L, "", nil) + b.visitInfo = appendVisitInfo(b.visitInfo, mysql.DeletePriv, tb.DBInfo.Name.L, tb.Name.L, "", nil) } } else { // Delete from a, b, c, d. + var tableList []*ast.TableName + tableList = extractTableList(delete.TableRefs.TableRefs, tableList, false) for _, v := range tableList { if v.TableInfo.IsView() { return nil, errors.Errorf("delete view %s is not supported now.", v.Name.O) @@ -4423,7 +4420,7 @@ func (p *Delete) cleanTblID2HandleMap( // matchingDeletingTable checks whether this column is from the table which is in the deleting list. 
func (p *Delete) matchingDeletingTable(names []*ast.TableName, name *types.FieldName) bool { for _, n := range names { - if (name.DBName.L == "" || name.DBName.L == n.Schema.L) && name.TblName.L == n.Name.L { + if (name.DBName.L == "" || name.DBName.L == n.DBInfo.Name.L) && name.TblName.L == n.Name.L { return true } } @@ -5097,6 +5094,25 @@ func extractTableList(node ast.ResultSetNode, input []*ast.TableName, asName boo return input } +func collectTableName(node ast.ResultSetNode, updatableName *map[string]bool, info *map[string]*ast.TableName) { + switch x := node.(type) { + case *ast.Join: + collectTableName(x.Left, updatableName, info) + collectTableName(x.Right, updatableName, info) + case *ast.TableSource: + name := x.AsName.L + var canUpdate bool + var s *ast.TableName + if s, canUpdate = x.Source.(*ast.TableName); canUpdate { + if name == "" { + name = s.Schema.L + "." + s.Name.L + } + (*info)[name] = s + } + (*updatableName)[name] = canUpdate + } +} + // extractTableSourceAsNames extracts TableSource.AsNames from node. // if onlySelectStmt is set to be true, only extracts AsNames when TableSource.Source.(type) == *ast.SelectStmt func extractTableSourceAsNames(node ast.ResultSetNode, input []string, onlySelectStmt bool) []string { diff --git a/planner/core/partition_pruner_test.go b/planner/core/partition_pruner_test.go new file mode 100644 index 0000000000000..0bc478e6e742a --- /dev/null +++ b/planner/core/partition_pruner_test.go @@ -0,0 +1,66 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package core_test + +import ( + "fmt" + + . 
"github.com/pingcap/check" + "github.com/pingcap/tidb/domain" + "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/util/mock" + "github.com/pingcap/tidb/util/testkit" + "github.com/pingcap/tidb/util/testutil" +) + +var _ = Suite(&testPartitionPruneSuit{}) + +type testPartitionPruneSuit struct { + store kv.Storage + dom *domain.Domain + ctx sessionctx.Context + testData testutil.TestData +} + +func (s *testPartitionPruneSuit) cleanEnv(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test_partition") + r := tk.MustQuery("show tables") + for _, tb := range r.Rows() { + tableName := tb[0] + tk.MustExec(fmt.Sprintf("drop table %v", tableName)) + } +} + +func (s *testPartitionPruneSuit) SetUpSuite(c *C) { + var err error + s.store, s.dom, err = newStoreWithBootstrap() + c.Assert(err, IsNil) + s.ctx = mock.NewContext() +} + +func (s *testPartitionPruneSuit) TearDownSuite(c *C) { + s.dom.Close() + s.store.Close() +} + +func (s *testPartitionPruneSuit) TestIssue23622(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("USE test;") + tk.MustExec("drop table if exists t2;") + tk.MustExec("create table t2 (a int, b int) partition by range (a) (partition p0 values less than (0), partition p1 values less than (5));") + tk.MustExec("insert into t2(a) values (-1), (1);") + tk.MustQuery("select * from t2 where a > 10 or b is NULL order by a;").Check(testkit.Rows("-1 ", "1 ")) +} diff --git a/planner/core/physical_plan_test.go b/planner/core/physical_plan_test.go index 4f3fa07e0cc27..91a1cff9a420f 100644 --- a/planner/core/physical_plan_test.go +++ b/planner/core/physical_plan_test.go @@ -880,7 +880,7 @@ func (s *testPlanSuite) TestLimitToCopHint(c *C) { output []struct { SQL string Plan []string - Warning string + Warning []string } ) @@ -897,15 +897,20 @@ func (s *testPlanSuite) TestLimitToCopHint(c *C) { warnings := tk.Se.GetSessionVars().StmtCtx.GetWarnings() s.testData.OnRecord(func() { if len(warnings) > 0 { - output[i].Warning = warnings[0].Err.Error() + output[i].Warning = make([]string, len(warnings)) + for j, warning := range warnings { + output[i].Warning[j] = warning.Err.Error() + } } }) - if output[i].Warning == "" { + if len(output[i].Warning) == 0 { c.Assert(len(warnings), Equals, 0, comment) } else { - c.Assert(len(warnings), Equals, 1, comment) - c.Assert(warnings[0].Level, Equals, stmtctx.WarnLevelWarning, comment) - c.Assert(warnings[0].Err.Error(), Equals, output[i].Warning, comment) + c.Assert(len(warnings), Equals, len(output[i].Warning), comment) + for j, warning := range warnings { + c.Assert(warning.Level, Equals, stmtctx.WarnLevelWarning, comment) + c.Assert(warning.Err.Error(), Equals, output[i].Warning[j], comment) + } } } } diff --git a/planner/core/physical_plans.go b/planner/core/physical_plans.go index 597bfeeb49fdc..3f2ebbb42e4c7 100644 --- a/planner/core/physical_plans.go +++ b/planner/core/physical_plans.go @@ -72,6 +72,9 @@ type PhysicalTableReader struct { // StoreType indicates table read from which type of store. StoreType kv.StoreType + + // BatchCop = true means the cop task in the physical table reader will be executed in batch mode(use in TiFlash only) + BatchCop bool } // GetTablePlan exports the tablePlan. @@ -186,6 +189,9 @@ type PhysicalIndexMergeReader struct { partialPlans []PhysicalPlan // tablePlan is a PhysicalTableScan to get the table tuples. Current, it must be not nil. 
tablePlan PhysicalPlan + // ExtraHandleCol indicates the index of extraHandleCol when the partial + // reader is TableReader. + ExtraHandleCol *expression.Column } // PhysicalIndexScan represents an index scan plan. diff --git a/planner/core/planbuilder.go b/planner/core/planbuilder.go index 8761aadbdb4bf..3b2a6193575a5 100644 --- a/planner/core/planbuilder.go +++ b/planner/core/planbuilder.go @@ -2237,15 +2237,31 @@ func (b *PlanBuilder) buildInsert(ctx context.Context, insert *ast.InsertStmt) ( return nil, ErrPartitionClauseOnNonpartitioned } + user := b.ctx.GetSessionVars().User var authErr error - if b.ctx.GetSessionVars().User != nil { - authErr = ErrTableaccessDenied.GenWithStackByArgs("INSERT", b.ctx.GetSessionVars().User.AuthUsername, - b.ctx.GetSessionVars().User.AuthHostname, tableInfo.Name.L) + if user != nil { + authErr = ErrTableaccessDenied.GenWithStackByArgs("INSERT", user.AuthUsername, user.AuthHostname, tableInfo.Name.L) } b.visitInfo = appendVisitInfo(b.visitInfo, mysql.InsertPriv, tn.DBInfo.Name.L, tableInfo.Name.L, "", authErr) + // `REPLACE INTO` requires both INSERT + DELETE privilege + // `ON DUPLICATE KEY UPDATE` requires both INSERT + UPDATE privilege + var extraPriv mysql.PrivilegeType + if insert.IsReplace { + extraPriv = mysql.DeletePriv + } else if insert.OnDuplicate != nil { + extraPriv = mysql.UpdatePriv + } + if extraPriv != 0 { + if user != nil { + cmd := strings.ToUpper(mysql.Priv2Str[extraPriv]) + authErr = ErrTableaccessDenied.GenWithStackByArgs(cmd, user.AuthUsername, user.AuthHostname, tableInfo.Name.L) + } + b.visitInfo = appendVisitInfo(b.visitInfo, extraPriv, tn.DBInfo.Name.L, tableInfo.Name.L, "", authErr) + } + mockTablePlan := LogicalTableDual{}.Init(b.ctx, b.getSelectOffset()) mockTablePlan.SetSchema(insertPlan.tableSchema) mockTablePlan.names = insertPlan.tableColNames diff --git a/planner/core/point_get_plan.go b/planner/core/point_get_plan.go index d78400cbc9648..a2a4477dd2e74 100644 --- a/planner/core/point_get_plan.go +++ b/planner/core/point_get_plan.go @@ -919,7 +919,7 @@ func buildSchemaFromFields( if col == nil { return nil, nil } - asName := col.Name + asName := colNameExpr.Name.Name if field.AsName.L != "" { asName = field.AsName } @@ -927,6 +927,7 @@ func buildSchemaFromFields( DBName: dbName, OrigTblName: tbl.Name, TblName: tblName, + OrigColName: col.Name, ColName: asName, }) columns = append(columns, colInfoToColumn(col, len(columns))) @@ -1040,7 +1041,8 @@ func getNameValuePairs(stmtCtx *stmtctx.StatementContext, tbl *model.TableInfo, } } // The converted result must be same as original datum. - cmp, err := d.CompareDatum(stmtCtx, &dVal) + // Compare them based on the dVal's type. 
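+ // For example, a bit(11) key stores two bytes, so 0x39, 0x0039 and
+ // 0x000039 all convert to "\x009"; comparing in the original literal's
+ // type could instead treat them as distinct point-get keys (issue #23511).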
+ cmp, err := dVal.CompareDatum(stmtCtx, &d) if err != nil { return nil, false } else if cmp != 0 { diff --git a/planner/core/point_get_plan_test.go b/planner/core/point_get_plan_test.go index b6140f1ada2e7..840ce7ce15735 100644 --- a/planner/core/point_get_plan_test.go +++ b/planner/core/point_get_plan_test.go @@ -522,3 +522,24 @@ func (s *testPointGetSuite) TestCBOShouldNotUsePointGet(c *C) { res.Check(testkit.Rows(output[i].Res...)) } } + +func (s *testPointGetSuite) TestIssue23511(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t1, t2;") + tk.MustExec("CREATE TABLE `t1` (`COL1` bit(11) NOT NULL,PRIMARY KEY (`COL1`));") + tk.MustExec("CREATE TABLE `t2` (`COL1` bit(11) NOT NULL);") + tk.MustExec("insert into t1 values(b'00000111001'), (b'00000000000');") + tk.MustExec("insert into t2 values(b'00000111001');") + tk.MustQuery("select * from t1 where col1 = 0x39;").Check(testkit.Rows("\x009")) + tk.MustQuery("select * from t2 where col1 = 0x39;").Check(testkit.Rows("\x009")) + tk.MustQuery("select * from t1 where col1 = 0x00;").Check(testkit.Rows("\x00\x00")) + tk.MustQuery("select * from t1 where col1 = 0x0000;").Check(testkit.Rows("\x00\x00")) + tk.MustQuery("explain select * from t1 where col1 = 0x39;"). + Check(testkit.Rows("Point_Get_1 1.00 root table:t1, index:PRIMARY(COL1) ")) + tk.MustQuery("select * from t1 where col1 = 0x0039;").Check(testkit.Rows("\x009")) + tk.MustQuery("select * from t2 where col1 = 0x0039;").Check(testkit.Rows("\x009")) + tk.MustQuery("select * from t1 where col1 = 0x000039;").Check(testkit.Rows("\x009")) + tk.MustQuery("select * from t2 where col1 = 0x000039;").Check(testkit.Rows("\x009")) + tk.MustExec("drop table t1, t2;") +} diff --git a/planner/core/prepare_test.go b/planner/core/prepare_test.go index 53b4a53eb07d9..ca1ddb8424768 100644 --- a/planner/core/prepare_test.go +++ b/planner/core/prepare_test.go @@ -626,6 +626,18 @@ func (s *testPrepareSerialSuite) TestConstPropAndPPDWithCache(c *C) { tk.MustQuery("execute stmt using @p0").Check(testkit.Rows( "0", )) + + // Need to check whether the arguments contain a mutable constant before calling RefineCompareConstant() in inToExpression(). + // Otherwise we may hit a wrong plan.
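+ // "Mutable" here means a constant backed by a parameter marker or user
+ // variable: refining it would bake one execution's value (e.g. '1.1')
+ // into the cached plan, breaking later executions with other values.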
+ tk.MustExec("use test;") + tk.MustExec("drop table if exists t1;") + tk.MustExec("create table t1(c1 tinyint unsigned);") + tk.MustExec("insert into t1 values(111);") + tk.MustExec("prepare stmt from 'select 1 from t1 where c1 in (?)';") + tk.MustExec("set @a = '1.1';") + tk.MustQuery("execute stmt using @a;").Check(testkit.Rows()) + tk.MustExec("set @a = '111';") + tk.MustQuery("execute stmt using @a;").Check(testkit.Rows("1")) } func (s *testPlanSerialSuite) TestPlanCacheUnionScan(c *C) { @@ -893,3 +905,53 @@ func (s *testPrepareSerialSuite) TestPrepareCacheWithJoinTable(c *C) { tk.MustQuery("execute stmt using @a").Check(testkit.Rows()) tk.MustQuery("execute stmt using @b").Check(testkit.Rows("a ")) } + +func (s *testPlanSerialSuite) TestPlanCacheSnapshot(c *C) { + store, _, err := newStoreWithBootstrap() + c.Assert(err, IsNil) + tk := testkit.NewTestKit(c, store) + orgEnable := core.PreparedPlanCacheEnabled() + defer func() { + store.Close() + core.SetPreparedPlanCache(orgEnable) + }() + core.SetPreparedPlanCache(true) + + tk.Se, err = session.CreateSession4TestWithOpt(store, &session.Opt{ + PreparedPlanCache: kvcache.NewSimpleLRUCache(100, 0.1, math.MaxUint64), + }) + c.Assert(err, IsNil) + + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(id int)") + tk.MustExec("insert into t values (1),(2),(3),(4)") + + // For mocktikv, safe point is not initialized, we manually insert it for snapshot to use. + timeSafe := time.Now().Add(-48 * 60 * 60 * time.Second).Format("20060102-15:04:05 -0700 MST") + safePointSQL := `INSERT HIGH_PRIORITY INTO mysql.tidb VALUES ('tikv_gc_safe_point', '%[1]s', '') + ON DUPLICATE KEY + UPDATE variable_value = '%[1]s'` + tk.MustExec(fmt.Sprintf(safePointSQL, timeSafe)) + + tk.MustExec("prepare stmt from 'select * from t where id=?'") + tk.MustExec("set @p = 1") + tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("0")) + tk.MustQuery("execute stmt using @p").Check(testkit.Rows("1")) + tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("0")) + tk.MustQuery("execute stmt using @p").Check(testkit.Rows("1")) + tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("1")) + + // Record the current tso. + tk.MustExec("begin") + tso := tk.Se.GetSessionVars().TxnCtx.StartTS + tk.MustExec("rollback") + c.Assert(tso > 0, IsTrue) + // Insert one more row with id = 1. 
+ tk.MustExec("insert into t values (1)") + + tk.MustExec(fmt.Sprintf("set @@tidb_snapshot = '%d'", tso)) + tk.MustQuery("select * from t where id = 1").Check(testkit.Rows("1")) + tk.MustQuery("execute stmt using @p").Check(testkit.Rows("1")) + tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("1")) +} diff --git a/planner/core/resolve_indices.go b/planner/core/resolve_indices.go index e37d51e57c1e0..974592a2701cc 100644 --- a/planner/core/resolve_indices.go +++ b/planner/core/resolve_indices.go @@ -365,6 +365,14 @@ func (p *PhysicalIndexMergeReader) ResolveIndices() (err error) { } } for i := 0; i < len(p.partialPlans); i++ { + switch x := p.partialPlans[i].(type) { + case *PhysicalTableReader: + newCol, err := p.ExtraHandleCol.ResolveIndices(x.Schema()) + if err != nil { + return err + } + p.ExtraHandleCol = newCol.(*expression.Column) + } err = p.partialPlans[i].ResolveIndices() if err != nil { return err diff --git a/planner/core/rule_column_pruning.go b/planner/core/rule_column_pruning.go index 80ea3c5775009..7bc633bdfdb3d 100644 --- a/planner/core/rule_column_pruning.go +++ b/planner/core/rule_column_pruning.go @@ -217,6 +217,8 @@ func (p *LogicalUnionAll) PruneColumns(parentUsedCols []*expression.Column) erro // PruneColumns implements LogicalPlan interface. func (p *LogicalUnionScan) PruneColumns(parentUsedCols []*expression.Column) error { parentUsedCols = append(parentUsedCols, p.handleCol) + condCols := expression.ExtractColumnsFromExpressions(nil, p.conditions, nil) + parentUsedCols = append(parentUsedCols, condCols...) return p.children[0].PruneColumns(parentUsedCols) } @@ -301,17 +303,7 @@ func (p *LogicalJoin) extractUsedCols(parentUsedCols []*expression.Column) (left } func (p *LogicalJoin) mergeSchema() { - lChild := p.children[0] - rChild := p.children[1] - if p.JoinType == SemiJoin || p.JoinType == AntiSemiJoin { - p.schema = lChild.Schema().Clone() - } else if p.JoinType == LeftOuterSemiJoin || p.JoinType == AntiLeftOuterSemiJoin { - joinCol := p.schema.Columns[len(p.schema.Columns)-1] - p.schema = lChild.Schema().Clone() - p.schema.Append(joinCol) - } else { - p.schema = expression.MergeSchema(lChild.Schema(), rChild.Schema()) - } + p.schema = buildLogicalJoinSchema(p.JoinType, p) } // PruneColumns implements LogicalPlan interface. diff --git a/planner/core/rule_eliminate_projection.go b/planner/core/rule_eliminate_projection.go index 404c5e01dd77c..51ab76a34dc7c 100644 --- a/planner/core/rule_eliminate_projection.go +++ b/planner/core/rule_eliminate_projection.go @@ -87,6 +87,9 @@ func doPhysicalProjectionElimination(p PhysicalPlan) PhysicalPlan { return p } child := p.Children()[0] + if childProj, ok := child.(*PhysicalProjection); ok { + childProj.SetSchema(p.Schema()) + } return child } diff --git a/planner/core/rule_partition_processor.go b/planner/core/rule_partition_processor.go index 12b7d53ab9f80..be40d793e77bb 100644 --- a/planner/core/rule_partition_processor.go +++ b/planner/core/rule_partition_processor.go @@ -266,6 +266,10 @@ func (or partitionRangeOR) union(x partitionRangeOR) partitionRangeOR { } func (or partitionRangeOR) simplify() partitionRangeOR { + // if the length of the `or` is zero. We should return early. + if len(or) == 0 { + return or + } // Make the ranges order by start. 
sort.Sort(or) sorted := or @@ -396,8 +400,10 @@ func makePartitionByFnCol(sctx sessionctx.Context, columns []*expression.Column, args := fn.GetArgs() if len(args) > 0 { arg0 := args[0] - if c, ok1 := arg0.(*expression.Column); ok1 { - col = c + if expression.ExtractColumnSet(args).Len() == 1 { + if c, ok1 := arg0.(*expression.Column); ok1 { + col = c + } } } } @@ -509,7 +515,16 @@ func partitionRangeForInExpr(sctx sessionctx.Context, args []expression.Expressi default: return pruner.fullRange() } - val, err := constExpr.Value.ToInt64(sctx.GetSessionVars().StmtCtx) + + var val int64 + var err error + if pruner.partFn != nil { + // replace fn(col) to fn(const) + partFnConst := replaceColumnWithConst(pruner.partFn, constExpr) + val, _, err = partFnConst.EvalInt(sctx, chunk.Row{}) + } else { + val, err = constExpr.Value.ToInt64(sctx.GetSessionVars().StmtCtx) + } if err != nil { return pruner.fullRange() } @@ -524,6 +539,9 @@ func partitionRangeForInExpr(sctx sessionctx.Context, args []expression.Expressi var monotoneIncFuncs = map[string]struct{}{ ast.ToDays: {}, ast.UnixTimestamp: {}, + // Only when the function form is fn(column, const) + ast.Plus: {}, + ast.Minus: {}, } // f(x) op const, op is > = < diff --git a/planner/core/stats.go b/planner/core/stats.go index a062fcf50b844..6a03e53b6ca97 100644 --- a/planner/core/stats.go +++ b/planner/core/stats.go @@ -464,16 +464,11 @@ func (ds *DataSource) buildIndexMergeOrPath(partialPaths []*util.AccessPath, cur indexMergePath := &util.AccessPath{PartialIndexPaths: partialPaths} indexMergePath.TableFilters = append(indexMergePath.TableFilters, ds.pushedDownConds[:current]...) indexMergePath.TableFilters = append(indexMergePath.TableFilters, ds.pushedDownConds[current+1:]...) - tableFilterCnt := 0 for _, path := range partialPaths { - // IndexMerge should not be used when the SQL is like 'select x from t WHERE (key1=1 AND key2=2) OR (key1=4 AND key3=6);'. - // Check issue https://github.com/pingcap/tidb/issues/22105 for details. + // If any partial path contains table filters, we need to keep the whole DNF filter in the Selection. if len(path.TableFilters) > 0 { - tableFilterCnt++ - if tableFilterCnt > 1 { - return nil - } - indexMergePath.TableFilters = append(indexMergePath.TableFilters, path.TableFilters...) 
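+ // Keeping the whole DNF (instead of giving up once two partial paths
+ // have table filters) re-enables IndexMerge for cases like issue #22105;
+ // the expected plans in the testdata below change accordingly.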
+ indexMergePath.TableFilters = append(indexMergePath.TableFilters, ds.pushedDownConds[current]) + break } } return indexMergePath diff --git a/planner/core/task.go b/planner/core/task.go index f4ba96cc9e0e2..3789a9f00dd50 100644 --- a/planner/core/task.go +++ b/planner/core/task.go @@ -16,6 +16,7 @@ package core import ( "math" + "github.com/pingcap/errors" "github.com/pingcap/parser/ast" "github.com/pingcap/parser/charset" "github.com/pingcap/parser/mysql" @@ -724,7 +725,7 @@ func finishCopTask(ctx sessionctx.Context, task task) task { cst: t.cst, } if t.idxMergePartPlans != nil { - p := PhysicalIndexMergeReader{partialPlans: t.idxMergePartPlans, tablePlan: t.tablePlan}.Init(ctx, t.idxMergePartPlans[0].SelectBlockOffset()) + p := PhysicalIndexMergeReader{partialPlans: t.idxMergePartPlans, tablePlan: t.tablePlan, ExtraHandleCol: t.extraHandleCol}.Init(ctx, t.idxMergePartPlans[0].SelectBlockOffset()) setTableScanToTableRowIDScan(p.tablePlan) newTask.p = p return newTask @@ -1043,6 +1044,13 @@ func CheckAggCanPushCop(sctx sessionctx.Context, aggFuncs []*aggregation.AggFunc return false } if !aggregation.CheckAggPushDown(aggFunc, storeType) { + if sc.InExplainStmt { + storageName := storeType.Name() + if storeType == kv.UnSpecified { + storageName = "storage layer" + } + sc.AppendWarning(errors.New("Agg function '" + aggFunc.Name + "' can not be pushed to " + storageName)) + } return false } if !expression.CanExprsPushDown(sc, aggFunc.Args, client, storeType) { diff --git a/planner/core/testdata/integration_serial_suite_in.json b/planner/core/testdata/integration_serial_suite_in.json index f7a561e28c73f..3103392597700 100644 --- a/planner/core/testdata/integration_serial_suite_in.json +++ b/planner/core/testdata/integration_serial_suite_in.json @@ -8,6 +8,13 @@ "explain select * from t where b > 'a' order by b limit 2" ] }, + { + "name": "TestPushDownToTiFlashWithKeepOrder", + "cases": [ + "explain select max(a) from t", + "explain select min(a) from t" + ] + }, { "name": "TestBroadcastJoin", "cases": [ @@ -89,5 +96,12 @@ "explain select /*+ inl_hash_join(s) */ * from t join s on t.a=s.a and t.b = s.a", "explain select /*+ inl_hash_join(s) */ * from t join s on t.a=s.a and t.a = s.b" ] + }, + { + "name": "TestLimitIndexLookUpKeepOrder", + "cases": [ + "desc select * from t where a = 1 and b > 2 and b < 10 and d = 10 order by b,c limit 10", + "desc select * from t where a = 1 and b > 2 and b < 10 and d = 10 order by b desc, c desc limit 10" + ] } ] diff --git a/planner/core/testdata/integration_serial_suite_out.json b/planner/core/testdata/integration_serial_suite_out.json index a26e1116bfc94..529feb4c091c0 100644 --- a/planner/core/testdata/integration_serial_suite_out.json +++ b/planner/core/testdata/integration_serial_suite_out.json @@ -25,9 +25,9 @@ "└─TopN_8 2.00 root Column#3:asc, offset:0, count:2", " └─Projection_19 2.00 root test.t.a, test.t.b, cast(test.t.b, bigint(22) UNSIGNED BINARY)->Column#3", " └─TableReader_14 2.00 root data:TopN_13", - " └─TopN_13 2.00 cop[tiflash] cast(test.t.b):asc, offset:0, count:2", - " └─Selection_12 3333.33 cop[tiflash] gt(test.t.b, \"a\")", - " └─TableFullScan_11 10000.00 cop[tiflash] table:t keep order:false, stats:pseudo" + " └─TopN_13 2.00 batchCop[tiflash] cast(test.t.b):asc, offset:0, count:2", + " └─Selection_12 3333.33 batchCop[tiflash] gt(test.t.b, \"a\")", + " └─TableFullScan_11 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" ] }, { @@ -35,9 +35,34 @@ "Result": [ "TopN_8 2.00 root test.t.b:asc, offset:0, count:2", 
"└─TableReader_17 2.00 root data:TopN_16", - " └─TopN_16 2.00 cop[tiflash] test.t.b:asc, offset:0, count:2", - " └─Selection_15 3333.33 cop[tiflash] gt(test.t.b, \"a\")", - " └─TableFullScan_14 10000.00 cop[tiflash] table:t keep order:false, stats:pseudo" + " └─TopN_16 2.00 batchCop[tiflash] test.t.b:asc, offset:0, count:2", + " └─Selection_15 3333.33 batchCop[tiflash] gt(test.t.b, \"a\")", + " └─TableFullScan_14 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" + ] + } + ] + }, + { + "Name": "TestPushDownToTiFlashWithKeepOrder", + "Cases": [ + { + "SQL": "explain select max(a) from t", + "Plan": [ + "StreamAgg_9 1.00 root funcs:max(test.t.a)->Column#3", + "└─TopN_10 1.00 root test.t.a:desc, offset:0, count:1", + " └─TableReader_18 1.00 root data:TopN_17", + " └─TopN_17 1.00 batchCop[tiflash] test.t.a:desc, offset:0, count:1", + " └─TableFullScan_16 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" + ] + }, + { + "SQL": "explain select min(a) from t", + "Plan": [ + "StreamAgg_9 1.00 root funcs:min(test.t.a)->Column#3", + "└─Limit_13 1.00 root offset:0, count:1", + " └─TableReader_23 1.00 root data:Limit_22", + " └─Limit_22 1.00 cop[tiflash] offset:0, count:1", + " └─TableFullScan_21 1.00 cop[tiflash] table:t keep order:true, stats:pseudo" ] } ] @@ -50,12 +75,12 @@ "Plan": [ "StreamAgg_24 1.00 root funcs:count(Column#13)->Column#11", "└─TableReader_25 1.00 root data:StreamAgg_9", - " └─StreamAgg_9 1.00 cop[tiflash] funcs:count(1)->Column#13", - " └─BroadcastJoin_23 8.00 cop[tiflash] inner join, left key:test.fact_t.d1_k, right key:test.d1_t.d1_k", - " ├─Selection_19(Build) 2.00 cop[tiflash] not(isnull(test.d1_t.d1_k))", - " │ └─TableFullScan_18 2.00 cop[tiflash] table:d1_t keep order:false, global read", - " └─Selection_17(Probe) 8.00 cop[tiflash] not(isnull(test.fact_t.d1_k))", - " └─TableFullScan_16 8.00 cop[tiflash] table:fact_t keep order:false" + " └─StreamAgg_9 1.00 batchCop[tiflash] funcs:count(1)->Column#13", + " └─BroadcastJoin_23 8.00 batchCop[tiflash] inner join, left key:test.fact_t.d1_k, right key:test.d1_t.d1_k", + " ├─Selection_19(Build) 2.00 batchCop[tiflash] not(isnull(test.d1_t.d1_k))", + " │ └─TableFullScan_18 2.00 batchCop[tiflash] table:d1_t keep order:false, global read", + " └─Selection_17(Probe) 8.00 batchCop[tiflash] not(isnull(test.fact_t.d1_k))", + " └─TableFullScan_16 8.00 batchCop[tiflash] table:fact_t keep order:false" ] }, { @@ -63,18 +88,18 @@ "Plan": [ "StreamAgg_44 1.00 root funcs:count(Column#19)->Column#17", "└─TableReader_45 1.00 root data:StreamAgg_13", - " └─StreamAgg_13 1.00 cop[tiflash] funcs:count(1)->Column#19", - " └─BroadcastJoin_43 8.00 cop[tiflash] inner join, left key:test.fact_t.d3_k, right key:test.d3_t.d3_k", - " ├─Selection_39(Build) 2.00 cop[tiflash] not(isnull(test.d3_t.d3_k))", - " │ └─TableFullScan_38 2.00 cop[tiflash] table:d3_t keep order:false, global read", - " └─BroadcastJoin_29(Probe) 8.00 cop[tiflash] inner join, left key:test.fact_t.d2_k, right key:test.d2_t.d2_k", - " ├─Selection_25(Build) 2.00 cop[tiflash] not(isnull(test.d2_t.d2_k))", - " │ └─TableFullScan_24 2.00 cop[tiflash] table:d2_t keep order:false, global read", - " └─BroadcastJoin_33(Probe) 8.00 cop[tiflash] inner join, left key:test.fact_t.d1_k, right key:test.d1_t.d1_k", - " ├─Selection_23(Build) 2.00 cop[tiflash] not(isnull(test.d1_t.d1_k))", - " │ └─TableFullScan_22 2.00 cop[tiflash] table:d1_t keep order:false, global read", - " └─Selection_37(Probe) 8.00 cop[tiflash] not(isnull(test.fact_t.d1_k)), not(isnull(test.fact_t.d2_k)), 
not(isnull(test.fact_t.d3_k))", - " └─TableFullScan_36 8.00 cop[tiflash] table:fact_t keep order:false" + " └─StreamAgg_13 1.00 batchCop[tiflash] funcs:count(1)->Column#19", + " └─BroadcastJoin_43 8.00 batchCop[tiflash] inner join, left key:test.fact_t.d3_k, right key:test.d3_t.d3_k", + " ├─Selection_39(Build) 2.00 batchCop[tiflash] not(isnull(test.d3_t.d3_k))", + " │ └─TableFullScan_38 2.00 batchCop[tiflash] table:d3_t keep order:false, global read", + " └─BroadcastJoin_29(Probe) 8.00 batchCop[tiflash] inner join, left key:test.fact_t.d2_k, right key:test.d2_t.d2_k", + " ├─Selection_25(Build) 2.00 batchCop[tiflash] not(isnull(test.d2_t.d2_k))", + " │ └─TableFullScan_24 2.00 batchCop[tiflash] table:d2_t keep order:false, global read", + " └─BroadcastJoin_33(Probe) 8.00 batchCop[tiflash] inner join, left key:test.fact_t.d1_k, right key:test.d1_t.d1_k", + " ├─Selection_23(Build) 2.00 batchCop[tiflash] not(isnull(test.d1_t.d1_k))", + " │ └─TableFullScan_22 2.00 batchCop[tiflash] table:d1_t keep order:false, global read", + " └─Selection_37(Probe) 8.00 batchCop[tiflash] not(isnull(test.fact_t.d1_k)), not(isnull(test.fact_t.d2_k)), not(isnull(test.fact_t.d3_k))", + " └─TableFullScan_36 8.00 batchCop[tiflash] table:fact_t keep order:false" ] }, { @@ -82,12 +107,12 @@ "Plan": [ "StreamAgg_18 1.00 root funcs:count(Column#13)->Column#11", "└─TableReader_19 1.00 root data:StreamAgg_9", - " └─StreamAgg_9 1.00 cop[tiflash] funcs:count(1)->Column#13", - " └─BroadcastJoin_17 8.00 cop[tiflash] inner join, left key:test.fact_t.d1_k, right key:test.d1_t.d1_k", - " ├─Selection_14(Build) 2.00 cop[tiflash] not(isnull(test.d1_t.d1_k))", - " │ └─TableFullScan_13 2.00 cop[tiflash] table:d1_t keep order:false", - " └─Selection_12(Probe) 8.00 cop[tiflash] not(isnull(test.fact_t.d1_k))", - " └─TableFullScan_11 8.00 cop[tiflash] table:fact_t keep order:false, global read" + " └─StreamAgg_9 1.00 batchCop[tiflash] funcs:count(1)->Column#13", + " └─BroadcastJoin_17 8.00 batchCop[tiflash] inner join, left key:test.fact_t.d1_k, right key:test.d1_t.d1_k", + " ├─Selection_14(Build) 2.00 batchCop[tiflash] not(isnull(test.d1_t.d1_k))", + " │ └─TableFullScan_13 2.00 batchCop[tiflash] table:d1_t keep order:false", + " └─Selection_12(Probe) 8.00 batchCop[tiflash] not(isnull(test.fact_t.d1_k))", + " └─TableFullScan_11 8.00 batchCop[tiflash] table:fact_t keep order:false, global read" ] }, { @@ -95,18 +120,18 @@ "Plan": [ "StreamAgg_29 1.00 root funcs:count(Column#19)->Column#17", "└─TableReader_30 1.00 root data:StreamAgg_13", - " └─StreamAgg_13 1.00 cop[tiflash] funcs:count(1)->Column#19", - " └─BroadcastJoin_28 8.00 cop[tiflash] inner join, left key:test.fact_t.d3_k, right key:test.d3_t.d3_k", - " ├─Selection_25(Build) 2.00 cop[tiflash] not(isnull(test.d3_t.d3_k))", - " │ └─TableFullScan_24 2.00 cop[tiflash] table:d3_t keep order:false, global read", - " └─BroadcastJoin_15(Probe) 8.00 cop[tiflash] inner join, left key:test.fact_t.d2_k, right key:test.d2_t.d2_k", - " ├─Selection_23(Build) 2.00 cop[tiflash] not(isnull(test.d2_t.d2_k))", - " │ └─TableFullScan_22 2.00 cop[tiflash] table:d2_t keep order:false", - " └─BroadcastJoin_16(Probe) 8.00 cop[tiflash] inner join, left key:test.fact_t.d1_k, right key:test.d1_t.d1_k", - " ├─Selection_21(Build) 2.00 cop[tiflash] not(isnull(test.d1_t.d1_k))", - " │ └─TableFullScan_20 2.00 cop[tiflash] table:d1_t keep order:false, global read", - " └─Selection_19(Probe) 8.00 cop[tiflash] not(isnull(test.fact_t.d1_k)), not(isnull(test.fact_t.d2_k)), not(isnull(test.fact_t.d3_k))", - " 
└─TableFullScan_18 8.00 cop[tiflash] table:fact_t keep order:false, global read" + " └─StreamAgg_13 1.00 batchCop[tiflash] funcs:count(1)->Column#19", + " └─BroadcastJoin_28 8.00 batchCop[tiflash] inner join, left key:test.fact_t.d3_k, right key:test.d3_t.d3_k", + " ├─Selection_25(Build) 2.00 batchCop[tiflash] not(isnull(test.d3_t.d3_k))", + " │ └─TableFullScan_24 2.00 batchCop[tiflash] table:d3_t keep order:false, global read", + " └─BroadcastJoin_15(Probe) 8.00 batchCop[tiflash] inner join, left key:test.fact_t.d2_k, right key:test.d2_t.d2_k", + " ├─Selection_23(Build) 2.00 batchCop[tiflash] not(isnull(test.d2_t.d2_k))", + " │ └─TableFullScan_22 2.00 batchCop[tiflash] table:d2_t keep order:false", + " └─BroadcastJoin_16(Probe) 8.00 batchCop[tiflash] inner join, left key:test.fact_t.d1_k, right key:test.d1_t.d1_k", + " ├─Selection_21(Build) 2.00 batchCop[tiflash] not(isnull(test.d1_t.d1_k))", + " │ └─TableFullScan_20 2.00 batchCop[tiflash] table:d1_t keep order:false, global read", + " └─Selection_19(Probe) 8.00 batchCop[tiflash] not(isnull(test.fact_t.d1_k)), not(isnull(test.fact_t.d2_k)), not(isnull(test.fact_t.d3_k))", + " └─TableFullScan_18 8.00 batchCop[tiflash] table:fact_t keep order:false, global read" ] }, { @@ -114,11 +139,11 @@ "Plan": [ "StreamAgg_16 1.00 root funcs:count(Column#13)->Column#11", "└─TableReader_17 1.00 root data:StreamAgg_8", - " └─StreamAgg_8 1.00 cop[tiflash] funcs:count(1)->Column#13", - " └─BroadcastJoin_15 8.00 cop[tiflash] left outer join, left key:test.fact_t.d1_k, right key:test.d1_t.d1_k", - " ├─Selection_12(Build) 2.00 cop[tiflash] not(isnull(test.d1_t.d1_k))", - " │ └─TableFullScan_11 2.00 cop[tiflash] table:d1_t keep order:false, global read", - " └─TableFullScan_10(Probe) 8.00 cop[tiflash] table:fact_t keep order:false" + " └─StreamAgg_8 1.00 batchCop[tiflash] funcs:count(1)->Column#13", + " └─BroadcastJoin_15 8.00 batchCop[tiflash] left outer join, left key:test.fact_t.d1_k, right key:test.d1_t.d1_k", + " ├─Selection_12(Build) 2.00 batchCop[tiflash] not(isnull(test.d1_t.d1_k))", + " │ └─TableFullScan_11 2.00 batchCop[tiflash] table:d1_t keep order:false, global read", + " └─TableFullScan_10(Probe) 8.00 batchCop[tiflash] table:fact_t keep order:false" ] }, { @@ -126,11 +151,11 @@ "Plan": [ "StreamAgg_16 1.00 root funcs:count(Column#13)->Column#11", "└─TableReader_17 1.00 root data:StreamAgg_8", - " └─StreamAgg_8 1.00 cop[tiflash] funcs:count(1)->Column#13", - " └─BroadcastJoin_15 8.00 cop[tiflash] right outer join, left key:test.fact_t.d1_k, right key:test.d1_t.d1_k", - " ├─TableFullScan_12(Build) 2.00 cop[tiflash] table:d1_t keep order:false", - " └─Selection_11(Probe) 8.00 cop[tiflash] not(isnull(test.fact_t.d1_k))", - " └─TableFullScan_10 8.00 cop[tiflash] table:fact_t keep order:false, global read" + " └─StreamAgg_8 1.00 batchCop[tiflash] funcs:count(1)->Column#13", + " └─BroadcastJoin_15 8.00 batchCop[tiflash] right outer join, left key:test.fact_t.d1_k, right key:test.d1_t.d1_k", + " ├─TableFullScan_12(Build) 2.00 batchCop[tiflash] table:d1_t keep order:false", + " └─Selection_11(Probe) 8.00 batchCop[tiflash] not(isnull(test.fact_t.d1_k))", + " └─TableFullScan_10 8.00 batchCop[tiflash] table:fact_t keep order:false, global read" ] } ] @@ -143,8 +168,8 @@ "Plan": [ "StreamAgg_24 1.00 root funcs:avg(Column#7, Column#8)->Column#4", "└─TableReader_25 1.00 root data:StreamAgg_8", - " └─StreamAgg_8 1.00 cop[tiflash] funcs:count(test.t.a)->Column#7, funcs:sum(test.t.a)->Column#8", - " └─TableFullScan_22 10000.00 cop[tiflash] table:t keep 
order:false, stats:pseudo" + " └─StreamAgg_8 1.00 batchCop[tiflash] funcs:count(test.t.a)->Column#7, funcs:sum(test.t.a)->Column#8", + " └─TableFullScan_22 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" ], "Warn": null }, @@ -153,8 +178,8 @@ "Plan": [ "StreamAgg_16 1.00 root funcs:avg(Column#7, Column#8)->Column#4", "└─TableReader_17 1.00 root data:StreamAgg_8", - " └─StreamAgg_8 1.00 cop[tiflash] funcs:count(test.t.a)->Column#7, funcs:sum(test.t.a)->Column#8", - " └─TableFullScan_15 10000.00 cop[tiflash] table:t keep order:false, stats:pseudo" + " └─StreamAgg_8 1.00 batchCop[tiflash] funcs:count(test.t.a)->Column#7, funcs:sum(test.t.a)->Column#8", + " └─TableFullScan_15 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" ], "Warn": null }, @@ -163,8 +188,8 @@ "Plan": [ "StreamAgg_16 1.00 root funcs:sum(Column#6)->Column#4", "└─TableReader_17 1.00 root data:StreamAgg_8", - " └─StreamAgg_8 1.00 cop[tiflash] funcs:sum(test.t.a)->Column#6", - " └─TableFullScan_15 10000.00 cop[tiflash] table:t keep order:false, stats:pseudo" + " └─StreamAgg_8 1.00 batchCop[tiflash] funcs:sum(test.t.a)->Column#6", + " └─TableFullScan_15 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" ], "Warn": null }, @@ -173,8 +198,8 @@ "Plan": [ "StreamAgg_16 1.00 root funcs:sum(Column#6)->Column#4", "└─TableReader_17 1.00 root data:StreamAgg_8", - " └─StreamAgg_8 1.00 cop[tiflash] funcs:sum(plus(test.t.a, 1))->Column#6", - " └─TableFullScan_15 10000.00 cop[tiflash] table:t keep order:false, stats:pseudo" + " └─StreamAgg_8 1.00 batchCop[tiflash] funcs:sum(plus(test.t.a, 1))->Column#6", + " └─TableFullScan_15 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" ], "Warn": null }, @@ -183,8 +208,8 @@ "Plan": [ "StreamAgg_16 1.00 root funcs:sum(Column#6)->Column#4", "└─TableReader_17 1.00 root data:StreamAgg_8", - " └─StreamAgg_8 1.00 cop[tiflash] funcs:sum(isnull(test.t.a))->Column#6", - " └─TableFullScan_15 10000.00 cop[tiflash] table:t keep order:false, stats:pseudo" + " └─StreamAgg_8 1.00 batchCop[tiflash] funcs:sum(isnull(test.t.a))->Column#6", + " └─TableFullScan_15 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" ], "Warn": null }, @@ -488,5 +513,32 @@ ] } ] + }, + { + "Name": "TestLimitIndexLookUpKeepOrder", + "Cases": [ + { + "SQL": "desc select * from t where a = 1 and b > 2 and b < 10 and d = 10 order by b,c limit 10", + "Plan": [ + "Limit_12 0.00 root offset:0, count:10", + "└─Projection_34 0.00 root test.t.a, test.t.b, test.t.c, test.t.d", + " └─IndexLookUp_33 0.00 root ", + " ├─IndexRangeScan_30(Build) 2.50 cop[tikv] table:t, index:idx(a, b, c) range:(1 2,1 10), keep order:true, stats:pseudo", + " └─Selection_32(Probe) 0.00 cop[tikv] eq(test.t.d, 10)", + " └─TableRowIDScan_31 2.50 cop[tikv] table:t keep order:false, stats:pseudo" + ] + }, + { + "SQL": "desc select * from t where a = 1 and b > 2 and b < 10 and d = 10 order by b desc, c desc limit 10", + "Plan": [ + "Limit_12 0.00 root offset:0, count:10", + "└─Projection_34 0.00 root test.t.a, test.t.b, test.t.c, test.t.d", + " └─IndexLookUp_33 0.00 root ", + " ├─IndexRangeScan_30(Build) 2.50 cop[tikv] table:t, index:idx(a, b, c) range:(1 2,1 10), keep order:true, desc, stats:pseudo", + " └─Selection_32(Probe) 0.00 cop[tikv] eq(test.t.d, 10)", + " └─TableRowIDScan_31 2.50 cop[tikv] table:t keep order:false, stats:pseudo" + ] + } + ] } ] diff --git a/planner/core/testdata/integration_suite_out.json b/planner/core/testdata/integration_suite_out.json index 
fa38b1dc73407..6ebc5adc2e922 100644 --- a/planner/core/testdata/integration_suite_out.json +++ b/planner/core/testdata/integration_suite_out.json @@ -897,11 +897,12 @@ { "SQL": "explain SELECT /*+ use_index_merge(t1)*/ COUNT(*) FROM t1 WHERE (key4=42 AND key6 IS NOT NULL) OR (key1=4 AND key3=6)", "Plan": [ - "StreamAgg_20 1.00 root funcs:count(Column#12)->Column#10", - "└─TableReader_21 1.00 root data:StreamAgg_9", - " └─StreamAgg_9 1.00 cop[tikv] funcs:count(1)->Column#12", - " └─Selection_19 8000.00 cop[tikv] or(and(eq(test.t1.key4, 42), not(isnull(test.t1.key6))), and(eq(test.t1.key1, 4), eq(test.t1.key3, 6)))", - " └─TableFullScan_18 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo" + "HashAgg_8 1.00 root funcs:count(1)->Column#10", + "└─IndexMerge_15 16.00 root ", + " ├─IndexRangeScan_11(Build) 10.00 cop[tikv] table:t1, index:i4(key4) range:[42,42], keep order:false, stats:pseudo", + " ├─IndexRangeScan_12(Build) 10.00 cop[tikv] table:t1, index:i1(key1) range:[4,4], keep order:false, stats:pseudo", + " └─Selection_14(Probe) 16.00 cop[tikv] or(and(eq(test.t1.key4, 42), not(isnull(test.t1.key6))), and(eq(test.t1.key1, 4), eq(test.t1.key3, 6)))", + " └─TableRowIDScan_13 20.00 cop[tikv] table:t1 keep order:false, stats:pseudo" ] } ] diff --git a/planner/core/testdata/plan_suite_out.json b/planner/core/testdata/plan_suite_out.json index f16a5af4324d8..df5494298bdc2 100644 --- a/planner/core/testdata/plan_suite_out.json +++ b/planner/core/testdata/plan_suite_out.json @@ -1456,7 +1456,7 @@ " └─Selection_13 0.83 cop[tikv] gt(test.tn.c, 50)", " └─IndexRangeScan_12 2.50 cop[tikv] table:tn, index:a(a, b, c, d) range:(1 10,1 20), keep order:false, stats:pseudo" ], - "Warning": "" + "Warning": null }, { "SQL": "select * from tn where a = 1 and b > 10 and b < 20 and c > 50 order by d limit 1", @@ -1466,7 +1466,7 @@ " └─Selection_19 0.83 cop[tikv] gt(test.tn.c, 50)", " └─IndexRangeScan_18 2.50 cop[tikv] table:tn, index:a(a, b, c, d) range:(1 10,1 20), keep order:false, stats:pseudo" ], - "Warning": "" + "Warning": null }, { "SQL": "select /*+ LIMIT_TO_COP() */ a from tn where mod(a, 2) order by a limit 1", @@ -1476,7 +1476,10 @@ " └─IndexReader_21 1.00 root index:IndexFullScan_20", " └─IndexFullScan_20 1.00 cop[tikv] table:tn, index:a(a, b, c, d) keep order:true, stats:pseudo" ], - "Warning": "[planner:1815]Optimizer Hint LIMIT_TO_COP is inapplicable" + "Warning": [ + "Scalar function 'mod'(signature: ModInt) can not be pushed to storage layer", + "[planner:1815]Optimizer Hint LIMIT_TO_COP is inapplicable" + ] } ] }, diff --git a/planner/core/testdata/plan_suite_unexported_out.json b/planner/core/testdata/plan_suite_unexported_out.json index fd7cc4aa95ae4..bc5dbe6d230d3 100644 --- a/planner/core/testdata/plan_suite_unexported_out.json +++ b/planner/core/testdata/plan_suite_unexported_out.json @@ -183,11 +183,11 @@ { "Name": "TestWindowFunction", "Cases": [ - "TableReader(Table(t))->Window(avg(cast(test.t.a, decimal(65,30) BINARY))->Column#14 over(partition by test.t.a))->Projection", - "TableReader(Table(t))->Sort->Window(avg(cast(test.t.a, decimal(65,30) BINARY))->Column#14 over(partition by test.t.b))->Projection", + "TableReader(Table(t))->Window(avg(cast(test.t.a, decimal(15,4) BINARY))->Column#14 over(partition by test.t.a))->Projection", + "TableReader(Table(t))->Sort->Window(avg(cast(test.t.a, decimal(15,4) BINARY))->Column#14 over(partition by test.t.b))->Projection", "IndexReader(Index(t.f)[[NULL,+inf]])->Projection->Sort->Window(avg(cast(Column#16, decimal(24,4) 
BINARY))->Column#17 over(partition by Column#15))->Projection", - "TableReader(Table(t))->Sort->Window(avg(cast(test.t.a, decimal(65,30) BINARY))->Column#14 over(order by test.t.a asc, test.t.b desc range between unbounded preceding and current row))->Projection", - "TableReader(Table(t))->Window(avg(cast(test.t.a, decimal(65,30) BINARY))->Column#14 over(partition by test.t.a))->Projection", + "TableReader(Table(t))->Sort->Window(avg(cast(test.t.a, decimal(15,4) BINARY))->Column#14 over(order by test.t.a asc, test.t.b desc range between unbounded preceding and current row))->Projection", + "TableReader(Table(t))->Window(avg(cast(test.t.a, decimal(15,4) BINARY))->Column#14 over(partition by test.t.a))->Projection", "[planner:1054]Unknown column 'z' in 'field list'", "TableReader(Table(t))->Window(sum(cast(test.t.b, decimal(65,0) BINARY))->Column#14 over())->Sort->Projection", "IndexReader(Index(t.f)[[NULL,+inf]]->StreamAgg)->StreamAgg->Window(sum(Column#13)->Column#15 over())->Projection", @@ -206,7 +206,7 @@ "IndexReader(Index(t.f)[[NULL,+inf]])->Window(sum(cast(test.t.a, decimal(65,0) BINARY))->Column#14 over(rows between 1 preceding and 1 following))->Projection", "[planner:3583]Window '' cannot inherit 'w' since both contain an ORDER BY clause.", "[planner:3591]Window 'w1' is defined twice.", - "TableReader(Table(t))->Window(avg(cast(test.t.a, decimal(65,30) BINARY))->Column#14 over(partition by test.t.a))->Projection", + "TableReader(Table(t))->Window(avg(cast(test.t.a, decimal(15,4) BINARY))->Column#14 over(partition by test.t.a))->Projection", "TableReader(Table(t))->Window(sum(cast(test.t.a, decimal(65,0) BINARY))->Column#14 over(partition by test.t.a))->Sort->Projection", "[planner:1235]This version of TiDB doesn't yet support 'GROUPS'", "[planner:3584]Window '': frame start cannot be UNBOUNDED FOLLOWING.", @@ -227,7 +227,7 @@ "[planner:1210]Incorrect arguments to nth_value", "[planner:1210]Incorrect arguments to ntile", "IndexReader(Index(t.f)[[NULL,+inf]])->Window(ntile()->Column#14 over())->Projection", - "TableReader(Table(t))->Sort->Window(avg(cast(test.t.a, decimal(65,30) BINARY))->Column#14 over(partition by test.t.b))->Projection", + "TableReader(Table(t))->Sort->Window(avg(cast(test.t.a, decimal(15,4) BINARY))->Column#14 over(partition by test.t.b))->Projection", "TableReader(Table(t))->Window(nth_value(test.t.i_date, 1)->Column#14 over())->Projection", "TableReader(Table(t))->Window(sum(cast(test.t.b, decimal(65,0) BINARY))->Column#15, sum(cast(test.t.c, decimal(65,0) BINARY))->Column#16 over(order by test.t.a asc range between unbounded preceding and current row))->Projection", "[planner:3593]You cannot use the window function 'sum' in this context.'", @@ -256,11 +256,11 @@ { "Name": "TestWindowParallelFunction", "Cases": [ - "TableReader(Table(t))->Window(avg(cast(test.t.a, decimal(65,30) BINARY))->Column#14 over(partition by test.t.a))->Projection", - "TableReader(Table(t))->Sort->Window(avg(cast(test.t.a, decimal(65,30) BINARY))->Column#14 over(partition by test.t.b))->Partition(execution info: concurrency:4, data source:TableReader_10)->Projection", + "TableReader(Table(t))->Window(avg(cast(test.t.a, decimal(15,4) BINARY))->Column#14 over(partition by test.t.a))->Projection", + "TableReader(Table(t))->Sort->Window(avg(cast(test.t.a, decimal(15,4) BINARY))->Column#14 over(partition by test.t.b))->Partition(execution info: concurrency:4, data source:TableReader_10)->Projection", "IndexReader(Index(t.f)[[NULL,+inf]])->Projection->Sort->Window(avg(cast(Column#16, 
decimal(24,4) BINARY))->Column#17 over(partition by Column#15))->Partition(execution info: concurrency:4, data source:Projection_8)->Projection", - "TableReader(Table(t))->Sort->Window(avg(cast(test.t.a, decimal(65,30) BINARY))->Column#14 over(order by test.t.a asc, test.t.b desc range between unbounded preceding and current row))->Projection", - "TableReader(Table(t))->Window(avg(cast(test.t.a, decimal(65,30) BINARY))->Column#14 over(partition by test.t.a))->Projection", + "TableReader(Table(t))->Sort->Window(avg(cast(test.t.a, decimal(15,4) BINARY))->Column#14 over(order by test.t.a asc, test.t.b desc range between unbounded preceding and current row))->Projection", + "TableReader(Table(t))->Window(avg(cast(test.t.a, decimal(15,4) BINARY))->Column#14 over(partition by test.t.a))->Projection", "[planner:1054]Unknown column 'z' in 'field list'", "TableReader(Table(t))->Window(sum(cast(test.t.b, decimal(65,0) BINARY))->Column#14 over())->Sort->Projection", "IndexReader(Index(t.f)[[NULL,+inf]]->StreamAgg)->StreamAgg->Window(sum(Column#13)->Column#15 over())->Projection", @@ -279,7 +279,7 @@ "IndexReader(Index(t.f)[[NULL,+inf]])->Window(sum(cast(test.t.a, decimal(65,0) BINARY))->Column#14 over(rows between 1 preceding and 1 following))->Projection", "[planner:3583]Window '' cannot inherit 'w' since both contain an ORDER BY clause.", "[planner:3591]Window 'w1' is defined twice.", - "TableReader(Table(t))->Window(avg(cast(test.t.a, decimal(65,30) BINARY))->Column#14 over(partition by test.t.a))->Projection", + "TableReader(Table(t))->Window(avg(cast(test.t.a, decimal(15,4) BINARY))->Column#14 over(partition by test.t.a))->Projection", "TableReader(Table(t))->Window(sum(cast(test.t.a, decimal(65,0) BINARY))->Column#14 over(partition by test.t.a))->Sort->Projection", "[planner:1235]This version of TiDB doesn't yet support 'GROUPS'", "[planner:3584]Window '': frame start cannot be UNBOUNDED FOLLOWING.", @@ -300,7 +300,7 @@ "[planner:1210]Incorrect arguments to nth_value", "[planner:1210]Incorrect arguments to ntile", "IndexReader(Index(t.f)[[NULL,+inf]])->Window(ntile()->Column#14 over())->Projection", - "TableReader(Table(t))->Sort->Window(avg(cast(test.t.a, decimal(65,30) BINARY))->Column#14 over(partition by test.t.b))->Partition(execution info: concurrency:4, data source:TableReader_10)->Projection", + "TableReader(Table(t))->Sort->Window(avg(cast(test.t.a, decimal(15,4) BINARY))->Column#14 over(partition by test.t.b))->Partition(execution info: concurrency:4, data source:TableReader_10)->Projection", "TableReader(Table(t))->Window(nth_value(test.t.i_date, 1)->Column#14 over())->Projection", "TableReader(Table(t))->Window(sum(cast(test.t.b, decimal(65,0) BINARY))->Column#15, sum(cast(test.t.c, decimal(65,0) BINARY))->Column#16 over(order by test.t.a asc range between unbounded preceding and current row))->Projection", "[planner:3593]You cannot use the window function 'sum' in this context.'", diff --git a/planner/optimize.go b/planner/optimize.go index 35c5981d36eb7..bf08b87813124 100644 --- a/planner/optimize.go +++ b/planner/optimize.go @@ -124,6 +124,10 @@ func Optimize(ctx context.Context, sctx sessionctx.Context, node ast.Node, is in sctx.GetSessionVars().StmtCtx.AppendWarning(errors.New("sql_select_limit is set, so plan binding is not activated")) return bestPlan, names, nil } + err = setFoundInBinding(sctx, true) + if err != nil { + return nil, nil, err + } bestPlanHint := plannercore.GenHintsFromPhysicalPlan(bestPlan) if len(bindRecord.Bindings) > 0 { orgBinding := 
bindRecord.Bindings[0] // the first is the original binding @@ -248,7 +252,10 @@ func optimize(ctx context.Context, sctx sessionctx.Context, node ast.Node, is in finalPlan, cost, err := cascades.DefaultOptimizer.FindBestPlan(sctx, logic) return finalPlan, names, cost, err } + + beginOpt := time.Now() finalPlan, cost, err := plannercore.DoOptimize(ctx, sctx, builder.GetOptFlag(), logic) + sctx.GetSessionVars().DurationOptimization = time.Since(beginOpt) return finalPlan, names, cost, err } @@ -266,12 +273,7 @@ func extractSelectAndNormalizeDigest(stmtNode ast.StmtNode, specifiledDB string) switch x.Stmt.(type) { case *ast.SelectStmt, *ast.DeleteStmt, *ast.UpdateStmt, *ast.InsertStmt: plannercore.EraseLastSemicolon(x) - var normalizeExplainSQL string - if specifiledDB != "" { - normalizeExplainSQL = parser.Normalize(utilparser.RestoreWithDefaultDB(x, specifiledDB)) - } else { - normalizeExplainSQL = parser.Normalize(x.Text()) - } + normalizeExplainSQL := parser.Normalize(utilparser.RestoreWithDefaultDB(x, specifiledDB)) idx := int(0) switch n := x.Stmt.(type) { case *ast.SelectStmt: @@ -301,12 +303,7 @@ func extractSelectAndNormalizeDigest(stmtNode ast.StmtNode, specifiledDB string) if len(x.Text()) == 0 { return x, "", "" } - var normalizedSQL, hash string - if specifiledDB != "" { - normalizedSQL, hash = parser.NormalizeDigest(utilparser.RestoreWithDefaultDB(x, specifiledDB)) - } else { - normalizedSQL, hash = parser.NormalizeDigest(x.Text()) - } + normalizedSQL, hash := parser.NormalizeDigest(utilparser.RestoreWithDefaultDB(x, specifiledDB)) return x, normalizedSQL, hash } return nil, "", "" @@ -492,6 +489,12 @@ func handleStmtHints(hints []*ast.TableOptimizerHint) (stmtHints stmtctx.StmtHin return } +func setFoundInBinding(sctx sessionctx.Context, opt bool) error { + vars := sctx.GetSessionVars() + err := vars.SetSystemVar(variable.TiDBFoundInBinding, variable.BoolToIntStr(opt)) + return err +} + func init() { plannercore.OptimizeAstNode = Optimize } diff --git a/privilege/privileges/cache.go b/privilege/privileges/cache.go index 7d39c46f31843..6c59e6aa3970b 100644 --- a/privilege/privileges/cache.go +++ b/privilege/privileges/cache.go @@ -47,6 +47,23 @@ var ( const globalDBVisible = mysql.CreatePriv | mysql.SelectPriv | mysql.InsertPriv | mysql.UpdatePriv | mysql.DeletePriv | mysql.ShowDBPriv | mysql.DropPriv | mysql.AlterPriv | mysql.IndexPriv | mysql.CreateViewPriv | mysql.ShowViewPriv | mysql.GrantPriv | mysql.TriggerPriv | mysql.ReferencesPriv | mysql.ExecutePriv +const ( + sqlLoadRoleGraph = "SELECT HIGH_PRIORITY FROM_USER, FROM_HOST, TO_USER, TO_HOST FROM mysql.role_edges" + sqlLoadGlobalPrivTable = "SELECT HIGH_PRIORITY Host,User,Priv FROM mysql.global_priv" + sqlLoadDBTable = "SELECT HIGH_PRIORITY Host,DB,User,Select_priv,Insert_priv,Update_priv,Delete_priv,Create_priv,Drop_priv,Grant_priv,Index_priv,Alter_priv,Execute_priv,Create_view_priv,Show_view_priv FROM mysql.db ORDER BY host, db, user" + sqlLoadTablePrivTable = "SELECT HIGH_PRIORITY Host,DB,User,Table_name,Grantor,Timestamp,Table_priv,Column_priv FROM mysql.tables_priv" + sqlLoadColumnsPrivTable = "SELECT HIGH_PRIORITY Host,DB,User,Table_name,Column_name,Timestamp,Column_priv FROM mysql.columns_priv" + sqlLoadDefaultRoles = "SELECT HIGH_PRIORITY HOST, USER, DEFAULT_ROLE_HOST, DEFAULT_ROLE_USER FROM mysql.default_roles" + // list of privileges from mysql.Priv2UserCol + sqlLoadUserTable = `SELECT HIGH_PRIORITY Host,User,authentication_string, + Create_priv, Select_priv, Insert_priv, Update_priv, Delete_priv, Show_db_priv, 
Super_priv, + Create_user_priv,Trigger_priv,Drop_priv,Process_priv,Grant_priv, + References_priv,Alter_priv,Execute_priv,Index_priv,Create_view_priv,Show_view_priv, + Create_role_priv,Drop_role_priv,Create_tmp_table_priv,Lock_tables_priv,Create_routine_priv, + Alter_routine_priv,Event_priv,Shutdown_priv,Reload_priv,File_priv,Config_priv, + account_locked FROM mysql.user` +) + func computePrivMask(privs []mysql.PrivilegeType) mysql.PrivilegeType { var mask mysql.PrivilegeType for _, p := range privs { @@ -347,7 +364,7 @@ func noSuchTable(err error) bool { // LoadRoleGraph loads the mysql.role_edges table from database. func (p *MySQLPrivilege) LoadRoleGraph(ctx sessionctx.Context) error { p.RoleGraph = make(map[string]roleGraphEdgesTable) - err := p.loadTable(ctx, "select FROM_USER, FROM_HOST, TO_USER, TO_HOST from mysql.role_edges;", p.decodeRoleEdgesTable) + err := p.loadTable(ctx, sqlLoadRoleGraph, p.decodeRoleEdgesTable) if err != nil { return errors.Trace(err) } @@ -356,12 +373,7 @@ func (p *MySQLPrivilege) LoadRoleGraph(ctx sessionctx.Context) error { // LoadUserTable loads the mysql.user table from database. func (p *MySQLPrivilege) LoadUserTable(ctx sessionctx.Context) error { - userPrivCols := make([]string, 0, len(mysql.Priv2UserCol)) - for _, v := range mysql.Priv2UserCol { - userPrivCols = append(userPrivCols, v) - } - query := fmt.Sprintf("select HIGH_PRIORITY Host,User,authentication_string,%s,account_locked from mysql.user;", strings.Join(userPrivCols, ", ")) - err := p.loadTable(ctx, query, p.decodeUserTableRow) + err := p.loadTable(ctx, sqlLoadUserTable, p.decodeUserTableRow) if err != nil { return errors.Trace(err) } @@ -467,12 +479,12 @@ func (p MySQLPrivilege) SortUserTable() { // LoadGlobalPrivTable loads the mysql.global_priv table from database. func (p *MySQLPrivilege) LoadGlobalPrivTable(ctx sessionctx.Context) error { - return p.loadTable(ctx, "select HIGH_PRIORITY Host,User,Priv from mysql.global_priv", p.decodeGlobalPrivTableRow) + return p.loadTable(ctx, sqlLoadGlobalPrivTable, p.decodeGlobalPrivTableRow) } // LoadDBTable loads the mysql.db table from database. func (p *MySQLPrivilege) LoadDBTable(ctx sessionctx.Context) error { - err := p.loadTable(ctx, "select HIGH_PRIORITY Host,DB,User,Select_priv,Insert_priv,Update_priv,Delete_priv,Create_priv,Drop_priv,Grant_priv,Index_priv,Alter_priv,Execute_priv,Create_view_priv,Show_view_priv from mysql.db order by host, db, user;", p.decodeDBTableRow) + err := p.loadTable(ctx, sqlLoadDBTable, p.decodeDBTableRow) if err != nil { return err } @@ -490,7 +502,7 @@ func (p *MySQLPrivilege) buildDBMap() { // LoadTablesPrivTable loads the mysql.tables_priv table from database. func (p *MySQLPrivilege) LoadTablesPrivTable(ctx sessionctx.Context) error { - err := p.loadTable(ctx, "select HIGH_PRIORITY Host,DB,User,Table_name,Grantor,Timestamp,Table_priv,Column_priv from mysql.tables_priv", p.decodeTablesPrivTableRow) + err := p.loadTable(ctx, sqlLoadTablePrivTable, p.decodeTablesPrivTableRow) if err != nil { return err } @@ -508,24 +520,22 @@ func (p *MySQLPrivilege) buildTablesPrivMap() { // LoadColumnsPrivTable loads the mysql.columns_priv table from database. 
func (p *MySQLPrivilege) LoadColumnsPrivTable(ctx sessionctx.Context) error {
- return p.loadTable(ctx, "select HIGH_PRIORITY Host,DB,User,Table_name,Column_name,Timestamp,Column_priv from mysql.columns_priv", p.decodeColumnsPrivTableRow)
+ return p.loadTable(ctx, sqlLoadColumnsPrivTable, p.decodeColumnsPrivTableRow)
}
// LoadDefaultRoles loads the mysql.default_roles table from database.
func (p *MySQLPrivilege) LoadDefaultRoles(ctx sessionctx.Context) error {
- return p.loadTable(ctx, "select HOST, USER, DEFAULT_ROLE_HOST, DEFAULT_ROLE_USER from mysql.default_roles", p.decodeDefaultRoleTableRow)
+ return p.loadTable(ctx, sqlLoadDefaultRoles, p.decodeDefaultRoleTableRow)
}
func (p *MySQLPrivilege) loadTable(sctx sessionctx.Context, sql string, decodeTableRow func(chunk.Row, []*ast.ResultField) error) error {
ctx := context.Background()
- tmp, err := sctx.(sqlexec.SQLExecutor).Execute(ctx, sql)
+ rs, err := sctx.(sqlexec.SQLExecutor).ExecuteInternal(ctx, sql)
if err != nil {
return errors.Trace(err)
}
- rs := tmp[0]
defer terror.Call(rs.Close)
- fs := rs.Fields()
req := rs.NewChunk()
for {
diff --git a/privilege/privileges/privileges_test.go b/privilege/privileges/privileges_test.go
index 812e39be1b2aa..247e77e7bbcc5 100644
--- a/privilege/privileges/privileges_test.go
+++ b/privilege/privileges/privileges_test.go
@@ -156,6 +156,31 @@ func (s *testPrivilegeSuite) TestCheckPointGetDBPrivilege(c *C) {
c.Assert(terror.ErrorEqual(err, core.ErrTableaccessDenied), IsTrue)
}
+func (s *testPrivilegeSuite) TestIssue22946(c *C) {
+ rootSe := newSession(c, s.store, s.dbName)
+ mustExec(c, rootSe, "create database db1;")
+ mustExec(c, rootSe, "create database db2;")
+ mustExec(c, rootSe, "use test;")
+ mustExec(c, rootSe, "create table a(id int);")
+ mustExec(c, rootSe, "use db1;")
+ mustExec(c, rootSe, "create table a(id int primary key,name varchar(20));")
+ mustExec(c, rootSe, "use db2;")
+ mustExec(c, rootSe, "create table b(id int primary key,address varchar(50));")
+ mustExec(c, rootSe, "CREATE USER 'delTest'@'localhost';")
+ mustExec(c, rootSe, "grant all on db1.* to delTest@'localhost';")
+ mustExec(c, rootSe, "grant all on db2.* to delTest@'localhost';")
+ mustExec(c, rootSe, "grant select on test.* to delTest@'localhost';")
+ mustExec(c, rootSe, "flush privileges;")
+
+ se := newSession(c, s.store, s.dbName)
+ c.Assert(se.Auth(&auth.UserIdentity{Username: "delTest", Hostname: "localhost"}, nil, nil), IsTrue)
+ _, err := se.ExecuteInternal(context.Background(), `delete from db1.a as A where exists(select 1 from db2.b as B where A.id = B.id);`)
+ c.Assert(err, IsNil)
+ mustExec(c, rootSe, "use db1;")
+ _, err = se.ExecuteInternal(context.Background(), "delete from test.a as A;")
+ c.Assert(terror.ErrorEqual(err, core.ErrPrivilegeCheckFail), IsTrue)
+}
+
func (s *testPrivilegeSuite) TestCheckTablePrivilege(c *C) {
rootSe := newSession(c, s.store, s.dbName)
mustExec(c, rootSe, `CREATE USER 'test1'@'localhost';`)
@@ -877,6 +902,45 @@ func (s *testPrivilegeSuite) TestShowCreateTable(c *C) {
mustExec(c, se, `SHOW CREATE TABLE mysql.user`)
}
+func (s *testPrivilegeSuite) TestReplaceAndInsertOnDuplicate(c *C) {
+ se := newSession(c, s.store, s.dbName)
+ mustExec(c, se, `CREATE USER tr_insert`)
+ mustExec(c, se, `CREATE USER tr_update`)
+ mustExec(c, se, `CREATE USER tr_delete`)
+ mustExec(c, se, `CREATE TABLE t1 (a int primary key, b int)`)
+ mustExec(c, se, `GRANT INSERT ON t1 TO tr_insert`)
+ mustExec(c, se, `GRANT UPDATE ON t1 TO tr_update`)
+ mustExec(c, se, `GRANT DELETE ON t1 TO tr_delete`)
+
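+ // Each user above holds exactly one of INSERT, UPDATE, and DELETE on t1, so the checks below can show each statement failing when one of its two required privileges is missing.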
+ // Restrict the permission to INSERT only.
+ c.Assert(se.Auth(&auth.UserIdentity{Username: "tr_insert", Hostname: "localhost", AuthUsername: "tr_insert", AuthHostname: "%"}, nil, nil), IsTrue)
+
+ // REPLACE requires INSERT + DELETE privileges; having INSERT alone is insufficient.
+ _, err := se.ExecuteInternal(context.Background(), `REPLACE INTO t1 VALUES (1, 2)`)
+ c.Assert(terror.ErrorEqual(err, core.ErrTableaccessDenied), IsTrue)
+ c.Assert(err.Error(), Equals, "[planner:1142]DELETE command denied to user 'tr_insert'@'%' for table 't1'")
+
+ // INSERT ON DUPLICATE requires INSERT + UPDATE privileges; having INSERT alone is insufficient.
+ _, err = se.ExecuteInternal(context.Background(), `INSERT INTO t1 VALUES (3, 4) ON DUPLICATE KEY UPDATE b = 5`)
+ c.Assert(terror.ErrorEqual(err, core.ErrTableaccessDenied), IsTrue)
+ c.Assert(err.Error(), Equals, "[planner:1142]UPDATE command denied to user 'tr_insert'@'%' for table 't1'")
+
+ // Plain INSERT should work.
+ mustExec(c, se, `INSERT INTO t1 VALUES (6, 7)`)
+
+ // Also check that having DELETE alone is insufficient for REPLACE.
+ c.Assert(se.Auth(&auth.UserIdentity{Username: "tr_delete", Hostname: "localhost", AuthUsername: "tr_delete", AuthHostname: "%"}, nil, nil), IsTrue)
+ _, err = se.ExecuteInternal(context.Background(), `REPLACE INTO t1 VALUES (8, 9)`)
+ c.Assert(terror.ErrorEqual(err, core.ErrTableaccessDenied), IsTrue)
+ c.Assert(err.Error(), Equals, "[planner:1142]INSERT command denied to user 'tr_delete'@'%' for table 't1'")
+
+ // Also check that having UPDATE alone is insufficient for INSERT ON DUPLICATE.
+ c.Assert(se.Auth(&auth.UserIdentity{Username: "tr_update", Hostname: "localhost", AuthUsername: "tr_update", AuthHostname: "%"}, nil, nil), IsTrue)
+ _, err = se.ExecuteInternal(context.Background(), `INSERT INTO t1 VALUES (10, 11) ON DUPLICATE KEY UPDATE b = 12`)
+ c.Assert(terror.ErrorEqual(err, core.ErrTableaccessDenied), IsTrue)
+ c.Assert(err.Error(), Equals, "[planner:1142]INSERT command denied to user 'tr_update'@'%' for table 't1'")
+}
+
func (s *testPrivilegeSuite) TestAnalyzeTable(c *C) {
se := newSession(c, s.store, s.dbName)
diff --git a/server/conn.go b/server/conn.go
index b0d46538958e9..e8687a8a54c28 100644
--- a/server/conn.go
+++ b/server/conn.go
@@ -879,35 +879,39 @@ func (cc *clientConn) addMetrics(cmd byte, startTime time.Time, err error) {
sqlType = stmtType
}
+ cost := time.Since(startTime)
+ sessionVar := cc.ctx.GetSessionVars()
+ cc.ctx.GetTxnWriteThroughputSLI().FinishExecuteStmt(cost, cc.ctx.AffectedRows(), sessionVar.InTxn())
+
switch sqlType {
case "Use":
- queryDurationHistogramUse.Observe(time.Since(startTime).Seconds())
+ queryDurationHistogramUse.Observe(cost.Seconds())
case "Show":
- queryDurationHistogramShow.Observe(time.Since(startTime).Seconds())
+ queryDurationHistogramShow.Observe(cost.Seconds())
case "Begin":
- queryDurationHistogramBegin.Observe(time.Since(startTime).Seconds())
+ queryDurationHistogramBegin.Observe(cost.Seconds())
case "Commit":
- queryDurationHistogramCommit.Observe(time.Since(startTime).Seconds())
+ queryDurationHistogramCommit.Observe(cost.Seconds())
case "Rollback":
- queryDurationHistogramRollback.Observe(time.Since(startTime).Seconds())
+ queryDurationHistogramRollback.Observe(cost.Seconds())
case "Insert":
- queryDurationHistogramInsert.Observe(time.Since(startTime).Seconds())
+ queryDurationHistogramInsert.Observe(cost.Seconds())
case "Replace":
- queryDurationHistogramReplace.Observe(time.Since(startTime).Seconds())
+
queryDurationHistogramReplace.Observe(cost.Seconds()) case "Delete": - queryDurationHistogramDelete.Observe(time.Since(startTime).Seconds()) + queryDurationHistogramDelete.Observe(cost.Seconds()) case "Update": - queryDurationHistogramUpdate.Observe(time.Since(startTime).Seconds()) + queryDurationHistogramUpdate.Observe(cost.Seconds()) case "Select": - queryDurationHistogramSelect.Observe(time.Since(startTime).Seconds()) + queryDurationHistogramSelect.Observe(cost.Seconds()) case "Execute": - queryDurationHistogramExecute.Observe(time.Since(startTime).Seconds()) + queryDurationHistogramExecute.Observe(cost.Seconds()) case "Set": - queryDurationHistogramSet.Observe(time.Since(startTime).Seconds()) + queryDurationHistogramSet.Observe(cost.Seconds()) case metrics.LblGeneral: - queryDurationHistogramGeneral.Observe(time.Since(startTime).Seconds()) + queryDurationHistogramGeneral.Observe(cost.Seconds()) default: - metrics.QueryDurationHistogram.WithLabelValues(sqlType).Observe(time.Since(startTime).Seconds()) + metrics.QueryDurationHistogram.WithLabelValues(sqlType).Observe(cost.Seconds()) } } diff --git a/server/conn_stmt.go b/server/conn_stmt.go index c8e8545c43d08..402f4aae61905 100644 --- a/server/conn_stmt.go +++ b/server/conn_stmt.go @@ -46,6 +46,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/parser/mysql" "github.com/pingcap/parser/terror" + "github.com/pingcap/tidb/metrics" plannercore "github.com/pingcap/tidb/planner/core" "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/types" @@ -112,6 +113,11 @@ func (cc *clientConn) handleStmtPrepare(ctx context.Context, sql string) error { func (cc *clientConn) handleStmtExecute(ctx context.Context, data []byte) (err error) { defer trace.StartRegion(ctx, "HandleStmtExecute").End() + defer func() { + if err != nil { + metrics.ExecuteErrorCounter.WithLabelValues(metrics.ExecuteErrorToLabel(err)).Inc() + } + }() if len(data) < 9 { return mysql.ErrMalformPacket } diff --git a/server/driver.go b/server/driver.go index 389f9f2700173..b8f95984849ac 100644 --- a/server/driver.go +++ b/server/driver.go @@ -25,6 +25,7 @@ import ( "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util" "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/sli" ) // IDriver opens IContext. @@ -100,6 +101,9 @@ type QueryCtx interface { SetCommandValue(command byte) SetSessionManager(util.SessionManager) + + // GetTxnWriteThroughputSLI returns the TxnWriteThroughputSLI. + GetTxnWriteThroughputSLI() *sli.TxnWriteThroughputSLI } // PreparedStatement is the interface to use a prepared statement. diff --git a/server/driver_tidb.go b/server/driver_tidb.go index e677310716668..4959bda9faeb9 100644 --- a/server/driver_tidb.go +++ b/server/driver_tidb.go @@ -33,6 +33,7 @@ import ( "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util" "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/sli" "github.com/pingcap/tidb/util/sqlexec" ) @@ -367,6 +368,11 @@ func (tc *TiDBContext) GetSessionVars() *variable.SessionVars { return tc.session.GetSessionVars() } +// GetTxnWriteThroughputSLI implements QueryCtx GetTxnWriteThroughputSLI method. 
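+// It simply forwards to the wrapped session, which owns the per-transaction SLI state.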
+func (tc *TiDBContext) GetTxnWriteThroughputSLI() *sli.TxnWriteThroughputSLI {
+ return tc.session.GetTxnWriteThroughputSLI()
+}
+
type tidbResultSet struct {
recordSet sqlexec.RecordSet
columns []*ColumnInfo
diff --git a/server/server.go b/server/server.go
index 26b6f310848a6..ee9ab43f7eadc 100644
--- a/server/server.go
+++ b/server/server.go
@@ -172,8 +172,11 @@ func (s *Server) newConn(conn net.Conn) *clientConn {
return cc
}
+// isUnixSocket should ideally be a function of clientConnection!
+// But currently, since unix-socket connections are forwarded to TCP when the server listens on both, it can only be accurate at the server level.
+// If the server is listening on both, it *must* return FALSE for remote-host authentication to be performed correctly. See #23460.
func (s *Server) isUnixSocket() bool {
- return s.cfg.Socket != ""
+ return s.cfg.Socket != "" && s.cfg.Port == 0
}
func (s *Server) forwardUnixSocketToTCP() {
@@ -449,6 +452,10 @@ func (s *Server) onConn(conn *clientConn) {
conn.Run(ctx)
err = plugin.ForeachPlugin(plugin.Audit, func(p *plugin.Plugin) error {
+ // The audit plugin may be disabled before a conn is created, leaving no connectionInfo in sessionVars.
+ if sessionVars.ConnectionInfo == nil {
+ sessionVars.ConnectionInfo = conn.connectInfo()
+ }
authPlugin := plugin.DeclareAuditManifest(p.Manifest)
if authPlugin.OnConnectionEvent != nil {
sessionVars.ConnectionInfo.Duration = float64(time.Since(connectedTime)) / float64(time.Millisecond)
diff --git a/server/server_test.go b/server/server_test.go
index f70c8e2fbda38..9c672162c3555 100644
--- a/server/server_test.go
+++ b/server/server_test.go
@@ -903,6 +903,74 @@ func (cli *testServerClient) runTestLoadData(c *C, server *Server) {
})
}
+func (cli *testServerClient) runTestLoadDataAutoRandom(c *C) {
+ path := "/tmp/load_data_txn_error.csv"
+
+ fp, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
+ c.Assert(err, IsNil)
+ c.Assert(fp, NotNil)
+
+ defer func() {
+ _ = os.Remove(path)
+ }()
+
+ cksum1 := 0
+ cksum2 := 0
+ for i := 0; i < 50000; i++ {
+ n1 := rand.Intn(1000)
+ n2 := rand.Intn(1000)
+ str1 := strconv.Itoa(n1)
+ str2 := strconv.Itoa(n2)
+ row := str1 + "\t" + str2
+ _, err := fp.WriteString(row)
+ c.Assert(err, IsNil)
+ _, err = fp.WriteString("\n")
+ c.Assert(err, IsNil)
+
+ if i == 0 {
+ cksum1 = n1
+ cksum2 = n2
+ } else {
+ cksum1 = cksum1 ^ n1
+ cksum2 = cksum2 ^ n2
+ }
+ }
+
+ err = fp.Close()
+ c.Assert(err, IsNil)
+
+ cli.runTestsOnNewDB(c, func(config *mysql.Config) {
+ config.AllowAllFiles = true
+ config.Params = map[string]string{"sql_mode": "''"}
+ }, "load_data_batch_dml", func(dbt *DBTest) {
+ // Set the batch size, then check whether LOAD DATA hits an invalid txn error.
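+ // With 50000 rows and a batch size of 128, LOAD DATA has to commit in many batches; this is the pattern that used to trip the invalid-txn error on auto_random tables (issue #22540).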
+ dbt.mustExec("set @@session.tidb_dml_batch_size = 128") + dbt.mustExec("drop table if exists t") + dbt.mustExec("create table t(c1 bigint auto_random primary key, c2 bigint, c3 bigint)") + dbt.mustExec(fmt.Sprintf("load data local infile %q into table t (c2, c3)", path)) + + var ( + rowCnt int + colCkSum1 int + colCkSum2 int + ) + rows := dbt.mustQuery("select count(*) from t") + dbt.Check(rows.Next(), IsTrue, Commentf("unexpected data")) + err = rows.Scan(&rowCnt) + dbt.Check(err, IsNil) + dbt.Check(rowCnt, DeepEquals, 50000) + dbt.Check(rows.Next(), IsFalse, Commentf("unexpected data")) + + rows = dbt.mustQuery("select bit_xor(c2), bit_xor(c3) from t") + dbt.Check(rows.Next(), IsTrue, Commentf("unexpected data")) + err = rows.Scan(&colCkSum1, &colCkSum2) + dbt.Check(err, IsNil) + dbt.Check(colCkSum1, DeepEquals, cksum1) + dbt.Check(colCkSum2, DeepEquals, cksum2) + dbt.Check(rows.Next(), IsFalse, Commentf("unexpected data")) + }) +} + func (cli *testServerClient) runTestConcurrentUpdate(c *C) { dbName := "Concurrent" cli.runTestsOnNewDB(c, func(config *mysql.Config) { diff --git a/server/sql_info_fetcher.go b/server/sql_info_fetcher.go index 34236f8eabe7d..76ba5d6682341 100644 --- a/server/sql_info_fetcher.go +++ b/server/sql_info_fetcher.go @@ -88,7 +88,7 @@ func (sh *sqlInfoFetcher) zipInfoForSQL(w http.ResponseWriter, r *http.Request) timeoutString := r.FormValue("timeout") curDB := strings.ToLower(r.FormValue("current_db")) if curDB != "" { - _, err = sh.s.Execute(reqCtx, fmt.Sprintf("use %v", curDB)) + _, err = sh.s.ExecuteInternal(reqCtx, "use %n", curDB) if err != nil { serveError(w, http.StatusInternalServerError, fmt.Sprintf("use database %v failed, err: %v", curDB, err)) return diff --git a/server/statistics_handler.go b/server/statistics_handler.go index a40e1b19b321f..ecb246e46cf48 100644 --- a/server/statistics_handler.go +++ b/server/statistics_handler.go @@ -25,7 +25,6 @@ import ( "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/gcutil" - "github.com/pingcap/tidb/util/sqlexec" ) // StatsHandler is the handler for dumping statistics. @@ -122,9 +121,7 @@ func (sh StatsHistoryHandler) ServeHTTP(w http.ResponseWriter, req *http.Request writeError(w, err) return } - se.GetSessionVars().SnapshotInfoschema, se.GetSessionVars().SnapshotTS = is, snapshot - historyStatsExec := se.(sqlexec.RestrictedSQLExecutor) - js, err := h.DumpStatsToJSON(params[pDBName], tbl.Meta(), historyStatsExec) + js, err := h.DumpStatsToJSONBySnapshot(params[pDBName], tbl.Meta(), snapshot) if err != nil { writeError(w, err) } else { diff --git a/server/tidb_test.go b/server/tidb_test.go index 9290cd35b6d75..5bafaa137c382 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -148,6 +148,12 @@ func (ts *tidbTestSerialSuite) TestLoadData(c *C) { ts.runTestLoadDataForSlowLog(c, ts.server) } +// Fix issue#22540. Change tidb_dml_batch_size, +// then check if load data into table with auto random column works properly. +func (ts *tidbTestSerialSuite) TestLoadDataAutoRandom(c *C) { + ts.runTestLoadDataAutoRandom(c) +} + func (ts *tidbTestSerialSuite) TestExplainFor(c *C) { ts.runTestExplainForConn(c) } @@ -242,6 +248,7 @@ func (ts *tidbTestSuite) TestStatusAPIWithTLS(c *C) { cli.statusPort = getPortFromTCPAddr(server.statusListener.Addr()) go server.Run() time.Sleep(time.Millisecond * 100) + c.Assert(server.isUnixSocket(), IsFalse) // If listening on tcp-only, return FALSE // https connection should work. 
ts.runTestStatusAPI(c)
@@ -342,6 +349,7 @@ func (ts *tidbTestSuite) TestSocketForwarding(c *C) {
cli.port = getPortFromTCPAddr(server.listener.Addr())
go server.Run()
time.Sleep(time.Millisecond * 100)
+ c.Assert(server.isUnixSocket(), IsFalse) // If listening on both, return FALSE
defer server.Close()
cli.runTestRegression(c, func(config *mysql.Config) {
@@ -365,6 +373,7 @@ func (ts *tidbTestSuite) TestSocket(c *C) {
c.Assert(err, IsNil)
go server.Run()
time.Sleep(time.Millisecond * 100)
+ c.Assert(server.isUnixSocket(), IsTrue) // If listening on socket-only, return TRUE
defer server.Close()
// a fake server client; its config is overridden, just used to run tests
diff --git a/session/bootstrap.go b/session/bootstrap.go
index b85bb79e7de13..ace9d791bf5d6 100644
--- a/session/bootstrap.go
+++ b/session/bootstrap.go
@@ -465,7 +465,7 @@ var (
func checkBootstrapped(s Session) (bool, error) {
// Check if system db exists.
- _, err := s.Execute(context.Background(), fmt.Sprintf("USE %s;", mysql.SystemDB))
+ _, err := s.ExecuteInternal(context.Background(), "USE %n", mysql.SystemDB)
if err != nil && infoschema.ErrDatabaseNotExists.NotEqual(err) {
logutil.BgLogger().Fatal("check bootstrap error", zap.Error(err))
@@ -491,20 +491,18 @@ func checkBootstrapped(s Session) (bool, error) {
// getTiDBVar gets a variable value from the mysql.tidb table.
// Those variables are used by TiDB server.
func getTiDBVar(s Session, name string) (sVal string, isNull bool, e error) {
- sql := fmt.Sprintf(`SELECT HIGH_PRIORITY VARIABLE_VALUE FROM %s.%s WHERE VARIABLE_NAME="%s"`,
- mysql.SystemDB, mysql.TiDBTable, name)
ctx := context.Background()
- rs, err := s.Execute(ctx, sql)
+ rs, err := s.ExecuteInternal(ctx, `SELECT HIGH_PRIORITY VARIABLE_VALUE FROM %n.%n WHERE VARIABLE_NAME= %?`,
+ mysql.SystemDB,
+ mysql.TiDBTable,
+ name,
+ )
if err != nil {
return "", true, errors.Trace(err)
}
- if len(rs) != 1 {
- return "", true, errors.New("Wrong number of Recordset")
- }
- r := rs[0]
- defer terror.Call(r.Close)
- req := r.NewChunk()
- err = r.Next(ctx, req)
+ defer terror.Call(rs.Close)
+ req := rs.NewChunk()
+ err = rs.Next(ctx, req)
if err != nil || req.NumRows() == 0 {
return "", true, errors.Trace(err)
}
@@ -530,7 +528,7 @@ func upgrade(s Session) {
}
updateBootstrapVer(s)
- _, err = s.Execute(context.Background(), "COMMIT")
+ _, err = s.ExecuteInternal(context.Background(), "COMMIT")
if err != nil {
sleepTime := 1 * time.Second
@@ -622,7 +620,7 @@ func upgradeToVer8(s Session, ver int64) {
return
}
// This is a dummy upgrade: it checks whether upgradeToVer7 succeeded and, if not, runs it again.
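// The probe below selects the Process_priv column that upgradeToVer7 adds; an error means the column is still missing.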
- if _, err := s.Execute(context.Background(), "SELECT HIGH_PRIORITY `Process_priv` from mysql.user limit 0"); err == nil { + if _, err := s.ExecuteInternal(context.Background(), "SELECT HIGH_PRIORITY `Process_priv` from mysql.user limit 0"); err == nil { return } upgradeToVer7(s, ver) @@ -638,7 +636,7 @@ func upgradeToVer9(s Session, ver int64) { } func doReentrantDDL(s Session, sql string, ignorableErrs ...error) { - _, err := s.Execute(context.Background(), sql) + _, err := s.ExecuteInternal(context.Background(), sql) for _, ignorableErr := range ignorableErrs { if terror.ErrorEqual(err, ignorableErr) { return @@ -664,7 +662,7 @@ func upgradeToVer11(s Session, ver int64) { if ver >= version11 { return } - _, err := s.Execute(context.Background(), "ALTER TABLE mysql.user ADD COLUMN `References_priv` enum('N','Y') CHARACTER SET utf8 NOT NULL DEFAULT 'N' AFTER `Grant_priv`") + _, err := s.ExecuteInternal(context.Background(), "ALTER TABLE mysql.user ADD COLUMN `References_priv` enum('N','Y') CHARACTER SET utf8 NOT NULL DEFAULT 'N' AFTER `Grant_priv`") if err != nil { if terror.ErrorEqual(err, infoschema.ErrColumnExists) { return @@ -679,21 +677,20 @@ func upgradeToVer12(s Session, ver int64) { return } ctx := context.Background() - _, err := s.Execute(ctx, "BEGIN") + _, err := s.ExecuteInternal(ctx, "BEGIN") terror.MustNil(err) sql := "SELECT HIGH_PRIORITY user, host, password FROM mysql.user WHERE password != ''" - rs, err := s.Execute(ctx, sql) + rs, err := s.ExecuteInternal(ctx, sql) if terror.ErrorEqual(err, core.ErrUnknownColumn) { sql := "SELECT HIGH_PRIORITY user, host, authentication_string FROM mysql.user WHERE authentication_string != ''" - rs, err = s.Execute(ctx, sql) + rs, err = s.ExecuteInternal(ctx, sql) } terror.MustNil(err) - r := rs[0] sqls := make([]string, 0, 1) - defer terror.Call(r.Close) - req := r.NewChunk() + defer terror.Call(rs.Close) + req := rs.NewChunk() it := chunk.NewIterator4Chunk(req) - err = r.Next(ctx, req) + err = rs.Next(ctx, req) for err == nil && req.NumRows() != 0 { for row := it.Begin(); row != it.End(); row = it.Next() { user := row.GetString(0) @@ -705,7 +702,7 @@ func upgradeToVer12(s Session, ver int64) { updateSQL := fmt.Sprintf(`UPDATE HIGH_PRIORITY mysql.user set password = "%s" where user="%s" and host="%s"`, newPass, user, host) sqls = append(sqls, updateSQL) } - err = r.Next(ctx, req) + err = rs.Next(ctx, req) } terror.MustNil(err) @@ -735,7 +732,7 @@ func upgradeToVer13(s Session, ver int64) { } ctx := context.Background() for _, sql := range sqls { - _, err := s.Execute(ctx, sql) + _, err := s.ExecuteInternal(ctx, sql) if err != nil { if terror.ErrorEqual(err, infoschema.ErrColumnExists) { continue @@ -764,7 +761,7 @@ func upgradeToVer14(s Session, ver int64) { } ctx := context.Background() for _, sql := range sqls { - _, err := s.Execute(ctx, sql) + _, err := s.ExecuteInternal(ctx, sql) if err != nil { if terror.ErrorEqual(err, infoschema.ErrColumnExists) { continue @@ -779,7 +776,7 @@ func upgradeToVer15(s Session, ver int64) { return } var err error - _, err = s.Execute(context.Background(), CreateGCDeleteRangeTable) + _, err = s.ExecuteInternal(context.Background(), CreateGCDeleteRangeTable) if err != nil { logutil.BgLogger().Fatal("upgradeToVer15 error", zap.Error(err)) } @@ -849,9 +846,13 @@ func upgradeToVer23(s Session, ver int64) { // writeSystemTZ writes system timezone info into mysql.tidb func writeSystemTZ(s Session) { - sql := fmt.Sprintf(`INSERT HIGH_PRIORITY INTO %s.%s VALUES ("%s", "%s", "TiDB Global System Timezone.") 
ON DUPLICATE KEY UPDATE VARIABLE_VALUE="%s"`, - mysql.SystemDB, mysql.TiDBTable, tidbSystemTZ, timeutil.InferSystemTZ(), timeutil.InferSystemTZ()) - mustExecute(s, sql) + mustExecute(s, `INSERT HIGH_PRIORITY INTO %n.%n VALUES (%?, %?, "TiDB Global System Timezone.") ON DUPLICATE KEY UPDATE VARIABLE_VALUE= %?`, + mysql.SystemDB, + mysql.TiDBTable, + tidbSystemTZ, + timeutil.InferSystemTZ(), + timeutil.InferSystemTZ(), + ) } // upgradeToVer24 initializes `System` timezone according to docs/design/2018-09-10-adding-tz-env.md @@ -980,7 +981,7 @@ func upgradeToVer38(s Session, ver int64) { return } var err error - _, err = s.Execute(context.Background(), CreateGlobalPrivTable) + _, err = s.ExecuteInternal(context.Background(), CreateGlobalPrivTable) if err != nil { logutil.BgLogger().Fatal("upgradeToVer38 error", zap.Error(err)) } @@ -1002,9 +1003,9 @@ func writeNewCollationParameter(s Session, flag bool) { if flag { b = varTrue } - sql := fmt.Sprintf(`INSERT HIGH_PRIORITY INTO %s.%s VALUES ("%s", '%s', '%s') ON DUPLICATE KEY UPDATE VARIABLE_VALUE='%s'`, - mysql.SystemDB, mysql.TiDBTable, tidbNewCollationEnabled, b, comment, b) - mustExecute(s, sql) + mustExecute(s, `INSERT HIGH_PRIORITY INTO %n.%n VALUES (%?, %?, %?) ON DUPLICATE KEY UPDATE VARIABLE_VALUE=%?`, + mysql.SystemDB, mysql.TiDBTable, tidbNewCollationEnabled, b, comment, b, + ) } func upgradeToVer40(s Session, ver int64) { @@ -1040,14 +1041,14 @@ func upgradeToVer42(s Session, ver int64) { // Convert statement summary global variables to non-empty values. func writeStmtSummaryVars(s Session) { - sql := fmt.Sprintf("UPDATE %s.%s SET variable_value='%%s' WHERE variable_name='%%s' AND variable_value=''", mysql.SystemDB, mysql.GlobalVariablesTable) + sql := "UPDATE mysql.global_variables SET variable_value= %? WHERE variable_name= %? 
AND variable_value=''"
stmtSummaryConfig := config.GetGlobalConfig().StmtSummary
- mustExecute(s, fmt.Sprintf(sql, variable.BoolToIntStr(stmtSummaryConfig.Enable), variable.TiDBEnableStmtSummary))
- mustExecute(s, fmt.Sprintf(sql, variable.BoolToIntStr(stmtSummaryConfig.EnableInternalQuery), variable.TiDBStmtSummaryInternalQuery))
- mustExecute(s, fmt.Sprintf(sql, strconv.Itoa(stmtSummaryConfig.RefreshInterval), variable.TiDBStmtSummaryRefreshInterval))
- mustExecute(s, fmt.Sprintf(sql, strconv.Itoa(stmtSummaryConfig.HistorySize), variable.TiDBStmtSummaryHistorySize))
- mustExecute(s, fmt.Sprintf(sql, strconv.FormatUint(uint64(stmtSummaryConfig.MaxStmtCount), 10), variable.TiDBStmtSummaryMaxStmtCount))
- mustExecute(s, fmt.Sprintf(sql, strconv.FormatUint(uint64(stmtSummaryConfig.MaxSQLLength), 10), variable.TiDBStmtSummaryMaxSQLLength))
+ mustExecute(s, sql, variable.BoolToIntStr(stmtSummaryConfig.Enable), variable.TiDBEnableStmtSummary)
+ mustExecute(s, sql, variable.BoolToIntStr(stmtSummaryConfig.EnableInternalQuery), variable.TiDBStmtSummaryInternalQuery)
+ mustExecute(s, sql, strconv.Itoa(stmtSummaryConfig.RefreshInterval), variable.TiDBStmtSummaryRefreshInterval)
+ mustExecute(s, sql, strconv.Itoa(stmtSummaryConfig.HistorySize), variable.TiDBStmtSummaryHistorySize)
+ mustExecute(s, sql, strconv.FormatUint(uint64(stmtSummaryConfig.MaxStmtCount), 10), variable.TiDBStmtSummaryMaxStmtCount)
+ mustExecute(s, sql, strconv.FormatUint(uint64(stmtSummaryConfig.MaxSQLLength), 10), variable.TiDBStmtSummaryMaxSQLLength)
}
func upgradeToVer43(s Session, ver int64) {
@@ -1128,9 +1129,9 @@ func initBindInfoTable(s Session) {
}
func insertBuiltinBindInfoRow(s Session) {
- sql := fmt.Sprintf(`INSERT HIGH_PRIORITY INTO mysql.bind_info VALUES ("%s", "%s", "mysql", "%s", "0000-00-00 00:00:00", "0000-00-00 00:00:00", "", "", "%s")`,
- bindinfo.BuiltinPseudoSQL4BindLock, bindinfo.BuiltinPseudoSQL4BindLock, bindinfo.Builtin, bindinfo.Builtin)
- mustExecute(s, sql)
+ mustExecute(s, `INSERT HIGH_PRIORITY INTO mysql.bind_info VALUES (%?, %?, "mysql", %?, "0000-00-00 00:00:00", "0000-00-00 00:00:00", "", "", %?)`,
+ bindinfo.BuiltinPseudoSQL4BindLock, bindinfo.BuiltinPseudoSQL4BindLock, bindinfo.Builtin, bindinfo.Builtin,
+ )
}
type bindInfo struct {
@@ -1160,8 +1161,8 @@ func upgradeToVer51(s Session, ver int64) {
mustExecute(s, "COMMIT")
}()
mustExecute(s, h.LockBindInfoSQL())
- var recordSets []sqlexec.RecordSet
- recordSets, err = s.ExecuteInternal(context.Background(),
+ var rs sqlexec.RecordSet
+ rs, err = s.ExecuteInternal(context.Background(),
`SELECT bind_sql, default_db, status, create_time, charset, collation, source
FROM mysql.bind_info
WHERE source != 'builtin'
@@ -1169,15 +1170,13 @@
if err != nil {
logutil.BgLogger().Fatal("upgradeToVer51 error", zap.Error(err))
}
- if len(recordSets) > 0 {
- defer terror.Call(recordSets[0].Close)
- }
- req := recordSets[0].NewChunk()
+ defer terror.Call(rs.Close)
+ req := rs.NewChunk()
iter := chunk.NewIterator4Chunk(req)
p := parser.New()
now := types.NewTime(types.FromGoTime(time.Now()), mysql.TypeTimestamp, 3)
for {
- err = recordSets[0].Next(context.TODO(), req)
+ err = rs.Next(context.TODO(), req)
if err != nil {
logutil.BgLogger().Fatal("upgradeToVer51 error", zap.Error(err))
}
@@ -1235,17 +1234,17 @@ func updateBindInfo(iter *chunk.Iterator4Chunk, p *parser.Parser, bindMap map[st
func writeMemoryQuotaQuery(s Session) {
comment := "memory_quota_query is 32GB by default in v3.0.x, 1GB by default in v4.0.x"
- sql
:= fmt.Sprintf(`INSERT HIGH_PRIORITY INTO %s.%s VALUES ("%s", '%d', '%s') ON DUPLICATE KEY UPDATE VARIABLE_VALUE='%d'`, - mysql.SystemDB, mysql.TiDBTable, tidbDefMemoryQuotaQuery, 32<<30, comment, 32<<30) - mustExecute(s, sql) + mustExecute(s, `INSERT HIGH_PRIORITY INTO %n.%n VALUES (%?, %?, %?) ON DUPLICATE KEY UPDATE VARIABLE_VALUE=%?`, + mysql.SystemDB, mysql.TiDBTable, tidbDefMemoryQuotaQuery, 32<<30, comment, 32<<30, + ) } // updateBootstrapVer updates bootstrap version variable in mysql.TiDB table. func updateBootstrapVer(s Session) { // Update bootstrap version. - sql := fmt.Sprintf(`INSERT HIGH_PRIORITY INTO %s.%s VALUES ("%s", "%d", "TiDB bootstrap version.") ON DUPLICATE KEY UPDATE VARIABLE_VALUE="%d"`, - mysql.SystemDB, mysql.TiDBTable, tidbServerVersionVar, currentBootstrapVersion, currentBootstrapVersion) - mustExecute(s, sql) + mustExecute(s, `INSERT HIGH_PRIORITY INTO %n.%n VALUES (%?, %?, "TiDB bootstrap version.") ON DUPLICATE KEY UPDATE VARIABLE_VALUE=%?`, + mysql.SystemDB, mysql.TiDBTable, tidbServerVersionVar, currentBootstrapVersion, currentBootstrapVersion, + ) } // getBootstrapVersion gets bootstrap version from mysql.tidb table; @@ -1265,7 +1264,7 @@ func doDDLWorks(s Session) { // Create a test database. mustExecute(s, "CREATE DATABASE IF NOT EXISTS test") // Create system db. - mustExecute(s, fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %s;", mysql.SystemDB)) + mustExecute(s, "CREATE DATABASE IF NOT EXISTS %n", mysql.SystemDB) // Create user table. mustExecute(s, CreateUserTable) // Create privilege tables. @@ -1334,14 +1333,13 @@ func doDMLWorks(s Session) { strings.Join(values, ", ")) mustExecute(s, sql) - sql = fmt.Sprintf(`INSERT HIGH_PRIORITY INTO %s.%s VALUES("%s", "%s", "Bootstrap flag. Do not delete.") - ON DUPLICATE KEY UPDATE VARIABLE_VALUE="%s"`, - mysql.SystemDB, mysql.TiDBTable, bootstrappedVar, varTrue, varTrue) - mustExecute(s, sql) + mustExecute(s, `INSERT HIGH_PRIORITY INTO %n.%n VALUES(%?, %?, "Bootstrap flag. Do not delete.") ON DUPLICATE KEY UPDATE VARIABLE_VALUE=%?`, + mysql.SystemDB, mysql.TiDBTable, bootstrappedVar, varTrue, varTrue, + ) - sql = fmt.Sprintf(`INSERT HIGH_PRIORITY INTO %s.%s VALUES("%s", "%d", "Bootstrap version. Do not delete.")`, - mysql.SystemDB, mysql.TiDBTable, tidbServerVersionVar, currentBootstrapVersion) - mustExecute(s, sql) + mustExecute(s, `INSERT HIGH_PRIORITY INTO %n.%n VALUES(%?, %?, "Bootstrap version. Do not delete.")`, + mysql.SystemDB, mysql.TiDBTable, tidbServerVersionVar, currentBootstrapVersion, + ) writeSystemTZ(s) @@ -1351,7 +1349,7 @@ func doDMLWorks(s Session) { writeStmtSummaryVars(s) - _, err := s.Execute(context.Background(), "COMMIT") + _, err := s.ExecuteInternal(context.Background(), "COMMIT") if err != nil { sleepTime := 1 * time.Second logutil.BgLogger().Info("doDMLWorks failed", zap.Error(err), zap.Duration("sleeping time", sleepTime)) @@ -1368,8 +1366,8 @@ func doDMLWorks(s Session) { } } -func mustExecute(s Session, sql string) { - _, err := s.Execute(context.Background(), sql) +func mustExecute(s Session, sql string, args ...interface{}) { + _, err := s.ExecuteInternal(context.Background(), sql, args...) 
if err != nil { debug.PrintStack() logutil.BgLogger().Fatal("mustExecute error", zap.Error(err)) diff --git a/session/bootstrap_test.go b/session/bootstrap_test.go index e587aadbc8f74..2a59620a98d28 100644 --- a/session/bootstrap_test.go +++ b/session/bootstrap_test.go @@ -54,7 +54,7 @@ func (s *testBootstrapSuite) TestBootstrap(c *C) { c.Assert(err, IsNil) c.Assert(req.NumRows() == 0, IsFalse) datums := statistics.RowToDatums(req.GetRow(0), r.Fields()) - match(c, datums, `%`, "root", []byte(""), "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "N", "Y", "Y", "Y", "Y") + match(c, datums, `%`, "root", "", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "N", "Y", "Y", "Y", "Y") c.Assert(se.Auth(&auth.UserIdentity{Username: "root", Hostname: "anyhost"}, []byte(""), []byte("")), IsTrue) mustExecSQL(c, se, "USE test;") @@ -159,7 +159,7 @@ func (s *testBootstrapSuite) TestBootstrapWithError(c *C) { c.Assert(req.NumRows() == 0, IsFalse) row := req.GetRow(0) datums := statistics.RowToDatums(row, r.Fields()) - match(c, datums, `%`, "root", []byte(""), "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "N", "Y", "Y", "Y", "Y") + match(c, datums, `%`, "root", "", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "N", "Y", "Y", "Y", "Y") c.Assert(r.Close(), IsNil) mustExecSQL(c, se, "USE test;") diff --git a/session/pessimistic_test.go b/session/pessimistic_test.go index d68c9b06327a4..536b7feaab31d 100644 --- a/session/pessimistic_test.go +++ b/session/pessimistic_test.go @@ -299,6 +299,14 @@ func (s *testPessimisticSuite) TestInsertOnDup(c *C) { tk.MustQuery("select * from dup").Check(testkit.Rows("1 2")) } +func (s *testPessimisticSuite) TestPointGetOverflow(c *C) { + tk := testkit.NewTestKitWithInit(c, s.store) + tk.MustExec("create table t(k tinyint, v int, unique key(k))") + tk.MustExec("begin pessimistic") + tk.MustExec("update t set v = 100 where k = -200;") + tk.MustExec("update t set v = 100 where k in (-200, -400);") +} + func (s *testPessimisticSuite) TestPointGetKeyLock(c *C) { tk := testkit.NewTestKitWithInit(c, s.store) tk2 := testkit.NewTestKitWithInit(c, s.store) diff --git a/session/session.go b/session/session.go index 8912f55795c84..45690f166d689 100644 --- a/session/session.go +++ b/session/session.go @@ -70,6 +70,7 @@ import ( "github.com/pingcap/tidb/util/execdetails" "github.com/pingcap/tidb/util/kvcache" "github.com/pingcap/tidb/util/logutil" + "github.com/pingcap/tidb/util/sli" "github.com/pingcap/tidb/util/sqlexec" "github.com/pingcap/tidb/util/timeutil" "github.com/pingcap/tipb/go-binlog" @@ -100,10 +101,12 @@ type Session interface { LastMessage() string // LastMessage is the info message that may be generated by last command AffectedRows() uint64 // Affected rows by latest executed stmt. // Execute is deprecated, use ExecuteStmt() instead. - Execute(context.Context, string) ([]sqlexec.RecordSet, error) // Execute a sql statement. - ExecuteInternal(context.Context, string) ([]sqlexec.RecordSet, error) // Execute a internal sql statement. + Execute(context.Context, string) ([]sqlexec.RecordSet, error) // Execute a sql statement. ExecuteStmt(context.Context, ast.StmtNode) (sqlexec.RecordSet, error) + // Parse is deprecated, use ParseWithParams() instead. 
Parse(ctx context.Context, sql string) ([]ast.StmtNode, error) + // ExecuteInternal is a helper around ParseWithParams() and ExecuteStmt(). It is not allowed to execute multiple statements. + ExecuteInternal(context.Context, string, ...interface{}) (sqlexec.RecordSet, error) String() string // String is used to debug. CommitTxn(context.Context) error RollbackTxn(context.Context) @@ -858,37 +861,17 @@ func (s *session) ExecRestrictedSQLWithSnapshot(sql string) ([]chunk.Row, []*ast func execRestrictedSQL(ctx context.Context, se *session, sql string) ([]chunk.Row, []*ast.ResultField, error) { ctx = context.WithValue(ctx, execdetails.StmtExecDetailKey, &execdetails.StmtExecDetails{}) startTime := time.Now() - recordSets, err := se.Execute(ctx, sql) - defer func() { - for _, rs := range recordSets { - closeErr := rs.Close() - if closeErr != nil && err == nil { - err = closeErr - } - } - }() - if err != nil { + rs, err := se.ExecuteInternal(ctx, sql) + if err != nil || rs == nil { return nil, nil, err } - - var ( - rows []chunk.Row - fields []*ast.ResultField - ) - // Execute all recordset, take out the first one as result. - for i, rs := range recordSets { - tmp, err := drainRecordSet(ctx, se, rs) - if err != nil { - return nil, nil, err - } - - if i == 0 { - rows = tmp - fields = rs.Fields() - } + defer terror.Call(rs.Close) + rows, err := drainRecordSet(ctx, se, rs) + if err != nil { + return nil, nil, err } metrics.QueryDurationHistogram.WithLabelValues(metrics.LblInternal).Observe(time.Since(startTime).Seconds()) - return rows, fields, err + return rows, rs.Fields(), err } func createSessionFunc(store kv.Storage) pools.Factory { @@ -951,15 +934,19 @@ func drainRecordSet(ctx context.Context, se *session, rs sqlexec.RecordSet) ([]c } } -// getExecRet executes restricted sql and the result is one column. +// getTableValue executes restricted sql and the result is one column. // It returns a string value. -func (s *session) getExecRet(ctx sessionctx.Context, sql string) (string, error) { - rows, fields, err := s.ExecRestrictedSQL(sql) +func (s *session) getTableValue(ctx context.Context, tblName string, varName string) (string, error) { + stmt, err := s.ParseWithParams(ctx, "SELECT VARIABLE_VALUE FROM %n.%n WHERE VARIABLE_NAME=%?", mysql.SystemDB, tblName, varName) + if err != nil { + return "", err + } + rows, fields, err := s.ExecRestrictedStmt(ctx, stmt) if err != nil { return "", err } if len(rows) == 0 { - return "", executor.ErrResultIsEmpty + return "", errResultIsEmpty } d := rows[0].GetDatum(0, &fields[0].Column.FieldType) value, err := d.ToString() @@ -974,9 +961,11 @@ func (s *session) GetAllSysVars() (map[string]string, error) { if s.Value(sessionctx.Initing) != nil { return nil, nil } - sql := `SELECT VARIABLE_NAME, VARIABLE_VALUE FROM %s.%s;` - sql = fmt.Sprintf(sql, mysql.SystemDB, mysql.GlobalVariablesTable) - rows, _, err := s.ExecRestrictedSQL(sql) + stmt, err := s.ParseWithParams(context.TODO(), `SELECT VARIABLE_NAME, VARIABLE_VALUE FROM %n.%n`, mysql.SystemDB, mysql.GlobalVariablesTable) + if err != nil { + return nil, err + } + rows, _, err := s.ExecRestrictedStmt(context.TODO(), stmt) if err != nil { return nil, err } @@ -997,11 +986,9 @@ func (s *session) GetGlobalSysVar(name string) (string, error) { // When running bootstrap or upgrade, we should not access global storage. 
return "", nil } - sql := fmt.Sprintf(`SELECT VARIABLE_VALUE FROM %s.%s WHERE VARIABLE_NAME="%s";`, - mysql.SystemDB, mysql.GlobalVariablesTable, name) - sysVar, err := s.getExecRet(s, sql) + sysVar, err := s.getTableValue(context.TODO(), mysql.GlobalVariablesTable, name) if err != nil { - if executor.ErrResultIsEmpty.Equal(err) { + if errResultIsEmpty.Equal(err) { if sv, ok := variable.SysVars[name]; ok { return sv.Value, nil } @@ -1030,9 +1017,11 @@ func (s *session) SetGlobalSysVar(name, value string) error { return err } name = strings.ToLower(name) - sql := fmt.Sprintf(`REPLACE %s.%s VALUES ('%s', '%s');`, - mysql.SystemDB, mysql.GlobalVariablesTable, name, sVal) - _, _, err = s.ExecRestrictedSQL(sql) + stmt, err := s.ParseWithParams(context.TODO(), "REPLACE %n.%n VALUES (%?, %?)", mysql.SystemDB, mysql.GlobalVariablesTable, name, sVal) + if err != nil { + return err + } + _, _, err = s.ExecRestrictedStmt(context.TODO(), stmt) return err } @@ -1121,13 +1110,30 @@ func (rs *execStmtResult) Close() error { return finishStmt(context.Background(), se, err, rs.sql) } -func (s *session) ExecuteInternal(ctx context.Context, sql string) (recordSets []sqlexec.RecordSet, err error) { +func (s *session) ExecuteInternal(ctx context.Context, sql string, args ...interface{}) (rs sqlexec.RecordSet, err error) { origin := s.sessionVars.InRestrictedSQL s.sessionVars.InRestrictedSQL = true defer func() { s.sessionVars.InRestrictedSQL = origin }() - return s.Execute(ctx, sql) + + if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { + span1 := span.Tracer().StartSpan("session.ExecuteInternal", opentracing.ChildOf(span.Context())) + defer span1.Finish() + ctx = opentracing.ContextWithSpan(ctx, span1) + logutil.Eventf(ctx, "execute: %s", sql) + } + + stmt, err := s.ParseWithParams(ctx, sql, args...) + if err != nil { + return nil, err + } + + rs, err = s.ExecuteStmt(ctx, stmt) + if err != nil { + s.sessionVars.StmtCtx.AppendError(err) + } + return rs, err } func (s *session) Execute(ctx context.Context, sql string) (recordSets []sqlexec.RecordSet, err error) { @@ -1200,6 +1206,130 @@ func (s *session) Parse(ctx context.Context, sql string) ([]ast.StmtNode, error) return stmts, nil } +// ParseWithParams parses a query string, with arguments, to raw ast.StmtNode. +// Note that it will not do escaping if no variable arguments are passed. +func (s *session) ParseWithParams(ctx context.Context, sql string, args ...interface{}) (ast.StmtNode, error) { + var err error + sql, err = sqlexec.EscapeSQL(sql, args...) + if err != nil { + return nil, err + } + + internal := s.isInternal() + + var stmts []ast.StmtNode + var warns []error + var parseStartTime time.Time + if internal { + // Do no respect the settings from clients, if it is for internal usage. + // Charsets from clients may give chance injections. + // Refer to https://stackoverflow.com/questions/5741187/sql-injection-that-gets-around-mysql-real-escape-string/12118602. + parseStartTime = time.Now() + stmts, warns, err = s.ParseSQL(ctx, sql, mysql.UTF8MB4Charset, mysql.UTF8MB4DefaultCollation) + } else { + charsetInfo, collation := s.sessionVars.GetCharsetInfo() + parseStartTime = time.Now() + stmts, warns, err = s.ParseSQL(ctx, sql, charsetInfo, collation) + } + if len(stmts) != 1 { + err = errors.New("run multiple statements internally is not supported") + } + if err != nil { + s.rollbackOnError(ctx) + // Only print log message when this SQL is from the user. + // Mute the warning for internal SQLs. 
+ if !s.sessionVars.InRestrictedSQL {
+ if s.sessionVars.EnableRedactLog {
+ logutil.Logger(ctx).Debug("parse SQL failed", zap.Error(err), zap.String("SQL", sql))
+ } else {
+ logutil.Logger(ctx).Warn("parse SQL failed", zap.Error(err), zap.String("SQL", sql))
+ }
+ }
+ return nil, util.SyntaxError(err)
+ }
+ durParse := time.Since(parseStartTime)
+ if s.isInternal() {
+ sessionExecuteParseDurationInternal.Observe(durParse.Seconds())
+ } else {
+ sessionExecuteParseDurationGeneral.Observe(durParse.Seconds())
+ }
+ for _, warn := range warns {
+ s.sessionVars.StmtCtx.AppendWarning(util.SyntaxWarn(warn))
+ }
+ return stmts[0], nil
+}
+
+// ExecRestrictedStmt implements the RestrictedSQLExecutor interface.
+func (s *session) ExecRestrictedStmt(ctx context.Context, stmtNode ast.StmtNode, opts ...sqlexec.OptionFuncAlias) (
+ []chunk.Row, []*ast.ResultField, error) {
+ var execOption sqlexec.ExecOption
+ for _, opt := range opts {
+ opt(&execOption)
+ }
+ // Use a special session to execute the SQL.
+ tmp, err := s.sysSessionPool().Get()
+ if err != nil {
+ return nil, nil, err
+ }
+ defer s.sysSessionPool().Put(tmp)
+ se := tmp.(*session)
+
+ startTime := time.Now()
+ // The special session will share the `InspectionTableCache` with the current session
+ // if the current session is in inspection mode.
+ if cache := s.sessionVars.InspectionTableCache; cache != nil {
+ se.sessionVars.InspectionTableCache = cache
+ defer func() { se.sessionVars.InspectionTableCache = nil }()
+ }
+ defer func() {
+ if !execOption.IgnoreWarning {
+ if se != nil && se.GetSessionVars().StmtCtx.WarningCount() > 0 {
+ warnings := se.GetSessionVars().StmtCtx.GetWarnings()
+ s.GetSessionVars().StmtCtx.AppendWarnings(warnings)
+ }
+ }
+ }()
+
+ if execOption.SnapshotTS != 0 {
+ se.sessionVars.SnapshotInfoschema, err = domain.GetDomain(s).GetSnapshotInfoSchema(execOption.SnapshotTS)
+ if err != nil {
+ return nil, nil, err
+ }
+ if err := se.sessionVars.SetSystemVar(variable.TiDBSnapshot, strconv.FormatUint(execOption.SnapshotTS, 10)); err != nil {
+ return nil, nil, err
+ }
+ defer func() {
+ if err := se.sessionVars.SetSystemVar(variable.TiDBSnapshot, ""); err != nil {
+ logutil.BgLogger().Error("set tidbSnapshot error", zap.Error(err))
+ }
+ se.sessionVars.SnapshotInfoschema = nil
+ }()
+ }
+
+ metrics.SessionRestrictedSQLCounter.Inc()
+
+ ctx = context.WithValue(ctx, execdetails.StmtExecDetailKey, &execdetails.StmtExecDetails{})
+ rs, err := se.ExecuteStmt(ctx, stmtNode)
+ if err != nil {
+ se.sessionVars.StmtCtx.AppendError(err)
+ }
+ if rs == nil {
+ return nil, nil, err
+ }
+ defer func() {
+ if closeErr := rs.Close(); closeErr != nil {
+ err = closeErr
+ }
+ }()
+ var rows []chunk.Row
+ rows, err = drainRecordSet(ctx, se, rs)
+ if err != nil {
+ return nil, nil, err
+ }
+ metrics.QueryDurationHistogram.WithLabelValues(metrics.LblInternal).Observe(time.Since(startTime).Seconds())
+ return rows, rs.Fields(), err
+}
+
func (s *session) ExecuteStmt(ctx context.Context, stmtNode ast.StmtNode) (sqlexec.RecordSet, error) {
if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil {
span1 := span.Tracer().StartSpan("session.ExecuteStmt", opentracing.ChildOf(span.Context()))
@@ -1440,6 +1570,9 @@ func (s *session) Txn(active bool) (kv.Transaction, error) {
return &s.txn, errors.AddStack(kv.ErrInvalidTxn)
}
if s.txn.pending() && active {
+ defer func(begin time.Time) {
+ s.sessionVars.DurationWaitTS = time.Since(begin)
+ }(time.Now())
// Transaction is lazy initialized.
		// PrepareTxnCtx is called to get a tso future, makes s.txn a pending txn,
		// If Txn() is called later, wait for the future to get a valid txn.
@@ -1688,11 +1821,6 @@ func CreateSessionWithOpt(store kv.Storage, opt *Opt) (Session, error) {
	return s, nil
}

-// loadSystemTZ loads systemTZ from mysql.tidb
-func loadSystemTZ(se *session) (string, error) {
-	return loadParameter(se, "system_tz")
-}
-
// loadCollationParameter loads collation parameter from mysql.tidb
func loadCollationParameter(se *session) (bool, error) {
	para, err := loadParameter(se, tidbNewCollationEnabled)
@@ -1736,25 +1864,7 @@ var (

// loadParameter loads read-only parameter from mysql.tidb
func loadParameter(se *session, name string) (string, error) {
-	sql := "select variable_value from mysql.tidb where variable_name = '" + name + "'"
-	rss, errLoad := se.Execute(context.Background(), sql)
-	if errLoad != nil {
-		return "", errLoad
-	}
-	// the record of mysql.tidb under where condition: variable_name = $name should shall only be one.
-	defer func() {
-		if err := rss[0].Close(); err != nil {
-			logutil.BgLogger().Error("close result set error", zap.Error(err))
-		}
-	}()
-	req := rss[0].NewChunk()
-	if err := rss[0].Next(context.Background(), req); err != nil {
-		return "", err
-	}
-	if req.NumRows() == 0 {
-		return "", errResultIsEmpty
-	}
-	return req.GetRow(0).GetString(0), nil
+	return se.getTableValue(context.TODO(), mysql.TiDBTable, name)
}

// BootstrapSession runs the first time when the TiDB server start.
@@ -1786,7 +1896,7 @@ func BootstrapSession(store kv.Storage) (*domain.Domain, error) {
		return nil, err
	}
	// get system tz from mysql.tidb
-	tz, err := loadSystemTZ(se)
+	tz, err := se.getTableValue(context.TODO(), mysql.TiDBTable, "system_tz")
	if err != nil {
		return nil, err
	}
@@ -2335,3 +2445,8 @@ func (s *session) recordOnTransactionExecution(err error, counter int, duration
		}
	}
}
+
+// GetTxnWriteThroughputSLI implements the Context interface.
+func (s *session) GetTxnWriteThroughputSLI() *sli.TxnWriteThroughputSLI {
+	return &s.txn.writeSLI
+}
diff --git a/session/session_test.go b/session/session_test.go
index 96670b909aa37..a32e9e05349ed 100644
--- a/session/session_test.go
+++ b/session/session_test.go
@@ -29,6 +29,7 @@ import (
	"github.com/pingcap/failpoint"
	"github.com/pingcap/parser"
	"github.com/pingcap/parser/auth"
+	"github.com/pingcap/parser/format"
	"github.com/pingcap/parser/model"
	"github.com/pingcap/parser/mysql"
	"github.com/pingcap/parser/terror"
@@ -3568,3 +3569,46 @@ func (s *testSessionSuite2) TestRetryCommitWithSet(c *C) {
	tk2.MustQuery("select * from t use index(k1)").Check(testkit.Rows("1 11 101"))
	tk2.MustQuery("select * from t where pk = '1'").Check(testkit.Rows("1 11 101"))
}
+
+func (s *testSessionSerialSuite) TestParseWithParams(c *C) {
+	tk := testkit.NewTestKitWithInit(c, s.store)
+	se := tk.Se
+	exec := se.(sqlexec.RestrictedSQLExecutor)
+
+	// test compatibility with ExecuteInternal
+	origin := se.GetSessionVars().InRestrictedSQL
+	se.GetSessionVars().InRestrictedSQL = true
+	defer func() {
+		se.GetSessionVars().InRestrictedSQL = origin
+	}()
+	_, err := exec.ParseWithParams(context.Background(), "SELECT 4")
+	c.Assert(err, IsNil)
+
+	// test charset attack
+	stmt, err := exec.ParseWithParams(context.Background(), "SELECT * FROM test WHERE name = %?
LIMIT 1", "\xbf\x27 OR 1=1 /*") + c.Assert(err, IsNil) + + var sb strings.Builder + ctx := format.NewRestoreCtx(format.RestoreStringDoubleQuotes, &sb) + err = stmt.Restore(ctx) + c.Assert(err, IsNil) + c.Assert(sb.String(), Equals, "SELECT * FROM test WHERE name=_utf8mb4\"\xbf' OR 1=1 /*\" LIMIT 1") + + // test invalid sql + _, err = exec.ParseWithParams(context.Background(), "SELECT") + c.Assert(err, ErrorMatches, ".*You have an error in your SQL syntax.*") + + // test invalid arguments to escape + _, err = exec.ParseWithParams(context.Background(), "SELECT %?") + c.Assert(err, ErrorMatches, "missing arguments.*") + + // test noescape + stmt, err = exec.ParseWithParams(context.TODO(), "SELECT 3") + c.Assert(err, IsNil) + + sb.Reset() + ctx = format.NewRestoreCtx(0, &sb) + err = stmt.Restore(ctx) + c.Assert(err, IsNil) + c.Assert(sb.String(), Equals, "SELECT 3") +} diff --git a/session/txn.go b/session/txn.go index b9ce2741f6d57..fca8d75ccaae5 100644 --- a/session/txn.go +++ b/session/txn.go @@ -34,6 +34,7 @@ import ( "github.com/pingcap/tidb/tablecodec" "github.com/pingcap/tidb/util/logutil" "github.com/pingcap/tidb/util/memory" + "github.com/pingcap/tidb/util/sli" "github.com/pingcap/tipb/go-binlog" "go.uber.org/zap" ) @@ -57,6 +58,8 @@ type TxnState struct { // If doNotCommit is not nil, Commit() will not commit the transaction. // doNotCommit flag may be set when StmtCommit fail. doNotCommit error + + writeSLI sli.TxnWriteThroughputSLI } func (st *TxnState) init() { diff --git a/session/utils.go b/session/utils.go new file mode 100644 index 0000000000000..67788d5d53a4f --- /dev/null +++ b/session/utils.go @@ -0,0 +1,213 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "encoding/json" + "strconv" + "strings" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/tidb/util/hack" +) + +func reserveBuffer(buf []byte, appendSize int) []byte { + newSize := len(buf) + appendSize + if cap(buf) < newSize { + newBuf := make([]byte, len(buf)*2+appendSize) + copy(newBuf, buf) + buf = newBuf + } + return buf[:newSize] +} + +// escapeBytesBackslash will escape []byte into the buffer, with backslash. +func escapeBytesBackslash(buf []byte, v []byte) []byte { + pos := len(buf) + buf = reserveBuffer(buf, len(v)*2) + + for _, c := range v { + switch c { + case '\x00': + buf[pos] = '\\' + buf[pos+1] = '0' + pos += 2 + case '\n': + buf[pos] = '\\' + buf[pos+1] = 'n' + pos += 2 + case '\r': + buf[pos] = '\\' + buf[pos+1] = 'r' + pos += 2 + case '\x1a': + buf[pos] = '\\' + buf[pos+1] = 'Z' + pos += 2 + case '\'': + buf[pos] = '\\' + buf[pos+1] = '\'' + pos += 2 + case '"': + buf[pos] = '\\' + buf[pos+1] = '"' + pos += 2 + case '\\': + buf[pos] = '\\' + buf[pos+1] = '\\' + pos += 2 + default: + buf[pos] = c + pos++ + } + } + + return buf[:pos] +} + +// escapeStringBackslash will escape string into the buffer, with backslash. 
+func escapeStringBackslash(buf []byte, v string) []byte {
+	return escapeBytesBackslash(buf, hack.Slice(v))
+}
+
+// EscapeSQL will escape input arguments into the sql string, doing necessary processing.
+// It works like printf() in C; the following format specifiers are supported:
+// 1. %?: automatic conversion by the type of arguments. E.g. []string -> ('s1','s2'..)
+// 2. %%: output %
+// 3. %n: for identifiers, for example ("use %n", db)
+// But it does not prevent you from doing EscapeSQL("select '%?", ";SQL injection!;") => "select '';SQL injection!;'".
+// It is still your responsibility to write safe SQL.
+func EscapeSQL(sql string, args ...interface{}) (string, error) {
+	buf := make([]byte, 0, len(sql))
+	argPos := 0
+	for i := 0; i < len(sql); i++ {
+		q := strings.IndexByte(sql[i:], '%')
+		if q == -1 {
+			buf = append(buf, sql[i:]...)
+			break
+		}
+		buf = append(buf, sql[i:i+q]...)
+		i += q
+
+		ch := byte(0)
+		if i+1 < len(sql) {
+			ch = sql[i+1] // get the specifier
+		}
+		switch ch {
+		case 'n':
+			if argPos >= len(args) {
+				return "", errors.Errorf("missing arguments, need %d-th arg, but only got %d args", argPos+1, len(args))
+			}
+			arg := args[argPos]
+			argPos++
+
+			v, ok := arg.(string)
+			if !ok {
+				return "", errors.Errorf("expect a string identifier, got %v", arg)
+			}
+			buf = append(buf, '`')
+			buf = append(buf, strings.Replace(v, "`", "``", -1)...)
+			buf = append(buf, '`')
+			i++ // skip specifier
+		case '?':
+			if argPos >= len(args) {
+				return "", errors.Errorf("missing arguments, need %d-th arg, but only got %d args", argPos+1, len(args))
+			}
+			arg := args[argPos]
+			argPos++
+
+			if arg == nil {
+				buf = append(buf, "NULL"...)
+			} else {
+				switch v := arg.(type) {
+				case int:
+					buf = strconv.AppendInt(buf, int64(v), 10)
+				case int8:
+					buf = strconv.AppendInt(buf, int64(v), 10)
+				case int16:
+					buf = strconv.AppendInt(buf, int64(v), 10)
+				case int32:
+					buf = strconv.AppendInt(buf, int64(v), 10)
+				case int64:
+					buf = strconv.AppendInt(buf, v, 10)
+				case uint:
+					buf = strconv.AppendUint(buf, uint64(v), 10)
+				case uint8:
+					buf = strconv.AppendUint(buf, uint64(v), 10)
+				case uint16:
+					buf = strconv.AppendUint(buf, uint64(v), 10)
+				case uint32:
+					buf = strconv.AppendUint(buf, uint64(v), 10)
+				case uint64:
+					buf = strconv.AppendUint(buf, v, 10)
+				case float32:
+					buf = strconv.AppendFloat(buf, float64(v), 'g', -1, 32)
+				case float64:
+					buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
+				case bool:
+					if v {
+						buf = append(buf, '1')
+					} else {
+						buf = append(buf, '0')
+					}
+				case time.Time:
+					if v.IsZero() {
+						buf = append(buf, "'0000-00-00'"...)
+					} else {
+						buf = append(buf, '\'')
+						buf = v.AppendFormat(buf, "2006-01-02 15:04:05.999999")
+						buf = append(buf, '\'')
+					}
+				case json.RawMessage:
+					buf = append(buf, '\'')
+					buf = escapeBytesBackslash(buf, v)
+					buf = append(buf, '\'')
+				case []byte:
+					if v == nil {
+						buf = append(buf, "NULL"...)
+					} else {
+						buf = append(buf, "_binary'"...)
+ buf = escapeBytesBackslash(buf, v) + buf = append(buf, '\'') + } + case string: + buf = append(buf, '\'') + buf = escapeStringBackslash(buf, v) + buf = append(buf, '\'') + case []string: + buf = append(buf, '(') + for i, k := range v { + if i > 0 { + buf = append(buf, ',') + } + buf = append(buf, '\'') + buf = escapeStringBackslash(buf, k) + buf = append(buf, '\'') + } + buf = append(buf, ')') + default: + return "", errors.Errorf("unsupported %d-th argument: %v", argPos, arg) + } + } + i++ // skip specifier + case '%': + buf = append(buf, '%') + i++ // skip specifier + default: + buf = append(buf, '%') + } + } + return string(buf), nil +} diff --git a/session/utils_test.go b/session/utils_test.go new file mode 100644 index 0000000000000..f7d754418c019 --- /dev/null +++ b/session/utils_test.go @@ -0,0 +1,387 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "encoding/json" + "time" + + . "github.com/pingcap/check" + "github.com/pingcap/tidb/util/hack" +) + +var _ = Suite(&testUtilsSuite{}) + +type testUtilsSuite struct{} + +func (s *testUtilsSuite) TestReserveBuffer(c *C) { + res0 := reserveBuffer(nil, 0) + c.Assert(res0, HasLen, 0) + + res1 := reserveBuffer(res0, 3) + c.Assert(res1, HasLen, 3) + res1[1] = 3 + + res2 := reserveBuffer(res1, 9) + c.Assert(res2, HasLen, 12) + c.Assert(cap(res2), Equals, 15) + c.Assert(res2[:3], DeepEquals, res1) +} + +func (s *testUtilsSuite) TestEscapeBackslash(c *C) { + type TestCase struct { + name string + input []byte + output []byte + } + tests := []TestCase{ + { + name: "normal", + input: []byte("hello"), + output: []byte("hello"), + }, + { + name: "0", + input: []byte("he\x00lo"), + output: []byte("he\\0lo"), + }, + { + name: "break line", + input: []byte("he\nlo"), + output: []byte("he\\nlo"), + }, + { + name: "carry", + input: []byte("he\rlo"), + output: []byte("he\\rlo"), + }, + { + name: "substitute", + input: []byte("he\x1alo"), + output: []byte("he\\Zlo"), + }, + { + name: "single quote", + input: []byte("he'lo"), + output: []byte("he\\'lo"), + }, + { + name: "double quote", + input: []byte("he\"lo"), + output: []byte("he\\\"lo"), + }, + { + name: "back slash", + input: []byte("he\\lo"), + output: []byte("he\\\\lo"), + }, + { + name: "double escape", + input: []byte("he\x00lo\""), + output: []byte("he\\0lo\\\""), + }, + { + name: "chinese", + input: []byte("中文?"), + output: []byte("中文?"), + }, + } + for _, t := range tests { + commentf := Commentf("%s", t.name) + c.Assert(escapeBytesBackslash(nil, t.input), DeepEquals, t.output, commentf) + c.Assert(escapeStringBackslash(nil, string(hack.String(t.input))), DeepEquals, t.output, commentf) + } +} + +func (s *testUtilsSuite) TestEscapeSQL(c *C) { + type TestCase struct { + name string + input string + params []interface{} + output string + err string + } + time2, err := time.Parse("2006-01-02 15:04:05", "2018-01-23 04:03:05") + c.Assert(err, IsNil) + tests := []TestCase{ + { + name: "normal 1", + input: "select * from 1", + params: []interface{}{}, + output: "select * from 1", + err: "", + }, + { + 
name: "normal 2", + input: "WHERE source != 'builtin'", + params: []interface{}{}, + output: "WHERE source != 'builtin'", + err: "", + }, + { + name: "discard extra arguments", + input: "select * from 1", + params: []interface{}{4, 5, "rt"}, + output: "select * from 1", + err: "", + }, + { + name: "%? missing arguments", + input: "select %? from %?", + params: []interface{}{4}, + err: "missing arguments.*", + }, + { + name: "nil", + input: "select %?", + params: []interface{}{nil}, + output: "select NULL", + err: "", + }, + { + name: "int", + input: "select %?", + params: []interface{}{int(3)}, + output: "select 3", + err: "", + }, + { + name: "int8", + input: "select %?", + params: []interface{}{int8(4)}, + output: "select 4", + err: "", + }, + { + name: "int16", + input: "select %?", + params: []interface{}{int16(5)}, + output: "select 5", + err: "", + }, + { + name: "int32", + input: "select %?", + params: []interface{}{int32(6)}, + output: "select 6", + err: "", + }, + { + name: "int64", + input: "select %?", + params: []interface{}{int64(7)}, + output: "select 7", + err: "", + }, + { + name: "uint", + input: "select %?", + params: []interface{}{uint(8)}, + output: "select 8", + err: "", + }, + { + name: "uint8", + input: "select %?", + params: []interface{}{uint8(9)}, + output: "select 9", + err: "", + }, + { + name: "uint16", + input: "select %?", + params: []interface{}{uint16(10)}, + output: "select 10", + err: "", + }, + { + name: "uint32", + input: "select %?", + params: []interface{}{uint32(11)}, + output: "select 11", + err: "", + }, + { + name: "uint64", + input: "select %?", + params: []interface{}{uint64(12)}, + output: "select 12", + err: "", + }, + { + name: "float32", + input: "select %?", + params: []interface{}{float32(0.13)}, + output: "select 0.13", + err: "", + }, + { + name: "float64", + input: "select %?", + params: []interface{}{float64(0.14)}, + output: "select 0.14", + err: "", + }, + { + name: "bool on", + input: "select %?", + params: []interface{}{true}, + output: "select 1", + err: "", + }, + { + name: "bool off", + input: "select %?", + params: []interface{}{false}, + output: "select 0", + err: "", + }, + { + name: "time 0", + input: "select %?", + params: []interface{}{time.Time{}}, + output: "select '0000-00-00'", + err: "", + }, + { + name: "time 1", + input: "select %?", + params: []interface{}{time.Date(2019, 1, 1, 0, 0, 0, 0, time.UTC)}, + output: "select '2019-01-01 00:00:00'", + err: "", + }, + { + name: "time 2", + input: "select %?", + params: []interface{}{time2}, + output: "select '2018-01-23 04:03:05'", + err: "", + }, + { + name: "time 3", + input: "select %?", + params: []interface{}{time.Unix(0, 888888888)}, + output: "select '1970-01-01 08:00:00.888888'", + err: "", + }, + { + name: "empty byte slice1", + input: "select %?", + params: []interface{}{[]byte(nil)}, + output: "select NULL", + err: "", + }, + { + name: "empty byte slice2", + input: "select %?", + params: []interface{}{[]byte{}}, + output: "select _binary''", + err: "", + }, + { + name: "byte slice", + input: "select %?", + params: []interface{}{[]byte{2, 3}}, + output: "select _binary'\x02\x03'", + err: "", + }, + { + name: "string", + input: "select %?", + params: []interface{}{"33"}, + output: "select '33'", + }, + { + name: "string slice", + input: "select %?", + params: []interface{}{[]string{"33", "44"}}, + output: "select ('33','44')", + }, + { + name: "raw json", + input: "select %?", + params: []interface{}{json.RawMessage(`{"h": "hello"}`)}, + output: "select 
'{\\\"h\\\": \\\"hello\\\"}'", + }, + { + name: "unsupported args", + input: "select %?", + params: []interface{}{make(chan byte)}, + err: "unsupported 1-th argument.*", + }, + { + name: "mixed arguments", + input: "select %?, %?, %?", + params: []interface{}{"33", 44, time.Time{}}, + output: "select '33', 44, '0000-00-00'", + }, + { + name: "simple injection", + input: "select %?", + params: []interface{}{"0; drop database"}, + output: "select '0; drop database'", + }, + { + name: "identifier, wrong arg", + input: "use %n", + params: []interface{}{3}, + err: "expect a string identifier.*", + }, + { + name: "identifier", + input: "use %n", + params: []interface{}{"table`"}, + output: "use `table```", + err: "", + }, + { + name: "%n missing arguments", + input: "use %n", + params: []interface{}{}, + err: "missing arguments.*", + }, + { + name: "% escape", + input: "select * from t where val = '%%?'", + params: []interface{}{}, + output: "select * from t where val = '%?'", + err: "", + }, + { + name: "unknown specifier", + input: "%v", + params: []interface{}{}, + output: "%v", + err: "", + }, + { + name: "truncated specifier ", + input: "rv %", + params: []interface{}{}, + output: "rv %", + err: "", + }, + } + for _, t := range tests { + comment := Commentf("%s", t.name) + escaped, err := EscapeSQL(t.input, t.params...) + if t.err == "" { + c.Assert(err, IsNil, comment) + c.Assert(escaped, Equals, t.output, comment) + } else { + c.Assert(err, NotNil, comment) + c.Assert(err, ErrorMatches, t.err, comment) + } + } +} diff --git a/sessionctx/context.go b/sessionctx/context.go index 167fc3bd79803..21cedc5cc2a22 100644 --- a/sessionctx/context.go +++ b/sessionctx/context.go @@ -24,6 +24,7 @@ import ( "github.com/pingcap/tidb/util" "github.com/pingcap/tidb/util/kvcache" "github.com/pingcap/tidb/util/memory" + "github.com/pingcap/tidb/util/sli" "github.com/pingcap/tipb/go-binlog" ) @@ -103,6 +104,8 @@ type Context interface { HasLockedTables() bool // PrepareTSFuture uses to prepare timestamp by future. PrepareTSFuture(ctx context.Context) + // GetTxnWriteThroughputSLI returns the TxnWriteThroughputSLI. + GetTxnWriteThroughputSLI() *sli.TxnWriteThroughputSLI } type basicCtxType int diff --git a/sessionctx/stmtctx/stmtctx.go b/sessionctx/stmtctx/stmtctx.go index fbba707a9cc77..7832d717d7e35 100644 --- a/sessionctx/stmtctx/stmtctx.go +++ b/sessionctx/stmtctx/stmtctx.go @@ -511,12 +511,10 @@ func (sc *StatementContext) MergeExecDetails(details *execdetails.ExecDetails, c sc.mu.Lock() if details != nil { sc.mu.execDetails.CopTime += details.CopTime - sc.mu.execDetails.ProcessTime += details.ProcessTime - sc.mu.execDetails.WaitTime += details.WaitTime sc.mu.execDetails.BackoffTime += details.BackoffTime sc.mu.execDetails.RequestCount++ - sc.mu.execDetails.TotalKeys += details.TotalKeys - sc.mu.execDetails.ProcessedKeys += details.ProcessedKeys + sc.MergeScanDetail(details.ScanDetail) + sc.MergeTimeDetail(details.TimeDetail) sc.mu.allExecDetails = append(sc.mu.allExecDetails, details) } if commitDetails != nil { @@ -529,6 +527,24 @@ func (sc *StatementContext) MergeExecDetails(details *execdetails.ExecDetails, c sc.mu.Unlock() } +// MergeScanDetail merges scan details into self. 
+func (sc *StatementContext) MergeScanDetail(scanDetail *execdetails.ScanDetail) { + // Currently TiFlash cop task does not fill scanDetail, so need to skip it if scanDetail is nil + if scanDetail == nil { + return + } + if sc.mu.execDetails.ScanDetail == nil { + sc.mu.execDetails.ScanDetail = &execdetails.ScanDetail{} + } + sc.mu.execDetails.ScanDetail.Merge(scanDetail) +} + +// MergeTimeDetail merges time details into self. +func (sc *StatementContext) MergeTimeDetail(timeDetail execdetails.TimeDetail) { + sc.mu.execDetails.TimeDetail.ProcessTime += timeDetail.ProcessTime + sc.mu.execDetails.TimeDetail.WaitTime += timeDetail.WaitTime +} + // MergeLockKeysExecDetails merges lock keys execution details into self. func (sc *StatementContext) MergeLockKeysExecDetails(lockKeys *execdetails.LockKeysDetails) { sc.mu.Lock() @@ -615,21 +631,21 @@ func (sc *StatementContext) CopTasksDetails() *CopTasksDetails { if n == 0 { return d } - d.AvgProcessTime = sc.mu.execDetails.ProcessTime / time.Duration(n) - d.AvgWaitTime = sc.mu.execDetails.WaitTime / time.Duration(n) + d.AvgProcessTime = sc.mu.execDetails.TimeDetail.ProcessTime / time.Duration(n) + d.AvgWaitTime = sc.mu.execDetails.TimeDetail.WaitTime / time.Duration(n) sort.Slice(sc.mu.allExecDetails, func(i, j int) bool { - return sc.mu.allExecDetails[i].ProcessTime < sc.mu.allExecDetails[j].ProcessTime + return sc.mu.allExecDetails[i].TimeDetail.ProcessTime < sc.mu.allExecDetails[j].TimeDetail.ProcessTime }) - d.P90ProcessTime = sc.mu.allExecDetails[n*9/10].ProcessTime - d.MaxProcessTime = sc.mu.allExecDetails[n-1].ProcessTime + d.P90ProcessTime = sc.mu.allExecDetails[n*9/10].TimeDetail.ProcessTime + d.MaxProcessTime = sc.mu.allExecDetails[n-1].TimeDetail.ProcessTime d.MaxProcessAddress = sc.mu.allExecDetails[n-1].CalleeAddress sort.Slice(sc.mu.allExecDetails, func(i, j int) bool { - return sc.mu.allExecDetails[i].WaitTime < sc.mu.allExecDetails[j].WaitTime + return sc.mu.allExecDetails[i].TimeDetail.WaitTime < sc.mu.allExecDetails[j].TimeDetail.WaitTime }) - d.P90WaitTime = sc.mu.allExecDetails[n*9/10].WaitTime - d.MaxWaitTime = sc.mu.allExecDetails[n-1].WaitTime + d.P90WaitTime = sc.mu.allExecDetails[n*9/10].TimeDetail.WaitTime + d.MaxWaitTime = sc.mu.allExecDetails[n-1].TimeDetail.WaitTime d.MaxWaitAddress = sc.mu.allExecDetails[n-1].CalleeAddress // calculate backoff details diff --git a/sessionctx/stmtctx/stmtctx_test.go b/sessionctx/stmtctx/stmtctx_test.go index c786d5e10fcbf..cd0a51800a471 100644 --- a/sessionctx/stmtctx/stmtctx_test.go +++ b/sessionctx/stmtctx/stmtctx_test.go @@ -37,10 +37,12 @@ func (s *stmtctxSuit) TestCopTasksDetails(c *C) { for i := 0; i < 100; i++ { d := &execdetails.ExecDetails{ CalleeAddress: fmt.Sprintf("%v", i+1), - ProcessTime: time.Second * time.Duration(i+1), - WaitTime: time.Millisecond * time.Duration(i+1), BackoffSleep: make(map[string]time.Duration), BackoffTimes: make(map[string]int), + TimeDetail: execdetails.TimeDetail{ + ProcessTime: time.Second * time.Duration(i+1), + WaitTime: time.Millisecond * time.Duration(i+1), + }, } for _, backoff := range backoffs { d.BackoffSleep[backoff] = time.Millisecond * 100 * time.Duration(i+1) diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go index 53e018bfd52e1..2ec0d56e965c8 100644 --- a/sessionctx/variable/session.go +++ b/sessionctx/variable/session.go @@ -588,6 +588,12 @@ type SessionVars struct { // RewritePhaseInfo records all information about the rewriting phase. 
RewritePhaseInfo + // DurationOptimization is the duration of optimizing a query. + DurationOptimization time.Duration + + // DurationWaitTS is the duration of waiting for a snapshot TS + DurationWaitTS time.Duration + // PrevStmt is used to store the previous executed statement in the current session. PrevStmt fmt.Stringer @@ -650,6 +656,10 @@ type SessionVars struct { // PrevFoundInPlanCache indicates whether the last statement was found in plan cache. PrevFoundInPlanCache bool + // FoundInBinding indicates whether the execution plan is matched with the hints in the binding. + FoundInBinding bool + // PrevFoundInBinding indicates whether the last execution plan is matched with the hints in the binding. + PrevFoundInBinding bool // SelectLimit limits the max counts of select statement's output SelectLimit uint64 @@ -754,6 +764,8 @@ func NewSessionVars() *SessionVars { WindowingUseHighPrecision: true, PrevFoundInPlanCache: DefTiDBFoundInPlanCache, FoundInPlanCache: DefTiDBFoundInPlanCache, + PrevFoundInBinding: DefTiDBFoundInBinding, + FoundInBinding: DefTiDBFoundInBinding, SelectLimit: math.MaxUint64, AllowAutoRandExplicitInsert: DefTiDBAllowAutoRandExplicitInsert, EnableAmendPessimisticTxn: DefTiDBEnableAmendPessimisticTxn, @@ -1364,6 +1376,8 @@ func (s *SessionVars) SetSystemVar(name string, val string) error { config.GetGlobalConfig().CheckMb4ValueInUTF8 = TiDBOptOn(val) case TiDBFoundInPlanCache: s.FoundInPlanCache = TiDBOptOn(val) + case TiDBFoundInBinding: + s.FoundInBinding = TiDBOptOn(val) case SQLSelectLimit: result, err := strconv.ParseUint(val, 10, 64) if err != nil { @@ -1589,6 +1603,10 @@ const ( SlowLogCompileTimeStr = "Compile_time" // SlowLogRewriteTimeStr is the rewrite time. SlowLogRewriteTimeStr = "Rewrite_time" + // SlowLogOptimizeTimeStr is the optimization time. + SlowLogOptimizeTimeStr = "Optimize_time" + // SlowLogWaitTSTimeStr is the time of waiting TS. + SlowLogWaitTSTimeStr = "Wait_TS" // SlowLogPreprocSubQueriesStr is the number of pre-processed sub-queries. SlowLogPreprocSubQueriesStr = "Preproc_subqueries" // SlowLogPreProcSubQueryTimeStr is the total time of pre-processing sub-queries. @@ -1633,6 +1651,8 @@ const ( SlowLogPrepared = "Prepared" // SlowLogPlanFromCache is used to indicate whether this plan is from plan cache. SlowLogPlanFromCache = "Plan_from_cache" + // SlowLogPlanFromBinding is used to indicate whether this plan is matched with the hints in the binding. + SlowLogPlanFromBinding = "Plan_from_binding" // SlowLogHasMoreResults is used to indicate whether this sql has more following results. SlowLogHasMoreResults = "Has_more_results" // SlowLogSucc is used to indicate whether this sql execute successfully. 
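For reference, a sketch of how the two new timing items and the binding flag render in the slow log; the format follows writeSlowLogItem and matches the golden output in the session test below (values here are illustrative):

# Optimize_time: 0.00000001
# Wait_TS: 0.000000003
# Plan_from_binding: true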
@@ -1674,6 +1694,8 @@ type SlowQueryLogItems struct { TimeTotal time.Duration TimeParse time.Duration TimeCompile time.Duration + TimeOptimize time.Duration + TimeWaitTS time.Duration IndexNames string StatsInfos map[string]uint64 CopTasks *stmtctx.CopTasksDetails @@ -1683,6 +1705,7 @@ type SlowQueryLogItems struct { Succ bool Prepared bool PlanFromCache bool + PlanFromBinding bool HasMoreResults bool PrevStmt string Plan string @@ -1754,6 +1777,9 @@ func (s *SessionVars) SlowLogFormat(logItems *SlowQueryLogItems) string { } buf.WriteString("\n") + writeSlowLogItem(&buf, SlowLogOptimizeTimeStr, strconv.FormatFloat(logItems.TimeOptimize.Seconds(), 'f', -1, 64)) + writeSlowLogItem(&buf, SlowLogWaitTSTimeStr, strconv.FormatFloat(logItems.TimeWaitTS.Seconds(), 'f', -1, 64)) + if execDetailStr := logItems.ExecDetail.String(); len(execDetailStr) > 0 { buf.WriteString(SlowLogRowPrefixStr + execDetailStr + "\n") } @@ -1847,6 +1873,7 @@ func (s *SessionVars) SlowLogFormat(logItems *SlowQueryLogItems) string { writeSlowLogItem(&buf, SlowLogPrepared, strconv.FormatBool(logItems.Prepared)) writeSlowLogItem(&buf, SlowLogPlanFromCache, strconv.FormatBool(logItems.PlanFromCache)) + writeSlowLogItem(&buf, SlowLogPlanFromBinding, strconv.FormatBool(logItems.PlanFromBinding)) writeSlowLogItem(&buf, SlowLogHasMoreResults, strconv.FormatBool(logItems.HasMoreResults)) writeSlowLogItem(&buf, SlowLogKVTotal, strconv.FormatFloat(logItems.KVTotal.Seconds(), 'f', -1, 64)) writeSlowLogItem(&buf, SlowLogPDTotal, strconv.FormatFloat(logItems.PDTotal.Seconds(), 'f', -1, 64)) diff --git a/sessionctx/variable/session_test.go b/sessionctx/variable/session_test.go index b58ab81a6b4d9..7419c4d6d07f2 100644 --- a/sessionctx/variable/session_test.go +++ b/sessionctx/variable/session_test.go @@ -136,12 +136,16 @@ func (*testSessionSuite) TestSlowLogFormat(c *C) { txnTS := uint64(406649736972468225) costTime := time.Second execDetail := execdetails.ExecDetails{ - ProcessTime: time.Second * time.Duration(2), - WaitTime: time.Minute, - BackoffTime: time.Millisecond, - RequestCount: 2, - TotalKeys: 10000, - ProcessedKeys: 20001, + BackoffTime: time.Millisecond, + RequestCount: 2, + ScanDetail: &execdetails.ScanDetail{ + ProcessedKeys: 20001, + TotalKeys: 10000, + }, + TimeDetail: execdetails.TimeDetail{ + ProcessTime: time.Second * time.Duration(2), + WaitTime: time.Minute, + }, } statsInfos := make(map[string]uint64) statsInfos["t1"] = 0 @@ -183,6 +187,8 @@ func (*testSessionSuite) TestSlowLogFormat(c *C) { # Parse_time: 0.00000001 # Compile_time: 0.00000001 # Rewrite_time: 0.000000003 Preproc_subqueries: 2 Preproc_subqueries_time: 0.000000002 +# Optimize_time: 0.00000001 +# Wait_TS: 0.000000003 # Process_time: 2 Wait_time: 60 Backoff_time: 0.001 Request_count: 2 Total_keys: 10000 Process_keys: 20001 # DB: test # Index_names: [t1:a,t2:b] @@ -199,6 +205,7 @@ func (*testSessionSuite) TestSlowLogFormat(c *C) { # Disk_max: 6666 # Prepared: true # Plan_from_cache: true +# Plan_from_binding: true # Has_more_results: true # KV_total: 10 # PD_total: 11 @@ -214,6 +221,8 @@ func (*testSessionSuite) TestSlowLogFormat(c *C) { TimeTotal: costTime, TimeParse: time.Duration(10), TimeCompile: time.Duration(10), + TimeOptimize: time.Duration(10), + TimeWaitTS: time.Duration(3), IndexNames: "[t1:a,t2:b]", StatsInfos: statsInfos, CopTasks: copTasks, @@ -222,6 +231,7 @@ func (*testSessionSuite) TestSlowLogFormat(c *C) { DiskMax: diskMax, Prepared: true, PlanFromCache: true, + PlanFromBinding: true, HasMoreResults: true, KVTotal: 10 * time.Second, 
PDTotal: 11 * time.Second, diff --git a/sessionctx/variable/sysvar.go b/sessionctx/variable/sysvar.go index cef0093d86eb8..5049c650fb48a 100644 --- a/sessionctx/variable/sysvar.go +++ b/sessionctx/variable/sysvar.go @@ -734,6 +734,7 @@ var defaultSysVars = []*SysVar{ {ScopeSession, TiDBQueryLogMaxLen, strconv.Itoa(logutil.DefaultQueryLogMaxLen)}, {ScopeSession, TiDBCheckMb4ValueInUTF8, BoolToIntStr(config.GetGlobalConfig().CheckMb4ValueInUTF8)}, {ScopeSession, TiDBFoundInPlanCache, BoolToIntStr(DefTiDBFoundInPlanCache)}, + {ScopeSession, TiDBFoundInBinding, BoolToIntStr(DefTiDBFoundInBinding)}, {ScopeSession, TiDBEnableCollectExecutionInfo, BoolToIntStr(DefTiDBEnableCollectExecutionInfo)}, {ScopeGlobal | ScopeSession, TiDBAllowAutoRandExplicitInsert, boolToOnOff(DefTiDBAllowAutoRandExplicitInsert)}, {ScopeGlobal | ScopeSession, TiDBSlowLogMasking, BoolToIntStr(DefTiDBRedactLog)}, diff --git a/sessionctx/variable/tidb_vars.go b/sessionctx/variable/tidb_vars.go index d81cdf775db46..b48920e185f8b 100644 --- a/sessionctx/variable/tidb_vars.go +++ b/sessionctx/variable/tidb_vars.go @@ -181,6 +181,9 @@ const ( // TiDBFoundInPlanCache indicates whether the last statement was found in plan cache TiDBFoundInPlanCache = "last_plan_from_cache" + // TiDBFoundInBinding indicates whether the last statement was matched with the hints in the binding. + TiDBFoundInBinding = "last_plan_from_binding" + // TiDBAllowAutoRandExplicitInsert indicates whether explicit insertion on auto_random column is allowed. TiDBAllowAutoRandExplicitInsert = "allow_auto_random_explicit_insert" ) @@ -525,6 +528,7 @@ const ( DefTiDBMetricSchemaStep = 60 // 60s DefTiDBMetricSchemaRangeDuration = 60 // 60s DefTiDBFoundInPlanCache = false + DefTiDBFoundInBinding = false DefTiDBEnableCollectExecutionInfo = true DefTiDBAllowAutoRandExplicitInsert = false DefTiDBRedactLog = false diff --git a/sessionctx/variable/varsutil.go b/sessionctx/variable/varsutil.go index f5fcabe795211..676dcc2fb455c 100644 --- a/sessionctx/variable/varsutil.go +++ b/sessionctx/variable/varsutil.go @@ -163,6 +163,8 @@ func GetSessionOnlySysVars(s *SessionVars, key string) (string, bool, error) { return CapturePlanBaseline.GetVal(), true, nil case TiDBFoundInPlanCache: return BoolToIntStr(s.PrevFoundInPlanCache), true, nil + case TiDBFoundInBinding: + return BoolToIntStr(s.PrevFoundInBinding), true, nil case TiDBEnableCollectExecutionInfo: return BoolToIntStr(config.GetGlobalConfig().EnableCollectExecutionInfo), true, nil } diff --git a/sessionctx/variable/varsutil_test.go b/sessionctx/variable/varsutil_test.go index be47632f29636..536d490a57847 100644 --- a/sessionctx/variable/varsutil_test.go +++ b/sessionctx/variable/varsutil_test.go @@ -87,6 +87,7 @@ func (s *testVarsutilSuite) TestNewSessionVars(c *C) { c.Assert(vars.TiDBOptJoinReorderThreshold, Equals, DefTiDBOptJoinReorderThreshold) c.Assert(vars.EnableFastAnalyze, Equals, DefTiDBUseFastAnalyze) c.Assert(vars.FoundInPlanCache, Equals, DefTiDBFoundInPlanCache) + c.Assert(vars.FoundInBinding, Equals, DefTiDBFoundInBinding) c.Assert(vars.AllowAutoRandExplicitInsert, Equals, DefTiDBAllowAutoRandExplicitInsert) assertFieldsGreaterThanZero(c, reflect.ValueOf(vars.Concurrency)) @@ -456,6 +457,13 @@ func (s *testVarsutilSuite) TestVarsutil(c *C) { c.Assert(val, Equals, "0") c.Assert(v.systems[TiDBFoundInPlanCache], Equals, "1") + err = SetSessionSystemVar(v, TiDBFoundInBinding, types.NewStringDatum("1")) + c.Assert(err, IsNil) + val, err = GetSessionSystemVar(v, TiDBFoundInBinding) + c.Assert(err, IsNil) + 
c.Assert(val, Equals, "0") + c.Assert(v.systems[TiDBFoundInBinding], Equals, "1") + err = SetSessionSystemVar(v, "UnknownVariable", types.NewStringDatum("on")) c.Assert(err, ErrorMatches, ".*]Unknown system variable 'UnknownVariable'") } diff --git a/statistics/handle/bootstrap.go b/statistics/handle/bootstrap.go index 4bc8257537eda..bc87750400f2a 100644 --- a/statistics/handle/bootstrap.go +++ b/statistics/handle/bootstrap.go @@ -60,18 +60,16 @@ func (h *Handle) initStatsMeta4Chunk(is infoschema.InfoSchema, cache *statsCache func (h *Handle) initStatsMeta(is infoschema.InfoSchema) (statsCache, error) { sql := "select HIGH_PRIORITY version, table_id, modify_count, count from mysql.stats_meta" - rc, err := h.mu.ctx.(sqlexec.SQLExecutor).Execute(context.TODO(), sql) - if len(rc) > 0 { - defer terror.Call(rc[0].Close) - } + rc, err := h.mu.ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), sql) if err != nil { return statsCache{}, errors.Trace(err) } + defer terror.Call(rc.Close) tables := statsCache{tables: make(map[int64]*statistics.Table)} - req := rc[0].NewChunk() + req := rc.NewChunk() iter := chunk.NewIterator4Chunk(req) for { - err := rc[0].Next(context.TODO(), req) + err := rc.Next(context.TODO(), req) if err != nil { return statsCache{}, errors.Trace(err) } @@ -147,17 +145,15 @@ func (h *Handle) initStatsHistograms4Chunk(is infoschema.InfoSchema, cache *stat func (h *Handle) initStatsHistograms(is infoschema.InfoSchema, cache *statsCache) error { sql := "select HIGH_PRIORITY table_id, is_index, hist_id, distinct_count, version, null_count, cm_sketch, tot_col_size, stats_ver, correlation, flag, last_analyze_pos from mysql.stats_histograms" - rc, err := h.mu.ctx.(sqlexec.SQLExecutor).Execute(context.TODO(), sql) - if len(rc) > 0 { - defer terror.Call(rc[0].Close) - } + rc, err := h.mu.ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), sql) if err != nil { return errors.Trace(err) } - req := rc[0].NewChunk() + defer terror.Call(rc.Close) + req := rc.NewChunk() iter := chunk.NewIterator4Chunk(req) for { - err := rc[0].Next(context.TODO(), req) + err := rc.Next(context.TODO(), req) if err != nil { return errors.Trace(err) } @@ -187,17 +183,15 @@ func (h *Handle) initStatsTopN4Chunk(cache *statsCache, iter *chunk.Iterator4Chu func (h *Handle) initStatsTopN(cache *statsCache) error { sql := "select HIGH_PRIORITY table_id, hist_id, value, count from mysql.stats_top_n where is_index = 1" - rc, err := h.mu.ctx.(sqlexec.SQLExecutor).Execute(context.TODO(), sql) - if len(rc) > 0 { - defer terror.Call(rc[0].Close) - } + rc, err := h.mu.ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), sql) if err != nil { return errors.Trace(err) } - req := rc[0].NewChunk() + defer terror.Call(rc.Close) + req := rc.NewChunk() iter := chunk.NewIterator4Chunk(req) for { - err := rc[0].Next(context.TODO(), req) + err := rc.Next(context.TODO(), req) if err != nil { return errors.Trace(err) } @@ -257,17 +251,15 @@ func initStatsBuckets4Chunk(ctx sessionctx.Context, cache *statsCache, iter *chu func (h *Handle) initStatsBuckets(cache *statsCache) error { sql := "select HIGH_PRIORITY table_id, is_index, hist_id, count, repeats, lower_bound, upper_bound from mysql.stats_buckets order by table_id, is_index, hist_id, bucket_id" - rc, err := h.mu.ctx.(sqlexec.SQLExecutor).Execute(context.TODO(), sql) - if len(rc) > 0 { - defer terror.Call(rc[0].Close) - } + rc, err := h.mu.ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), sql) if err != nil { return errors.Trace(err) } - req := rc[0].NewChunk() 
+ defer terror.Call(rc.Close) + req := rc.NewChunk() iter := chunk.NewIterator4Chunk(req) for { - err := rc[0].Next(context.TODO(), req) + err := rc.Next(context.TODO(), req) if err != nil { return errors.Trace(err) } @@ -300,13 +292,13 @@ func (h *Handle) initStatsBuckets(cache *statsCache) error { func (h *Handle) InitStats(is infoschema.InfoSchema) (err error) { h.mu.Lock() defer func() { - _, err1 := h.mu.ctx.(sqlexec.SQLExecutor).Execute(context.TODO(), "commit") + _, err1 := h.mu.ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), "commit") if err == nil && err1 != nil { err = err1 } h.mu.Unlock() }() - _, err = h.mu.ctx.(sqlexec.SQLExecutor).Execute(context.TODO(), "begin") + _, err = h.mu.ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), "begin") if err != nil { return err } diff --git a/statistics/handle/ddl.go b/statistics/handle/ddl.go index 127608edb7bb8..7a1770d592be0 100644 --- a/statistics/handle/ddl.go +++ b/statistics/handle/ddl.go @@ -15,7 +15,6 @@ package handle import ( "context" - "fmt" "github.com/pingcap/errors" "github.com/pingcap/parser/model" @@ -75,28 +74,34 @@ func (h *Handle) DDLEventCh() chan *util.Event { func (h *Handle) insertTableStats2KV(info *model.TableInfo, physicalID int64) (err error) { h.mu.Lock() defer h.mu.Unlock() + ctx := context.Background() exec := h.mu.ctx.(sqlexec.SQLExecutor) - _, err = exec.Execute(context.Background(), "begin") + _, err = exec.ExecuteInternal(ctx, "begin") if err != nil { return errors.Trace(err) } defer func() { - err = finishTransaction(context.Background(), exec, err) + err = finishTransaction(ctx, exec, err) }() txn, err := h.mu.ctx.Txn(true) if err != nil { return errors.Trace(err) } startTS := txn.StartTS() - sqls := make([]string, 0, 1+len(info.Columns)+len(info.Indices)) - sqls = append(sqls, fmt.Sprintf("insert into mysql.stats_meta (version, table_id) values(%d, %d)", startTS, physicalID)) + if _, err := exec.ExecuteInternal(ctx, "insert into mysql.stats_meta (version, table_id) values(%?, %?)", startTS, physicalID); err != nil { + return err + } for _, col := range info.Columns { - sqls = append(sqls, fmt.Sprintf("insert into mysql.stats_histograms (table_id, is_index, hist_id, distinct_count, version) values(%d, 0, %d, 0, %d)", physicalID, col.ID, startTS)) + if _, err := exec.ExecuteInternal(ctx, "insert into mysql.stats_histograms (table_id, is_index, hist_id, distinct_count, version) values(%?, 0, %?, 0, %?)", physicalID, col.ID, startTS); err != nil { + return err + } } for _, idx := range info.Indices { - sqls = append(sqls, fmt.Sprintf("insert into mysql.stats_histograms (table_id, is_index, hist_id, distinct_count, version) values(%d, 1, %d, 0, %d)", physicalID, idx.ID, startTS)) + if _, err := exec.ExecuteInternal(ctx, "insert into mysql.stats_histograms (table_id, is_index, hist_id, distinct_count, version) values(%?, 1, %?, 0, %?)", physicalID, idx.ID, startTS); err != nil { + return err + } } - return execSQLs(context.Background(), exec, sqls) + return nil } // insertColStats2KV insert a record to stats_histograms with distinct_count 1 and insert a bucket to stats_buckets with default value. 
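The recurring pattern in these statistics changes: SQL built with fmt.Sprintf is replaced by %? placeholders, which ExecuteInternal escapes per argument type via ParseWithParams/EscapeSQL. A minimal before/after sketch (values invented for illustration):

	// Before: arguments interpolated as raw text into the statement.
	sql := fmt.Sprintf("insert into mysql.stats_meta (version, table_id) values(%d, %d)", startTS, physicalID)
	_, err = exec.Execute(ctx, sql)

	// After: each %? is escaped by type, e.g. a string becomes '...' and a
	// []byte becomes _binary'...', so hand-built hex literals like X'%X' are no longer needed.
	_, err = exec.ExecuteInternal(ctx, "insert into mysql.stats_meta (version, table_id) values(%?, %?)", startTS, physicalID)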
@@ -105,13 +110,14 @@ func (h *Handle) insertColStats2KV(physicalID int64, colInfo *model.ColumnInfo) h.mu.Lock() defer h.mu.Unlock() + ctx := context.TODO() exec := h.mu.ctx.(sqlexec.SQLExecutor) - _, err = exec.Execute(context.Background(), "begin") + _, err = exec.ExecuteInternal(ctx, "begin") if err != nil { return errors.Trace(err) } defer func() { - err = finishTransaction(context.Background(), exec, err) + err = finishTransaction(ctx, exec, err) }() txn, err := h.mu.ctx.Txn(true) if err != nil { @@ -119,24 +125,21 @@ func (h *Handle) insertColStats2KV(physicalID int64, colInfo *model.ColumnInfo) } startTS := txn.StartTS() // First of all, we update the version. - _, err = exec.Execute(context.Background(), fmt.Sprintf("update mysql.stats_meta set version = %d where table_id = %d ", startTS, physicalID)) + _, err = exec.ExecuteInternal(ctx, "update mysql.stats_meta set version = %? where table_id = %?", startTS, physicalID) if err != nil { return } - ctx := context.TODO() // If we didn't update anything by last SQL, it means the stats of this table does not exist. if h.mu.ctx.GetSessionVars().StmtCtx.AffectedRows() > 0 { // By this step we can get the count of this table, then we can sure the count and repeats of bucket. - var rs []sqlexec.RecordSet - rs, err = exec.Execute(ctx, fmt.Sprintf("select count from mysql.stats_meta where table_id = %d", physicalID)) - if len(rs) > 0 { - defer terror.Call(rs[0].Close) - } + var rs sqlexec.RecordSet + rs, err = exec.ExecuteInternal(ctx, "select count from mysql.stats_meta where table_id = %?", physicalID) if err != nil { return } - req := rs[0].NewChunk() - err = rs[0].Next(ctx, req) + defer terror.Call(rs.Close) + req := rs.NewChunk() + err = rs.Next(ctx, req) if err != nil { return } @@ -146,21 +149,26 @@ func (h *Handle) insertColStats2KV(physicalID int64, colInfo *model.ColumnInfo) if err != nil { return } - sqls := make([]string, 0, 1) if value.IsNull() { // If the adding column has default value null, all the existing rows have null value on the newly added column. - sqls = append(sqls, fmt.Sprintf("insert into mysql.stats_histograms (version, table_id, is_index, hist_id, distinct_count, null_count) values (%d, %d, 0, %d, 0, %d)", startTS, physicalID, colInfo.ID, count)) + if _, err := exec.ExecuteInternal(ctx, "insert into mysql.stats_histograms (version, table_id, is_index, hist_id, distinct_count, null_count) values (%?, %?, 0, %?, 0, %?)", startTS, physicalID, colInfo.ID, count); err != nil { + return err + } } else { // If this stats exists, we insert histogram meta first, the distinct_count will always be one. - sqls = append(sqls, fmt.Sprintf("insert into mysql.stats_histograms (version, table_id, is_index, hist_id, distinct_count, tot_col_size) values (%d, %d, 0, %d, 1, %d)", startTS, physicalID, colInfo.ID, int64(len(value.GetBytes()))*count)) + if _, err := exec.ExecuteInternal(ctx, "insert into mysql.stats_histograms (version, table_id, is_index, hist_id, distinct_count, tot_col_size) values (%?, %?, 0, %?, 1, %?)", startTS, physicalID, colInfo.ID, int64(len(value.GetBytes()))*count); err != nil { + return err + } + value, err = value.ConvertTo(h.mu.ctx.GetSessionVars().StmtCtx, types.NewFieldType(mysql.TypeBlob)) if err != nil { return } // There must be only one bucket for this new column and the value is the default value. 
- sqls = append(sqls, fmt.Sprintf("insert into mysql.stats_buckets (table_id, is_index, hist_id, bucket_id, repeats, count, lower_bound, upper_bound) values (%d, 0, %d, 0, %d, %d, X'%X', X'%X')", physicalID, colInfo.ID, count, count, value.GetBytes(), value.GetBytes())) + if _, err := exec.ExecuteInternal(ctx, "insert into mysql.stats_buckets (table_id, is_index, hist_id, bucket_id, repeats, count, lower_bound, upper_bound) values (%?, 0, %?, 0, %?, %?, %?, %?)", physicalID, colInfo.ID, count, count, value.GetBytes(), value.GetBytes()); err != nil { + return err + } } - return execSQLs(context.Background(), exec, sqls) } return } @@ -168,20 +176,10 @@ func (h *Handle) insertColStats2KV(physicalID int64, colInfo *model.ColumnInfo) // finishTransaction will execute `commit` when error is nil, otherwise `rollback`. func finishTransaction(ctx context.Context, exec sqlexec.SQLExecutor, err error) error { if err == nil { - _, err = exec.Execute(ctx, "commit") + _, err = exec.ExecuteInternal(ctx, "commit") } else { - _, err1 := exec.Execute(ctx, "rollback") + _, err1 := exec.ExecuteInternal(ctx, "rollback") terror.Log(errors.Trace(err1)) } return errors.Trace(err) } - -func execSQLs(ctx context.Context, exec sqlexec.SQLExecutor, sqls []string) error { - for _, sql := range sqls { - _, err := exec.Execute(ctx, sql) - if err != nil { - return err - } - } - return nil -} diff --git a/statistics/handle/dump.go b/statistics/handle/dump.go index 16295569d76c0..33ebfdd92fb94 100644 --- a/statistics/handle/dump.go +++ b/statistics/handle/dump.go @@ -20,6 +20,7 @@ import ( "github.com/pingcap/parser/model" "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/infoschema" + "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/statistics" "github.com/pingcap/tidb/types" @@ -63,9 +64,19 @@ func dumpJSONCol(hist *statistics.Histogram, CMSketch *statistics.CMSketch) *jso // DumpStatsToJSON dumps statistic to json. func (h *Handle) DumpStatsToJSON(dbName string, tableInfo *model.TableInfo, historyStatsExec sqlexec.RestrictedSQLExecutor) (*JSONTable, error) { + var snapshot uint64 + if historyStatsExec != nil { + sctx := historyStatsExec.(sessionctx.Context) + snapshot = sctx.GetSessionVars().SnapshotTS + } + return h.DumpStatsToJSONBySnapshot(dbName, tableInfo, snapshot) +} + +// DumpStatsToJSONBySnapshot dumps statistic to json. 
+func (h *Handle) DumpStatsToJSONBySnapshot(dbName string, tableInfo *model.TableInfo, snapshot uint64) (*JSONTable, error) { pi := tableInfo.GetPartitionInfo() if pi == nil { - return h.tableStatsToJSON(dbName, tableInfo, tableInfo.ID, historyStatsExec) + return h.tableStatsToJSON(dbName, tableInfo, tableInfo.ID, snapshot) } jsonTbl := &JSONTable{ DatabaseName: dbName, @@ -73,7 +84,7 @@ func (h *Handle) DumpStatsToJSON(dbName string, tableInfo *model.TableInfo, hist Partitions: make(map[string]*JSONTable, len(pi.Definitions)), } for _, def := range pi.Definitions { - tbl, err := h.tableStatsToJSON(dbName, tableInfo, def.ID, historyStatsExec) + tbl, err := h.tableStatsToJSON(dbName, tableInfo, def.ID, snapshot) if err != nil { return nil, errors.Trace(err) } @@ -85,12 +96,12 @@ func (h *Handle) DumpStatsToJSON(dbName string, tableInfo *model.TableInfo, hist return jsonTbl, nil } -func (h *Handle) tableStatsToJSON(dbName string, tableInfo *model.TableInfo, physicalID int64, historyStatsExec sqlexec.RestrictedSQLExecutor) (*JSONTable, error) { - tbl, err := h.tableStatsFromStorage(tableInfo, physicalID, true, historyStatsExec) +func (h *Handle) tableStatsToJSON(dbName string, tableInfo *model.TableInfo, physicalID int64, snapshot uint64) (*JSONTable, error) { + tbl, err := h.tableStatsFromStorage(tableInfo, physicalID, true, snapshot) if err != nil || tbl == nil { return nil, err } - tbl.Version, tbl.ModifyCount, tbl.Count, err = h.statsMetaByTableIDFromStorage(physicalID, historyStatsExec) + tbl.Version, tbl.ModifyCount, tbl.Count, err = h.statsMetaByTableIDFromStorage(physicalID, snapshot) if err != nil { return nil, err } diff --git a/statistics/handle/gc.go b/statistics/handle/gc.go index 3264485e6b703..7bb4bcd6a426f 100644 --- a/statistics/handle/gc.go +++ b/statistics/handle/gc.go @@ -15,7 +15,6 @@ package handle import ( "context" - "fmt" "time" "github.com/cznic/mathutil" @@ -27,6 +26,7 @@ import ( // GCStats will garbage collect the useless stats info. For dropped tables, we will first update their version so that // other tidb could know that table is deleted. func (h *Handle) GCStats(is infoschema.InfoSchema, ddlLease time.Duration) error { + ctx := context.Background() // To make sure that all the deleted tables' schema and stats info have been acknowledged to all tidb, // we only garbage collect version before 10 lease. 
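	// Illustrative arithmetic (not part of this change): with a 3s stats lease
	// and a 45s DDL lease, the guard window is 10*max(3s, 45s) = 450s, so only
	// stats versions older than 450 seconds become eligible for GC.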
lease := mathutil.MaxInt64(int64(h.Lease()), int64(ddlLease)) @@ -34,8 +34,8 @@ func (h *Handle) GCStats(is infoschema.InfoSchema, ddlLease time.Duration) error if h.LastUpdateVersion() < offset { return nil } - sql := fmt.Sprintf("select table_id from mysql.stats_meta where version < %d", h.LastUpdateVersion()-offset) - rows, _, err := h.restrictedExec.ExecRestrictedSQL(sql) + gcVer := h.LastUpdateVersion() - offset + rows, _, err := h.execRestrictedSQL(ctx, "select table_id from mysql.stats_meta where version < %?", gcVer) if err != nil { return errors.Trace(err) } @@ -48,17 +48,18 @@ func (h *Handle) GCStats(is infoschema.InfoSchema, ddlLease time.Duration) error } func (h *Handle) gcTableStats(is infoschema.InfoSchema, physicalID int64) error { - sql := fmt.Sprintf("select is_index, hist_id from mysql.stats_histograms where table_id = %d", physicalID) - rows, _, err := h.restrictedExec.ExecRestrictedSQL(sql) + ctx := context.Background() + rows, _, err := h.execRestrictedSQL(ctx, "select is_index, hist_id from mysql.stats_histograms where table_id = %?", physicalID) if err != nil { return errors.Trace(err) } // The table has already been deleted in stats and acknowledged to all tidb, // we can safely remove the meta info now. if len(rows) == 0 { - sql := fmt.Sprintf("delete from mysql.stats_meta where table_id = %d", physicalID) - _, _, err := h.restrictedExec.ExecRestrictedSQL(sql) - return errors.Trace(err) + _, _, err = h.execRestrictedSQL(ctx, "delete from mysql.stats_meta where table_id = %?", physicalID) + if err != nil { + return errors.Trace(err) + } } h.mu.Lock() tbl, ok := h.getTableByPhysicalID(is, physicalID) @@ -99,29 +100,37 @@ func (h *Handle) deleteHistStatsFromKV(physicalID int64, histID int64, isIndex i h.mu.Lock() defer h.mu.Unlock() + ctx := context.Background() exec := h.mu.ctx.(sqlexec.SQLExecutor) - _, err = exec.Execute(context.Background(), "begin") + _, err = exec.ExecuteInternal(ctx, "begin") if err != nil { return errors.Trace(err) } defer func() { - err = finishTransaction(context.Background(), exec, err) + err = finishTransaction(ctx, exec, err) }() txn, err := h.mu.ctx.Txn(true) if err != nil { return errors.Trace(err) } startTS := txn.StartTS() - sqls := make([]string, 0, 4) // First of all, we update the version. If this table doesn't exist, it won't have any problem. Because we cannot delete anything. - sqls = append(sqls, fmt.Sprintf("update mysql.stats_meta set version = %d where table_id = %d ", startTS, physicalID)) + if _, err = exec.ExecuteInternal(ctx, "update mysql.stats_meta set version = %? where table_id = %? ", startTS, physicalID); err != nil { + return err + } // delete histogram meta - sqls = append(sqls, fmt.Sprintf("delete from mysql.stats_histograms where table_id = %d and hist_id = %d and is_index = %d", physicalID, histID, isIndex)) + if _, err = exec.ExecuteInternal(ctx, "delete from mysql.stats_histograms where table_id = %? and hist_id = %? and is_index = %?", physicalID, histID, isIndex); err != nil { + return err + } // delete top n data - sqls = append(sqls, fmt.Sprintf("delete from mysql.stats_top_n where table_id = %d and hist_id = %d and is_index = %d", physicalID, histID, isIndex)) + if _, err = exec.ExecuteInternal(ctx, "delete from mysql.stats_top_n where table_id = %? and hist_id = %? 
and is_index = %?", physicalID, histID, isIndex); err != nil { + return err + } // delete all buckets - sqls = append(sqls, fmt.Sprintf("delete from mysql.stats_buckets where table_id = %d and hist_id = %d and is_index = %d", physicalID, histID, isIndex)) - return execSQLs(context.Background(), exec, sqls) + if _, err = exec.ExecuteInternal(ctx, "delete from mysql.stats_buckets where table_id = %? and hist_id = %? and is_index = %?", physicalID, histID, isIndex); err != nil { + return err + } + return nil } // DeleteTableStatsFromKV deletes table statistics from kv. @@ -129,7 +138,7 @@ func (h *Handle) DeleteTableStatsFromKV(physicalID int64) (err error) { h.mu.Lock() defer h.mu.Unlock() exec := h.mu.ctx.(sqlexec.SQLExecutor) - _, err = exec.Execute(context.Background(), "begin") + _, err = exec.ExecuteInternal(context.Background(), "begin") if err != nil { return errors.Trace(err) } @@ -140,13 +149,23 @@ func (h *Handle) DeleteTableStatsFromKV(physicalID int64) (err error) { if err != nil { return errors.Trace(err) } + ctx := context.Background() startTS := txn.StartTS() - sqls := make([]string, 0, 5) // We only update the version so that other tidb will know that this table is deleted. - sqls = append(sqls, fmt.Sprintf("update mysql.stats_meta set version = %d where table_id = %d ", startTS, physicalID)) - sqls = append(sqls, fmt.Sprintf("delete from mysql.stats_histograms where table_id = %d", physicalID)) - sqls = append(sqls, fmt.Sprintf("delete from mysql.stats_buckets where table_id = %d", physicalID)) - sqls = append(sqls, fmt.Sprintf("delete from mysql.stats_top_n where table_id = %d", physicalID)) - sqls = append(sqls, fmt.Sprintf("delete from mysql.stats_feedback where table_id = %d", physicalID)) - return execSQLs(context.Background(), exec, sqls) + if _, err = exec.ExecuteInternal(ctx, "update mysql.stats_meta set version = %? where table_id = %? ", startTS, physicalID); err != nil { + return err + } + if _, err = exec.ExecuteInternal(ctx, "delete from mysql.stats_histograms where table_id = %?", physicalID); err != nil { + return err + } + if _, err = exec.ExecuteInternal(ctx, "delete from mysql.stats_buckets where table_id = %?", physicalID); err != nil { + return err + } + if _, err = exec.ExecuteInternal(ctx, "delete from mysql.stats_top_n where table_id = %?", physicalID); err != nil { + return err + } + if _, err = exec.ExecuteInternal(ctx, "delete from mysql.stats_feedback where table_id = %?", physicalID); err != nil { + return err + } + return nil } diff --git a/statistics/handle/handle.go b/statistics/handle/handle.go index 3c681aafb3804..9dce7a5c45a7c 100644 --- a/statistics/handle/handle.go +++ b/statistics/handle/handle.go @@ -20,12 +20,12 @@ import ( "sync/atomic" "time" + "github.com/ngaut/pools" "github.com/pingcap/errors" "github.com/pingcap/failpoint" "github.com/pingcap/parser/ast" "github.com/pingcap/parser/model" "github.com/pingcap/parser/mysql" - "github.com/pingcap/parser/terror" "github.com/pingcap/tidb/ddl/util" "github.com/pingcap/tidb/infoschema" "github.com/pingcap/tidb/sessionctx" @@ -68,7 +68,7 @@ type Handle struct { atomic.Value } - restrictedExec sqlexec.RestrictedSQLExecutor + pool sessionPool // ddlEventCh is a channel to notify a ddl operation has happened. // It is sent only by owner or the drop stats executor, and read by stats handle. 
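The statistics Handle now borrows a session per restricted query instead of pinning a single executor. A minimal sketch of the borrow-parse-execute flow that the next hunk packages into execRestrictedSQL (identifiers taken from this diff):

	se, err := h.pool.Get() // borrow a pooled session
	if err != nil {
		return nil, nil, err
	}
	defer h.pool.Put(se) // always hand it back
	exec := se.(sqlexec.RestrictedSQLExecutor)
	stmt, err := exec.ParseWithParams(ctx, "select table_id from mysql.stats_meta where version < %?", gcVer)
	if err != nil {
		return nil, nil, err
	}
	rows, fields, err := exec.ExecRestrictedStmt(ctx, stmt)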
@@ -83,6 +83,37 @@ type Handle struct { lease atomic2.Duration } +func (h *Handle) withRestrictedSQLExecutor(ctx context.Context, fn func(context.Context, sqlexec.RestrictedSQLExecutor) ([]chunk.Row, []*ast.ResultField, error)) ([]chunk.Row, []*ast.ResultField, error) { + se, err := h.pool.Get() + if err != nil { + return nil, nil, errors.Trace(err) + } + defer h.pool.Put(se) + + exec := se.(sqlexec.RestrictedSQLExecutor) + return fn(ctx, exec) +} + +func (h *Handle) execRestrictedSQL(ctx context.Context, sql string, params ...interface{}) ([]chunk.Row, []*ast.ResultField, error) { + return h.withRestrictedSQLExecutor(ctx, func(ctx context.Context, exec sqlexec.RestrictedSQLExecutor) ([]chunk.Row, []*ast.ResultField, error) { + stmt, err := exec.ParseWithParams(ctx, sql, params...) + if err != nil { + return nil, nil, errors.Trace(err) + } + return exec.ExecRestrictedStmt(ctx, stmt) + }) +} + +func (h *Handle) execRestrictedSQLWithSnapshot(ctx context.Context, sql string, snapshot uint64, params ...interface{}) ([]chunk.Row, []*ast.ResultField, error) { + return h.withRestrictedSQLExecutor(ctx, func(ctx context.Context, exec sqlexec.RestrictedSQLExecutor) ([]chunk.Row, []*ast.ResultField, error) { + stmt, err := exec.ParseWithParams(ctx, sql, params...) + if err != nil { + return nil, nil, errors.Trace(err) + } + return exec.ExecRestrictedStmt(ctx, stmt, sqlexec.ExecOptionWithSnapshot(snapshot)) + }) +} + // Clear the statsCache, only for test. func (h *Handle) Clear() { h.mu.Lock() @@ -101,19 +132,21 @@ func (h *Handle) Clear() { h.mu.Unlock() } +type sessionPool interface { + Get() (pools.Resource, error) + Put(pools.Resource) +} + // NewHandle creates a Handle for update stats. -func NewHandle(ctx sessionctx.Context, lease time.Duration) *Handle { +func NewHandle(ctx sessionctx.Context, lease time.Duration, pool sessionPool) *Handle { handle := &Handle{ ddlEventCh: make(chan *util.Event, 100), listHead: &SessionStatsCollector{mapper: make(tableDeltaMap), rateMap: make(errorRateDeltaMap)}, globalMap: make(tableDeltaMap), feedback: statistics.NewQueryFeedbackMap(), + pool: pool, } handle.lease.Store(lease) - // It is safe to use it concurrently because the exec won't touch the ctx. - if exec, ok := ctx.(sqlexec.RestrictedSQLExecutor); ok { - handle.restrictedExec = exec - } handle.mu.ctx = ctx handle.mu.rateMap = make(errorRateDeltaMap) handle.statsCache.Store(statsCache{tables: make(map[int64]*statistics.Table)}) @@ -158,8 +191,8 @@ func (h *Handle) Update(is infoschema.InfoSchema) error { } else { lastVersion = 0 } - sql := fmt.Sprintf("SELECT version, table_id, modify_count, count from mysql.stats_meta where version > %d order by version", lastVersion) - rows, _, err := h.restrictedExec.ExecRestrictedSQL(sql) + ctx := context.Background() + rows, _, err := h.execRestrictedSQL(ctx, "SELECT version, table_id, modify_count, count from mysql.stats_meta where version > %? order by version", lastVersion) if err != nil { return errors.Trace(err) } @@ -181,7 +214,10 @@ func (h *Handle) Update(is infoschema.InfoSchema) error { continue } tableInfo := table.Meta() - tbl, err := h.tableStatsFromStorage(tableInfo, physicalID, false, nil) + if oldTbl, ok := oldCache.tables[physicalID]; ok && oldTbl.Version >= version && tableInfo.UpdateTS == oldTbl.TblInfoUpdateTS { + continue + } + tbl, err := h.tableStatsFromStorage(tableInfo, physicalID, false, 0) // Error is not nil may mean that there are some ddl changes on this table, we will not update it. 
if err != nil { logutil.BgLogger().Error("[stats] error occurred when read table stats", zap.String("table", tableInfo.Name.O), zap.Error(err)) continue } if tbl == nil { deletedTableIDs = append(deletedTableIDs, physicalID) continue } tbl.Version = version tbl.Count = count tbl.ModifyCount = modifyCount tbl.Name = getFullTableName(is, tableInfo) + tbl.TblInfoUpdateTS = tableInfo.UpdateTS tables = append(tables, tbl) } h.updateStatsCache(oldCache.update(tables, deletedTableIDs, lastVersion)) @@ -281,7 +318,7 @@ func (sc statsCache) update(tables []*statistics.Table, deletedIDs []int64, newV // LoadNeededHistograms will load histograms for those needed columns. func (h *Handle) LoadNeededHistograms() (err error) { cols := statistics.HistogramNeededColumns.AllCols() - reader, err := h.getStatsReader(nil) + reader, err := h.getStatsReader(0) if err != nil { return err } @@ -355,13 +392,11 @@ func (h *Handle) FlushStats() { } func (h *Handle) cmSketchFromStorage(reader *statsReader, tblID int64, isIndex, histID int64) (_ *statistics.CMSketch, err error) { - selSQL := fmt.Sprintf("select cm_sketch from mysql.stats_histograms where table_id = %d and is_index = %d and hist_id = %d", tblID, isIndex, histID) - rows, _, err := reader.read(selSQL) + rows, _, err := reader.read("select cm_sketch from mysql.stats_histograms where table_id = %? and is_index = %? and hist_id = %?", tblID, isIndex, histID) if err != nil || len(rows) == 0 { return nil, err } - selSQL = fmt.Sprintf("select HIGH_PRIORITY value, count from mysql.stats_top_n where table_id = %d and is_index = %d and hist_id = %d", tblID, isIndex, histID) - topNRows, _, err := reader.read(selSQL) + topNRows, _, err := reader.read("select HIGH_PRIORITY value, count from mysql.stats_top_n where table_id = %? and is_index = %? and hist_id = %?", tblID, isIndex, histID) if err != nil { return nil, err } @@ -497,8 +532,8 @@ func (h *Handle) columnStatsFromStorage(reader *statsReader, row chunk.Row, tabl } // tableStatsFromStorage loads table stats info from storage. -func (h *Handle) tableStatsFromStorage(tableInfo *model.TableInfo, physicalID int64, loadAll bool, historyStatsExec sqlexec.RestrictedSQLExecutor) (_ *statistics.Table, err error) { - reader, err := h.getStatsReader(historyStatsExec) +func (h *Handle) tableStatsFromStorage(tableInfo *model.TableInfo, physicalID int64, loadAll bool, snapshot uint64) (_ *statistics.Table, err error) { + reader, err := h.getStatsReader(snapshot) if err != nil { return nil, err } @@ -511,7 +546,7 @@ func (h *Handle) tableStatsFromStorage(tableInfo *model.TableInfo, physicalID in table, ok := h.statsCache.Load().(statsCache).tables[physicalID] // If the table stats are pseudo, we also need to copy them, since we will use the column stats when // their average error rate is small.
- if !ok || historyStatsExec != nil { + if !ok || snapshot > 0 { histColl := statistics.HistColl{ PhysicalID: physicalID, HavePhysicalID: true, @@ -526,8 +561,7 @@ func (h *Handle) tableStatsFromStorage(tableInfo *model.TableInfo, physicalID in table = table.Copy() } table.Pseudo = false - selSQL := fmt.Sprintf("select table_id, is_index, hist_id, distinct_count, version, null_count, tot_col_size, stats_ver, flag, correlation, last_analyze_pos from mysql.stats_histograms where table_id = %d", physicalID) - rows, _, err := reader.read(selSQL) + rows, _, err := reader.read("select table_id, is_index, hist_id, distinct_count, version, null_count, tot_col_size, stats_ver, flag, correlation, last_analyze_pos from mysql.stats_histograms where table_id = %?", physicalID) // Check deleted table. if err != nil || len(rows) == 0 { return nil, nil @@ -551,7 +585,7 @@ func (h *Handle) SaveStatsToStorage(tableID int64, count int64, isIndex int, hg defer h.mu.Unlock() ctx := context.TODO() exec := h.mu.ctx.(sqlexec.SQLExecutor) - _, err = exec.Execute(ctx, "begin") + _, err = exec.ExecuteInternal(ctx, "begin") if err != nil { return errors.Trace(err) } @@ -564,29 +598,40 @@ func (h *Handle) SaveStatsToStorage(tableID int64, count int64, isIndex int, hg } version := txn.StartTS() - sqls := make([]string, 0, 4) // If the count is less than 0, then we do not want to update the modify count and count. if count >= 0 { - sqls = append(sqls, fmt.Sprintf("replace into mysql.stats_meta (version, table_id, count) values (%d, %d, %d)", version, tableID, count)) + _, err = exec.ExecuteInternal(ctx, "replace into mysql.stats_meta (version, table_id, count) values (%?, %?, %?)", version, tableID, count) } else { - sqls = append(sqls, fmt.Sprintf("update mysql.stats_meta set version = %d where table_id = %d", version, tableID)) + _, err = exec.ExecuteInternal(ctx, "update mysql.stats_meta set version = %? where table_id = %?", version, tableID) + } + if err != nil { + return err } data, err := statistics.EncodeCMSketchWithoutTopN(cms) if err != nil { - return + return err } // Delete outdated data - sqls = append(sqls, fmt.Sprintf("delete from mysql.stats_top_n where table_id = %d and is_index = %d and hist_id = %d", tableID, isIndex, hg.ID)) + if _, err = exec.ExecuteInternal(ctx, "delete from mysql.stats_top_n where table_id = %? and is_index = %? 
and hist_id = %?", tableID, isIndex, hg.ID); err != nil { + return err + } for _, meta := range cms.TopN() { - sqls = append(sqls, fmt.Sprintf("insert into mysql.stats_top_n (table_id, is_index, hist_id, value, count) values (%d, %d, %d, X'%X', %d)", tableID, isIndex, hg.ID, meta.Data, meta.Count)) + _, err = exec.ExecuteInternal(ctx, "insert into mysql.stats_top_n (table_id, is_index, hist_id, value, count) values (%?, %?, %?, %?, %?)", tableID, isIndex, hg.ID, meta.Data, meta.Count) + if err != nil { + return err + } } flag := 0 if isAnalyzed == 1 { flag = statistics.AnalyzeFlag } - sqls = append(sqls, fmt.Sprintf("replace into mysql.stats_histograms (table_id, is_index, hist_id, distinct_count, version, null_count, cm_sketch, tot_col_size, stats_ver, flag, correlation) values (%d, %d, %d, %d, %d, %d, X'%X', %d, %d, %d, %f)", - tableID, isIndex, hg.ID, hg.NDV, version, hg.NullCount, data, hg.TotColSize, statistics.CurStatsVersion, flag, hg.Correlation)) - sqls = append(sqls, fmt.Sprintf("delete from mysql.stats_buckets where table_id = %d and is_index = %d and hist_id = %d", tableID, isIndex, hg.ID)) + if _, err = exec.ExecuteInternal(ctx, "replace into mysql.stats_histograms (table_id, is_index, hist_id, distinct_count, version, null_count, cm_sketch, tot_col_size, stats_ver, flag, correlation) values (%?, %?, %?, %?, %?, %?, %?, %?, %?, %?, %?)", + tableID, isIndex, hg.ID, hg.NDV, version, hg.NullCount, data, hg.TotColSize, statistics.CurStatsVersion, flag, hg.Correlation); err != nil { + return err + } + if _, err = exec.ExecuteInternal(ctx, "delete from mysql.stats_buckets where table_id = %? and is_index = %? and hist_id = %?", tableID, isIndex, hg.ID); err != nil { + return err + } sc := h.mu.ctx.GetSessionVars().StmtCtx var lastAnalyzePos []byte for i := range hg.Buckets { @@ -607,12 +652,16 @@ func (h *Handle) SaveStatsToStorage(tableID int64, count int64, isIndex int, hg if err != nil { return } - sqls = append(sqls, fmt.Sprintf("insert into mysql.stats_buckets(table_id, is_index, hist_id, bucket_id, count, repeats, lower_bound, upper_bound) values(%d, %d, %d, %d, %d, %d, X'%X', X'%X')", tableID, isIndex, hg.ID, i, count, hg.Buckets[i].Repeat, lowerBound.GetBytes(), upperBound.GetBytes())) + if _, err = exec.ExecuteInternal(ctx, "insert into mysql.stats_buckets(table_id, is_index, hist_id, bucket_id, count, repeats, lower_bound, upper_bound) values(%?, %?, %?, %?, %?, %?, %?, %?)", tableID, isIndex, hg.ID, i, count, hg.Buckets[i].Repeat, lowerBound.GetBytes(), upperBound.GetBytes()); err != nil { + return err + } } if isAnalyzed == 1 && len(lastAnalyzePos) > 0 { - sqls = append(sqls, fmt.Sprintf("update mysql.stats_histograms set last_analyze_pos = X'%X' where table_id = %d and is_index = %d and hist_id = %d", lastAnalyzePos, tableID, isIndex, hg.ID)) + if _, err = exec.ExecuteInternal(ctx, "update mysql.stats_histograms set last_analyze_pos = %? where table_id = %? and is_index = %? and hist_id = %?", lastAnalyzePos, tableID, isIndex, hg.ID); err != nil { + return err + } } - return execSQLs(context.Background(), exec, sqls) + return } // SaveMetaToStorage will save stats_meta to storage. 
@@ -621,7 +670,7 @@ func (h *Handle) SaveMetaToStorage(tableID, count, modifyCount int64) (err error defer h.mu.Unlock() ctx := context.TODO() exec := h.mu.ctx.(sqlexec.SQLExecutor) - _, err = exec.Execute(ctx, "begin") + _, err = exec.ExecuteInternal(ctx, "begin") if err != nil { return errors.Trace(err) } @@ -632,16 +681,13 @@ func (h *Handle) SaveMetaToStorage(tableID, count, modifyCount int64) (err error if err != nil { return errors.Trace(err) } - var sql string version := txn.StartTS() - sql = fmt.Sprintf("replace into mysql.stats_meta (version, table_id, count, modify_count) values (%d, %d, %d, %d)", version, tableID, count, modifyCount) - _, err = exec.Execute(ctx, sql) - return + _, err = exec.ExecuteInternal(ctx, "replace into mysql.stats_meta (version, table_id, count, modify_count) values (%?, %?, %?, %?)", version, tableID, count, modifyCount) + return err } func (h *Handle) histogramFromStorage(reader *statsReader, tableID int64, colID int64, tp *types.FieldType, distinct int64, isIndex int, ver uint64, nullCount int64, totColSize int64, corr float64) (_ *statistics.Histogram, err error) { - selSQL := fmt.Sprintf("select count, repeats, lower_bound, upper_bound from mysql.stats_buckets where table_id = %d and is_index = %d and hist_id = %d order by bucket_id", tableID, isIndex, colID) - rows, fields, err := reader.read(selSQL) + rows, fields, err := reader.read("select count, repeats, lower_bound, upper_bound from mysql.stats_buckets where table_id = %? and is_index = %? and hist_id = %? order by bucket_id", tableID, isIndex, colID) if err != nil { return nil, errors.Trace(err) } @@ -677,8 +723,7 @@ func (h *Handle) histogramFromStorage(reader *statsReader, tableID int64, colID } func (h *Handle) columnCountFromStorage(reader *statsReader, tableID, colID int64) (int64, error) { - selSQL := fmt.Sprintf("select sum(count) from mysql.stats_buckets where table_id = %d and is_index = %d and hist_id = %d", tableID, 0, colID) - rows, _, err := reader.read(selSQL) + rows, _, err := reader.read("select sum(count) from mysql.stats_buckets where table_id = %? and is_index = %? and hist_id = %?", tableID, 0, colID) if err != nil { return 0, errors.Trace(err) } @@ -688,13 +733,16 @@ func (h *Handle) columnCountFromStorage(reader *statsReader, tableID, colID int6 return rows[0].GetMyDecimal(0).ToInt() } -func (h *Handle) statsMetaByTableIDFromStorage(tableID int64, historyStatsExec sqlexec.RestrictedSQLExecutor) (version uint64, modifyCount, count int64, err error) { - selSQL := fmt.Sprintf("SELECT version, modify_count, count from mysql.stats_meta where table_id = %d order by version", tableID) +func (h *Handle) statsMetaByTableIDFromStorage(tableID int64, snapshot uint64) (version uint64, modifyCount, count int64, err error) { + ctx := context.Background() var rows []chunk.Row - if historyStatsExec == nil { - rows, _, err = h.restrictedExec.ExecRestrictedSQL(selSQL) + if snapshot == 0 { + rows, _, err = h.execRestrictedSQL(ctx, "SELECT version, modify_count, count from mysql.stats_meta where table_id = %? order by version", tableID) } else { - rows, _, err = historyStatsExec.ExecRestrictedSQLWithSnapshot(selSQL) + rows, _, err = h.execRestrictedSQLWithSnapshot(ctx, "SELECT version, modify_count, count from mysql.stats_meta where table_id = %? 
order by version", snapshot, tableID) + if err != nil { + return 0, 0, 0, err + } } if err != nil || len(rows) == 0 { return @@ -708,49 +756,34 @@ func (h *Handle) statsMetaByTableIDFromStorage(tableID int64, historyStatsExec s // statsReader is used for simplify code that needs to read system tables in different sqls // but requires the same transactions. type statsReader struct { - ctx sessionctx.Context - history sqlexec.RestrictedSQLExecutor + ctx sqlexec.RestrictedSQLExecutor + snapshot uint64 } -func (sr *statsReader) read(sql string) (rows []chunk.Row, fields []*ast.ResultField, err error) { - if sr.history != nil { - return sr.history.ExecRestrictedSQLWithSnapshot(sql) - } - rc, err := sr.ctx.(sqlexec.SQLExecutor).Execute(context.TODO(), sql) - if len(rc) > 0 { - defer terror.Call(rc[0].Close) - } +func (sr *statsReader) read(sql string, args ...interface{}) (rows []chunk.Row, fields []*ast.ResultField, err error) { + ctx := context.TODO() + stmt, err := sr.ctx.ParseWithParams(ctx, sql, args...) if err != nil { - return nil, nil, err + return nil, nil, errors.Trace(err) } - for { - req := rc[0].NewChunk() - err := rc[0].Next(context.TODO(), req) - if err != nil { - return nil, nil, err - } - if req.NumRows() == 0 { - break - } - for i := 0; i < req.NumRows(); i++ { - rows = append(rows, req.GetRow(i)) - } + if sr.snapshot > 0 { + return sr.ctx.ExecRestrictedStmt(ctx, stmt, sqlexec.ExecOptionWithSnapshot(sr.snapshot)) } - return rows, rc[0].Fields(), nil + return sr.ctx.ExecRestrictedStmt(ctx, stmt) } func (sr *statsReader) isHistory() bool { - return sr.history != nil + return sr.snapshot > 0 } -func (h *Handle) getStatsReader(history sqlexec.RestrictedSQLExecutor) (reader *statsReader, err error) { +func (h *Handle) getStatsReader(snapshot uint64) (reader *statsReader, err error) { failpoint.Inject("mockGetStatsReaderFail", func(val failpoint.Value) { if val.(bool) { failpoint.Return(nil, errors.New("gofail genStatsReader error")) } }) - if history != nil { - return &statsReader{history: history}, nil + if snapshot > 0 { + return &statsReader{ctx: h.mu.ctx.(sqlexec.RestrictedSQLExecutor), snapshot: snapshot}, nil } h.mu.Lock() defer func() { @@ -762,18 +795,18 @@ func (h *Handle) getStatsReader(history sqlexec.RestrictedSQLExecutor) (reader * } }() failpoint.Inject("mockGetStatsReaderPanic", nil) - _, err = h.mu.ctx.(sqlexec.SQLExecutor).Execute(context.TODO(), "begin") + _, err = h.mu.ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), "begin") if err != nil { return nil, err } - return &statsReader{ctx: h.mu.ctx}, nil + return &statsReader{ctx: h.mu.ctx.(sqlexec.RestrictedSQLExecutor)}, nil } func (h *Handle) releaseStatsReader(reader *statsReader) error { - if reader.history != nil { + if reader.snapshot > 0 { return nil } - _, err := h.mu.ctx.(sqlexec.SQLExecutor).Execute(context.TODO(), "commit") + _, err := h.mu.ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), "commit") h.mu.Unlock() return err } diff --git a/statistics/handle/handle_test.go b/statistics/handle/handle_test.go index eefcc08dee748..c8f486efda38f 100644 --- a/statistics/handle/handle_test.go +++ b/statistics/handle/handle_test.go @@ -265,7 +265,7 @@ func (s *testStatsSuite) TestVersion(c *C) { tbl1, err := is.TableByName(model.NewCIStr("test"), model.NewCIStr("t1")) c.Assert(err, IsNil) tableInfo1 := tbl1.Meta() - h := handle.NewHandle(testKit.Se, time.Millisecond) + h := handle.NewHandle(testKit.Se, time.Millisecond, do.SysSessionPool()) unit := oracle.ComposeTS(1, 0) testKit.MustExec("update 
mysql.stats_meta set version = ? where table_id = ?", 2*unit, tableInfo1.ID) @@ -499,7 +499,7 @@ func (s *testStatsSuite) TestCorrelation(c *C) { result = testKit.MustQuery("show stats_histograms where Table_name = 't'").Sort() c.Assert(len(result.Rows()), Equals, 2) c.Assert(result.Rows()[0][9], Equals, "0") - c.Assert(result.Rows()[1][9], Equals, "0.828571") + c.Assert(result.Rows()[1][9], Equals, "0.8285714285714286") testKit.MustExec("truncate table t") result = testKit.MustQuery("show stats_histograms where Table_name = 't'").Sort() @@ -515,7 +515,7 @@ func (s *testStatsSuite) TestCorrelation(c *C) { result = testKit.MustQuery("show stats_histograms where Table_name = 't'").Sort() c.Assert(len(result.Rows()), Equals, 2) c.Assert(result.Rows()[0][9], Equals, "0") - c.Assert(result.Rows()[1][9], Equals, "-0.942857") + c.Assert(result.Rows()[1][9], Equals, "-0.9428571428571428") testKit.MustExec("truncate table t") testKit.MustExec("insert into t values (1,1),(2,1),(3,1),(4,1),(5,1),(6,1),(7,1),(8,1),(9,1),(10,1),(11,1),(12,1),(13,1),(14,1),(15,1),(16,1),(17,1),(18,1),(19,1),(20,2),(21,2),(22,2),(23,2),(24,2),(25,2)") @@ -532,14 +532,14 @@ func (s *testStatsSuite) TestCorrelation(c *C) { result = testKit.MustQuery("show stats_histograms where Table_name = 't'").Sort() c.Assert(len(result.Rows()), Equals, 2) c.Assert(result.Rows()[0][9], Equals, "1") - c.Assert(result.Rows()[1][9], Equals, "0.828571") + c.Assert(result.Rows()[1][9], Equals, "0.8285714285714286") testKit.MustExec("truncate table t") testKit.MustExec("insert into t values(1,1),(2,7),(3,12),(8,18),(4,20),(5,21)") testKit.MustExec("analyze table t") result = testKit.MustQuery("show stats_histograms where Table_name = 't'").Sort() c.Assert(len(result.Rows()), Equals, 2) - c.Assert(result.Rows()[0][9], Equals, "0.828571") + c.Assert(result.Rows()[0][9], Equals, "0.8285714285714286") c.Assert(result.Rows()[1][9], Equals, "1") testKit.MustExec("drop table t") @@ -555,3 +555,24 @@ func (s *testStatsSuite) TestCorrelation(c *C) { c.Assert(len(result.Rows()), Equals, 1) c.Assert(result.Rows()[0][9], Equals, "0") } + +func (s *testStatsSuite) TestStatsCacheUpdateSkip(c *C) { + defer cleanEnv(c, s.store, s.do) + testKit := testkit.NewTestKit(c, s.store) + do := s.do + h := do.StatsHandle() + testKit.MustExec("use test") + testKit.MustExec("create table t (c1 int, c2 int)") + testKit.MustExec("insert into t values(1, 2)") + c.Assert(h.DumpStatsDeltaToKV(handle.DumpAll), IsNil) + testKit.MustExec("analyze table t") + is := do.InfoSchema() + tbl, err := is.TableByName(model.NewCIStr("test"), model.NewCIStr("t")) + c.Assert(err, IsNil) + tableInfo := tbl.Meta() + statsTbl1 := h.GetTableStats(tableInfo) + c.Assert(statsTbl1.Pseudo, IsFalse) + h.Update(is) + statsTbl2 := h.GetTableStats(tableInfo) + c.Assert(statsTbl1, Equals, statsTbl2) +} diff --git a/statistics/handle/update.go b/statistics/handle/update.go index affd2c0ca2457..7c56cef58d2d1 100644 --- a/statistics/handle/update.go +++ b/statistics/handle/update.go @@ -331,7 +331,7 @@ func (h *Handle) dumpTableStatCountToKV(id int64, delta variable.TableDelta) (up defer h.mu.Unlock() ctx := context.TODO() exec := h.mu.ctx.(sqlexec.SQLExecutor) - _, err = exec.Execute(ctx, "begin") + _, err = exec.ExecuteInternal(ctx, "begin") if err != nil { return false, errors.Trace(err) } @@ -344,13 +344,14 @@ func (h *Handle) dumpTableStatCountToKV(id int64, delta variable.TableDelta) (up return false, errors.Trace(err) } startTS := txn.StartTS() - var sql string if delta.Delta < 0 { - sql = 
fmt.Sprintf("update mysql.stats_meta set version = %d, count = count - %d, modify_count = modify_count + %d where table_id = %d and count >= %d", startTS, -delta.Delta, delta.Count, id, -delta.Delta) + _, err = exec.ExecuteInternal(ctx, "update mysql.stats_meta set version = %?, count = count - %?, modify_count = modify_count + %? where table_id = %? and count >= %?", startTS, -delta.Delta, delta.Count, id, -delta.Delta) } else { - sql = fmt.Sprintf("update mysql.stats_meta set version = %d, count = count + %d, modify_count = modify_count + %d where table_id = %d", startTS, delta.Delta, delta.Count, id) + _, err = exec.ExecuteInternal(ctx, "update mysql.stats_meta set version = %?, count = count + %?, modify_count = modify_count + %? where table_id = %?", startTS, delta.Delta, delta.Count, id) + } + if err != nil { + return false, errors.Trace(err) } - err = execSQLs(context.Background(), exec, []string{sql}) updated = h.mu.ctx.GetSessionVars().StmtCtx.AffectedRows() > 0 return } @@ -371,7 +372,7 @@ func (h *Handle) dumpTableStatColSizeToKV(id int64, delta variable.TableDelta) e } sql := fmt.Sprintf("insert into mysql.stats_histograms (table_id, is_index, hist_id, distinct_count, tot_col_size) "+ "values %s on duplicate key update tot_col_size = tot_col_size + values(tot_col_size)", strings.Join(values, ",")) - _, _, err := h.restrictedExec.ExecRestrictedSQL(sql) + _, _, err := h.execRestrictedSQL(context.Background(), sql) return errors.Trace(err) } @@ -409,10 +410,9 @@ func (h *Handle) DumpFeedbackToKV(fb *statistics.QueryFeedback) error { if fb.Tp == statistics.IndexType { isIndex = 1 } - sql := fmt.Sprintf("insert into mysql.stats_feedback (table_id, hist_id, is_index, feedback) values "+ - "(%d, %d, %d, X'%X')", fb.PhysicalID, fb.Hist.ID, isIndex, vals) + const sql = "insert into mysql.stats_feedback (table_id, hist_id, is_index, feedback) values (%?, %?, %?, %?)" h.mu.Lock() - _, err = h.mu.ctx.(sqlexec.SQLExecutor).Execute(context.TODO(), sql) + _, err = h.mu.ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), sql, fb.PhysicalID, fb.Hist.ID, isIndex, vals) h.mu.Unlock() if err != nil { metrics.DumpFeedbackCounter.WithLabelValues(metrics.LblError).Inc() @@ -503,8 +503,8 @@ func (h *Handle) UpdateErrorRate(is infoschema.InfoSchema) { // HandleUpdateStats update the stats using feedback. func (h *Handle) HandleUpdateStats(is infoschema.InfoSchema) error { - sql := "SELECT distinct table_id from mysql.stats_feedback" - tables, _, err := h.restrictedExec.ExecRestrictedSQL(sql) + ctx := context.Background() + tables, _, err := h.execRestrictedSQL(ctx, "SELECT distinct table_id from mysql.stats_feedback") if err != nil { return errors.Trace(err) } @@ -516,20 +516,18 @@ func (h *Handle) HandleUpdateStats(is infoschema.InfoSchema) error { // this func lets `defer` works normally, where `Close()` should be called before any return err = func() error { tbl := ptbl.GetInt64(0) - sql = fmt.Sprintf("select table_id, hist_id, is_index, feedback from mysql.stats_feedback where table_id=%d order by hist_id, is_index", tbl) - rc, err := h.mu.ctx.(sqlexec.SQLExecutor).Execute(context.TODO(), sql) - if len(rc) > 0 { - defer terror.Call(rc[0].Close) - } + const sql = "select table_id, hist_id, is_index, feedback from mysql.stats_feedback where table_id=%? 
order by hist_id, is_index" + rc, err := h.mu.ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), sql, tbl) if err != nil { return errors.Trace(err) } + defer terror.Call(rc.Close) tableID, histID, isIndex := int64(-1), int64(-1), int64(-1) var rows []chunk.Row for { - req := rc[0].NewChunk() + req := rc.NewChunk() iter := chunk.NewIterator4Chunk(req) - err := rc[0].Next(context.TODO(), req) + err := rc.Next(context.TODO(), req) if err != nil { return errors.Trace(err) } @@ -622,8 +620,8 @@ func (h *Handle) deleteOutdatedFeedback(tableID, histID, isIndex int64) error { defer h.mu.Unlock() hasData := true for hasData { - sql := fmt.Sprintf("delete from mysql.stats_feedback where table_id = %d and hist_id = %d and is_index = %d limit 10000", tableID, histID, isIndex) - _, err := h.mu.ctx.(sqlexec.SQLExecutor).Execute(context.TODO(), sql) + sql := "delete from mysql.stats_feedback where table_id = %? and hist_id = %? and is_index = %? limit 10000" + _, err := h.mu.ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), sql, tableID, histID, isIndex) if err != nil { return errors.Trace(err) } @@ -671,11 +669,16 @@ func TableAnalyzed(tbl *statistics.Table) bool { // "tbl.ModifyCount/tbl.Count > autoAnalyzeRatio" and the current time is // between `start` and `end`. func NeedAnalyzeTable(tbl *statistics.Table, limit time.Duration, autoAnalyzeRatio float64, start, end, now time.Time) (bool, string) { + // Tests if current time is within the time period. + if !timeutil.WithinDayTimePeriod(start, end, now) { + return false, "" + } + analyzed := TableAnalyzed(tbl) if !analyzed { t := time.Unix(0, oracle.ExtractPhysical(tbl.Version)*int64(time.Millisecond)) dur := time.Since(t) - return dur >= limit, fmt.Sprintf("table unanalyzed, time since last updated %vs", dur) + return dur >= limit, fmt.Sprintf("table unanalyzed, time since last updated %v", dur) } // Auto analyze is disabled. if autoAnalyzeRatio == 0 { @@ -685,14 +688,13 @@ func NeedAnalyzeTable(tbl *statistics.Table, limit time.Duration, autoAnalyzeRat if float64(tbl.ModifyCount)/float64(tbl.Count) <= autoAnalyzeRatio { return false, "" } - // Tests if current time is within the time period. - return timeutil.WithinDayTimePeriod(start, end, now), fmt.Sprintf("too many modifications(%v/%v>%v)", tbl.ModifyCount, tbl.Count, autoAnalyzeRatio) + return true, fmt.Sprintf("too many modifications(%v/%v>%v)", tbl.ModifyCount, tbl.Count, autoAnalyzeRatio) } func (h *Handle) getAutoAnalyzeParameters() map[string]string { - sql := fmt.Sprintf("select variable_name, variable_value from mysql.global_variables where variable_name in ('%s', '%s', '%s')", - variable.TiDBAutoAnalyzeRatio, variable.TiDBAutoAnalyzeStartTime, variable.TiDBAutoAnalyzeEndTime) - rows, _, err := h.restrictedExec.ExecRestrictedSQL(sql) + ctx := context.Background() + sql := "select variable_name, variable_value from mysql.global_variables where variable_name in (%?, %?, %?)" + rows, _, err := h.execRestrictedSQL(ctx, sql, variable.TiDBAutoAnalyzeRatio, variable.TiDBAutoAnalyzeStartTime, variable.TiDBAutoAnalyzeEndTime) if err != nil { return map[string]string{} } @@ -727,66 +729,77 @@ func parseAnalyzePeriod(start, end string) (time.Time, time.Time, error) { } // HandleAutoAnalyze analyzes the newly created table or index. 
-func (h *Handle) HandleAutoAnalyze(is infoschema.InfoSchema) { +func (h *Handle) HandleAutoAnalyze(is infoschema.InfoSchema) (analyzed bool) { dbs := is.AllSchemaNames() parameters := h.getAutoAnalyzeParameters() autoAnalyzeRatio := parseAutoAnalyzeRatio(parameters[variable.TiDBAutoAnalyzeRatio]) start, end, err := parseAnalyzePeriod(parameters[variable.TiDBAutoAnalyzeStartTime], parameters[variable.TiDBAutoAnalyzeEndTime]) if err != nil { logutil.BgLogger().Error("[stats] parse auto analyze period failed", zap.Error(err)) - return + return false } for _, db := range dbs { tbls := is.SchemaTables(model.NewCIStr(db)) for _, tbl := range tbls { tblInfo := tbl.Meta() pi := tblInfo.GetPartitionInfo() - tblName := "`" + db + "`.`" + tblInfo.Name.O + "`" if pi == nil { statsTbl := h.GetTableStats(tblInfo) - sql := fmt.Sprintf("analyze table %s", tblName) - analyzed := h.autoAnalyzeTable(tblInfo, statsTbl, start, end, autoAnalyzeRatio, sql) + sql := "analyze table %n.%n" + analyzed := h.autoAnalyzeTable(tblInfo, statsTbl, start, end, autoAnalyzeRatio, sql, db, tblInfo.Name.O) if analyzed { - return + // Analyze one table at a time to let it get the freshest parameters. + // The others will be analyzed in the next round, which is just 3s later. + return true } continue } for _, def := range pi.Definitions { - sql := fmt.Sprintf("analyze table %s partition `%s`", tblName, def.Name.O) + sql := "analyze table %n.%n partition %n" statsTbl := h.GetPartitionStats(tblInfo, def.ID) - analyzed := h.autoAnalyzeTable(tblInfo, statsTbl, start, end, autoAnalyzeRatio, sql) + analyzed := h.autoAnalyzeTable(tblInfo, statsTbl, start, end, autoAnalyzeRatio, sql, db, tblInfo.Name.O, def.Name.O) if analyzed { - return + return true } continue } } } + return false } -func (h *Handle) autoAnalyzeTable(tblInfo *model.TableInfo, statsTbl *statistics.Table, start, end time.Time, ratio float64, sql string) bool { +func (h *Handle) autoAnalyzeTable(tblInfo *model.TableInfo, statsTbl *statistics.Table, start, end time.Time, ratio float64, sql string, params ...interface{}) bool { if statsTbl.Pseudo || statsTbl.Count < AutoAnalyzeMinCnt { return false } if needAnalyze, reason := NeedAnalyzeTable(statsTbl, 20*h.Lease(), ratio, start, end, time.Now()); needAnalyze { - logutil.BgLogger().Info("[stats] auto analyze triggered", zap.String("sql", sql), zap.String("reason", reason)) - h.execAutoAnalyze(sql) + escaped, err := sqlexec.EscapeSQL(sql, params...) + if err != nil { + return false + } + logutil.BgLogger().Info("[stats] auto analyze triggered", zap.String("sql", escaped), zap.String("reason", reason)) + h.execAutoAnalyze(sql, params...) return true } for _, idx := range tblInfo.Indices { if _, ok := statsTbl.Indices[idx.ID]; !ok && idx.State == model.StatePublic { - sql = fmt.Sprintf("%s index `%s`", sql, idx.Name.O) - logutil.BgLogger().Info("[stats] auto analyze for unanalyzed", zap.String("sql", sql)) - h.execAutoAnalyze(sql) + sqlWithIdx := sql + " index %n" + paramsWithIdx := append(params, idx.Name.O) + escaped, err := sqlexec.EscapeSQL(sqlWithIdx, paramsWithIdx...) + if err != nil { + return false + } + logutil.BgLogger().Info("[stats] auto analyze for unanalyzed", zap.String("sql", escaped)) + h.execAutoAnalyze(sqlWithIdx, paramsWithIdx...)
return true } } return false } -func (h *Handle) execAutoAnalyze(sql string) { +func (h *Handle) execAutoAnalyze(sql string, params ...interface{}) { startTime := time.Now() - _, _, err := h.restrictedExec.ExecRestrictedSQL(sql) + _, _, err := h.execRestrictedSQL(context.Background(), sql, params...) dur := time.Since(startTime) metrics.AutoAnalyzeHistogram.Observe(dur.Seconds()) if err != nil { diff --git a/statistics/handle/update_test.go b/statistics/handle/update_test.go index 06196cb9f226f..712dc77cf3ca4 100644 --- a/statistics/handle/update_test.go +++ b/statistics/handle/update_test.go @@ -41,6 +41,26 @@ import ( ) var _ = Suite(&testStatsSuite{}) +var _ = SerialSuites(&testSerialStatsSuite{}) + +type testSerialStatsSuite struct { + store kv.Storage + do *domain.Domain +} + +func (s *testSerialStatsSuite) SetUpSuite(c *C) { + testleak.BeforeTest() + // Add the hook here to avoid data race. + var err error + s.store, s.do, err = newStoreWithBootstrap() + c.Assert(err, IsNil) +} + +func (s *testSerialStatsSuite) TearDownSuite(c *C) { + s.do.Close() + s.store.Close() + testleak.AfterTest(c)() +} type testStatsSuite struct { store kv.Storage @@ -465,6 +485,41 @@ func (s *testStatsSuite) TestAutoUpdate(c *C) { c.Assert(hg.Len(), Equals, 3) } +func (s *testSerialStatsSuite) TestAutoAnalyzeOnEmptyTable(c *C) { + defer cleanEnv(c, s.store, s.do) + tk := testkit.NewTestKit(c, s.store) + + oriStart := tk.MustQuery("select @@tidb_auto_analyze_start_time").Rows()[0][0].(string) + oriEnd := tk.MustQuery("select @@tidb_auto_analyze_end_time").Rows()[0][0].(string) + defer func() { + tk.MustExec(fmt.Sprintf("set global tidb_auto_analyze_start_time='%v'", oriStart)) + tk.MustExec(fmt.Sprintf("set global tidb_auto_analyze_end_time='%v'", oriEnd)) + }() + + t := time.Now().Add(-1 * time.Minute) + h, m := t.Hour(), t.Minute() + start, end := fmt.Sprintf("%02d:%02d +0000", h, m), fmt.Sprintf("%02d:%02d +0000", h, m) + tk.MustExec(fmt.Sprintf("set global tidb_auto_analyze_start_time='%v'", start)) + tk.MustExec(fmt.Sprintf("set global tidb_auto_analyze_end_time='%v'", end)) + s.do.StatsHandle().HandleAutoAnalyze(s.do.InfoSchema()) + + tk.MustExec("use test") + tk.MustExec("create table t (a int, index idx(a))") + // to pass the stats.Pseudo check in autoAnalyzeTable + tk.MustExec("analyze table t") + // to pass the AutoAnalyzeMinCnt check in autoAnalyzeTable + tk.MustExec("insert into t values (1)" + strings.Repeat(", (1)", int(handle.AutoAnalyzeMinCnt))) + c.Assert(s.do.StatsHandle().DumpStatsDeltaToKV(handle.DumpAll), IsNil) + c.Assert(s.do.StatsHandle().Update(s.do.InfoSchema()), IsNil) + + // test if it will be limited by the time range + c.Assert(s.do.StatsHandle().HandleAutoAnalyze(s.do.InfoSchema()), IsFalse) + + tk.MustExec(fmt.Sprintf("set global tidb_auto_analyze_start_time='00:00 +0000'")) + tk.MustExec(fmt.Sprintf("set global tidb_auto_analyze_end_time='23:59 +0000'")) + c.Assert(s.do.StatsHandle().HandleAutoAnalyze(s.do.InfoSchema()), IsTrue) +} + func (s *testStatsSuite) TestAutoUpdatePartition(c *C) { defer cleanEnv(c, s.store, s.do) testKit := testkit.NewTestKit(c, s.store) diff --git a/statistics/histogram.go b/statistics/histogram.go index 66e8f9f3f3677..76cbfe1732d07 100644 --- a/statistics/histogram.go +++ b/statistics/histogram.go @@ -22,6 +22,7 @@ import ( "time" "github.com/pingcap/errors" + "github.com/pingcap/parser/charset" "github.com/pingcap/parser/model" "github.com/pingcap/parser/mysql" "github.com/pingcap/parser/terror" @@ -89,6 +90,13 @@ type scalar struct { // 
NewHistogram creates a new histogram. func NewHistogram(id, ndv, nullCount int64, version uint64, tp *types.FieldType, bucketSize int, totColSize int64) *Histogram { + if tp.EvalType() == types.ETString { + // The histogram stores a string value's 'sort key' representation under its collation. + // If we kept the field type's original collation, we would decode the key representation using that collation, which would cause a panic. + // So we apply a little trick here to avoid decoding it, by explicitly changing the collation to 'CollationBin'. + tp = tp.Clone() + tp.Collate = charset.CollationBin + } return &Histogram{ ID: id, NDV: ndv, diff --git a/statistics/table.go b/statistics/table.go index 518da999b6695..59592791af12e 100644 --- a/statistics/table.go +++ b/statistics/table.go @@ -61,6 +61,12 @@ type Table struct { HistColl Version uint64 Name string + // TblInfoUpdateTS is the UpdateTS of the TableInfo used when filling this struct. + // It is the schema version of the corresponding table. It is used to skip redundant + // loading of stats, i.e., if the cached stats are already up to date with the mysql.stats_xxx tables, + // and the schema of the table does not change, we don't need to load the stats for this + // table again. + TblInfoUpdateTS uint64 } // HistColl is a collection of histogram. It collects enough information for plan to calculate the selectivity. @@ -99,9 +105,10 @@ func (t *Table) Copy() *Table { newHistColl.Indices[id] = idx } nt := &Table{ - HistColl: newHistColl, - Version: t.Version, - Name: t.Name, + HistColl: newHistColl, + Version: t.Version, + Name: t.Name, + TblInfoUpdateTS: t.TblInfoUpdateTS, } return nt } diff --git a/store/mockstore/mocktikv/cluster.go b/store/mockstore/mocktikv/cluster.go index cfc1f09e5405a..8eeb9f676761e 100644 --- a/store/mockstore/mocktikv/cluster.go +++ b/store/mockstore/mocktikv/cluster.go @@ -178,19 +178,20 @@ func (c *Cluster) GetStoreByAddr(addr string) *metapb.Store { } // GetAndCheckStoreByAddr checks and returns a Store's meta by an address -func (c *Cluster) GetAndCheckStoreByAddr(addr string) (*metapb.Store, error) { +func (c *Cluster) GetAndCheckStoreByAddr(addr string) (ss []*metapb.Store, err error) { c.RLock() defer c.RUnlock() for _, s := range c.stores { if s.cancel { - return nil, context.Canceled + err = context.Canceled + return } if s.meta.GetAddress() == addr { - return proto.Clone(s.meta).(*metapb.Store), nil + ss = append(ss, proto.Clone(s.meta).(*metapb.Store)) } } - return nil, nil + return } // AddStore adds a new Store to the cluster. @@ -209,6 +210,15 @@ func (c *Cluster) RemoveStore(storeID uint64) { delete(c.stores, storeID) } +// MarkTombstone marks a store as tombstone. +func (c *Cluster) MarkTombstone(storeID uint64) { + c.Lock() + defer c.Unlock() + nm := *c.stores[storeID].meta + nm.State = metapb.StoreState_Tombstone + c.stores[storeID].meta = &nm +} + // UpdateStoreAddr updates store address for cluster.
func (c *Cluster) UpdateStoreAddr(storeID uint64, addr string, labels ...*metapb.StoreLabel) { c.Lock() diff --git a/store/mockstore/mocktikv/rpc.go b/store/mockstore/mocktikv/rpc.go index 8d0fb5ce39c98..290798cd087ad 100644 --- a/store/mockstore/mocktikv/rpc.go +++ b/store/mockstore/mocktikv/rpc.go @@ -740,18 +740,20 @@ func NewRPCClient(cluster *Cluster, mvccStore MVCCStore) *RPCClient { } func (c *RPCClient) getAndCheckStoreByAddr(addr string) (*metapb.Store, error) { - store, err := c.Cluster.GetAndCheckStoreByAddr(addr) + stores, err := c.Cluster.GetAndCheckStoreByAddr(addr) if err != nil { return nil, err } - if store == nil { + if len(stores) == 0 { return nil, errors.New("connect fail") } - if store.GetState() == metapb.StoreState_Offline || - store.GetState() == metapb.StoreState_Tombstone { - return nil, errors.New("connection refused") + for _, store := range stores { + if store.GetState() != metapb.StoreState_Offline && + store.GetState() != metapb.StoreState_Tombstone { + return store, nil + } } - return store, nil + return nil, errors.New("connection refused") } func (c *RPCClient) checkArgs(ctx context.Context, addr string) (*rpcHandler, error) { diff --git a/store/tikv/batch_coprocessor.go b/store/tikv/batch_coprocessor.go index 56b1b1fed9f3e..dfa1b58b87eea 100644 --- a/store/tikv/batch_coprocessor.go +++ b/store/tikv/batch_coprocessor.go @@ -350,8 +350,8 @@ func (b *batchCopIterator) handleTaskOnce(ctx context.Context, bo *Backoffer, ta IsolationLevel: pbIsolationLevel(b.req.IsolationLevel), Priority: kvPriorityToCommandPri(b.req.Priority), NotFillCache: b.req.NotFillCache, - HandleTime: true, - ScanDetail: true, + RecordTimeStat: true, + RecordScanStat: true, TaskId: b.req.TaskID, }) req.StoreTp = kv.TiFlash @@ -387,7 +387,8 @@ func (b *batchCopIterator) handleStreamedBatchCopResponse(ctx context.Context, b return nil } - if err1 := bo.Backoff(boTiKVRPC, errors.Errorf("recv stream response error: %v, task store addr: %s", err, task.storeAddr)); err1 != nil { + // Currently this function is only used in TiFlash batch cop. 
+ if err1 := bo.Backoff(boTiFlashRPC, errors.Errorf("recv stream response error: %v, task store addr: %s", err, task.storeAddr)); err1 != nil { return errors.Trace(err) } diff --git a/store/tikv/coprocessor.go b/store/tikv/coprocessor.go index 9022e877d53d6..2846159e89331 100644 --- a/store/tikv/coprocessor.go +++ b/store/tikv/coprocessor.go @@ -872,8 +872,8 @@ func (worker *copIteratorWorker) handleTaskOnce(bo *Backoffer, task *copTask, ch IsolationLevel: pbIsolationLevel(worker.req.IsolationLevel), Priority: kvPriorityToCommandPri(worker.req.Priority), NotFillCache: worker.req.NotFillCache, - HandleTime: true, - ScanDetail: true, + RecordTimeStat: true, + RecordScanStat: true, TaskId: worker.req.TaskID, }) req.StoreTp = task.storeType @@ -1013,9 +1013,9 @@ func (worker *copIteratorWorker) logTimeCopTask(costTime time.Duration, task *co } } - if detail != nil && detail.HandleTime != nil { - processMs := detail.HandleTime.ProcessMs - waitMs := detail.HandleTime.WaitMs + if detail != nil && detail.TimeDetail != nil { + processMs := detail.TimeDetail.ProcessWallTimeMs + waitMs := detail.TimeDetail.WaitWallTimeMs if processMs > minLogKVProcessTime { logStr += fmt.Sprintf(" kv_process_ms:%d", processMs) if detail.ScanDetail != nil { @@ -1149,18 +1149,29 @@ func (worker *copIteratorWorker) handleCopResponse(bo *Backoffer, rpcCtx *RPCCon resp.detail.CalleeAddress = rpcCtx.Addr } resp.respTime = costTime - if pbDetails := resp.pbResp.ExecDetails; pbDetails != nil { - if handleTime := pbDetails.HandleTime; handleTime != nil { - resp.detail.WaitTime = time.Duration(handleTime.WaitMs) * time.Millisecond - resp.detail.ProcessTime = time.Duration(handleTime.ProcessMs) * time.Millisecond + sd := &execdetails.ScanDetail{} + td := execdetails.TimeDetail{} + if pbDetails := resp.pbResp.ExecDetailsV2; pbDetails != nil { + // Take values in `ExecDetailsV2` first. + if timeDetail := pbDetails.TimeDetail; timeDetail != nil { + td.MergeFromTimeDetail(timeDetail) + } + if scanDetailV2 := pbDetails.ScanDetailV2; scanDetailV2 != nil { + sd.MergeFromScanDetailV2(scanDetailV2) + } + } else if pbDetails := resp.pbResp.ExecDetails; pbDetails != nil { + if timeDetail := pbDetails.TimeDetail; timeDetail != nil { + td.MergeFromTimeDetail(timeDetail) } if scanDetail := pbDetails.ScanDetail; scanDetail != nil { if scanDetail.Write != nil { - resp.detail.TotalKeys += scanDetail.Write.Total - resp.detail.ProcessedKeys += scanDetail.Write.Processed + sd.ProcessedKeys = scanDetail.Write.Processed + sd.TotalKeys = scanDetail.Write.Total } } } + resp.detail.ScanDetail = sd + resp.detail.TimeDetail = td if resp.pbResp.IsCacheHit { if cacheValue == nil { return nil, errors.New("Internal error: received illegal TiKV response") @@ -1173,7 +1184,7 @@ func (worker *copIteratorWorker) handleCopResponse(bo *Backoffer, rpcCtx *RPCCon } else { // Cache not hit or cache hit but not valid: update the cache if the response can be cached. 
if cacheKey != nil && resp.pbResp.CanBeCached && resp.pbResp.CacheLastVersion > 0 { - if worker.store.coprCache.CheckResponseAdmission(resp.pbResp.Data.Size(), resp.detail.ProcessTime) { + if worker.store.coprCache.CheckResponseAdmission(resp.pbResp.Data.Size(), resp.detail.TimeDetail.ProcessTime) { data := make([]byte, len(resp.pbResp.Data)) copy(data, resp.pbResp.Data) diff --git a/store/tikv/gcworker/gc_worker.go b/store/tikv/gcworker/gc_worker.go index 0ceea6292019b..41c7427779db5 100644 --- a/store/tikv/gcworker/gc_worker.go +++ b/store/tikv/gcworker/gc_worker.go @@ -285,7 +285,7 @@ func (w *GCWorker) prepare() (bool, uint64, error) { ctx := context.Background() se := createSession(w.store) defer se.Close() - _, err := se.Execute(ctx, "BEGIN") + _, err := se.ExecuteInternal(ctx, "BEGIN") if err != nil { return false, 0, errors.Trace(err) } @@ -1599,7 +1599,7 @@ func (w *GCWorker) checkLeader() (bool, error) { defer se.Close() ctx := context.Background() - _, err := se.Execute(ctx, "BEGIN") + _, err := se.ExecuteInternal(ctx, "BEGIN") if err != nil { return false, errors.Trace(err) } @@ -1624,7 +1624,7 @@ func (w *GCWorker) checkLeader() (bool, error) { se.RollbackTxn(ctx) - _, err = se.Execute(ctx, "BEGIN") + _, err = se.ExecuteInternal(ctx, "BEGIN") if err != nil { return false, errors.Trace(err) } @@ -1732,16 +1732,13 @@ func (w *GCWorker) loadValueFromSysTable(key string) (string, error) { ctx := context.Background() se := createSession(w.store) defer se.Close() - stmt := fmt.Sprintf(`SELECT HIGH_PRIORITY (variable_value) FROM mysql.tidb WHERE variable_name='%s' FOR UPDATE`, key) - rs, err := se.Execute(ctx, stmt) - if len(rs) > 0 { - defer terror.Call(rs[0].Close) - } + rs, err := se.ExecuteInternal(ctx, `SELECT HIGH_PRIORITY (variable_value) FROM mysql.tidb WHERE variable_name=%? FOR UPDATE`, key) if err != nil { return "", errors.Trace(err) } - req := rs[0].NewChunk() - err = rs[0].Next(ctx, req) + defer terror.Call(rs.Close) + req := rs.NewChunk() + err = rs.Next(ctx, req) if err != nil { return "", errors.Trace(err) } @@ -1758,13 +1755,14 @@ func (w *GCWorker) loadValueFromSysTable(key string) (string, error) { } func (w *GCWorker) saveValueToSysTable(key, value string) error { - stmt := fmt.Sprintf(`INSERT HIGH_PRIORITY INTO mysql.tidb VALUES ('%[1]s', '%[2]s', '%[3]s') + const stmt = `INSERT HIGH_PRIORITY INTO mysql.tidb VALUES (%?, %?, %?) ON DUPLICATE KEY - UPDATE variable_value = '%[2]s', comment = '%[3]s'`, - key, value, gcVariableComments[key]) + UPDATE variable_value = %?, comment = %?` se := createSession(w.store) defer se.Close() - _, err := se.Execute(context.Background(), stmt) + _, err := se.ExecuteInternal(context.Background(), stmt, + key, value, gcVariableComments[key], + value, gcVariableComments[key]) logutil.BgLogger().Debug("[gc worker] save kv", zap.String("key", key), zap.String("value", value), diff --git a/store/tikv/region_cache.go b/store/tikv/region_cache.go index 456a56ae347b7..e629d81a5a0a8 100644 --- a/store/tikv/region_cache.go +++ b/store/tikv/region_cache.go @@ -1522,7 +1522,7 @@ func (s *Store) reResolve(c *RegionCache) { // we cannot back off in the reResolve loop, but we can try other stores and wait for the next tick. return } - if store == nil { + if store == nil || store.State == metapb.StoreState_Tombstone { // the store has been removed from PD; we should invalidate all regions using that store.
logutil.BgLogger().Info("invalidate regions in removed store", zap.Uint64("store", s.storeID), zap.String("add", s.addr)) diff --git a/store/tikv/region_cache_test.go b/store/tikv/region_cache_test.go index c2f2d6e9d7cf3..b2dc597524b41 100644 --- a/store/tikv/region_cache_test.go +++ b/store/tikv/region_cache_test.go @@ -833,6 +833,33 @@ func (s *testRegionCacheSuite) TestReplaceNewAddrAndOldOfflineImmediately(c *C) c.Assert(getVal, BytesEquals, testValue) } +func (s *testRegionCacheSuite) TestReplaceStore(c *C) { + mvccStore := mocktikv.MustNewMVCCStore() + defer mvccStore.Close() + + client := &RawKVClient{ + clusterID: 0, + regionCache: NewRegionCache(mocktikv.NewPDClient(s.cluster)), + rpcClient: mocktikv.NewRPCClient(s.cluster, mvccStore), + } + defer client.Close() + testKey := []byte("test_key") + testValue := []byte("test_value") + err := client.Put(testKey, testValue) + c.Assert(err, IsNil) + + s.cluster.MarkTombstone(s.store1) + store3 := s.cluster.AllocID() + peer3 := s.cluster.AllocID() + s.cluster.AddStore(store3, s.storeAddr(s.store1)) + s.cluster.AddPeer(s.region1, store3, peer3) + s.cluster.RemovePeer(s.region1, s.peer1) + s.cluster.ChangeLeader(s.region1, peer3) + + err = client.Put(testKey, testValue) + c.Assert(err, IsNil) +} + func (s *testRegionCacheSuite) TestListRegionIDsInCache(c *C) { // ['' - 'm' - 'z'] region2 := s.cluster.AllocID() diff --git a/store/tikv/region_request.go b/store/tikv/region_request.go index f1523456427b2..bcb071da5022e 100644 --- a/store/tikv/region_request.go +++ b/store/tikv/region_request.go @@ -199,7 +199,10 @@ func (ss *RegionBatchRequestSender) onSendFail(bo *Backoffer, ctxs []copTaskAndR // When a store is not available, the leader of related region should be elected quickly. // TODO: the number of retry time should be limited:since region may be unavailable // when some unrecoverable disaster happened. - err = bo.Backoff(boTiKVRPC, errors.Errorf("send tikv request error: %v, ctxs: %v, try next peer later", err, ctxs)) + + // Currently this function is only used in TiFlash batch cop. + // TODO: need code refactoring for these functions. 
+ err = bo.Backoff(boTiFlashRPC, errors.Errorf("send tiflash request error: %v, ctxs: %v, try next peer later", err, ctxs)) return errors.Trace(err) } @@ -604,7 +607,7 @@ func (s *RegionRequestSender) onRegionError(bo *Backoffer, ctx *RPCContext, seed if storeNotMatch := regionErr.GetStoreNotMatch(); storeNotMatch != nil { // store not match - logutil.BgLogger().Warn("tikv reports `StoreNotMatch` retry later", + logutil.BgLogger().Debug("tikv reports `StoreNotMatch` retry later", zap.Stringer("storeNotMatch", storeNotMatch), zap.Stringer("ctx", ctx)) ctx.Store.markNeedCheck(s.regionCache.notifyCheckCh) diff --git a/store/tikv/snapshot.go b/store/tikv/snapshot.go index d735527919088..d6b798970665e 100644 --- a/store/tikv/snapshot.go +++ b/store/tikv/snapshot.go @@ -324,6 +324,9 @@ func (s *tikvSnapshot) batchGetSingleRegion(bo *Backoffer, batch batchKeys, coll lockedKeys = append(lockedKeys, lock.Key) locks = append(locks, lock) } + if batchGetResp.ExecDetailsV2 != nil { + s.mergeExecDetail(batchGetResp.ExecDetailsV2) + } if len(lockedKeys) > 0 { msBeforeExpired, err := cli.ResolveLocks(bo, s.version.Ver, locks) if err != nil { @@ -436,6 +439,9 @@ func (s *tikvSnapshot) get(ctx context.Context, bo *Backoffer, k kv.Key) ([]byte return nil, errors.Trace(ErrBodyMissing) } cmdGetResp := resp.Resp.(*pb.GetResponse) + if cmdGetResp.ExecDetailsV2 != nil { + s.mergeExecDetail(cmdGetResp.ExecDetailsV2) + } val := cmdGetResp.GetValue() if keyErr := cmdGetResp.GetError(); keyErr != nil { lock, err := extractLockFromKeyErr(keyErr) @@ -458,6 +464,22 @@ func (s *tikvSnapshot) get(ctx context.Context, bo *Backoffer, k kv.Key) ([]byte } } +func (s *tikvSnapshot) mergeExecDetail(detail *pb.ExecDetailsV2) { + s.mu.Lock() + defer s.mu.Unlock() + if detail == nil || s.mu.stats == nil { + return + } + if s.mu.stats.scanDetail == nil { + s.mu.stats.scanDetail = &execdetails.ScanDetail{} + } + if s.mu.stats.timeDetail == nil { + s.mu.stats.timeDetail = &execdetails.TimeDetail{} + } + s.mu.stats.scanDetail.MergeFromScanDetailV2(detail.ScanDetailV2) + s.mu.stats.timeDetail.MergeFromTimeDetail(detail.TimeDetail) +} + // Iter returns a list of key-value pairs after `k`. func (s *tikvSnapshot) Iter(k kv.Key, upperBound kv.Key) (kv.Iterator, error) { scanner, err := newScanner(s, k, upperBound, scanBatchSize, false) @@ -659,6 +681,8 @@ type SnapshotRuntimeStats struct { rpcStats RegionRequestRuntimeStats backoffSleepMS map[backoffType]int backoffTimes map[backoffType]int + scanDetail *execdetails.ScanDetail + timeDetail *execdetails.TimeDetail } // Tp implements the RuntimeStats interface.
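Note that mergeExecDetail accumulates rather than overwrites: merging the same detail twice doubles the counters, which the snapshot test in the next hunk asserts explicitly. A sketch of that behavior, assuming only the MergeFromScanDetailV2 helper and the kvrpcpb types used above (pb is the kvrpcpb alias from snapshot.go):

    // exampleMerge shows the additive semantics mergeExecDetail relies on.
    func exampleMerge() string {
        sd := &execdetails.ScanDetail{}
        d := &pb.ScanDetailV2{ProcessedVersions: 10, TotalVersions: 15}
        sd.MergeFromScanDetailV2(d)
        sd.MergeFromScanDetailV2(d) // counters add up: 20 processed, 30 total
        // Yields something like "scan_detail: {total_process_keys: 20, total_keys: 30}".
        return sd.String()
    }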
@@ -727,5 +751,15 @@ func (rs *SnapshotRuntimeStats) String() string { d := time.Duration(ms) * time.Millisecond buf.WriteString(fmt.Sprintf("%s_backoff:{num:%d, total_time:%s}", k.String(), v, execdetails.FormatDuration(d))) } + timeDetail := rs.timeDetail.String() + if timeDetail != "" { + buf.WriteString(", ") + buf.WriteString(timeDetail) + } + scanDetail := rs.scanDetail.String() + if scanDetail != "" { + buf.WriteString(", ") + buf.WriteString(scanDetail) + } return buf.String() } diff --git a/store/tikv/snapshot_test.go b/store/tikv/snapshot_test.go index 87e883c7f595a..a3b51d8b86ab5 100644 --- a/store/tikv/snapshot_test.go +++ b/store/tikv/snapshot_test.go @@ -317,4 +317,26 @@ func (s *testSnapshotSuite) TestSnapshotRuntimeStats(c *C) { snapshot.recordBackoffInfo(bo) expect := "Get:{num_rpc:4, total_time:2s},txnLockFast_backoff:{num:2, total_time:60ms}" c.Assert(snapshot.mu.stats.String(), Equals, expect) + detail := &pb.ExecDetailsV2{ + TimeDetail: &pb.TimeDetail{ + WaitWallTimeMs: 100, + ProcessWallTimeMs: 100, + }, + ScanDetailV2: &pb.ScanDetailV2{ + ProcessedVersions: 10, + TotalVersions: 15, + }, + } + snapshot.mergeExecDetail(detail) + expect = "Get:{num_rpc:4, total_time:2s},txnLockFast_backoff:{num:2, total_time:60ms}, " + + "total_process_time: 100ms, total_wait_time: 100ms, " + + "scan_detail: {total_process_keys: 10, " + + "total_keys: 15}" + c.Assert(snapshot.mu.stats.String(), Equals, expect) + snapshot.mergeExecDetail(detail) + expect = "Get:{num_rpc:4, total_time:2s},txnLockFast_backoff:{num:2, total_time:60ms}, " + + "total_process_time: 200ms, total_wait_time: 200ms, " + + "scan_detail: {total_process_keys: 20, " + + "total_keys: 30}" + c.Assert(snapshot.mu.stats.String(), Equals, expect) } diff --git a/store/tikv/split_region.go b/store/tikv/split_region.go index a9ce30affd95a..6a22945603032 100644 --- a/store/tikv/split_region.go +++ b/store/tikv/split_region.go @@ -33,7 +33,7 @@ import ( "go.uber.org/zap" ) -const splitBatchRegionLimit = 16 +const splitBatchRegionLimit = 2048 func equalRegionStartKey(key, regionStartKey []byte) bool { return bytes.Equal(key, regionStartKey) diff --git a/table/column.go b/table/column.go index d3fa703266285..72fedda4fd796 100644 --- a/table/column.go +++ b/table/column.go @@ -248,11 +248,11 @@ func handleZeroDatetime(ctx sessionctx.Context, col *model.ColumnInfo, casted ty // Set it to true only in FillVirtualColumnValue and UnionScanExec.Next() // If the handling of err is changed later, the behavior of forceIgnoreTruncate also needs to change. // TODO: change the third arg to TypeField. Do not pass ColumnInfo. -func CastValue(ctx sessionctx.Context, val types.Datum, col *model.ColumnInfo, returnOverflow, forceIgnoreTruncate bool) (casted types.Datum, err error) { +func CastValue(ctx sessionctx.Context, val types.Datum, col *model.ColumnInfo, returnErr, forceIgnoreTruncate bool) (casted types.Datum, err error) { sc := ctx.GetSessionVars().StmtCtx casted, err = val.ConvertTo(sc, &col.FieldType) // TODO: make sure all truncate errors are handled by ConvertTo.
- if types.ErrOverflow.Equal(err) && returnOverflow { + if returnErr && err != nil { return casted, err } if err != nil && types.ErrTruncated.Equal(err) && col.Tp != mysql.TypeSet && col.Tp != mysql.TypeEnum { @@ -604,10 +604,8 @@ func GetZeroValue(col *model.ColumnInfo) types.Datum { } else { d.SetString("", col.Collate) } - case mysql.TypeVarString, mysql.TypeVarchar: + case mysql.TypeVarString, mysql.TypeVarchar, mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: d.SetString("", col.Collate) - case mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: - d.SetBytes([]byte{}) case mysql.TypeDuration: d.SetMysqlDuration(types.ZeroDuration) case mysql.TypeDate: diff --git a/table/column_test.go b/table/column_test.go index feffb801ebf6a..826c9a8a4f11b 100644 --- a/table/column_test.go +++ b/table/column_test.go @@ -183,7 +183,7 @@ func (t *testTableSuite) TestGetZeroValue(c *C) { }, { types.NewFieldType(mysql.TypeBlob), - types.NewBytesDatum([]byte{}), + types.NewStringDatum(""), }, { types.NewFieldType(mysql.TypeDuration), diff --git a/table/tables/index.go b/table/tables/index.go index 5e4ad9ee4c1bc..4536414f019be 100644 --- a/table/tables/index.go +++ b/table/tables/index.go @@ -376,3 +376,12 @@ func (c *index) FetchValues(r []types.Datum, vals []types.Datum) ([]types.Datum, } return vals, nil } + +// IsIndexWritable checks whether the index is writable. +func IsIndexWritable(idx table.Index) bool { + s := idx.Meta().State + if s != model.StateDeleteOnly && s != model.StateDeleteReorganization { + return true + } + return false +} diff --git a/table/tables/tables.go b/table/tables/tables.go index 51230c2446875..622a35d0ea8a3 100644 --- a/table/tables/tables.go +++ b/table/tables/tables.go @@ -606,7 +606,7 @@ func (t *TableCommon) AddRecord(ctx sessionctx.Context, r []types.Datum, opts .. } // genIndexKeyStr generates the string representation of the index content. -func (t *TableCommon) genIndexKeyStr(colVals []types.Datum) (string, error) { +func genIndexKeyStr(colVals []types.Datum) (string, error) { // Pass pre-composed error to txn. strVals := make([]string, 0, len(colVals)) for _, cv := range colVals { @@ -658,7 +658,7 @@ func (t *TableCommon) addIndices(sctx sessionctx.Context, recordID int64, r []ty } var dupErr error if !skipCheck && v.Meta().Unique { - entryKey, err := t.genIndexKeyStr(indexVals) + entryKey, err := genIndexKeyStr(indexVals) if err != nil { return 0, err } @@ -913,7 +913,7 @@ func (t *TableCommon) buildIndexForRow(ctx sessionctx.Context, rm kv.RetrieverMu if _, err := idx.Create(ctx, rm, vals, h, opts...); err != nil { if kv.ErrKeyExists.Equal(err) { // Make error message consistent with MySQL. - entryKey, err1 := t.genIndexKeyStr(vals) + entryKey, err1 := genIndexKeyStr(vals) if err1 != nil { // if genIndexKeyStr fails, return the original error. return err @@ -1232,6 +1232,68 @@ func CheckHandleExists(ctx context.Context, sctx sessionctx.Context, t table.Tab return nil } +// CheckUniqueKeyExistForUpdateIgnoreOrInsertOnDupIgnore checks whether the recordID key or a unique index key exists. If none exists, it returns nil; +// otherwise it returns the kv.ErrKeyExists error.
+func CheckUniqueKeyExistForUpdateIgnoreOrInsertOnDupIgnore(ctx context.Context, sctx sessionctx.Context, t table.Table, recordID int64, data []types.Datum, modified []bool) error { + if pt, ok := t.(*partitionedTable); ok { + info := t.Meta().GetPartitionInfo() + pid, err := pt.locatePartition(sctx, info, data) + if err != nil { + return err + } + t = pt.GetPartition(pid) + } + txn, err := sctx.Txn(true) + if err != nil { + return err + } + shouldSkipIgnoreCheck := func(idx table.Index) bool { + if !IsIndexWritable(idx) || !idx.Meta().Unique { + return true + } + for _, c := range idx.Meta().Columns { + if modified[c.Offset] { + return false + } + } + return true + } + for _, idx := range t.Indices() { + if shouldSkipIgnoreCheck(idx) { + continue + } + vals, err := idx.FetchValues(data, nil) + if err != nil { + return err + } + key, _, err := idx.GenIndexKey(sctx.GetSessionVars().StmtCtx, vals, recordID, nil) + if err != nil { + return err + } + entryKey, err := genIndexKeyStr(vals) + if err != nil { + return err + } + err = func() error { + existErrInfo := kv.NewExistErrInfo(idx.Meta().Name.String(), entryKey) + txn.SetOption(kv.PresumeKeyNotExistsError, existErrInfo) + txn.SetOption(kv.CheckExists, sctx.GetSessionVars().StmtCtx.CheckKeyExists) + defer txn.DelOption(kv.PresumeKeyNotExistsError) + _, err = txn.Get(ctx, key) + if err == nil { + return existErrInfo.Err() + } else if !kv.ErrNotExist.Equal(err) { + return err + } + return nil + }() + if err != nil { + return err + } + } + return nil +} + func init() { table.TableFromMeta = TableFromMeta table.MockTableFromMeta = MockTableFromMeta diff --git a/tablecodec/tablecodec.go b/tablecodec/tablecodec.go index 9179955af674f..7c480d4ab8751 100644 --- a/tablecodec/tablecodec.go +++ b/tablecodec/tablecodec.go @@ -512,11 +512,11 @@ func unflatten(datum types.Datum, ft *types.FieldType, loc *time.Location) (type case mysql.TypeFloat: datum.SetFloat32(float32(datum.GetFloat64())) return datum, nil - case mysql.TypeVarchar, mysql.TypeString, mysql.TypeVarString: + case mysql.TypeVarchar, mysql.TypeString, mysql.TypeVarString, mysql.TypeTinyBlob, + mysql.TypeMediumBlob, mysql.TypeBlob, mysql.TypeLongBlob: datum.SetString(datum.GetString(), ft.Collate) case mysql.TypeTiny, mysql.TypeShort, mysql.TypeYear, mysql.TypeInt24, - mysql.TypeLong, mysql.TypeLonglong, mysql.TypeDouble, mysql.TypeTinyBlob, - mysql.TypeMediumBlob, mysql.TypeBlob, mysql.TypeLongBlob: + mysql.TypeLong, mysql.TypeLonglong, mysql.TypeDouble: return datum, nil case mysql.TypeDate, mysql.TypeDatetime, mysql.TypeTimestamp: t := types.NewTime(types.ZeroCoreTime, ft.Tp, int8(ft.Decimal)) diff --git a/tablecodec/tablecodec_test.go b/tablecodec/tablecodec_test.go index b56c3cd6b9901..3a27b01ae8be9 100644 --- a/tablecodec/tablecodec_test.go +++ b/tablecodec/tablecodec_test.go @@ -175,6 +175,16 @@ func (s *testTableCodecSuite) TestUnflattenDatums(c *C) { cmp, err := input[0].CompareDatum(sc, &output[0]) c.Assert(err, IsNil) c.Assert(cmp, Equals, 0) + + input = []types.Datum{types.NewCollationStringDatum("aaa", "utf8mb4_unicode_ci", 0)} + tps = []*types.FieldType{types.NewFieldType(mysql.TypeBlob)} + tps[0].Collate = "utf8mb4_unicode_ci" + output, err = UnflattenDatums(input, tps, sc.TimeZone) + c.Assert(err, IsNil) + cmp, err = input[0].CompareDatum(sc, &output[0]) + c.Assert(err, IsNil) + c.Assert(cmp, Equals, 0) + c.Assert(output[0].Collation(), Equals, "utf8mb4_unicode_ci") } func (s *testTableCodecSuite) TestTimeCodec(c *C) { diff --git a/telemetry/data_cluster_hardware.go 
b/telemetry/data_cluster_hardware.go index 318e1ba63f48f..eb380e9503ac7 100644 --- a/telemetry/data_cluster_hardware.go +++ b/telemetry/data_cluster_hardware.go @@ -14,6 +14,7 @@ package telemetry import ( + "context" "regexp" "sort" "strings" @@ -66,7 +67,12 @@ func normalizeFieldName(name string) string { } func getClusterHardware(ctx sessionctx.Context) ([]*clusterHardwareItem, error) { - rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(`SELECT TYPE, INSTANCE, DEVICE_TYPE, DEVICE_NAME, NAME, VALUE FROM information_schema.cluster_hardware`) + exec := ctx.(sqlexec.RestrictedSQLExecutor) + stmt, err := exec.ParseWithParams(context.TODO(), `SELECT TYPE, INSTANCE, DEVICE_TYPE, DEVICE_NAME, NAME, VALUE FROM information_schema.cluster_hardware`) + if err != nil { + return nil, errors.Trace(err) + } + rows, _, err := exec.ExecRestrictedStmt(context.TODO(), stmt) if err != nil { return nil, errors.Trace(err) } diff --git a/telemetry/data_cluster_info.go b/telemetry/data_cluster_info.go index fdb1be6bafc27..46b8cfb8f7b47 100644 --- a/telemetry/data_cluster_info.go +++ b/telemetry/data_cluster_info.go @@ -14,6 +14,8 @@ package telemetry import ( + "context" + "github.com/pingcap/errors" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/util/sqlexec" @@ -33,7 +35,12 @@ type clusterInfoItem struct { } func getClusterInfo(ctx sessionctx.Context) ([]*clusterInfoItem, error) { // Explicitly list all field names instead of using `*` to avoid potentially leaking sensitive info when adding new fields in the future. - rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(`SELECT TYPE, INSTANCE, STATUS_ADDRESS, VERSION, GIT_HASH, START_TIME, UPTIME FROM information_schema.cluster_info`) + exec := ctx.(sqlexec.RestrictedSQLExecutor) + stmt, err := exec.ParseWithParams(context.TODO(), `SELECT TYPE, INSTANCE, STATUS_ADDRESS, VERSION, GIT_HASH, START_TIME, UPTIME FROM information_schema.cluster_info`) + if err != nil { + return nil, errors.Trace(err) + } + rows, _, err := exec.ExecRestrictedStmt(context.TODO(), stmt) if err != nil { return nil, errors.Trace(err) } diff --git a/types/datum.go b/types/datum.go index c48e5637ab460..f0d14c7520eae 100644 --- a/types/datum.go +++ b/types/datum.go @@ -194,7 +194,10 @@ var sink = func(s string) { // GetBytes gets bytes value. func (d *Datum) GetBytes() []byte { - return d.b + if d.b != nil { + return d.b + } + return []byte{} } // SetBytes sets bytes value to datum. @@ -235,6 +238,22 @@ func (d *Datum) SetMinNotNull() { d.x = nil } +// GetBinaryLiteral4Cmp gets the Bit value and removes its prefix 0s for comparison. +func (d *Datum) GetBinaryLiteral4Cmp() BinaryLiteral { + bitLen := len(d.b) + if bitLen == 0 { + return d.b + } + for i := 0; i < bitLen; i++ { + // Remove the prefix 0s in the bit array. + if d.b[i] != 0 { + return d.b[i:] + } + } + // The result is 0x000...00, we just return 0x00.
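+ // For example, 0x0001 is trimmed to 0x01 by the loop above, so BIT values of different widths compare by value; an all-zero array falls through to the single zero byte below.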
+ return d.b[bitLen-1:] +} + // GetBinaryLiteral gets Bit value func (d *Datum) GetBinaryLiteral() BinaryLiteral { return d.b @@ -249,6 +268,7 @@ func (d *Datum) GetMysqlBit() BinaryLiteral { func (d *Datum) SetBinaryLiteral(b BinaryLiteral) { d.k = KindBinaryLiteral d.b = b + d.collation = charset.CollationBin } // SetMysqlBit sets MysqlBit value @@ -569,7 +589,7 @@ func (d *Datum) CompareDatum(sc *stmtctx.StatementContext, ad *Datum) (int, erro case KindMysqlEnum: return d.compareMysqlEnum(sc, ad.GetMysqlEnum()) case KindBinaryLiteral, KindMysqlBit: - return d.compareBinaryLiteral(sc, ad.GetBinaryLiteral()) + return d.compareBinaryLiteral(sc, ad.GetBinaryLiteral4Cmp()) case KindMysqlSet: return d.compareMysqlSet(sc, ad.GetMysqlSet()) case KindMysqlJSON: @@ -638,7 +658,7 @@ func (d *Datum) compareFloat64(sc *stmtctx.StatementContext, f float64) (int, er fVal := d.GetMysqlEnum().ToNumber() return CompareFloat64(fVal, f), nil case KindBinaryLiteral, KindMysqlBit: - val, err := d.GetBinaryLiteral().ToInt(sc) + val, err := d.GetBinaryLiteral4Cmp().ToInt(sc) fVal := float64(val) return CompareFloat64(fVal, f), errors.Trace(err) case KindMysqlSet: @@ -675,7 +695,7 @@ func (d *Datum) compareString(sc *stmtctx.StatementContext, s string, retCollati case KindMysqlEnum: return CompareString(d.GetMysqlEnum().String(), s, d.collation), nil case KindBinaryLiteral, KindMysqlBit: - return CompareString(d.GetBinaryLiteral().ToString(), s, d.collation), nil + return CompareString(d.GetBinaryLiteral4Cmp().ToString(), s, d.collation), nil default: fVal, err := StrToFloat(sc, s, false) if err != nil { @@ -747,9 +767,9 @@ func (d *Datum) compareBinaryLiteral(sc *stmtctx.StatementContext, b BinaryLiter case KindMaxValue: return 1, nil case KindString, KindBytes: - return CompareString(d.GetString(), b.ToString(), d.collation), nil + fallthrough // in this case, d is converted to Binary and then compared with b case KindBinaryLiteral, KindMysqlBit: - return CompareString(d.GetBinaryLiteral().ToString(), b.ToString(), d.collation), nil + return CompareString(d.GetBinaryLiteral4Cmp().ToString(), b.ToString(), d.collation), nil default: val, err := b.ToInt(sc) if err != nil { @@ -1424,6 +1444,10 @@ func (d *Datum) convertToMysqlFloatYear(sc *stmtctx.StatementContext, target *Fi y = float64(d.GetMysqlTime().Year()) case KindMysqlDuration: y = float64(time.Now().Year()) + case KindNull: + // if datum is NULL, we should keep it as it is, instead of setting it to zero or any other value. 
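+ // For example, a NULL datum being converted for a YEAR column must stay NULL here instead of becoming the zero year.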
+ ret = *d + return ret, nil default: ret, err = d.convertToFloat(sc, NewFieldType(mysql.TypeDouble)) if err != nil { @@ -2136,14 +2160,10 @@ func GetMaxValue(ft *FieldType) (max Datum) { max.SetFloat32(float32(GetMaxFloat(ft.Flen, ft.Decimal))) case mysql.TypeDouble: max.SetFloat64(GetMaxFloat(ft.Flen, ft.Decimal)) - case mysql.TypeString, mysql.TypeVarString, mysql.TypeVarchar: + case mysql.TypeString, mysql.TypeVarString, mysql.TypeVarchar, mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: // codec.Encode KindMaxValue, to avoid import circle bytes := []byte{250} max.SetString(string(bytes), ft.Collate) - case mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: - // codec.Encode KindMaxValue, to avoid import circle - bytes := []byte{250} - max.SetBytes(bytes) case mysql.TypeNewDecimal: max.SetMysqlDecimal(NewMaxOrMinDec(false, ft.Flen, ft.Decimal)) case mysql.TypeDuration: @@ -2171,14 +2191,10 @@ func GetMinValue(ft *FieldType) (min Datum) { min.SetFloat32(float32(-GetMaxFloat(ft.Flen, ft.Decimal))) case mysql.TypeDouble: min.SetFloat64(-GetMaxFloat(ft.Flen, ft.Decimal)) - case mysql.TypeString, mysql.TypeVarString, mysql.TypeVarchar: + case mysql.TypeString, mysql.TypeVarString, mysql.TypeVarchar, mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: // codec.Encode KindMinNotNull, to avoid import circle bytes := []byte{1} min.SetString(string(bytes), ft.Collate) - case mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: - // codec.Encode KindMinNotNull, to avoid import circle - bytes := []byte{1} - min.SetBytes(bytes) case mysql.TypeNewDecimal: min.SetMysqlDecimal(NewMaxOrMinDec(true, ft.Flen, ft.Decimal)) case mysql.TypeDuration: diff --git a/types/field_type.go b/types/field_type.go index b93147564c5fe..5804e38c3069e 100644 --- a/types/field_type.go +++ b/types/field_type.go @@ -133,8 +133,8 @@ func AggregateEvalType(fts []*FieldType, flag *uint) EvalType { } lft = rft } - setTypeFlag(flag, mysql.UnsignedFlag, unsigned) - setTypeFlag(flag, mysql.BinaryFlag, !aggregatedEvalType.IsStringKind() || gotBinString) + SetTypeFlag(flag, mysql.UnsignedFlag, unsigned) + SetTypeFlag(flag, mysql.BinaryFlag, !aggregatedEvalType.IsStringKind() || gotBinString) return aggregatedEvalType } @@ -159,7 +159,8 @@ func mergeEvalType(lhs, rhs EvalType, lft, rft *FieldType, isLHSUnsigned, isRHSU return ETInt } -func setTypeFlag(flag *uint, flagItem uint, on bool) { +// SetTypeFlag turns the flagItem on or off. +func SetTypeFlag(flag *uint, flagItem uint, on bool) { if on { *flag |= flagItem } else { @@ -259,8 +260,8 @@ func DefaultTypeForValue(value interface{}, tp *FieldType, char string, collate tp.Flag |= mysql.UnsignedFlag SetBinChsClnFlag(tp) case BinaryLiteral: - tp.Tp = mysql.TypeBit - tp.Flen = len(x) * 8 + tp.Tp = mysql.TypeVarString + tp.Flen = len(x) tp.Decimal = 0 SetBinChsClnFlag(tp) tp.Flag &= ^mysql.BinaryFlag @@ -329,7 +330,7 @@ func DefaultCharsetForType(tp byte) (string, string) { // This is used in hybrid field type expression. // For example "select case c when 1 then 2 when 2 then 'tidb' from t;" // The result field type of the case expression is the merged type of the two when clause. 
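// For example, under the updated rules below, merging mysql.TypeLong with mysql.TypeBit yields mysql.TypeLonglong, where it previously merged to mysql.TypeVarchar.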
-// See https://github.com/mysql/mysql-server/blob/5.7/sql/field.cc#L1042 +// See https://github.com/mysql/mysql-server/blob/8.0/sql/field.cc#L1042 func MergeFieldType(a byte, b byte) byte { ia := getFieldTypeIndex(a) ib := getFieldTypeIndex(b) @@ -357,6 +358,7 @@ const ( fieldTypeNum = fieldTypeTearFrom + (255 - fieldTypeTearTo) ) +// https://github.com/mysql/mysql-server/blob/8.0/sql/field.cc#L248 var fieldTypeMergeRules = [fieldTypeNum][fieldTypeNum]byte{ /* mysql.TypeDecimal -> */ { @@ -409,9 +411,9 @@ var fieldTypeMergeRules = [fieldTypeNum][fieldTypeNum]byte{ mysql.TypeVarchar, mysql.TypeTiny, //mysql.TypeNewDate mysql.TypeVarchar mysql.TypeVarchar, mysql.TypeVarchar, - //mysql.TypeBit <16>-<244> - mysql.TypeVarchar, - //mysql.TypeJSON + // mysql.TypeBit <16>-<244> + mysql.TypeLonglong, + // mysql.TypeJSON mysql.TypeVarchar, //mysql.TypeNewDecimal mysql.TypeEnum mysql.TypeNewDecimal, mysql.TypeVarchar, @@ -442,9 +444,9 @@ var fieldTypeMergeRules = [fieldTypeNum][fieldTypeNum]byte{ mysql.TypeVarchar, mysql.TypeShort, //mysql.TypeNewDate mysql.TypeVarchar mysql.TypeVarchar, mysql.TypeVarchar, - //mysql.TypeBit <16>-<244> - mysql.TypeVarchar, - //mysql.TypeJSON + // mysql.TypeBit <16>-<244> + mysql.TypeLonglong, + // mysql.TypeJSON mysql.TypeVarchar, //mysql.TypeNewDecimal mysql.TypeEnum mysql.TypeNewDecimal, mysql.TypeVarchar, @@ -475,9 +477,9 @@ var fieldTypeMergeRules = [fieldTypeNum][fieldTypeNum]byte{ mysql.TypeVarchar, mysql.TypeLong, //mysql.TypeNewDate mysql.TypeVarchar mysql.TypeVarchar, mysql.TypeVarchar, - //mysql.TypeBit <16>-<244> - mysql.TypeVarchar, - //mysql.TypeJSON + // mysql.TypeBit <16>-<244> + mysql.TypeLonglong, + // mysql.TypeJSON mysql.TypeVarchar, //mysql.TypeNewDecimal mysql.TypeEnum mysql.TypeNewDecimal, mysql.TypeVarchar, @@ -508,9 +510,9 @@ var fieldTypeMergeRules = [fieldTypeNum][fieldTypeNum]byte{ mysql.TypeVarchar, mysql.TypeFloat, //mysql.TypeNewDate mysql.TypeVarchar mysql.TypeVarchar, mysql.TypeVarchar, - //mysql.TypeBit <16>-<244> - mysql.TypeVarchar, - //mysql.TypeJSON + // mysql.TypeBit <16>-<244> + mysql.TypeDouble, + // mysql.TypeJSON mysql.TypeVarchar, //mysql.TypeNewDecimal mysql.TypeEnum mysql.TypeDouble, mysql.TypeVarchar, @@ -541,9 +543,9 @@ var fieldTypeMergeRules = [fieldTypeNum][fieldTypeNum]byte{ mysql.TypeVarchar, mysql.TypeDouble, //mysql.TypeNewDate mysql.TypeVarchar mysql.TypeVarchar, mysql.TypeVarchar, - //mysql.TypeBit <16>-<244> - mysql.TypeVarchar, - //mysql.TypeJSON + // mysql.TypeBit <16>-<244> + mysql.TypeDouble, + // mysql.TypeJSON mysql.TypeVarchar, //mysql.TypeNewDecimal mysql.TypeEnum mysql.TypeDouble, mysql.TypeVarchar, @@ -640,9 +642,9 @@ var fieldTypeMergeRules = [fieldTypeNum][fieldTypeNum]byte{ mysql.TypeVarchar, mysql.TypeLonglong, //mysql.TypeNewDate mysql.TypeVarchar mysql.TypeNewDate, mysql.TypeVarchar, - //mysql.TypeBit <16>-<244> - mysql.TypeVarchar, - //mysql.TypeJSON + // mysql.TypeBit <16>-<244> + mysql.TypeLonglong, + // mysql.TypeJSON mysql.TypeVarchar, //mysql.TypeNewDecimal mysql.TypeEnum mysql.TypeNewDecimal, mysql.TypeVarchar, @@ -673,9 +675,9 @@ var fieldTypeMergeRules = [fieldTypeNum][fieldTypeNum]byte{ mysql.TypeVarchar, mysql.TypeInt24, //mysql.TypeNewDate mysql.TypeVarchar mysql.TypeNewDate, mysql.TypeVarchar, - //mysql.TypeBit <16>-<244> - mysql.TypeVarchar, - //mysql.TypeJSON + // mysql.TypeBit <16>-<244> + mysql.TypeLonglong, + // mysql.TypeJSON mysql.TypeVarchar, //mysql.TypeNewDecimal mysql.TypeEnum mysql.TypeNewDecimal, mysql.TypeVarchar, @@ -805,9 +807,9 @@ var fieldTypeMergeRules = 
[fieldTypeNum][fieldTypeNum]byte{ mysql.TypeVarchar, mysql.TypeYear, //mysql.TypeNewDate mysql.TypeVarchar mysql.TypeVarchar, mysql.TypeVarchar, - //mysql.TypeBit <16>-<244> - mysql.TypeVarchar, - //mysql.TypeJSON + // mysql.TypeBit <16>-<244> + mysql.TypeLonglong, + // mysql.TypeJSON mysql.TypeVarchar, //mysql.TypeNewDecimal mysql.TypeEnum mysql.TypeNewDecimal, mysql.TypeVarchar, @@ -888,29 +890,29 @@ var fieldTypeMergeRules = [fieldTypeNum][fieldTypeNum]byte{ }, /* mysql.TypeBit -> */ { - //mysql.TypeDecimal mysql.TypeTiny - mysql.TypeVarchar, mysql.TypeVarchar, - //mysql.TypeShort mysql.TypeLong - mysql.TypeVarchar, mysql.TypeVarchar, - //mysql.TypeFloat mysql.TypeDouble - mysql.TypeVarchar, mysql.TypeVarchar, - //mysql.TypeNull mysql.TypeTimestamp + // mysql.TypeUnspecified mysql.TypeTiny + mysql.TypeVarchar, mysql.TypeLonglong, + // mysql.TypeShort mysql.TypeLong + mysql.TypeLonglong, mysql.TypeLonglong, + // mysql.TypeFloat mysql.TypeDouble + mysql.TypeDouble, mysql.TypeDouble, + // mysql.TypeNull mysql.TypeTimestamp mysql.TypeBit, mysql.TypeVarchar, - //mysql.TypeLonglong mysql.TypeInt24 - mysql.TypeVarchar, mysql.TypeVarchar, - //mysql.TypeDate mysql.TypeTime - mysql.TypeVarchar, mysql.TypeVarchar, - //mysql.TypeDatetime mysql.TypeYear + // mysql.TypeLonglong mysql.TypeInt24 + mysql.TypeLonglong, mysql.TypeLonglong, + // mysql.TypeDate mysql.TypeTime mysql.TypeVarchar, mysql.TypeVarchar, - //mysql.TypeNewDate mysql.TypeVarchar + // mysql.TypeDatetime mysql.TypeYear + mysql.TypeVarchar, mysql.TypeLonglong, + // mysql.TypeNewDate mysql.TypeVarchar mysql.TypeVarchar, mysql.TypeVarchar, //mysql.TypeBit <16>-<244> mysql.TypeBit, //mysql.TypeJSON mysql.TypeVarchar, - //mysql.TypeNewDecimal mysql.TypeEnum - mysql.TypeVarchar, mysql.TypeVarchar, - //mysql.TypeSet mysql.TypeTinyBlob + // mysql.TypeNewDecimal mysql.TypeEnum + mysql.TypeNewDecimal, mysql.TypeVarchar, + // mysql.TypeSet mysql.TypeTinyBlob mysql.TypeVarchar, mysql.TypeTinyBlob, //mysql.TypeMediumBlob mysql.TypeLongBlob mysql.TypeMediumBlob, mysql.TypeLongBlob, @@ -970,9 +972,9 @@ var fieldTypeMergeRules = [fieldTypeNum][fieldTypeNum]byte{ mysql.TypeVarchar, mysql.TypeNewDecimal, //mysql.TypeNewDate mysql.TypeVarchar mysql.TypeVarchar, mysql.TypeVarchar, - //mysql.TypeBit <16>-<244> - mysql.TypeVarchar, - //mysql.TypeJSON + // mysql.TypeBit <16>-<244> + mysql.TypeNewDecimal, + // mysql.TypeJSON mysql.TypeVarchar, //mysql.TypeNewDecimal mysql.TypeEnum mysql.TypeNewDecimal, mysql.TypeVarchar, diff --git a/types/field_type_test.go b/types/field_type_test.go index 4d2583ec97d97..821ef86b36899 100644 --- a/types/field_type_test.go +++ b/types/field_type_test.go @@ -270,9 +270,11 @@ func (s *testFieldTypeSuite) TestAggFieldType(c *C) { c.Assert(aggTp.Tp, Equals, mysql.TypeDouble) case mysql.TypeTimestamp, mysql.TypeDate, mysql.TypeDuration, mysql.TypeDatetime, mysql.TypeNewDate, mysql.TypeVarchar, - mysql.TypeBit, mysql.TypeJSON, mysql.TypeEnum, mysql.TypeSet, + mysql.TypeJSON, mysql.TypeEnum, mysql.TypeSet, mysql.TypeVarString, mysql.TypeGeometry: c.Assert(aggTp.Tp, Equals, mysql.TypeVarchar) + case mysql.TypeBit: + c.Assert(aggTp.Tp, Equals, mysql.TypeLonglong) case mysql.TypeString: c.Assert(aggTp.Tp, Equals, mysql.TypeString) case mysql.TypeDecimal, mysql.TypeNewDecimal: diff --git a/types/time.go b/types/time.go index 842c62daaf5f8..3e86a1e9b88a0 100644 --- a/types/time.go +++ b/types/time.go @@ -2864,6 +2864,9 @@ var dateFormatParserTable = map[string]dateFormatParser{ "%S": secondsNumeric, // Seconds (00..59) "%T": time24Hour, 
// Time, 24-hour (hh:mm:ss) "%Y": yearNumericFourDigits, // Year, numeric, four digits + "%#": skipAllNums, // Skip all numbers + "%.": skipAllPunct, // Skip all punctuation characters + "%@": skipAllAlpha, // Skip all alpha characters // Deprecated since MySQL 5.7.5 "%y": yearNumericTwoDigits, // Year, numeric (two digits) // TODO: Add the following... @@ -3242,3 +3245,39 @@ func DateTimeIsOverflow(sc *stmtctx.StatementContext, date Time) (bool, error) { inRange := (t.After(b) || t.Equal(b)) && (t.Before(e) || t.Equal(e)) return !inRange, nil } + +func skipAllNums(t *CoreTime, input string, ctx map[string]int) (string, bool) { + retIdx := 0 + for i, ch := range input { + if unicode.IsNumber(ch) { + retIdx = i + 1 + } else { + break + } + } + return input[retIdx:], true +} + +func skipAllPunct(t *CoreTime, input string, ctx map[string]int) (string, bool) { + retIdx := 0 + for i, ch := range input { + if unicode.IsPunct(ch) { + retIdx = i + 1 + } else { + break + } + } + return input[retIdx:], true +} + +func skipAllAlpha(t *CoreTime, input string, ctx map[string]int) (string, bool) { + retIdx := 0 + for i, ch := range input { + if unicode.IsLetter(ch) { + retIdx = i + 1 + } else { + break + } + } + return input[retIdx:], true +} diff --git a/util/admin/admin.go b/util/admin/admin.go index a5f46de92f9aa..7419a1ab6d0b4 100644 --- a/util/admin/admin.go +++ b/util/admin/admin.go @@ -16,12 +16,12 @@ package admin import ( "context" "encoding/json" - "fmt" "math" "sort" "time" "github.com/pingcap/errors" + "github.com/pingcap/parser/ast" "github.com/pingcap/parser/model" "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/errno" @@ -288,13 +288,13 @@ type RecordData struct { Values []types.Datum } -func getCount(ctx sessionctx.Context, sql string) (int64, error) { - rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithSnapshot(sql) +func getCount(exec sqlexec.RestrictedSQLExecutor, stmt ast.StmtNode, snapshot uint64) (int64, error) { + rows, _, err := exec.ExecRestrictedStmt(context.Background(), stmt, sqlexec.ExecOptionWithSnapshot(snapshot)) if err != nil { return 0, errors.Trace(err) } if len(rows) != 1 { - return 0, errors.Errorf("can not get count, sql %s result rows %d", sql, len(rows)) + return 0, errors.Errorf("can not get count, rows count = %d", len(rows)) } return rows[0].GetInt64(0), nil } @@ -313,14 +313,34 @@ const ( // otherwise it returns an error and the corresponding index's offset. func CheckIndicesCount(ctx sessionctx.Context, dbName, tableName string, indices []string) (byte, int, error) { // Add `` for some names like `table name`.
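// (The parameterized form below handles this safely: the %n specifier wraps each identifier in backticks and doubles any embedded backtick, as implemented in util/sqlexec/utils.go in this change.)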
- sql := fmt.Sprintf("SELECT COUNT(*) FROM `%s`.`%s` USE INDEX()", dbName, tableName) - tblCnt, err := getCount(ctx, sql) + exec := ctx.(sqlexec.RestrictedSQLExecutor) + stmt, err := exec.ParseWithParams(context.Background(), "SELECT COUNT(*) FROM %n.%n USE INDEX()", dbName, tableName) + if err != nil { + return 0, 0, errors.Trace(err) + } + + var snapshot uint64 + txn, err := ctx.Txn(false) + if err != nil { + return 0, 0, err + } + if txn.Valid() { + snapshot = txn.StartTS() + } + if ctx.GetSessionVars().SnapshotTS != 0 { + snapshot = ctx.GetSessionVars().SnapshotTS + } + + tblCnt, err := getCount(exec, stmt, snapshot) if err != nil { return 0, 0, errors.Trace(err) } for i, idx := range indices { - sql = fmt.Sprintf("SELECT COUNT(*) FROM `%s`.`%s` USE INDEX(`%s`)", dbName, tableName, idx) - idxCnt, err := getCount(ctx, sql) + stmt, err := exec.ParseWithParams(context.Background(), "SELECT COUNT(*) FROM %n.%n USE INDEX(%n)", dbName, tableName, idx) + if err != nil { + return 0, i, errors.Trace(err) + } + idxCnt, err := getCount(exec, stmt, snapshot) if err != nil { return 0, i, errors.Trace(err) } diff --git a/util/chunk/row.go b/util/chunk/row.go index 993ec9b58b9d1..0951de6803900 100644 --- a/util/chunk/row.go +++ b/util/chunk/row.go @@ -148,14 +148,10 @@ func (r Row) GetDatum(colIdx int, tp *types.FieldType) types.Datum { if !r.IsNull(colIdx) { d.SetFloat64(r.GetFloat64(colIdx)) } - case mysql.TypeVarchar, mysql.TypeVarString, mysql.TypeString: + case mysql.TypeVarchar, mysql.TypeVarString, mysql.TypeString, mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: if !r.IsNull(colIdx) { d.SetString(r.GetString(colIdx), tp.Collate) } - case mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: - if !r.IsNull(colIdx) { - d.SetBytes(r.GetBytes(colIdx)) - } case mysql.TypeDate, mysql.TypeDatetime, mysql.TypeTimestamp: if !r.IsNull(colIdx) { d.SetMysqlTime(r.GetTime(colIdx)) diff --git a/util/execdetails/execdetails.go b/util/execdetails/execdetails.go index 1d8588982364e..537af6599d699 100644 --- a/util/execdetails/execdetails.go +++ b/util/execdetails/execdetails.go @@ -24,6 +24,7 @@ import ( "sync/atomic" "time" + "github.com/pingcap/kvproto/pkg/kvrpcpb" "github.com/pingcap/tipb/go-tipb" "go.uber.org/zap" ) @@ -43,17 +44,15 @@ var ( type ExecDetails struct { CalleeAddress string CopTime time.Duration - ProcessTime time.Duration - WaitTime time.Duration BackoffTime time.Duration LockKeysDuration time.Duration BackoffSleep map[string]time.Duration BackoffTimes map[string]int RequestCount int - TotalKeys int64 - ProcessedKeys int64 CommitDetail *CommitDetails LockKeysDetail *LockKeysDetails + ScanDetail *ScanDetail + TimeDetail TimeDetail } type stmtExecDetailKeyType struct{} @@ -169,6 +168,89 @@ func (ld *LockKeysDetails) Clone() *LockKeysDetails { return lock } +// TimeDetail contains coprocessor time detail information. +type TimeDetail struct { + // ProcessTime is the off-cpu and on-cpu wall time elapsed to actually process the request payload. It does not + // include `wait_wall_time`. + // This field is very close to the CPU time in most cases. Some wait time spent in RocksDB + // cannot be excluded for now, like Mutex wait time, which is included in this field, so that + // this field is called wall time instead of CPU time. + ProcessTime time.Duration + // WaitTime is the off-cpu wall time which is elapsed in TiKV side. Usually this includes queue waiting time and + // other kinds of waiting in series.
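+ // (WaitTime is accumulated from kvrpcpb.TimeDetail.WaitWallTimeMs in MergeFromTimeDetail below.)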
+ WaitTime time.Duration +} + +// String implements the fmt.Stringer interface. +func (td *TimeDetail) String() string { + if td == nil { + return "" + } + buf := bytes.NewBuffer(make([]byte, 0, 16)) + if td.ProcessTime > 0 { + buf.WriteString("total_process_time: ") + buf.WriteString(FormatDuration(td.ProcessTime)) + } + if td.WaitTime > 0 { + if buf.Len() > 0 { + buf.WriteString(", ") + } + buf.WriteString("total_wait_time: ") + buf.WriteString(FormatDuration(td.WaitTime)) + } + return buf.String() +} + +// MergeFromTimeDetail merges time detail from pb into itself. +func (td *TimeDetail) MergeFromTimeDetail(timeDetail *kvrpcpb.TimeDetail) { + if timeDetail != nil { + td.WaitTime += time.Duration(timeDetail.WaitWallTimeMs) * time.Millisecond + td.ProcessTime += time.Duration(timeDetail.ProcessWallTimeMs) * time.Millisecond + } +} + +// ScanDetail contains coprocessor scan detail information. +type ScanDetail struct { + // TotalKeys is the approximate number of MVCC keys met during scanning. It includes + // deleted versions, but does not include RocksDB tombstone keys. + TotalKeys int64 + // ProcessedKeys is the number of user keys scanned from the storage. + // It does not include deleted versions or RocksDB tombstone keys. + // For Coprocessor requests, it includes keys that have been filtered out by Selection. + ProcessedKeys int64 +} + +// Merge merges scan detail execution details into itself. +func (sd *ScanDetail) Merge(scanDetail *ScanDetail) { + atomic.AddInt64(&sd.TotalKeys, scanDetail.TotalKeys) + atomic.AddInt64(&sd.ProcessedKeys, scanDetail.ProcessedKeys) +} + +var zeroScanDetail = ScanDetail{} + +// String implements the fmt.Stringer interface. +func (sd *ScanDetail) String() string { + if sd == nil || *sd == zeroScanDetail { + return "" + } + buf := bytes.NewBuffer(make([]byte, 0, 16)) + buf.WriteString("scan_detail: {") + buf.WriteString("total_process_keys: ") + buf.WriteString(strconv.FormatInt(sd.ProcessedKeys, 10)) + buf.WriteString(", total_keys: ") + buf.WriteString(strconv.FormatInt(sd.TotalKeys, 10)) + buf.WriteString("}") + return buf.String() +} + +// MergeFromScanDetailV2 merges scan detail from pb into itself. +func (sd *ScanDetail) MergeFromScanDetailV2(scanDetail *kvrpcpb.ScanDetailV2) { + if scanDetail != nil { + sd.TotalKeys += int64(scanDetail.TotalVersions) + sd.ProcessedKeys += int64(scanDetail.ProcessedVersions) + } +} + const ( // CopTimeStr represents the sum of cop-task time spent in TiDB distSQL.
CopTimeStr = "Cop_time" @@ -218,11 +300,11 @@ func (d ExecDetails) String() string { if d.CopTime > 0 { parts = append(parts, CopTimeStr+": "+strconv.FormatFloat(d.CopTime.Seconds(), 'f', -1, 64)) } - if d.ProcessTime > 0 { - parts = append(parts, ProcessTimeStr+": "+strconv.FormatFloat(d.ProcessTime.Seconds(), 'f', -1, 64)) + if d.TimeDetail.ProcessTime > 0 { + parts = append(parts, ProcessTimeStr+": "+strconv.FormatFloat(d.TimeDetail.ProcessTime.Seconds(), 'f', -1, 64)) } - if d.WaitTime > 0 { - parts = append(parts, WaitTimeStr+": "+strconv.FormatFloat(d.WaitTime.Seconds(), 'f', -1, 64)) + if d.TimeDetail.WaitTime > 0 { + parts = append(parts, WaitTimeStr+": "+strconv.FormatFloat(d.TimeDetail.WaitTime.Seconds(), 'f', -1, 64)) } if d.BackoffTime > 0 { parts = append(parts, BackoffTimeStr+": "+strconv.FormatFloat(d.BackoffTime.Seconds(), 'f', -1, 64)) @@ -233,11 +315,14 @@ func (d ExecDetails) String() string { if d.RequestCount > 0 { parts = append(parts, RequestCountStr+": "+strconv.FormatInt(int64(d.RequestCount), 10)) } - if d.TotalKeys > 0 { - parts = append(parts, TotalKeysStr+": "+strconv.FormatInt(d.TotalKeys, 10)) - } - if d.ProcessedKeys > 0 { - parts = append(parts, ProcessKeysStr+": "+strconv.FormatInt(d.ProcessedKeys, 10)) + scanDetail := d.ScanDetail + if scanDetail != nil { + if scanDetail.TotalKeys > 0 { + parts = append(parts, TotalKeysStr+": "+strconv.FormatInt(scanDetail.TotalKeys, 10)) + } + if scanDetail.ProcessedKeys > 0 { + parts = append(parts, ProcessKeysStr+": "+strconv.FormatInt(scanDetail.ProcessedKeys, 10)) + } } commitDetails := d.CommitDetail if commitDetails != nil { @@ -292,11 +377,11 @@ func (d ExecDetails) ToZapFields() (fields []zap.Field) { if d.CopTime > 0 { fields = append(fields, zap.String(strings.ToLower(CopTimeStr), strconv.FormatFloat(d.CopTime.Seconds(), 'f', -1, 64)+"s")) } - if d.ProcessTime > 0 { - fields = append(fields, zap.String(strings.ToLower(ProcessTimeStr), strconv.FormatFloat(d.ProcessTime.Seconds(), 'f', -1, 64)+"s")) + if d.TimeDetail.ProcessTime > 0 { + fields = append(fields, zap.String(strings.ToLower(ProcessTimeStr), strconv.FormatFloat(d.TimeDetail.ProcessTime.Seconds(), 'f', -1, 64)+"s")) } - if d.WaitTime > 0 { - fields = append(fields, zap.String(strings.ToLower(WaitTimeStr), strconv.FormatFloat(d.WaitTime.Seconds(), 'f', -1, 64)+"s")) + if d.TimeDetail.WaitTime > 0 { + fields = append(fields, zap.String(strings.ToLower(WaitTimeStr), strconv.FormatFloat(d.TimeDetail.WaitTime.Seconds(), 'f', -1, 64)+"s")) } if d.BackoffTime > 0 { fields = append(fields, zap.String(strings.ToLower(BackoffTimeStr), strconv.FormatFloat(d.BackoffTime.Seconds(), 'f', -1, 64)+"s")) @@ -304,11 +389,11 @@ func (d ExecDetails) ToZapFields() (fields []zap.Field) { if d.RequestCount > 0 { fields = append(fields, zap.String(strings.ToLower(RequestCountStr), strconv.FormatInt(int64(d.RequestCount), 10))) } - if d.TotalKeys > 0 { - fields = append(fields, zap.String(strings.ToLower(TotalKeysStr), strconv.FormatInt(d.TotalKeys, 10))) + if d.ScanDetail != nil && d.ScanDetail.TotalKeys > 0 { + fields = append(fields, zap.String(strings.ToLower(TotalKeysStr), strconv.FormatInt(d.ScanDetail.TotalKeys, 10))) } - if d.ProcessedKeys > 0 { - fields = append(fields, zap.String(strings.ToLower(ProcessKeysStr), strconv.FormatInt(d.ProcessedKeys, 10))) + if d.ScanDetail != nil && d.ScanDetail.ProcessedKeys > 0 { + fields = append(fields, zap.String(strings.ToLower(ProcessKeysStr), strconv.FormatInt(d.ScanDetail.ProcessedKeys, 10))) } commitDetails := d.CommitDetail 
if commitDetails != nil { @@ -363,7 +448,8 @@ type CopRuntimeStats struct { // have many region leaders, several coprocessor tasks can be sent to the // same tikv-server instance. We have to use a list to maintain all tasks // executed on each instance. - stats map[string][]*BasicRuntimeStats + stats map[string][]*BasicRuntimeStats + scanDetail *ScanDetail } // RecordOneCopTask records a specific cop task's execution detail. @@ -412,6 +498,13 @@ func (crs *CopRuntimeStats) String() string { FormatDuration(procTimes[n-1]), FormatDuration(procTimes[0]), FormatDuration(procTimes[n*4/5]), FormatDuration(procTimes[n*19/20]), totalIters, totalTasks)) } + if crs.scanDetail != nil { + detail := crs.scanDetail.String() + if detail != "" { + buf.WriteString(", ") + buf.WriteString(detail) + } + } return buf.String() } @@ -624,7 +717,10 @@ func (e *RuntimeStatsColl) GetCopStats(planID int) *CopRuntimeStats { defer e.mu.Unlock() copStats, ok := e.copStats[planID] if !ok { - copStats = &CopRuntimeStats{stats: make(map[string][]*BasicRuntimeStats)} + copStats = &CopRuntimeStats{ + stats: make(map[string][]*BasicRuntimeStats), + scanDetail: &ScanDetail{}, + } e.copStats[planID] = copStats } return copStats @@ -651,6 +747,12 @@ func (e *RuntimeStatsColl) RecordOneCopTask(planID int, address string, summary copStats.RecordOneCopTask(address, summary) } +// RecordScanDetail records a specific cop task's cop detail. +func (e *RuntimeStatsColl) RecordScanDetail(planID int, detail *ScanDetail) { + copStats := e.GetCopStats(planID) + copStats.scanDetail.Merge(detail) +} + // ExistsRootStats checks if the planID exists in the rootStats collection. func (e *RuntimeStatsColl) ExistsRootStats(planID int) bool { e.mu.Lock() diff --git a/util/execdetails/execdetails_test.go b/util/execdetails/execdetails_test.go index af18e52b9be23..0e29f0a7850ef 100644 --- a/util/execdetails/execdetails_test.go +++ b/util/execdetails/execdetails_test.go @@ -15,6 +15,7 @@ package execdetails import ( "fmt" + "strconv" "sync" "testing" "time" @@ -30,13 +31,9 @@ func TestT(t *testing.T) { func TestString(t *testing.T) { detail := &ExecDetails{ - CopTime: time.Second + 3*time.Millisecond, - ProcessTime: 2*time.Second + 5*time.Millisecond, - WaitTime: time.Second, - BackoffTime: time.Second, - RequestCount: 1, - TotalKeys: 100, - ProcessedKeys: 10, + CopTime: time.Second + 3*time.Millisecond, + BackoffTime: time.Second, + RequestCount: 1, CommitDetail: &CommitDetails{ GetCommitTsTime: time.Second, PrewriteTime: time.Second, @@ -60,6 +57,14 @@ PrewriteRegionNum: 1, TxnRetry: 1, }, + ScanDetail: &ScanDetail{ + ProcessedKeys: 10, + TotalKeys: 100, + }, + TimeDetail: TimeDetail{ + ProcessTime: 2*time.Second + 5*time.Millisecond, + WaitTime: time.Second, + }, } expected := "Cop_time: 1.003 Process_time: 2.005 Wait_time: 1 Backoff_time: 1 Request_count: 1 Total_keys: 100 Process_keys: 10 Prewrite_time: 1 Commit_time: 1 " + "Get_commit_ts_time: 1 Commit_backoff_time: 1 Backoff_types: [backoff1 backoff2] Resolve_lock_time: 1 Local_latch_wait_time: 1 Write_keys: 1 Write_size: 1 Prewrite_region: 1 Txn_retry: 1" @@ -91,12 +96,74 @@ func TestCopRuntimeStats(t *testing.T) { stats.RecordOneCopTask(tableScanID, "8.8.8.9", mockExecutorExecutionSummary(2, 2, 2)) stats.RecordOneCopTask(aggID, "8.8.8.8", mockExecutorExecutionSummary(3, 3, 3)) stats.RecordOneCopTask(aggID, "8.8.8.9", mockExecutorExecutionSummary(4, 4, 4)) + scanDetail := &ScanDetail{ + TotalKeys: 15, + ProcessedKeys: 10, + } +
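+ // RecordScanDetail merges this into the per-plan CopRuntimeStats.scanDetail, which String() then renders as "scan_detail: {...}" in the assertion below.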
stats.RecordScanDetail(tableScanID, scanDetail) + if stats.ExistsCopStats(tableScanID) != true { + t.Fatal("exist") + } + cop := stats.GetCopStats(tableScanID) + if cop.String() != "tikv_task:{proc max:2ns, min:1ns, p80:2ns, p95:2ns, iters:3, tasks:2}, "+ + "scan_detail: {total_process_keys: 10, total_keys: 15}" { + t.Fatalf(cop.String()) + } + copStats := cop.stats["8.8.8.8"] + if copStats == nil { + t.Fatal("cop stats is nil") + } + copStats[0].SetRowNum(10) + copStats[0].Record(time.Second, 10) + if copStats[0].String() != "time:1s, loops:2" { + t.Fatalf("cop stats string is not expect, got: %v", copStats[0].String()) + } + + if stats.GetCopStats(aggID).String() != "tikv_task:{proc max:4ns, min:3ns, p80:4ns, p95:4ns, iters:7, tasks:2}" { + t.Fatalf("agg cop stats string is not as expected, got: %v", stats.GetCopStats(aggID).String()) + } + rootStats := stats.GetRootStats(tableReaderID) + if rootStats == nil { + t.Fatal("table_reader") + } + if stats.ExistsRootStats(tableReaderID) == false { + t.Fatal("table_reader not exists") + } + + cop.scanDetail.ProcessedKeys = 0 + // Print all fields even though the value of some fields is 0. + if cop.String() != "tikv_task:{proc max:1s, min:2ns, p80:1s, p95:1s, iters:4, tasks:2}, "+ + "scan_detail: {total_process_keys: 0, total_keys: 15}" { + t.Fatalf(cop.String()) + } + + zeroScanDetail := ScanDetail{} + if zeroScanDetail.String() != "" { + t.Fatalf(zeroScanDetail.String()) + } +} + +func TestCopRuntimeStatsForTiFlash(t *testing.T) { + stats := NewRuntimeStatsColl() + tableScanID := 1 + aggID := 2 + tableReaderID := 3 + stats.RecordOneCopTask(aggID, "8.8.8.8", mockExecutorExecutionSummaryForTiFlash(1, 1, 1, "tablescan_"+strconv.Itoa(tableScanID))) + stats.RecordOneCopTask(aggID, "8.8.8.9", mockExecutorExecutionSummaryForTiFlash(2, 2, 2, "tablescan_"+strconv.Itoa(tableScanID))) + stats.RecordOneCopTask(tableScanID, "8.8.8.8", mockExecutorExecutionSummaryForTiFlash(3, 3, 3, "aggregation_"+strconv.Itoa(aggID))) + stats.RecordOneCopTask(tableScanID, "8.8.8.9", mockExecutorExecutionSummaryForTiFlash(4, 4, 4, "aggregation_"+strconv.Itoa(aggID))) + scanDetail := &ScanDetail{ + TotalKeys: 10, + ProcessedKeys: 10, + } + stats.RecordScanDetail(tableScanID, scanDetail) if stats.ExistsCopStats(tableScanID) != true { t.Fatal("exist") } cop := stats.GetCopStats(tableScanID) - if cop.String() != "tikv_task:{proc max:2ns, min:1ns, p80:2ns, p95:2ns, iters:3, tasks:2}" { - t.Fatal("table_scan") + if cop.String() != "tikv_task:{proc max:2ns, min:1ns, p80:2ns, p95:2ns, iters:3, tasks:2}"+ + ", scan_detail: {total_process_keys: 10, total_keys: 10}" { + t.Fatal(cop.String()) } copStats := cop.stats["8.8.8.8"] if copStats == nil { diff --git a/util/gcutil/gcutil.go b/util/gcutil/gcutil.go index f265e4dd0603f..6d3c116cdb0a4 100644 --- a/util/gcutil/gcutil.go +++ b/util/gcutil/gcutil.go @@ -14,7 +14,7 @@ package gcutil import ( - "fmt" + "context" "github.com/pingcap/errors" "github.com/pingcap/parser/model" @@ -25,18 +25,20 @@ import ( ) const ( - selectVariableValueSQL = `SELECT HIGH_PRIORITY variable_value FROM mysql.tidb WHERE variable_name='%s'` - insertVariableValueSQL = `INSERT HIGH_PRIORITY INTO mysql.tidb VALUES ('%[1]s', '%[2]s', '%[3]s') - ON DUPLICATE KEY - UPDATE variable_value = '%[2]s', comment = '%[3]s'` + insertVariableValueSQL = `INSERT HIGH_PRIORITY INTO mysql.tidb VALUES (%?, %?, %?) 
+ ON DUPLICATE KEY UPDATE variable_value = %?, comment = %?` + selectVariableValueSQL = `SELECT HIGH_PRIORITY variable_value FROM mysql.tidb WHERE variable_name=%?` ) // CheckGCEnable is used to check whether GC is enabled. func CheckGCEnable(ctx sessionctx.Context) (enable bool, err error) { - sql := fmt.Sprintf(selectVariableValueSQL, "tikv_gc_enable") - rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql) - if err != nil { - return false, errors.Trace(err) + stmt, err1 := ctx.(sqlexec.RestrictedSQLExecutor).ParseWithParams(context.Background(), selectVariableValueSQL, "tikv_gc_enable") + if err1 != nil { + return false, errors.Trace(err1) + } + rows, _, err2 := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedStmt(context.Background(), stmt) + if err2 != nil { + return false, errors.Trace(err2) + } if len(rows) != 1 { return false, errors.New("can not get 'tikv_gc_enable'") @@ -46,15 +48,19 @@ func CheckGCEnable(ctx sessionctx.Context) (enable bool, err error) { // DisableGC will disable the GC enable variable. func DisableGC(ctx sessionctx.Context) error { - sql := fmt.Sprintf(insertVariableValueSQL, "tikv_gc_enable", "false", "Current GC enable status") - _, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql) + stmt, err := ctx.(sqlexec.RestrictedSQLExecutor).ParseWithParams(context.Background(), insertVariableValueSQL, "tikv_gc_enable", "false", "Current GC enable status", "false", "Current GC enable status") + if err == nil { + _, _, err = ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedStmt(context.Background(), stmt) + } return errors.Trace(err) } // EnableGC will enable the GC enable variable. func EnableGC(ctx sessionctx.Context) error { - sql := fmt.Sprintf(insertVariableValueSQL, "tikv_gc_enable", "true", "Current GC enable status") - _, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql) + stmt, err := ctx.(sqlexec.RestrictedSQLExecutor).ParseWithParams(context.Background(), insertVariableValueSQL, "tikv_gc_enable", "true", "Current GC enable status", "true", "Current GC enable status") + if err == nil { + _, _, err = ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedStmt(context.Background(), stmt) + } return errors.Trace(err) } @@ -80,8 +86,12 @@ func ValidateSnapshotWithGCSafePoint(snapshotTS, safePointTS uint64) error { // GetGCSafePoint loads GC safe point time from mysql.tidb.
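// (Like CheckGCEnable above, GetGCSafePoint follows the same two-step ParseWithParams + ExecRestrictedStmt pattern that replaces the old string-formatted ExecRestrictedSQL calls.)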
func GetGCSafePoint(ctx sessionctx.Context) (uint64, error) { - sql := fmt.Sprintf(selectVariableValueSQL, "tikv_gc_safe_point") - rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql) + exec := ctx.(sqlexec.RestrictedSQLExecutor) + stmt, err := exec.ParseWithParams(context.Background(), selectVariableValueSQL, "tikv_gc_safe_point") + if err != nil { + return 0, errors.Trace(err) + } + rows, _, err := exec.ExecRestrictedStmt(context.Background(), stmt) if err != nil { return 0, errors.Trace(err) } diff --git a/util/memory/tracker.go b/util/memory/tracker.go index a8736376ecea1..e630a2b730942 100644 --- a/util/memory/tracker.go +++ b/util/memory/tracker.go @@ -430,6 +430,12 @@ func (t *Tracker) DetachFromGlobalTracker() { t.parent = nil } +// ReplaceBytesUsed replaces the consumed bytes of the tracker. +func (t *Tracker) ReplaceBytesUsed(bytes int64) { + t.Consume(-t.BytesConsumed()) + t.Consume(bytes) +} + const ( // LabelForSQLText represents the label of the SQL Text LabelForSQLText int = -1 diff --git a/util/mock/context.go b/util/mock/context.go index 7dded0b330a0b..6661500665f94 100644 --- a/util/mock/context.go +++ b/util/mock/context.go @@ -20,6 +20,7 @@ import ( "time" "github.com/pingcap/errors" + "github.com/pingcap/parser/ast" "github.com/pingcap/parser/model" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/owner" @@ -29,6 +30,7 @@ import ( "github.com/pingcap/tidb/util/disk" "github.com/pingcap/tidb/util/kvcache" "github.com/pingcap/tidb/util/memory" + "github.com/pingcap/tidb/util/sli" "github.com/pingcap/tidb/util/sqlexec" "github.com/pingcap/tipb/go-binlog" ) @@ -61,9 +63,14 @@ func (c *Context) Execute(ctx context.Context, sql string) ([]sqlexec.RecordSet, return nil, errors.Errorf("Not Support.") } +// ExecuteStmt implements sqlexec.SQLExecutor ExecuteStmt interface. +func (c *Context) ExecuteStmt(ctx context.Context, stmtNode ast.StmtNode) (sqlexec.RecordSet, error) { + return nil, errors.Errorf("Not Supported.") +} + // ExecuteInternal implements sqlexec.SQLExecutor ExecuteInternal interface. -func (c *Context) ExecuteInternal(ctx context.Context, sql string) ([]sqlexec.RecordSet, error) { - return nil, errors.Errorf("Not Support.") +func (c *Context) ExecuteInternal(ctx context.Context, sql string, args ...interface{}) (sqlexec.RecordSet, error) { + return nil, errors.Errorf("Not Supported.") } type mockDDLOwnerChecker struct{} @@ -206,6 +213,11 @@ func (c *Context) GoCtx() context.Context { // StoreQueryFeedback stores the query feedback. func (c *Context) StoreQueryFeedback(_ interface{}) {} +// GetTxnWriteThroughputSLI implements the sessionctx.Context interface. +func (c *Context) GetTxnWriteThroughputSLI() *sli.TxnWriteThroughputSLI { + return &sli.TxnWriteThroughputSLI{} +} + // StmtCommit implements the sessionctx.Context interface.
func (c *Context) StmtCommit(tracker *memory.Tracker) error { return nil diff --git a/util/rowcodec/decoder.go b/util/rowcodec/decoder.go index 8352fc8fa2efc..cd8f7438578de 100644 --- a/util/rowcodec/decoder.go +++ b/util/rowcodec/decoder.go @@ -133,10 +133,8 @@ func (decoder *DatumMapDecoder) decodeColDatum(col *ColInfo, colData []byte) (ty return d, err } d.SetFloat64(fVal) - case mysql.TypeVarString, mysql.TypeVarchar, mysql.TypeString: + case mysql.TypeVarString, mysql.TypeVarchar, mysql.TypeString, mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: d.SetString(string(colData), col.Collate) - case mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: - d.SetBytes(colData) case mysql.TypeNewDecimal: _, dec, precision, frac, err := codec.DecodeDecimal(colData) if err != nil { diff --git a/util/rowcodec/rowcodec_test.go b/util/rowcodec/rowcodec_test.go index 04b1c3455f5e1..dbe0c3568b8db 100644 --- a/util/rowcodec/rowcodec_test.go +++ b/util/rowcodec/rowcodec_test.go @@ -308,6 +308,8 @@ func (s *testSuite) TestTypesNewRowCodec(c *C) { c.Assert(len(remain), Equals, 0) if d.Kind() == types.KindMysqlDecimal { c.Assert(d.GetMysqlDecimal(), DeepEquals, t.bt.GetMysqlDecimal()) + } else if d.Kind() == types.KindBytes { + c.Assert(d.GetBytes(), DeepEquals, t.bt.GetBytes()) } else { c.Assert(d, DeepEquals, t.bt) } @@ -341,9 +343,9 @@ func (s *testSuite) TestTypesNewRowCodec(c *C) { }, { 24, - types.NewFieldType(mysql.TypeBlob), - types.NewBytesDatum([]byte("abc")), - types.NewBytesDatum([]byte("abc")), + types.NewFieldTypeWithCollation(mysql.TypeBlob, mysql.DefaultCollationName, types.UnspecifiedLength), + types.NewStringDatum("abc"), + types.NewStringDatum("abc"), nil, false, }, @@ -470,8 +472,8 @@ func (s *testSuite) TestTypesNewRowCodec(c *C) { testData[0].id = 1 // test large data - testData[3].dt = types.NewBytesDatum([]byte(strings.Repeat("a", math.MaxUint16+1))) - testData[3].bt = types.NewBytesDatum([]byte(strings.Repeat("a", math.MaxUint16+1))) + testData[3].dt = types.NewStringDatum(strings.Repeat("a", math.MaxUint16+1)) + testData[3].bt = types.NewStringDatum(strings.Repeat("a", math.MaxUint16+1)) encodeAndDecode(c, testData) } diff --git a/util/sli/sli.go b/util/sli/sli.go new file mode 100644 index 0000000000000..1c48f9a03a510 --- /dev/null +++ b/util/sli/sli.go @@ -0,0 +1,119 @@ +// Copyright 2021 PingCAP, Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package sli + +import ( + "fmt" + "time" + + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/metrics" +) + +// TxnWriteThroughputSLI is used to report transaction write throughput metrics for SLI. +type TxnWriteThroughputSLI struct { + invalid bool + affectRow uint64 + writeSize int + readKeys int + writeKeys int + writeTime time.Duration +} + +// FinishExecuteStmt records the cost for a write statement which affects more than 0 rows, +// and reports metrics when the transaction is committed.
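+// A hypothetical call sequence (illustration only): +// +// sli.FinishExecuteStmt(20*time.Millisecond, 5, true) // in-transaction DML: accumulate cost and affected rows +// sli.FinishExecuteStmt(2*time.Millisecond, 0, false) // COMMIT: add its cost, report metrics, then Reset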
+func (t *TxnWriteThroughputSLI) FinishExecuteStmt(cost time.Duration, affectRow uint64, inTxn bool) { + if affectRow > 0 { + t.writeTime += cost + t.affectRow += affectRow + } + + // Currently, not being in a transaction means the last transaction has finished, so we should report metrics and reset the data. + if !inTxn { + if affectRow == 0 { + // AffectRows is 0 when the statement is COMMIT. + t.writeTime += cost + } + // Report metrics after committing this transaction. + t.reportMetric() + + // Skip reset for test. + failpoint.Inject("CheckTxnWriteThroughput", func() { + failpoint.Return() + }) + + // Reset for the next transaction. + t.Reset() + } +} + +// AddReadKeys adds the read keys. +func (t *TxnWriteThroughputSLI) AddReadKeys(readKeys int64) { + t.readKeys += int(readKeys) +} + +// AddTxnWriteSize adds the transaction write size and keys. +func (t *TxnWriteThroughputSLI) AddTxnWriteSize(size, keys int) { + t.writeSize += size + t.writeKeys += keys +} + +func (t *TxnWriteThroughputSLI) reportMetric() { + if t.IsInvalid() { + return + } + if t.IsSmallTxn() { + metrics.SmallTxnWriteDuration.Observe(t.writeTime.Seconds()) + } else { + metrics.TxnWriteThroughput.Observe(float64(t.writeSize) / t.writeTime.Seconds()) + } +} + +// SetInvalid marks this transaction as invalid for reporting SLI metrics. +func (t *TxnWriteThroughputSLI) SetInvalid() { + t.invalid = true +} + +// IsInvalid checks whether the transaction is invalid for reporting SLI metrics. Currently, the following cases are considered invalid: +// 1. The transaction contains `insert|replace into ... select ... from ...` statement. +// 2. The write SQL statement has more read keys than write keys. +func (t *TxnWriteThroughputSLI) IsInvalid() bool { + return t.invalid || t.readKeys > t.writeKeys || t.writeSize == 0 || t.writeTime == 0 +} + +const ( + smallTxnAffectRow = 20 + smallTxnWriteSize = 1 * 1024 * 1024 // 1MB +) + +// IsSmallTxn is exported for testing. +func (t *TxnWriteThroughputSLI) IsSmallTxn() bool { + return t.affectRow <= smallTxnAffectRow && t.writeSize <= smallTxnWriteSize +} + +// Reset is exported for testing. +func (t *TxnWriteThroughputSLI) Reset() { + t.invalid = false + t.affectRow = 0 + t.writeSize = 0 + t.readKeys = 0 + t.writeKeys = 0 + t.writeTime = 0 +} + +// String is exported for testing. +func (t *TxnWriteThroughputSLI) String() string { + return fmt.Sprintf("invalid: %v, affectRow: %v, writeSize: %v, readKeys: %v, writeKeys: %v, writeTime: %v", + t.invalid, t.affectRow, t.writeSize, t.readKeys, t.writeKeys, t.writeTime.String()) +} diff --git a/util/sqlexec/restricted_sql_executor.go b/util/sqlexec/restricted_sql_executor.go index ce39db7fd00a7..597873d050151 100644 --- a/util/sqlexec/restricted_sql_executor.go +++ b/util/sqlexec/restricted_sql_executor.go @@ -42,6 +42,43 @@ type RestrictedSQLExecutor interface { // If current session sets the snapshot timestamp, then execute with this snapshot timestamp. // Otherwise, execute with the current transaction start timestamp if the transaction is valid. ExecRestrictedSQLWithSnapshot(sql string) ([]chunk.Row, []*ast.ResultField, error) + + // The above methods are all deprecated. + // After the refactoring finishes, they will be removed. + + // ParseWithParams is the parameterized version of Parse: it will try to prevent injection under utf8mb4. + // It works like printf() in C, with the following format specifiers: + // 1. %?: automatic conversion by the type of arguments. E.g. []string -> ('s1','s2'..) + // 2. %%: output % + // 3.
%n: for identifiers, for example ("use %n", db) + // + // Attention: it does not prevent you from doing parse("select '%?", ";SQL injection!;") => "select '';SQL injection!;'". + // One argument should be a standalone entity. It should not "concat" with other placeholders and characters. + // This function only saves you from processing potentially unsafe parameters. + ParseWithParams(ctx context.Context, sql string, args ...interface{}) (ast.StmtNode, error) + // ExecRestrictedStmt runs the sql statement in ctx with some restrictions. + ExecRestrictedStmt(ctx context.Context, stmt ast.StmtNode, opts ...OptionFuncAlias) ([]chunk.Row, []*ast.ResultField, error) +} + +// ExecOption is a struct that defines options for ExecRestrictedSQLWithContext. +type ExecOption struct { + IgnoreWarning bool + SnapshotTS uint64 +} + +// OptionFuncAlias is defined for the optional parameter of ExecRestrictedSQLWithContext. +type OptionFuncAlias = func(option *ExecOption) + +// ExecOptionIgnoreWarning tells ExecRestrictedSQLWithContext to ignore the warnings. +var ExecOptionIgnoreWarning OptionFuncAlias = func(option *ExecOption) { + option.IgnoreWarning = true +} + +// ExecOptionWithSnapshot tells ExecRestrictedSQLWithContext to use a snapshot. +func ExecOptionWithSnapshot(snapshot uint64) OptionFuncAlias { + return func(option *ExecOption) { + option.SnapshotTS = snapshot + } } // SQLExecutor is an interface that provides executing normal sql statements. @@ -51,7 +88,8 @@ type SQLExecutor interface { Execute(ctx context.Context, sql string) ([]RecordSet, error) // ExecuteInternal means execute sql as the internal sql. - ExecuteInternal(ctx context.Context, sql string) ([]RecordSet, error) + ExecuteInternal(ctx context.Context, sql string, args ...interface{}) (RecordSet, error) + ExecuteStmt(ctx context.Context, stmtNode ast.StmtNode) (RecordSet, error) } // SQLParser is an interface that provides parsing sql statements. diff --git a/util/sqlexec/utils.go b/util/sqlexec/utils.go new file mode 100644 index 0000000000000..1ffc29b72d8e0 --- /dev/null +++ b/util/sqlexec/utils.go @@ -0,0 +1,260 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlexec + +import ( + "encoding/json" + "io" + "strconv" + "strings" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/tidb/util/hack" +) + +func reserveBuffer(buf []byte, appendSize int) []byte { + newSize := len(buf) + appendSize + if cap(buf) < newSize { + newBuf := make([]byte, len(buf)*2+appendSize) + copy(newBuf, buf) + buf = newBuf + } + return buf[:newSize] +} + +// escapeBytesBackslash will escape []byte into the buffer, with backslash.
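+// For example (mirroring the tests below): escapeBytesBackslash(nil, []byte("he'lo")) returns he\'lo, and NUL, newline, carriage return, Ctrl-Z, quote, and backslash bytes each become a two-byte escape sequence.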
+func escapeBytesBackslash(buf []byte, v []byte) []byte { + pos := len(buf) + buf = reserveBuffer(buf, len(v)*2) + + for _, c := range v { + switch c { + case '\x00': + buf[pos] = '\\' + buf[pos+1] = '0' + pos += 2 + case '\n': + buf[pos] = '\\' + buf[pos+1] = 'n' + pos += 2 + case '\r': + buf[pos] = '\\' + buf[pos+1] = 'r' + pos += 2 + case '\x1a': + buf[pos] = '\\' + buf[pos+1] = 'Z' + pos += 2 + case '\'': + buf[pos] = '\\' + buf[pos+1] = '\'' + pos += 2 + case '"': + buf[pos] = '\\' + buf[pos+1] = '"' + pos += 2 + case '\\': + buf[pos] = '\\' + buf[pos+1] = '\\' + pos += 2 + default: + buf[pos] = c + pos++ + } + } + + return buf[:pos] +} + +// escapeStringBackslash will escape string into the buffer, with backslash. +func escapeStringBackslash(buf []byte, v string) []byte { + return escapeBytesBackslash(buf, hack.Slice(v)) +} + +// escapeSQL is the internal impl of EscapeSQL and FormatSQL. +func escapeSQL(sql string, args ...interface{}) ([]byte, error) { + buf := make([]byte, 0, len(sql)) + argPos := 0 + for i := 0; i < len(sql); i++ { + q := strings.IndexByte(sql[i:], '%') + if q == -1 { + buf = append(buf, sql[i:]...) + break + } + buf = append(buf, sql[i:i+q]...) + i += q + + ch := byte(0) + if i+1 < len(sql) { + ch = sql[i+1] // get the specifier + } + switch ch { + case 'n': + if argPos >= len(args) { + return nil, errors.Errorf("missing arguments, need %d-th arg, but only got %d args", argPos+1, len(args)) + } + arg := args[argPos] + argPos++ + + v, ok := arg.(string) + if !ok { + return nil, errors.Errorf("expect a string identifier, got %v", arg) + } + buf = append(buf, '`') + buf = append(buf, strings.Replace(v, "`", "``", -1)...) + buf = append(buf, '`') + i++ // skip specifier + case '?': + if argPos >= len(args) { + return nil, errors.Errorf("missing arguments, need %d-th arg, but only got %d args", argPos+1, len(args)) + } + arg := args[argPos] + argPos++ + + if arg == nil { + buf = append(buf, "NULL"...) + } else { + switch v := arg.(type) { + case int: + buf = strconv.AppendInt(buf, int64(v), 10) + case int8: + buf = strconv.AppendInt(buf, int64(v), 10) + case int16: + buf = strconv.AppendInt(buf, int64(v), 10) + case int32: + buf = strconv.AppendInt(buf, int64(v), 10) + case int64: + buf = strconv.AppendInt(buf, v, 10) + case uint: + buf = strconv.AppendUint(buf, uint64(v), 10) + case uint8: + buf = strconv.AppendUint(buf, uint64(v), 10) + case uint16: + buf = strconv.AppendUint(buf, uint64(v), 10) + case uint32: + buf = strconv.AppendUint(buf, uint64(v), 10) + case uint64: + buf = strconv.AppendUint(buf, v, 10) + case float32: + buf = strconv.AppendFloat(buf, float64(v), 'g', -1, 32) + case float64: + buf = strconv.AppendFloat(buf, v, 'g', -1, 64) + case bool: + if v { + buf = append(buf, '1') + } else { + buf = append(buf, '0') + } + case time.Time: + if v.IsZero() { + buf = append(buf, "'0000-00-00'"...) + } else { + buf = append(buf, '\'') + buf = v.AppendFormat(buf, "2006-01-02 15:04:05.999999") + buf = append(buf, '\'') + } + case json.RawMessage: + buf = append(buf, '\'') + buf = escapeBytesBackslash(buf, v) + buf = append(buf, '\'') + case []byte: + if v == nil { + buf = append(buf, "NULL"...) + } else { + buf = append(buf, "_binary'"...) 
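+ // The _binary introducer marks the escaped bytes as a binary string literal, so the server reads them independently of the connection character set.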
+ buf = escapeBytesBackslash(buf, v) + buf = append(buf, '\'') + } + case string: + buf = append(buf, '\'') + buf = escapeStringBackslash(buf, v) + buf = append(buf, '\'') + case []string: + for i, k := range v { + if i > 0 { + buf = append(buf, ',') + } + buf = append(buf, '\'') + buf = escapeStringBackslash(buf, k) + buf = append(buf, '\'') + } + case []float32: + for i, k := range v { + if i > 0 { + buf = append(buf, ',') + } + buf = strconv.AppendFloat(buf, float64(k), 'g', -1, 32) + } + case []float64: + for i, k := range v { + if i > 0 { + buf = append(buf, ',') + } + buf = strconv.AppendFloat(buf, k, 'g', -1, 64) + } + default: + return nil, errors.Errorf("unsupported %d-th argument: %v", argPos, arg) + } + } + i++ // skip specifier + case '%': + buf = append(buf, '%') + i++ // skip specifier + default: + buf = append(buf, '%') + } + } + return buf, nil +} + +// EscapeSQL will escape input arguments into the sql string, doing necessary processing. +// It works like printf() in C, with the following format specifiers: +// 1. %?: automatic conversion by the type of arguments. E.g. []string -> ('s1','s2'..) +// 2. %%: output % +// 3. %n: for identifiers, for example ("use %n", db) +// But it does not prevent you from doing EscapeSQL("select '%?", ";SQL injection!;") => "select '';SQL injection!;'". +// It is still your responsibility to write safe SQL. +func EscapeSQL(sql string, args ...interface{}) (string, error) { + str, err := escapeSQL(sql, args...) + return string(str), err +} + +// MustEscapeSQL is a helper around EscapeSQL. The error returned from escapeSQL can be avoided statically if you do not pass interface{}. +func MustEscapeSQL(sql string, args ...interface{}) string { + r, err := EscapeSQL(sql, args...) + if err != nil { + panic(err) + } + return r +} + +// FormatSQL is the io.Writer version of EscapeSQL. Please refer to EscapeSQL for details. +func FormatSQL(w io.Writer, sql string, args ...interface{}) error { + buf, err := escapeSQL(sql, args...) + if err != nil { + return err + } + _, err = w.Write(buf) + return err +} + +// MustFormatSQL is a helper around FormatSQL, like MustEscapeSQL. But it requires that the writer be a strings.Builder, +// whose Write(...) never returns an error. +func MustFormatSQL(w *strings.Builder, sql string, args ...interface{}) { + err := FormatSQL(w, sql, args...) + if err != nil { + panic(err) + } +} diff --git a/util/sqlexec/utils_test.go b/util/sqlexec/utils_test.go new file mode 100644 index 0000000000000..a8a912a33978f --- /dev/null +++ b/util/sqlexec/utils_test.go @@ -0,0 +1,430 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlexec + +import ( + "encoding/json" + "strings" + "testing" + "time" + + .
"github.com/pingcap/check" +) + +func TestT(t *testing.T) { + TestingT(t) +} + +var _ = Suite(&testUtilsSuite{}) + +type testUtilsSuite struct{} + +func (s *testUtilsSuite) TestReserveBuffer(c *C) { + res0 := reserveBuffer(nil, 0) + c.Assert(res0, HasLen, 0) + + res1 := reserveBuffer(res0, 3) + c.Assert(res1, HasLen, 3) + res1[1] = 3 + + res2 := reserveBuffer(res1, 9) + c.Assert(res2, HasLen, 12) + c.Assert(cap(res2), Equals, 15) + c.Assert(res2[:3], DeepEquals, res1) +} + +func (s *testUtilsSuite) TestEscapeBackslash(c *C) { + type TestCase struct { + name string + input []byte + output []byte + } + tests := []TestCase{ + { + name: "normal", + input: []byte("hello"), + output: []byte("hello"), + }, + { + name: "0", + input: []byte("he\x00lo"), + output: []byte("he\\0lo"), + }, + { + name: "break line", + input: []byte("he\nlo"), + output: []byte("he\\nlo"), + }, + { + name: "carry", + input: []byte("he\rlo"), + output: []byte("he\\rlo"), + }, + { + name: "substitute", + input: []byte("he\x1alo"), + output: []byte("he\\Zlo"), + }, + { + name: "single quote", + input: []byte("he'lo"), + output: []byte("he\\'lo"), + }, + { + name: "double quote", + input: []byte("he\"lo"), + output: []byte("he\\\"lo"), + }, + { + name: "back slash", + input: []byte("he\\lo"), + output: []byte("he\\\\lo"), + }, + { + name: "double escape", + input: []byte("he\x00lo\""), + output: []byte("he\\0lo\\\""), + }, + { + name: "chinese", + input: []byte("中文?"), + output: []byte("中文?"), + }, + } + for _, t := range tests { + commentf := Commentf("%s", t.name) + c.Assert(escapeBytesBackslash(nil, t.input), DeepEquals, t.output, commentf) + c.Assert(escapeStringBackslash(nil, string(t.input)), DeepEquals, t.output, commentf) + } +} + +func (s *testUtilsSuite) TestEscapeSQL(c *C) { + type TestCase struct { + name string + input string + params []interface{} + output string + err string + } + time2, err := time.Parse("2006-01-02 15:04:05", "2018-01-23 04:03:05") + c.Assert(err, IsNil) + tests := []TestCase{ + { + name: "normal 1", + input: "select * from 1", + params: []interface{}{}, + output: "select * from 1", + err: "", + }, + { + name: "normal 2", + input: "WHERE source != 'builtin'", + params: []interface{}{}, + output: "WHERE source != 'builtin'", + err: "", + }, + { + name: "discard extra arguments", + input: "select * from 1", + params: []interface{}{4, 5, "rt"}, + output: "select * from 1", + err: "", + }, + { + name: "%? missing arguments", + input: "select %? 
from %?", + params: []interface{}{4}, + err: "missing arguments.*", + }, + { + name: "nil", + input: "select %?", + params: []interface{}{nil}, + output: "select NULL", + err: "", + }, + { + name: "int", + input: "select %?", + params: []interface{}{int(3)}, + output: "select 3", + err: "", + }, + { + name: "int8", + input: "select %?", + params: []interface{}{int8(4)}, + output: "select 4", + err: "", + }, + { + name: "int16", + input: "select %?", + params: []interface{}{int16(5)}, + output: "select 5", + err: "", + }, + { + name: "int32", + input: "select %?", + params: []interface{}{int32(6)}, + output: "select 6", + err: "", + }, + { + name: "int64", + input: "select %?", + params: []interface{}{int64(7)}, + output: "select 7", + err: "", + }, + { + name: "uint", + input: "select %?", + params: []interface{}{uint(8)}, + output: "select 8", + err: "", + }, + { + name: "uint8", + input: "select %?", + params: []interface{}{uint8(9)}, + output: "select 9", + err: "", + }, + { + name: "uint16", + input: "select %?", + params: []interface{}{uint16(10)}, + output: "select 10", + err: "", + }, + { + name: "uint32", + input: "select %?", + params: []interface{}{uint32(11)}, + output: "select 11", + err: "", + }, + { + name: "uint64", + input: "select %?", + params: []interface{}{uint64(12)}, + output: "select 12", + err: "", + }, + { + name: "float32", + input: "select %?", + params: []interface{}{float32(0.13)}, + output: "select 0.13", + err: "", + }, + { + name: "float64", + input: "select %?", + params: []interface{}{float64(0.14)}, + output: "select 0.14", + err: "", + }, + { + name: "bool on", + input: "select %?", + params: []interface{}{true}, + output: "select 1", + err: "", + }, + { + name: "bool off", + input: "select %?", + params: []interface{}{false}, + output: "select 0", + err: "", + }, + { + name: "time 0", + input: "select %?", + params: []interface{}{time.Time{}}, + output: "select '0000-00-00'", + err: "", + }, + { + name: "time 1", + input: "select %?", + params: []interface{}{time.Date(2019, 1, 1, 0, 0, 0, 0, time.UTC)}, + output: "select '2019-01-01 00:00:00'", + err: "", + }, + { + name: "time 2", + input: "select %?", + params: []interface{}{time2}, + output: "select '2018-01-23 04:03:05'", + err: "", + }, + { + name: "time 3", + input: "select %?", + params: []interface{}{time.Unix(0, 888888888).UTC()}, + output: "select '1970-01-01 00:00:00.888888'", + err: "", + }, + { + name: "empty byte slice1", + input: "select %?", + params: []interface{}{[]byte(nil)}, + output: "select NULL", + err: "", + }, + { + name: "empty byte slice2", + input: "select %?", + params: []interface{}{[]byte{}}, + output: "select _binary''", + err: "", + }, + { + name: "byte slice", + input: "select %?", + params: []interface{}{[]byte{2, 3}}, + output: "select _binary'\x02\x03'", + err: "", + }, + { + name: "string", + input: "select %?", + params: []interface{}{"33"}, + output: "select '33'", + }, + { + name: "string slice", + input: "select %?", + params: []interface{}{[]string{"33", "44"}}, + output: "select '33','44'", + }, + { + name: "raw json", + input: "select %?", + params: []interface{}{json.RawMessage(`{"h": "hello"}`)}, + output: "select '{\\\"h\\\": \\\"hello\\\"}'", + }, + { + name: "unsupported args", + input: "select %?", + params: []interface{}{make(chan byte)}, + err: "unsupported 1-th argument.*", + }, + { + name: "mixed arguments", + input: "select %?, %?, %?", + params: []interface{}{"33", 44, time.Time{}}, + output: "select '33', 44, '0000-00-00'", + }, + { + name: 
"simple injection", + input: "select %?", + params: []interface{}{"0; drop database"}, + output: "select '0; drop database'", + }, + { + name: "identifier, wrong arg", + input: "use %n", + params: []interface{}{3}, + err: "expect a string identifier.*", + }, + { + name: "identifier", + input: "use %n", + params: []interface{}{"table`"}, + output: "use `table```", + err: "", + }, + { + name: "%n missing arguments", + input: "use %n", + params: []interface{}{}, + err: "missing arguments.*", + }, + { + name: "% escape", + input: "select * from t where val = '%%?'", + params: []interface{}{}, + output: "select * from t where val = '%?'", + err: "", + }, + { + name: "unknown specifier", + input: "%v", + params: []interface{}{}, + output: "%v", + err: "", + }, + { + name: "truncated specifier ", + input: "rv %", + params: []interface{}{}, + output: "rv %", + err: "", + }, + { + name: "float32 slice", + input: "select %?", + params: []interface{}{[]float32{33.1, 0.44}}, + output: "select 33.1,0.44", + }, + { + name: "float64 slice", + input: "select %?", + params: []interface{}{[]float64{55.2, 0.66}}, + output: "select 55.2,0.66", + }, + } + for _, t := range tests { + comment := Commentf("%s", t.name) + r3 := new(strings.Builder) + r1, e1 := escapeSQL(t.input, t.params...) + r2, e2 := EscapeSQL(t.input, t.params...) + e3 := FormatSQL(r3, t.input, t.params...) + if t.err == "" { + c.Assert(e1, IsNil, comment) + c.Assert(string(r1), Equals, t.output, comment) + c.Assert(e2, IsNil, comment) + c.Assert(r2, Equals, t.output, comment) + c.Assert(e3, IsNil, comment) + c.Assert(r3.String(), Equals, t.output, comment) + } else { + c.Assert(e1, NotNil, comment) + c.Assert(e1, ErrorMatches, t.err, comment) + c.Assert(e2, NotNil, comment) + c.Assert(e2, ErrorMatches, t.err, comment) + c.Assert(e3, NotNil, comment) + c.Assert(e3, ErrorMatches, t.err, comment) + } + } +} + +func (s *testUtilsSuite) TestMustUtils(c *C) { + c.Assert(func() { + MustEscapeSQL("%?") + }, PanicMatches, "missing arguments.*") + + c.Assert(func() { + sql := new(strings.Builder) + MustFormatSQL(sql, "%?") + }, PanicMatches, "missing arguments.*") + + sql := new(strings.Builder) + MustFormatSQL(sql, "t") + MustEscapeSQL("tt") +} diff --git a/util/stmtsummary/statement_summary.go b/util/stmtsummary/statement_summary.go index ccd2d8b9b4f70..502b94c874c4f 100644 --- a/util/stmtsummary/statement_summary.go +++ b/util/stmtsummary/statement_summary.go @@ -183,6 +183,7 @@ type stmtSummaryByDigestElement struct { // plan cache planInCache bool planCacheHits int64 + planInBinding bool // pessimistic execution retry information. 
execRetryCount uint execRetryTime time.Duration @@ -214,6 +215,7 @@ type StmtExecInfo struct { IsInternal bool Succeed bool PlanInCache bool + PlanInBinding bool ExecRetryCount uint ExecRetryTime time.Duration execdetails.StmtExecDetails @@ -627,6 +629,7 @@ func newStmtSummaryByDigestElement(sei *StmtExecInfo, beginTime int64, intervalS authUsers: make(map[string]struct{}), planInCache: false, planCacheHits: 0, + planInBinding: false, prepared: sei.Prepared, } ssElement.add(sei, intervalSeconds) @@ -698,25 +701,28 @@ func (ssElement *stmtSummaryByDigestElement) add(sei *StmtExecInfo, intervalSeco } // TiKV - ssElement.sumProcessTime += sei.ExecDetail.ProcessTime - if sei.ExecDetail.ProcessTime > ssElement.maxProcessTime { - ssElement.maxProcessTime = sei.ExecDetail.ProcessTime + ssElement.sumProcessTime += sei.ExecDetail.TimeDetail.ProcessTime + if sei.ExecDetail.TimeDetail.ProcessTime > ssElement.maxProcessTime { + ssElement.maxProcessTime = sei.ExecDetail.TimeDetail.ProcessTime } - ssElement.sumWaitTime += sei.ExecDetail.WaitTime - if sei.ExecDetail.WaitTime > ssElement.maxWaitTime { - ssElement.maxWaitTime = sei.ExecDetail.WaitTime + ssElement.sumWaitTime += sei.ExecDetail.TimeDetail.WaitTime + if sei.ExecDetail.TimeDetail.WaitTime > ssElement.maxWaitTime { + ssElement.maxWaitTime = sei.ExecDetail.TimeDetail.WaitTime } ssElement.sumBackoffTime += sei.ExecDetail.BackoffTime if sei.ExecDetail.BackoffTime > ssElement.maxBackoffTime { ssElement.maxBackoffTime = sei.ExecDetail.BackoffTime } - ssElement.sumTotalKeys += sei.ExecDetail.TotalKeys - if sei.ExecDetail.TotalKeys > ssElement.maxTotalKeys { - ssElement.maxTotalKeys = sei.ExecDetail.TotalKeys - } - ssElement.sumProcessedKeys += sei.ExecDetail.ProcessedKeys - if sei.ExecDetail.ProcessedKeys > ssElement.maxProcessedKeys { - ssElement.maxProcessedKeys = sei.ExecDetail.ProcessedKeys + + if sei.ExecDetail.ScanDetail != nil { + ssElement.sumTotalKeys += sei.ExecDetail.ScanDetail.TotalKeys + if sei.ExecDetail.ScanDetail.TotalKeys > ssElement.maxTotalKeys { + ssElement.maxTotalKeys = sei.ExecDetail.ScanDetail.TotalKeys + } + ssElement.sumProcessedKeys += sei.ExecDetail.ScanDetail.ProcessedKeys + if sei.ExecDetail.ScanDetail.ProcessedKeys > ssElement.maxProcessedKeys { + ssElement.maxProcessedKeys = sei.ExecDetail.ScanDetail.ProcessedKeys + } } // txn @@ -782,6 +788,13 @@ func (ssElement *stmtSummaryByDigestElement) add(sei *StmtExecInfo, intervalSeco ssElement.planInCache = false } + // SPM + if sei.PlanInBinding { + ssElement.planInBinding = true + } else { + ssElement.planInBinding = false + } + // other ssElement.sumAffectedRows += sei.StmtCtx.AffectedRows() ssElement.sumMem += sei.MemMax @@ -899,6 +912,7 @@ func (ssElement *stmtSummaryByDigestElement) toDatum(ssbd *stmtSummaryByDigest) types.NewTime(types.FromGoTime(ssElement.lastSeen), mysql.TypeTimestamp, 0), ssElement.planInCache, ssElement.planCacheHits, + ssElement.planInBinding, ssElement.sampleSQL, ssElement.prevSQL, ssbd.planDigest, diff --git a/util/stmtsummary/statement_summary_test.go b/util/stmtsummary/statement_summary_test.go index 39df0a45147bb..0bc7b3ab0cedb 100644 --- a/util/stmtsummary/statement_summary_test.go +++ b/util/stmtsummary/statement_summary_test.go @@ -95,16 +95,16 @@ func (s *testStmtSummarySuite) TestAddStatement(c *C) { maxCopProcessAddress: stmtExecInfo1.CopTasks.MaxProcessAddress, maxCopWaitTime: stmtExecInfo1.CopTasks.MaxWaitTime, maxCopWaitAddress: stmtExecInfo1.CopTasks.MaxWaitAddress, - sumProcessTime: stmtExecInfo1.ExecDetail.ProcessTime, - 
maxProcessTime: stmtExecInfo1.ExecDetail.ProcessTime, - sumWaitTime: stmtExecInfo1.ExecDetail.WaitTime, - maxWaitTime: stmtExecInfo1.ExecDetail.WaitTime, + sumProcessTime: stmtExecInfo1.ExecDetail.TimeDetail.ProcessTime, + maxProcessTime: stmtExecInfo1.ExecDetail.TimeDetail.ProcessTime, + sumWaitTime: stmtExecInfo1.ExecDetail.TimeDetail.WaitTime, + maxWaitTime: stmtExecInfo1.ExecDetail.TimeDetail.WaitTime, sumBackoffTime: stmtExecInfo1.ExecDetail.BackoffTime, maxBackoffTime: stmtExecInfo1.ExecDetail.BackoffTime, - sumTotalKeys: stmtExecInfo1.ExecDetail.TotalKeys, - maxTotalKeys: stmtExecInfo1.ExecDetail.TotalKeys, - sumProcessedKeys: stmtExecInfo1.ExecDetail.ProcessedKeys, - maxProcessedKeys: stmtExecInfo1.ExecDetail.ProcessedKeys, + sumTotalKeys: stmtExecInfo1.ExecDetail.ScanDetail.TotalKeys, + maxTotalKeys: stmtExecInfo1.ExecDetail.ScanDetail.TotalKeys, + sumProcessedKeys: stmtExecInfo1.ExecDetail.ScanDetail.ProcessedKeys, + maxProcessedKeys: stmtExecInfo1.ExecDetail.ScanDetail.ProcessedKeys, sumGetCommitTsTime: stmtExecInfo1.ExecDetail.CommitDetail.GetCommitTsTime, maxGetCommitTsTime: stmtExecInfo1.ExecDetail.CommitDetail.GetCommitTsTime, sumPrewriteTime: stmtExecInfo1.ExecDetail.CommitDetail.PrewriteTime, @@ -176,12 +176,8 @@ func (s *testStmtSummarySuite) TestAddStatement(c *C) { }, ExecDetail: &execdetails.ExecDetails{ CalleeAddress: "202", - ProcessTime: 1500, - WaitTime: 150, BackoffTime: 180, RequestCount: 20, - TotalKeys: 6000, - ProcessedKeys: 1500, CommitDetail: &execdetails.CommitDetails{ GetCommitTsTime: 500, PrewriteTime: 50000, @@ -200,6 +196,14 @@ func (s *testStmtSummarySuite) TestAddStatement(c *C) { PrewriteRegionNum: 100, TxnRetry: 10, }, + ScanDetail: &execdetails.ScanDetail{ + TotalKeys: 6000, + ProcessedKeys: 1500, + }, + TimeDetail: execdetails.TimeDetail{ + ProcessTime: 1500, + WaitTime: 150, + }, }, StmtCtx: &stmtctx.StatementContext{ StmtType: "Select", @@ -224,16 +228,16 @@ func (s *testStmtSummarySuite) TestAddStatement(c *C) { expectedSummaryElement.maxCopProcessAddress = stmtExecInfo2.CopTasks.MaxProcessAddress expectedSummaryElement.maxCopWaitTime = stmtExecInfo2.CopTasks.MaxWaitTime expectedSummaryElement.maxCopWaitAddress = stmtExecInfo2.CopTasks.MaxWaitAddress - expectedSummaryElement.sumProcessTime += stmtExecInfo2.ExecDetail.ProcessTime - expectedSummaryElement.maxProcessTime = stmtExecInfo2.ExecDetail.ProcessTime - expectedSummaryElement.sumWaitTime += stmtExecInfo2.ExecDetail.WaitTime - expectedSummaryElement.maxWaitTime = stmtExecInfo2.ExecDetail.WaitTime + expectedSummaryElement.sumProcessTime += stmtExecInfo2.ExecDetail.TimeDetail.ProcessTime + expectedSummaryElement.maxProcessTime = stmtExecInfo2.ExecDetail.TimeDetail.ProcessTime + expectedSummaryElement.sumWaitTime += stmtExecInfo2.ExecDetail.TimeDetail.WaitTime + expectedSummaryElement.maxWaitTime = stmtExecInfo2.ExecDetail.TimeDetail.WaitTime expectedSummaryElement.sumBackoffTime += stmtExecInfo2.ExecDetail.BackoffTime expectedSummaryElement.maxBackoffTime = stmtExecInfo2.ExecDetail.BackoffTime - expectedSummaryElement.sumTotalKeys += stmtExecInfo2.ExecDetail.TotalKeys - expectedSummaryElement.maxTotalKeys = stmtExecInfo2.ExecDetail.TotalKeys - expectedSummaryElement.sumProcessedKeys += stmtExecInfo2.ExecDetail.ProcessedKeys - expectedSummaryElement.maxProcessedKeys = stmtExecInfo2.ExecDetail.ProcessedKeys + expectedSummaryElement.sumTotalKeys += stmtExecInfo2.ExecDetail.ScanDetail.TotalKeys + expectedSummaryElement.maxTotalKeys = stmtExecInfo2.ExecDetail.ScanDetail.TotalKeys + 
expectedSummaryElement.sumProcessedKeys += stmtExecInfo2.ExecDetail.ScanDetail.ProcessedKeys + expectedSummaryElement.maxProcessedKeys = stmtExecInfo2.ExecDetail.ScanDetail.ProcessedKeys expectedSummaryElement.sumGetCommitTsTime += stmtExecInfo2.ExecDetail.CommitDetail.GetCommitTsTime expectedSummaryElement.maxGetCommitTsTime = stmtExecInfo2.ExecDetail.CommitDetail.GetCommitTsTime expectedSummaryElement.sumPrewriteTime += stmtExecInfo2.ExecDetail.CommitDetail.PrewriteTime @@ -294,12 +298,8 @@ func (s *testStmtSummarySuite) TestAddStatement(c *C) { }, ExecDetail: &execdetails.ExecDetails{ CalleeAddress: "302", - ProcessTime: 150, - WaitTime: 15, BackoffTime: 18, RequestCount: 2, - TotalKeys: 600, - ProcessedKeys: 150, CommitDetail: &execdetails.CommitDetails{ GetCommitTsTime: 50, PrewriteTime: 5000, @@ -318,6 +318,14 @@ func (s *testStmtSummarySuite) TestAddStatement(c *C) { PrewriteRegionNum: 10, TxnRetry: 1, }, + ScanDetail: &execdetails.ScanDetail{ + TotalKeys: 600, + ProcessedKeys: 150, + }, + TimeDetail: execdetails.TimeDetail{ + ProcessTime: 150, + WaitTime: 15, + }, }, StmtCtx: &stmtctx.StatementContext{ StmtType: "Select", @@ -336,11 +344,11 @@ func (s *testStmtSummarySuite) TestAddStatement(c *C) { expectedSummaryElement.sumParseLatency += stmtExecInfo3.ParseLatency expectedSummaryElement.sumCompileLatency += stmtExecInfo3.CompileLatency expectedSummaryElement.sumNumCopTasks += int64(stmtExecInfo3.CopTasks.NumCopTasks) - expectedSummaryElement.sumProcessTime += stmtExecInfo3.ExecDetail.ProcessTime - expectedSummaryElement.sumWaitTime += stmtExecInfo3.ExecDetail.WaitTime + expectedSummaryElement.sumProcessTime += stmtExecInfo3.ExecDetail.TimeDetail.ProcessTime + expectedSummaryElement.sumWaitTime += stmtExecInfo3.ExecDetail.TimeDetail.WaitTime expectedSummaryElement.sumBackoffTime += stmtExecInfo3.ExecDetail.BackoffTime - expectedSummaryElement.sumTotalKeys += stmtExecInfo3.ExecDetail.TotalKeys - expectedSummaryElement.sumProcessedKeys += stmtExecInfo3.ExecDetail.ProcessedKeys + expectedSummaryElement.sumTotalKeys += stmtExecInfo3.ExecDetail.ScanDetail.TotalKeys + expectedSummaryElement.sumProcessedKeys += stmtExecInfo3.ExecDetail.ScanDetail.ProcessedKeys expectedSummaryElement.sumGetCommitTsTime += stmtExecInfo3.ExecDetail.CommitDetail.GetCommitTsTime expectedSummaryElement.sumPrewriteTime += stmtExecInfo3.ExecDetail.CommitDetail.PrewriteTime expectedSummaryElement.sumCommitTime += stmtExecInfo3.ExecDetail.CommitDetail.CommitTime @@ -541,12 +549,8 @@ func generateAnyExecInfo() *StmtExecInfo { }, ExecDetail: &execdetails.ExecDetails{ CalleeAddress: "129", - ProcessTime: 500, - WaitTime: 50, BackoffTime: 80, RequestCount: 10, - TotalKeys: 1000, - ProcessedKeys: 500, CommitDetail: &execdetails.CommitDetails{ GetCommitTsTime: 100, PrewriteTime: 10000, @@ -565,6 +569,14 @@ func generateAnyExecInfo() *StmtExecInfo { PrewriteRegionNum: 20, TxnRetry: 2, }, + ScanDetail: &execdetails.ScanDetail{ + TotalKeys: 1000, + ProcessedKeys: 500, + }, + TimeDetail: execdetails.TimeDetail{ + ProcessTime: 500, + WaitTime: 50, + }, }, StmtCtx: &stmtctx.StatementContext{ StmtType: "Select", @@ -600,10 +612,10 @@ func (s *testStmtSummarySuite) TestToDatum(c *C) { int64(stmtExecInfo1.ParseLatency), int64(stmtExecInfo1.ParseLatency), int64(stmtExecInfo1.CompileLatency), int64(stmtExecInfo1.CompileLatency), stmtExecInfo1.CopTasks.NumCopTasks, int64(stmtExecInfo1.CopTasks.MaxProcessTime), stmtExecInfo1.CopTasks.MaxProcessAddress, int64(stmtExecInfo1.CopTasks.MaxWaitTime), - 
stmtExecInfo1.CopTasks.MaxWaitAddress, int64(stmtExecInfo1.ExecDetail.ProcessTime), int64(stmtExecInfo1.ExecDetail.ProcessTime), - int64(stmtExecInfo1.ExecDetail.WaitTime), int64(stmtExecInfo1.ExecDetail.WaitTime), int64(stmtExecInfo1.ExecDetail.BackoffTime), - int64(stmtExecInfo1.ExecDetail.BackoffTime), stmtExecInfo1.ExecDetail.TotalKeys, stmtExecInfo1.ExecDetail.TotalKeys, - stmtExecInfo1.ExecDetail.ProcessedKeys, stmtExecInfo1.ExecDetail.ProcessedKeys, + stmtExecInfo1.CopTasks.MaxWaitAddress, int64(stmtExecInfo1.ExecDetail.TimeDetail.ProcessTime), int64(stmtExecInfo1.ExecDetail.TimeDetail.ProcessTime), + int64(stmtExecInfo1.ExecDetail.TimeDetail.WaitTime), int64(stmtExecInfo1.ExecDetail.TimeDetail.WaitTime), int64(stmtExecInfo1.ExecDetail.BackoffTime), + int64(stmtExecInfo1.ExecDetail.BackoffTime), stmtExecInfo1.ExecDetail.ScanDetail.TotalKeys, stmtExecInfo1.ExecDetail.ScanDetail.TotalKeys, + stmtExecInfo1.ExecDetail.ScanDetail.ProcessedKeys, stmtExecInfo1.ExecDetail.ScanDetail.ProcessedKeys, int64(stmtExecInfo1.ExecDetail.CommitDetail.PrewriteTime), int64(stmtExecInfo1.ExecDetail.CommitDetail.PrewriteTime), int64(stmtExecInfo1.ExecDetail.CommitDetail.CommitTime), int64(stmtExecInfo1.ExecDetail.CommitDetail.CommitTime), int64(stmtExecInfo1.ExecDetail.CommitDetail.GetCommitTsTime), int64(stmtExecInfo1.ExecDetail.CommitDetail.GetCommitTsTime), @@ -616,7 +628,7 @@ func (s *testStmtSummarySuite) TestToDatum(c *C) { stmtExecInfo1.ExecDetail.CommitDetail.TxnRetry, stmtExecInfo1.ExecDetail.CommitDetail.TxnRetry, 0, 0, 1, "txnLock:1", stmtExecInfo1.MemMax, stmtExecInfo1.MemMax, stmtExecInfo1.DiskMax, stmtExecInfo1.DiskMax, 0, 0, 0, 0, 0, stmtExecInfo1.StmtCtx.AffectedRows(), - t, t, 0, 0, stmtExecInfo1.OriginalSQL, stmtExecInfo1.PrevSQL, "plan_digest", ""} + t, t, 0, 0, 0, stmtExecInfo1.OriginalSQL, stmtExecInfo1.PrevSQL, "plan_digest", ""} match(c, datums[0], expectedDatum...) datums = s.ssMap.ToHistoryDatum(nil, true) c.Assert(len(datums), Equals, 1)
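
For reference, here is a usage sketch of the EscapeSQL/MustEscapeSQL helpers added in util/sqlexec/utils.go above. The import path and the bind_info query are illustrative assumptions, not part of the patch; the specifier behavior (%? for values, %n for identifiers, %% for a literal percent) follows the test cases in utils_test.go.

package main

import (
	"fmt"

	"github.com/pingcap/tidb/util/sqlexec" // assumed import path for util/sqlexec
)

func main() {
	// %? converts by argument type: strings are quoted and escaped,
	// slices are expanded to a comma-separated list.
	sql, err := sqlexec.EscapeSQL(
		"SELECT original_sql FROM mysql.bind_info WHERE update_time > %? AND status IN (%?)",
		"0000-00-00 00:00:00",
		[]string{"using", "deleted"}, // hypothetical status values
	)
	if err != nil {
		panic(err)
	}
	fmt.Println(sql)
	// SELECT original_sql FROM mysql.bind_info WHERE update_time > '0000-00-00 00:00:00' AND status IN ('using','deleted')

	// %n backtick-quotes identifiers and %% emits a literal percent sign.
	fmt.Println(sqlexec.MustEscapeSQL("USE %n", "test"))         // USE `test`
	fmt.Println(sqlexec.MustEscapeSQL("SELECT '50%%' AS share")) // SELECT '50%' AS share
}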
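
Similarly, a minimal sketch of the io.Writer flavor, FormatSQL/MustFormatSQL: MustFormatSQL accepts only a *strings.Builder because its Write never returns an error, so the only panic source is a bad format/argument pair. The query fragments below are hypothetical.

package main

import (
	"fmt"
	"strings"

	"github.com/pingcap/tidb/util/sqlexec" // assumed import path
)

func main() {
	var sb strings.Builder
	// Build one statement incrementally; each call escapes its own arguments.
	sqlexec.MustFormatSQL(&sb, "DELETE FROM mysql.bind_info WHERE original_sql = %?", "select * from t")
	for _, db := range []string{"test", "mysql"} {
		sqlexec.MustFormatSQL(&sb, " OR default_db = %?", db)
	}
	fmt.Println(sb.String())
	// DELETE FROM mysql.bind_info WHERE original_sql = 'select * from t' OR default_db = 'test' OR default_db = 'mysql'
}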
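
Finally, the escaping rules asserted by TestEscapeBackslash can be summarized per byte as below. This is an illustrative re-implementation under that assumption, not the actual escapeStringBackslash/escapeBytesBackslash, which operate on whole buffers via reserveBuffer.

// escapeByte is a hypothetical helper showing the per-byte mapping the tests
// exercise: NUL, LF, CR, SUB (ctrl-Z), quotes, and backslashes are escaped;
// everything else, including multi-byte UTF-8 sequences, passes through.
func escapeByte(c byte) []byte {
	switch c {
	case 0:
		return []byte(`\0`)
	case '\n':
		return []byte(`\n`)
	case '\r':
		return []byte(`\r`)
	case 0x1a: // SUB / ctrl-Z
		return []byte(`\Z`)
	case '\'':
		return []byte(`\'`)
	case '"':
		return []byte(`\"`)
	case '\\':
		return []byte(`\\`)
	default:
		return []byte{c}
	}
}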