From 058e52ad7c1b477147dcb933b3f1b3b2be31e998 Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Fri, 26 Feb 2021 00:24:03 +0800 Subject: [PATCH 01/85] planner, executor: reset NotNullFlag when merge schema for join (#22955) (#22958) --- executor/executor_test.go | 26 ++++++++++++++++++++++++++ planner/core/rule_column_pruning.go | 7 +++++++ 2 files changed, 33 insertions(+) diff --git a/executor/executor_test.go b/executor/executor_test.go index f5b82c73c98aa..e310abc017c96 100644 --- a/executor/executor_test.go +++ b/executor/executor_test.go @@ -6647,3 +6647,29 @@ func (s *testSuite) TestIssue22201(c *C) { tk.MustQuery("SELECT HEX(WEIGHT_STRING('ab' AS char(1000000000000000000)));").Check(testkit.Rows("")) tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1301 Result of weight_string() was larger than max_allowed_packet (67108864) - truncated")) } + +func (s *testSuiteP1) TestIssue22941(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists m, mp") + tk.MustExec(`CREATE TABLE m ( + mid varchar(50) NOT NULL, + ParentId varchar(50) DEFAULT NULL, + PRIMARY KEY (mid), + KEY ind_bm_parent (ParentId,mid) + )`) + + tk.MustExec(`CREATE TABLE mp ( + mpid bigint(20) unsigned NOT NULL DEFAULT '0', + mid varchar(50) DEFAULT NULL COMMENT '模块主键', + PRIMARY KEY (mpid) + );`) + + tk.MustExec(`insert into mp values("1","1");`) + tk.MustExec(`insert into m values("0", "0");`) + rs := tk.MustQuery(`SELECT ( SELECT COUNT(1) FROM m WHERE ParentId = c.mid ) expand, bmp.mpid, bmp.mpid IS NULL,bmp.mpid IS NOT NULL FROM m c LEFT JOIN mp bmp ON c.mid = bmp.mid WHERE c.ParentId = '0'`) + rs.Check(testkit.Rows("1 1 0")) + + rs = tk.MustQuery(`SELECT bmp.mpid, bmp.mpid IS NULL,bmp.mpid IS NOT NULL FROM m c LEFT JOIN mp bmp ON c.mid = bmp.mid WHERE c.ParentId = '0'`) + rs.Check(testkit.Rows(" 1 0")) +} diff --git a/planner/core/rule_column_pruning.go b/planner/core/rule_column_pruning.go index 80ea3c5775009..ea0edad342c12 100644 --- a/planner/core/rule_column_pruning.go +++ b/planner/core/rule_column_pruning.go @@ -311,6 +311,13 @@ func (p *LogicalJoin) mergeSchema() { p.schema.Append(joinCol) } else { p.schema = expression.MergeSchema(lChild.Schema(), rChild.Schema()) + switch p.JoinType { + case LeftOuterJoin: + resetNotNullFlag(p.schema, p.children[1].Schema().Len(), p.schema.Len()) + case RightOuterJoin: + resetNotNullFlag(p.schema, 0, p.children[0].Schema().Len()) + default: + } } } From 17f20a7e578728ba78126dff311cb4d141749fc5 Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Mon, 1 Mar 2021 14:38:54 +0800 Subject: [PATCH 02/85] ddl: scattering truncated tables without pre-split option (#22787) (#22872) --- ddl/ddl_api.go | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/ddl/ddl_api.go b/ddl/ddl_api.go index f21a431346cb1..5e12df0920af8 100644 --- a/ddl/ddl_api.go +++ b/ddl/ddl_api.go @@ -3716,11 +3716,8 @@ func (d *ddl) TruncateTable(ctx sessionctx.Context, ti ast.Ident) error { } return errors.Trace(err) } - oldTblInfo := tb.Meta() - if oldTblInfo.PreSplitRegions > 0 { - if _, tb, err := d.getSchemaAndTableByIdent(ctx, ti); err == nil { - d.preSplitAndScatter(ctx, tb.Meta(), tb.Meta().GetPartitionInfo()) - } + if _, tb, err := d.getSchemaAndTableByIdent(ctx, ti); err == nil { + d.preSplitAndScatter(ctx, tb.Meta(), tb.Meta().GetPartitionInfo()) } if !config.TableLockEnabled() { From 66d97debe421a1137dd970780c410b9517f77169 Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Mon, 1 Mar 2021 15:27:17 +0800 Subject: [PATCH 03/85] cherry pick #22568 to release-4.0 (#22641) Signed-off-by: ti-srebot Co-authored-by: Kenan Yao --- executor/builder.go | 8 +++--- planner/core/cache.go | 3 --- planner/core/cache_test.go | 2 +- planner/core/prepare_test.go | 50 ++++++++++++++++++++++++++++++++++++ 4 files changed, 55 insertions(+), 8 deletions(-) diff --git a/executor/builder.go b/executor/builder.go index 4e1fd5a1b4161..aef4c484f5f3e 100644 --- a/executor/builder.go +++ b/executor/builder.go @@ -1315,11 +1315,11 @@ func (b *executorBuilder) getSnapshotTS() (uint64, error) { } snapshotTS := b.ctx.GetSessionVars().SnapshotTS - txn, err := b.ctx.Txn(true) - if err != nil { - return 0, err - } if snapshotTS == 0 { + txn, err := b.ctx.Txn(true) + if err != nil { + return 0, err + } snapshotTS = txn.StartTS() } b.snapshotTS = snapshotTS diff --git a/planner/core/cache.go b/planner/core/cache.go index 106ab9e84fe9a..9f31a08392a81 100644 --- a/planner/core/cache.go +++ b/planner/core/cache.go @@ -67,7 +67,6 @@ type pstmtPlanCacheKey struct { database string connID uint64 pstmtID uint32 - snapshot uint64 schemaVersion int64 sqlMode mysql.SQLMode timezoneOffset int @@ -90,7 +89,6 @@ func (key *pstmtPlanCacheKey) Hash() []byte { key.hash = append(key.hash, dbBytes...) key.hash = codec.EncodeInt(key.hash, int64(key.connID)) key.hash = codec.EncodeInt(key.hash, int64(key.pstmtID)) - key.hash = codec.EncodeInt(key.hash, int64(key.snapshot)) key.hash = codec.EncodeInt(key.hash, key.schemaVersion) key.hash = codec.EncodeInt(key.hash, int64(key.sqlMode)) key.hash = codec.EncodeInt(key.hash, int64(key.timezoneOffset)) @@ -134,7 +132,6 @@ func NewPSTMTPlanCacheKey(sessionVars *variable.SessionVars, pstmtID uint32, sch database: sessionVars.CurrentDB, connID: sessionVars.ConnectionID, pstmtID: pstmtID, - snapshot: sessionVars.SnapshotTS, schemaVersion: schemaVersion, sqlMode: sessionVars.SQLMode, timezoneOffset: timezoneOffset, diff --git a/planner/core/cache_test.go b/planner/core/cache_test.go index 8ff56ddb97e66..9f68dc27f286f 100644 --- a/planner/core/cache_test.go +++ b/planner/core/cache_test.go @@ -39,5 +39,5 @@ func (s *testCacheSuite) SetUpSuite(c *C) { func (s *testCacheSuite) TestCacheKey(c *C) { defer testleak.AfterTest(c)() key := NewPSTMTPlanCacheKey(s.ctx.GetSessionVars(), 1, 1) - c.Assert(key.Hash(), DeepEquals, []byte{0x74, 0x65, 0x73, 0x74, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x74, 0x69, 0x64, 0x62, 0x74, 0x69, 0x6b, 0x76, 0x74, 0x69, 0x66, 0x6c, 0x61, 0x73, 0x68}) + c.Assert(key.Hash(), DeepEquals, []byte{0x74, 0x65, 0x73, 0x74, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x74, 0x69, 0x64, 0x62, 0x74, 0x69, 0x6b, 0x76, 0x74, 0x69, 0x66, 0x6c, 0x61, 0x73, 0x68}) } diff --git a/planner/core/prepare_test.go b/planner/core/prepare_test.go index 53b4a53eb07d9..92d1bef41d81f 100644 --- a/planner/core/prepare_test.go +++ b/planner/core/prepare_test.go @@ -893,3 +893,53 @@ func (s *testPrepareSerialSuite) TestPrepareCacheWithJoinTable(c *C) { tk.MustQuery("execute stmt using @a").Check(testkit.Rows()) tk.MustQuery("execute stmt using @b").Check(testkit.Rows("a ")) } + +func (s *testPlanSerialSuite) TestPlanCacheSnapshot(c *C) { + store, _, err := newStoreWithBootstrap() + c.Assert(err, IsNil) + tk := testkit.NewTestKit(c, store) + orgEnable := core.PreparedPlanCacheEnabled() + defer func() { + store.Close() + core.SetPreparedPlanCache(orgEnable) + }() + core.SetPreparedPlanCache(true) + + tk.Se, err = session.CreateSession4TestWithOpt(store, &session.Opt{ + PreparedPlanCache: kvcache.NewSimpleLRUCache(100, 0.1, math.MaxUint64), + }) + c.Assert(err, IsNil) + + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(id int)") + tk.MustExec("insert into t values (1),(2),(3),(4)") + + // For mocktikv, safe point is not initialized, we manually insert it for snapshot to use. + timeSafe := time.Now().Add(-48 * 60 * 60 * time.Second).Format("20060102-15:04:05 -0700 MST") + safePointSQL := `INSERT HIGH_PRIORITY INTO mysql.tidb VALUES ('tikv_gc_safe_point', '%[1]s', '') + ON DUPLICATE KEY + UPDATE variable_value = '%[1]s'` + tk.MustExec(fmt.Sprintf(safePointSQL, timeSafe)) + + tk.MustExec("prepare stmt from 'select * from t where id=?'") + tk.MustExec("set @p = 1") + tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("0")) + tk.MustQuery("execute stmt using @p").Check(testkit.Rows("1")) + tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("0")) + tk.MustQuery("execute stmt using @p").Check(testkit.Rows("1")) + tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("1")) + + // Record the current tso. + tk.MustExec("begin") + tso := tk.Se.GetSessionVars().TxnCtx.StartTS + tk.MustExec("rollback") + c.Assert(tso > 0, IsTrue) + // Insert one more row with id = 1. + tk.MustExec("insert into t values (1)") + + tk.MustExec(fmt.Sprintf("set @@tidb_snapshot = '%d'", tso)) + tk.MustQuery("select * from t where id = 1").Check(testkit.Rows("1")) + tk.MustQuery("execute stmt using @p").Check(testkit.Rows("1")) + tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("1")) +} From 8406e7f58b1d5811966a6811d7b519bcfd6020b2 Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Mon, 1 Mar 2021 18:12:54 +0800 Subject: [PATCH 04/85] expression: Add warning info for exprs that can not be pushed to storage layer (#22713) (#23020) --- executor/show_test.go | 10 ++++++++++ expression/expression.go | 10 ++++++++++ planner/core/physical_plan_test.go | 17 +++++++++++------ planner/core/task.go | 8 ++++++++ planner/core/testdata/plan_suite_out.json | 9 ++++++--- 5 files changed, 45 insertions(+), 9 deletions(-) diff --git a/executor/show_test.go b/executor/show_test.go index 19c65fc8e6159..38c8e0ee85837 100644 --- a/executor/show_test.go +++ b/executor/show_test.go @@ -144,6 +144,16 @@ func (s *testSuite5) TestShowErrors(c *C) { tk.MustQuery("show errors").Check(testutil.RowsWithSep("|", "Error|1050|Table 'test.show_errors' already exists")) } +func (s *testSuite5) TestShowWarningsForExprPushdown(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + testSQL := `create table if not exists show_warnings_expr_pushdown (a int, value date)` + tk.MustExec(testSQL) + tk.MustExec("explain select * from show_warnings_expr_pushdown where date_add(value, interval 1 day) = '2020-01-01'") + c.Assert(tk.Se.GetSessionVars().StmtCtx.WarningCount(), Equals, uint16(1)) + tk.MustQuery("show warnings").Check(testutil.RowsWithSep("|", "Warning|1105|Scalar function 'date_add'(signature: AddDateDatetimeInt) can not be pushed to tikv")) +} + func (s *testSuite5) TestShowGrantsPrivilege(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("create user show_grants") diff --git a/expression/expression.go b/expression/expression.go index f5fcbbf44d95f..602383d1e76c1 100644 --- a/expression/expression.go +++ b/expression/expression.go @@ -1143,6 +1143,13 @@ func canScalarFuncPushDown(scalarFunc *ScalarFunction, pc PbConverter, storeType // Check whether this function can be pushed. if !canFuncBePushed(scalarFunc, storeType) { + if pc.sc.InExplainStmt { + storageName := storeType.Name() + if storeType == kv.UnSpecified { + storageName = "storage layer" + } + pc.sc.AppendWarning(errors.New("Scalar function '" + scalarFunc.FuncName.L + "'(signature: " + scalarFunc.Function.PbCode().String() + ") can not be pushed to " + storageName)) + } return false } @@ -1166,6 +1173,9 @@ func canScalarFuncPushDown(scalarFunc *ScalarFunction, pc PbConverter, storeType func canExprPushDown(expr Expression, pc PbConverter, storeType kv.StoreType) bool { if storeType == kv.TiFlash && expr.GetType().Tp == mysql.TypeDuration { + if pc.sc.InExplainStmt { + pc.sc.AppendWarning(errors.New("Expr '" + expr.String() + "' can not be pushed to TiFlash because it contains Duration type")) + } return false } switch x := expr.(type) { diff --git a/planner/core/physical_plan_test.go b/planner/core/physical_plan_test.go index 4f3fa07e0cc27..91a1cff9a420f 100644 --- a/planner/core/physical_plan_test.go +++ b/planner/core/physical_plan_test.go @@ -880,7 +880,7 @@ func (s *testPlanSuite) TestLimitToCopHint(c *C) { output []struct { SQL string Plan []string - Warning string + Warning []string } ) @@ -897,15 +897,20 @@ func (s *testPlanSuite) TestLimitToCopHint(c *C) { warnings := tk.Se.GetSessionVars().StmtCtx.GetWarnings() s.testData.OnRecord(func() { if len(warnings) > 0 { - output[i].Warning = warnings[0].Err.Error() + output[i].Warning = make([]string, len(warnings)) + for j, warning := range warnings { + output[i].Warning[j] = warning.Err.Error() + } } }) - if output[i].Warning == "" { + if len(output[i].Warning) == 0 { c.Assert(len(warnings), Equals, 0, comment) } else { - c.Assert(len(warnings), Equals, 1, comment) - c.Assert(warnings[0].Level, Equals, stmtctx.WarnLevelWarning, comment) - c.Assert(warnings[0].Err.Error(), Equals, output[i].Warning, comment) + c.Assert(len(warnings), Equals, len(output[i].Warning), comment) + for j, warning := range warnings { + c.Assert(warning.Level, Equals, stmtctx.WarnLevelWarning, comment) + c.Assert(warning.Err.Error(), Equals, output[i].Warning[j], comment) + } } } } diff --git a/planner/core/task.go b/planner/core/task.go index f4ba96cc9e0e2..93ba40e83e30d 100644 --- a/planner/core/task.go +++ b/planner/core/task.go @@ -16,6 +16,7 @@ package core import ( "math" + "github.com/pingcap/errors" "github.com/pingcap/parser/ast" "github.com/pingcap/parser/charset" "github.com/pingcap/parser/mysql" @@ -1043,6 +1044,13 @@ func CheckAggCanPushCop(sctx sessionctx.Context, aggFuncs []*aggregation.AggFunc return false } if !aggregation.CheckAggPushDown(aggFunc, storeType) { + if sc.InExplainStmt { + storageName := storeType.Name() + if storeType == kv.UnSpecified { + storageName = "storage layer" + } + sc.AppendWarning(errors.New("Agg function '" + aggFunc.Name + "' can not be pushed to " + storageName)) + } return false } if !expression.CanExprsPushDown(sc, aggFunc.Args, client, storeType) { diff --git a/planner/core/testdata/plan_suite_out.json b/planner/core/testdata/plan_suite_out.json index f16a5af4324d8..df5494298bdc2 100644 --- a/planner/core/testdata/plan_suite_out.json +++ b/planner/core/testdata/plan_suite_out.json @@ -1456,7 +1456,7 @@ " └─Selection_13 0.83 cop[tikv] gt(test.tn.c, 50)", " └─IndexRangeScan_12 2.50 cop[tikv] table:tn, index:a(a, b, c, d) range:(1 10,1 20), keep order:false, stats:pseudo" ], - "Warning": "" + "Warning": null }, { "SQL": "select * from tn where a = 1 and b > 10 and b < 20 and c > 50 order by d limit 1", @@ -1466,7 +1466,7 @@ " └─Selection_19 0.83 cop[tikv] gt(test.tn.c, 50)", " └─IndexRangeScan_18 2.50 cop[tikv] table:tn, index:a(a, b, c, d) range:(1 10,1 20), keep order:false, stats:pseudo" ], - "Warning": "" + "Warning": null }, { "SQL": "select /*+ LIMIT_TO_COP() */ a from tn where mod(a, 2) order by a limit 1", @@ -1476,7 +1476,10 @@ " └─IndexReader_21 1.00 root index:IndexFullScan_20", " └─IndexFullScan_20 1.00 cop[tikv] table:tn, index:a(a, b, c, d) keep order:true, stats:pseudo" ], - "Warning": "[planner:1815]Optimizer Hint LIMIT_TO_COP is inapplicable" + "Warning": [ + "Scalar function 'mod'(signature: ModInt) can not be pushed to storage layer", + "[planner:1815]Optimizer Hint LIMIT_TO_COP is inapplicable" + ] } ] }, From a920405126a59415dc2872b3e2115a12dd84c38d Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Mon, 1 Mar 2021 18:46:54 +0800 Subject: [PATCH 05/85] infoschema: support query partition_id from infoschema.partitions (#22240) (#22489) --- executor/infoschema_reader.go | 2 ++ executor/infoschema_reader_test.go | 5 ++++- infoschema/tables.go | 1 + 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/executor/infoschema_reader.go b/executor/infoschema_reader.go index de9189548d16d..72880529ad08f 100644 --- a/executor/infoschema_reader.go +++ b/executor/infoschema_reader.go @@ -684,6 +684,7 @@ func (e *memtableRetriever) setDataFromPartitions(ctx sessionctx.Context, schema nil, // PARTITION_COMMENT nil, // NODEGROUP nil, // TABLESPACE_NAME + nil, // TIDB_PARTITION_ID ) rows = append(rows, record) } else { @@ -727,6 +728,7 @@ func (e *memtableRetriever) setDataFromPartitions(ctx sessionctx.Context, schema pi.Comment, // PARTITION_COMMENT nil, // NODEGROUP nil, // TABLESPACE_NAME + pi.ID, // TIDB_PARTITION_ID ) rows = append(rows, record) } diff --git a/executor/infoschema_reader_test.go b/executor/infoschema_reader_test.go index 993226ef65963..98dff1fbddb1d 100644 --- a/executor/infoschema_reader_test.go +++ b/executor/infoschema_reader_test.go @@ -456,7 +456,10 @@ func (s *testInfoschemaTableSerialSuite) TestPartitionsTable(c *C) { tk.MustQuery("select PARTITION_NAME, TABLE_ROWS, AVG_ROW_LENGTH, DATA_LENGTH, INDEX_LENGTH from information_schema.PARTITIONS where table_name='test_partitions_1';").Check( testkit.Rows(" 3 18 54 6")) - tk.MustExec("DROP TABLE `test_partitions`;") + pid, err := strconv.Atoi(tk.MustQuery("select TIDB_PARTITION_ID from information_schema.partitions where table_name = 'test_partitions';").Rows()[0][0].(string)) + c.Assert(err, IsNil) + c.Assert(pid, Greater, 0) + tk.MustExec("drop table test_partitions") } func (s *testInfoschemaTableSuite) TestMetricTables(c *C) { diff --git a/infoschema/tables.go b/infoschema/tables.go index 41a4f2b180ba4..fb0336dd07ff4 100644 --- a/infoschema/tables.go +++ b/infoschema/tables.go @@ -475,6 +475,7 @@ var partitionsCols = []columnInfo{ {name: "PARTITION_COMMENT", tp: mysql.TypeVarchar, size: 80}, {name: "NODEGROUP", tp: mysql.TypeVarchar, size: 12}, {name: "TABLESPACE_NAME", tp: mysql.TypeVarchar, size: 64}, + {name: "TIDB_PARTITION_ID", tp: mysql.TypeLonglong, size: 21}, } var tableConstraintsCols = []columnInfo{ From d77d908f445deff7138662100c1c3ccfc541ff32 Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Mon, 1 Mar 2021 23:50:53 +0800 Subject: [PATCH 06/85] *: introduce new API ParseWithParams (#22499) (#22548) --- session/session.go | 94 +++++- session/session_test.go | 34 +++ session/utils.go | 213 +++++++++++++ session/utils_test.go | 387 ++++++++++++++++++++++++ util/mock/context.go | 2 +- util/sqlexec/restricted_sql_executor.go | 2 +- 6 files changed, 725 insertions(+), 7 deletions(-) create mode 100644 session/utils.go create mode 100644 session/utils_test.go diff --git a/session/session.go b/session/session.go index 8912f55795c84..7c0907eb3e585 100644 --- a/session/session.go +++ b/session/session.go @@ -100,10 +100,22 @@ type Session interface { LastMessage() string // LastMessage is the info message that may be generated by last command AffectedRows() uint64 // Affected rows by latest executed stmt. // Execute is deprecated, use ExecuteStmt() instead. - Execute(context.Context, string) ([]sqlexec.RecordSet, error) // Execute a sql statement. - ExecuteInternal(context.Context, string) ([]sqlexec.RecordSet, error) // Execute a internal sql statement. + Execute(context.Context, string) ([]sqlexec.RecordSet, error) // Execute a sql statement. ExecuteStmt(context.Context, ast.StmtNode) (sqlexec.RecordSet, error) + // Parse is deprecated, use ParseWithParams() instead. Parse(ctx context.Context, sql string) ([]ast.StmtNode, error) + // ParseWithParams is the parameterized version of Parse: it will try to prevent injection under utf8mb4. + // It works like printf() in c, there are following format specifiers: + // 1. %?: automatic conversion by the type of arguments. E.g. []string -> ('s1','s2'..) + // 2. %%: output % + // 3. %n: for identifiers, for example ("use %n", db) + // + // Attention: it does not prevent you from doing parse("select '%?", ";SQL injection!;") => "select '';SQL injection!;'". + // One argument should be a standalone entity. It should not "concat" with other placeholders and characters. + // This function only saves you from processing potentially unsafe parameters. + ParseWithParams(ctx context.Context, sql string, args ...interface{}) ([]ast.StmtNode, error) + // ExecuteInternal is a helper around ParseWithParams() and ExecuteStmt(). It is not allowed to execute multiple statements. + ExecuteInternal(context.Context, string, ...interface{}) ([]sqlexec.RecordSet, error) String() string // String is used to debug. CommitTxn(context.Context) error RollbackTxn(context.Context) @@ -858,7 +870,7 @@ func (s *session) ExecRestrictedSQLWithSnapshot(sql string) ([]chunk.Row, []*ast func execRestrictedSQL(ctx context.Context, se *session, sql string) ([]chunk.Row, []*ast.ResultField, error) { ctx = context.WithValue(ctx, execdetails.StmtExecDetailKey, &execdetails.StmtExecDetails{}) startTime := time.Now() - recordSets, err := se.Execute(ctx, sql) + recordSets, err := se.ExecuteInternal(ctx, sql) defer func() { for _, rs := range recordSets { closeErr := rs.Close() @@ -1121,13 +1133,37 @@ func (rs *execStmtResult) Close() error { return finishStmt(context.Background(), se, err, rs.sql) } -func (s *session) ExecuteInternal(ctx context.Context, sql string) (recordSets []sqlexec.RecordSet, err error) { +func (s *session) ExecuteInternal(ctx context.Context, sql string, args ...interface{}) (recordSets []sqlexec.RecordSet, err error) { origin := s.sessionVars.InRestrictedSQL s.sessionVars.InRestrictedSQL = true defer func() { s.sessionVars.InRestrictedSQL = origin }() - return s.Execute(ctx, sql) + + if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { + span1 := span.Tracer().StartSpan("session.ExecuteInternal", opentracing.ChildOf(span.Context())) + defer span1.Finish() + ctx = opentracing.ContextWithSpan(ctx, span1) + logutil.Eventf(ctx, "execute: %s", sql) + } + + stmtNodes, err := s.ParseWithParams(ctx, sql, args...) + if err != nil { + return nil, err + } + if len(stmtNodes) != 1 { + return nil, errors.New("Executing multiple statements internally is not supported") + } + + rs, err := s.ExecuteStmt(ctx, stmtNodes[0]) + if err != nil { + s.sessionVars.StmtCtx.AppendError(err) + } + if rs == nil { + return nil, err + } + + return []sqlexec.RecordSet{rs}, err } func (s *session) Execute(ctx context.Context, sql string) (recordSets []sqlexec.RecordSet, err error) { @@ -1200,6 +1236,54 @@ func (s *session) Parse(ctx context.Context, sql string) ([]ast.StmtNode, error) return stmts, nil } +// ParseWithParams parses a query string, with arguments, to raw ast.StmtNode. +func (s *session) ParseWithParams(ctx context.Context, sql string, args ...interface{}) ([]ast.StmtNode, error) { + sql, err := EscapeSQL(sql, args...) + if err != nil { + return nil, err + } + + internal := s.isInternal() + + var stmts []ast.StmtNode + var warns []error + var parseStartTime time.Time + if internal { + // Do no respect the settings from clients, if it is for internal usage. + // Charsets from clients may give chance injections. + // Refer to https://stackoverflow.com/questions/5741187/sql-injection-that-gets-around-mysql-real-escape-string/12118602. + parseStartTime = time.Now() + stmts, warns, err = s.ParseSQL(ctx, sql, mysql.UTF8MB4Charset, mysql.UTF8MB4DefaultCollation) + } else { + charsetInfo, collation := s.sessionVars.GetCharsetInfo() + parseStartTime = time.Now() + stmts, warns, err = s.ParseSQL(ctx, sql, charsetInfo, collation) + } + if err != nil { + s.rollbackOnError(ctx) + // Only print log message when this SQL is from the user. + // Mute the warning for internal SQLs. + if !s.sessionVars.InRestrictedSQL { + if s.sessionVars.EnableRedactLog { + logutil.Logger(ctx).Debug("parse SQL failed", zap.Error(err), zap.String("SQL", sql)) + } else { + logutil.Logger(ctx).Warn("parse SQL failed", zap.Error(err), zap.String("SQL", sql)) + } + } + return nil, util.SyntaxError(err) + } + durParse := time.Since(parseStartTime) + if s.isInternal() { + sessionExecuteParseDurationInternal.Observe(durParse.Seconds()) + } else { + sessionExecuteParseDurationGeneral.Observe(durParse.Seconds()) + } + for _, warn := range warns { + s.sessionVars.StmtCtx.AppendWarning(util.SyntaxWarn(warn)) + } + return stmts, nil +} + func (s *session) ExecuteStmt(ctx context.Context, stmtNode ast.StmtNode) (sqlexec.RecordSet, error) { if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { span1 := span.Tracer().StartSpan("session.ExecuteStmt", opentracing.ChildOf(span.Context())) diff --git a/session/session_test.go b/session/session_test.go index 96670b909aa37..53685703c18c6 100644 --- a/session/session_test.go +++ b/session/session_test.go @@ -29,6 +29,7 @@ import ( "github.com/pingcap/failpoint" "github.com/pingcap/parser" "github.com/pingcap/parser/auth" + "github.com/pingcap/parser/format" "github.com/pingcap/parser/model" "github.com/pingcap/parser/mysql" "github.com/pingcap/parser/terror" @@ -3568,3 +3569,36 @@ func (s *testSessionSuite2) TestRetryCommitWithSet(c *C) { tk2.MustQuery("select * from t use index(k1)").Check(testkit.Rows("1 11 101")) tk2.MustQuery("select * from t where pk = '1'").Check(testkit.Rows("1 11 101")) } + +func (s *testSessionSerialSuite) TestParseWithParams(c *C) { + tk := testkit.NewTestKitWithInit(c, s.store) + se := tk.Se + + // test compatibility with ExcuteInternal + origin := se.GetSessionVars().InRestrictedSQL + se.GetSessionVars().InRestrictedSQL = true + defer func() { + se.GetSessionVars().InRestrictedSQL = origin + }() + _, err := se.ParseWithParams(context.Background(), "SELECT 4") + c.Assert(err, IsNil) + + // test charset attack + stmts, err := se.ParseWithParams(context.Background(), "SELECT * FROM test WHERE name = %? LIMIT 1", "\xbf\x27 OR 1=1 /*") + c.Assert(err, IsNil) + c.Assert(stmts, HasLen, 1) + + var sb strings.Builder + ctx := format.NewRestoreCtx(format.RestoreStringDoubleQuotes, &sb) + err = stmts[0].Restore(ctx) + c.Assert(err, IsNil) + c.Assert(sb.String(), Equals, "SELECT * FROM test WHERE name=_utf8mb4\"\xbf' OR 1=1 /*\" LIMIT 1") + + // test invalid sql + _, err = se.ParseWithParams(context.Background(), "SELECT") + c.Assert(err, ErrorMatches, ".*You have an error in your SQL syntax.*") + + // test invalid arguments to escape + _, err = se.ParseWithParams(context.Background(), "SELECT %?") + c.Assert(err, ErrorMatches, "missing arguments.*") +} diff --git a/session/utils.go b/session/utils.go new file mode 100644 index 0000000000000..67788d5d53a4f --- /dev/null +++ b/session/utils.go @@ -0,0 +1,213 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "encoding/json" + "strconv" + "strings" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/tidb/util/hack" +) + +func reserveBuffer(buf []byte, appendSize int) []byte { + newSize := len(buf) + appendSize + if cap(buf) < newSize { + newBuf := make([]byte, len(buf)*2+appendSize) + copy(newBuf, buf) + buf = newBuf + } + return buf[:newSize] +} + +// escapeBytesBackslash will escape []byte into the buffer, with backslash. +func escapeBytesBackslash(buf []byte, v []byte) []byte { + pos := len(buf) + buf = reserveBuffer(buf, len(v)*2) + + for _, c := range v { + switch c { + case '\x00': + buf[pos] = '\\' + buf[pos+1] = '0' + pos += 2 + case '\n': + buf[pos] = '\\' + buf[pos+1] = 'n' + pos += 2 + case '\r': + buf[pos] = '\\' + buf[pos+1] = 'r' + pos += 2 + case '\x1a': + buf[pos] = '\\' + buf[pos+1] = 'Z' + pos += 2 + case '\'': + buf[pos] = '\\' + buf[pos+1] = '\'' + pos += 2 + case '"': + buf[pos] = '\\' + buf[pos+1] = '"' + pos += 2 + case '\\': + buf[pos] = '\\' + buf[pos+1] = '\\' + pos += 2 + default: + buf[pos] = c + pos++ + } + } + + return buf[:pos] +} + +// escapeStringBackslash will escape string into the buffer, with backslash. +func escapeStringBackslash(buf []byte, v string) []byte { + return escapeBytesBackslash(buf, hack.Slice(v)) +} + +// EscapeSQL will escape input arguments into the sql string, doing necessary processing. +// It works like printf() in c, there are following format specifiers: +// 1. %?: automatic conversion by the type of arguments. E.g. []string -> ('s1','s2'..) +// 2. %%: output % +// 3. %n: for identifiers, for example ("use %n", db) +// But it does not prevent you from doing EscapeSQL("select '%?", ";SQL injection!;") => "select '';SQL injection!;'". +// It is still your responsibility to write safe SQL. +func EscapeSQL(sql string, args ...interface{}) (string, error) { + buf := make([]byte, 0, len(sql)) + argPos := 0 + for i := 0; i < len(sql); i++ { + q := strings.IndexByte(sql[i:], '%') + if q == -1 { + buf = append(buf, sql[i:]...) + break + } + buf = append(buf, sql[i:i+q]...) + i += q + + ch := byte(0) + if i+1 < len(sql) { + ch = sql[i+1] // get the specifier + } + switch ch { + case 'n': + if argPos >= len(args) { + return "", errors.Errorf("missing arguments, need %d-th arg, but only got %d args", argPos+1, len(args)) + } + arg := args[argPos] + argPos++ + + v, ok := arg.(string) + if !ok { + return "", errors.Errorf("expect a string identifier, got %v", arg) + } + buf = append(buf, '`') + buf = append(buf, strings.Replace(v, "`", "``", -1)...) + buf = append(buf, '`') + i++ // skip specifier + case '?': + if argPos >= len(args) { + return "", errors.Errorf("missing arguments, need %d-th arg, but only got %d args", argPos+1, len(args)) + } + arg := args[argPos] + argPos++ + + if arg == nil { + buf = append(buf, "NULL"...) + } else { + switch v := arg.(type) { + case int: + buf = strconv.AppendInt(buf, int64(v), 10) + case int8: + buf = strconv.AppendInt(buf, int64(v), 10) + case int16: + buf = strconv.AppendInt(buf, int64(v), 10) + case int32: + buf = strconv.AppendInt(buf, int64(v), 10) + case int64: + buf = strconv.AppendInt(buf, v, 10) + case uint: + buf = strconv.AppendUint(buf, uint64(v), 10) + case uint8: + buf = strconv.AppendUint(buf, uint64(v), 10) + case uint16: + buf = strconv.AppendUint(buf, uint64(v), 10) + case uint32: + buf = strconv.AppendUint(buf, uint64(v), 10) + case uint64: + buf = strconv.AppendUint(buf, v, 10) + case float32: + buf = strconv.AppendFloat(buf, float64(v), 'g', -1, 32) + case float64: + buf = strconv.AppendFloat(buf, v, 'g', -1, 64) + case bool: + if v { + buf = append(buf, '1') + } else { + buf = append(buf, '0') + } + case time.Time: + if v.IsZero() { + buf = append(buf, "'0000-00-00'"...) + } else { + buf = append(buf, '\'') + buf = v.AppendFormat(buf, "2006-01-02 15:04:05.999999") + buf = append(buf, '\'') + } + case json.RawMessage: + buf = append(buf, '\'') + buf = escapeBytesBackslash(buf, v) + buf = append(buf, '\'') + case []byte: + if v == nil { + buf = append(buf, "NULL"...) + } else { + buf = append(buf, "_binary'"...) + buf = escapeBytesBackslash(buf, v) + buf = append(buf, '\'') + } + case string: + buf = append(buf, '\'') + buf = escapeStringBackslash(buf, v) + buf = append(buf, '\'') + case []string: + buf = append(buf, '(') + for i, k := range v { + if i > 0 { + buf = append(buf, ',') + } + buf = append(buf, '\'') + buf = escapeStringBackslash(buf, k) + buf = append(buf, '\'') + } + buf = append(buf, ')') + default: + return "", errors.Errorf("unsupported %d-th argument: %v", argPos, arg) + } + } + i++ // skip specifier + case '%': + buf = append(buf, '%') + i++ // skip specifier + default: + buf = append(buf, '%') + } + } + return string(buf), nil +} diff --git a/session/utils_test.go b/session/utils_test.go new file mode 100644 index 0000000000000..f7d754418c019 --- /dev/null +++ b/session/utils_test.go @@ -0,0 +1,387 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "encoding/json" + "time" + + . "github.com/pingcap/check" + "github.com/pingcap/tidb/util/hack" +) + +var _ = Suite(&testUtilsSuite{}) + +type testUtilsSuite struct{} + +func (s *testUtilsSuite) TestReserveBuffer(c *C) { + res0 := reserveBuffer(nil, 0) + c.Assert(res0, HasLen, 0) + + res1 := reserveBuffer(res0, 3) + c.Assert(res1, HasLen, 3) + res1[1] = 3 + + res2 := reserveBuffer(res1, 9) + c.Assert(res2, HasLen, 12) + c.Assert(cap(res2), Equals, 15) + c.Assert(res2[:3], DeepEquals, res1) +} + +func (s *testUtilsSuite) TestEscapeBackslash(c *C) { + type TestCase struct { + name string + input []byte + output []byte + } + tests := []TestCase{ + { + name: "normal", + input: []byte("hello"), + output: []byte("hello"), + }, + { + name: "0", + input: []byte("he\x00lo"), + output: []byte("he\\0lo"), + }, + { + name: "break line", + input: []byte("he\nlo"), + output: []byte("he\\nlo"), + }, + { + name: "carry", + input: []byte("he\rlo"), + output: []byte("he\\rlo"), + }, + { + name: "substitute", + input: []byte("he\x1alo"), + output: []byte("he\\Zlo"), + }, + { + name: "single quote", + input: []byte("he'lo"), + output: []byte("he\\'lo"), + }, + { + name: "double quote", + input: []byte("he\"lo"), + output: []byte("he\\\"lo"), + }, + { + name: "back slash", + input: []byte("he\\lo"), + output: []byte("he\\\\lo"), + }, + { + name: "double escape", + input: []byte("he\x00lo\""), + output: []byte("he\\0lo\\\""), + }, + { + name: "chinese", + input: []byte("中文?"), + output: []byte("中文?"), + }, + } + for _, t := range tests { + commentf := Commentf("%s", t.name) + c.Assert(escapeBytesBackslash(nil, t.input), DeepEquals, t.output, commentf) + c.Assert(escapeStringBackslash(nil, string(hack.String(t.input))), DeepEquals, t.output, commentf) + } +} + +func (s *testUtilsSuite) TestEscapeSQL(c *C) { + type TestCase struct { + name string + input string + params []interface{} + output string + err string + } + time2, err := time.Parse("2006-01-02 15:04:05", "2018-01-23 04:03:05") + c.Assert(err, IsNil) + tests := []TestCase{ + { + name: "normal 1", + input: "select * from 1", + params: []interface{}{}, + output: "select * from 1", + err: "", + }, + { + name: "normal 2", + input: "WHERE source != 'builtin'", + params: []interface{}{}, + output: "WHERE source != 'builtin'", + err: "", + }, + { + name: "discard extra arguments", + input: "select * from 1", + params: []interface{}{4, 5, "rt"}, + output: "select * from 1", + err: "", + }, + { + name: "%? missing arguments", + input: "select %? from %?", + params: []interface{}{4}, + err: "missing arguments.*", + }, + { + name: "nil", + input: "select %?", + params: []interface{}{nil}, + output: "select NULL", + err: "", + }, + { + name: "int", + input: "select %?", + params: []interface{}{int(3)}, + output: "select 3", + err: "", + }, + { + name: "int8", + input: "select %?", + params: []interface{}{int8(4)}, + output: "select 4", + err: "", + }, + { + name: "int16", + input: "select %?", + params: []interface{}{int16(5)}, + output: "select 5", + err: "", + }, + { + name: "int32", + input: "select %?", + params: []interface{}{int32(6)}, + output: "select 6", + err: "", + }, + { + name: "int64", + input: "select %?", + params: []interface{}{int64(7)}, + output: "select 7", + err: "", + }, + { + name: "uint", + input: "select %?", + params: []interface{}{uint(8)}, + output: "select 8", + err: "", + }, + { + name: "uint8", + input: "select %?", + params: []interface{}{uint8(9)}, + output: "select 9", + err: "", + }, + { + name: "uint16", + input: "select %?", + params: []interface{}{uint16(10)}, + output: "select 10", + err: "", + }, + { + name: "uint32", + input: "select %?", + params: []interface{}{uint32(11)}, + output: "select 11", + err: "", + }, + { + name: "uint64", + input: "select %?", + params: []interface{}{uint64(12)}, + output: "select 12", + err: "", + }, + { + name: "float32", + input: "select %?", + params: []interface{}{float32(0.13)}, + output: "select 0.13", + err: "", + }, + { + name: "float64", + input: "select %?", + params: []interface{}{float64(0.14)}, + output: "select 0.14", + err: "", + }, + { + name: "bool on", + input: "select %?", + params: []interface{}{true}, + output: "select 1", + err: "", + }, + { + name: "bool off", + input: "select %?", + params: []interface{}{false}, + output: "select 0", + err: "", + }, + { + name: "time 0", + input: "select %?", + params: []interface{}{time.Time{}}, + output: "select '0000-00-00'", + err: "", + }, + { + name: "time 1", + input: "select %?", + params: []interface{}{time.Date(2019, 1, 1, 0, 0, 0, 0, time.UTC)}, + output: "select '2019-01-01 00:00:00'", + err: "", + }, + { + name: "time 2", + input: "select %?", + params: []interface{}{time2}, + output: "select '2018-01-23 04:03:05'", + err: "", + }, + { + name: "time 3", + input: "select %?", + params: []interface{}{time.Unix(0, 888888888)}, + output: "select '1970-01-01 08:00:00.888888'", + err: "", + }, + { + name: "empty byte slice1", + input: "select %?", + params: []interface{}{[]byte(nil)}, + output: "select NULL", + err: "", + }, + { + name: "empty byte slice2", + input: "select %?", + params: []interface{}{[]byte{}}, + output: "select _binary''", + err: "", + }, + { + name: "byte slice", + input: "select %?", + params: []interface{}{[]byte{2, 3}}, + output: "select _binary'\x02\x03'", + err: "", + }, + { + name: "string", + input: "select %?", + params: []interface{}{"33"}, + output: "select '33'", + }, + { + name: "string slice", + input: "select %?", + params: []interface{}{[]string{"33", "44"}}, + output: "select ('33','44')", + }, + { + name: "raw json", + input: "select %?", + params: []interface{}{json.RawMessage(`{"h": "hello"}`)}, + output: "select '{\\\"h\\\": \\\"hello\\\"}'", + }, + { + name: "unsupported args", + input: "select %?", + params: []interface{}{make(chan byte)}, + err: "unsupported 1-th argument.*", + }, + { + name: "mixed arguments", + input: "select %?, %?, %?", + params: []interface{}{"33", 44, time.Time{}}, + output: "select '33', 44, '0000-00-00'", + }, + { + name: "simple injection", + input: "select %?", + params: []interface{}{"0; drop database"}, + output: "select '0; drop database'", + }, + { + name: "identifier, wrong arg", + input: "use %n", + params: []interface{}{3}, + err: "expect a string identifier.*", + }, + { + name: "identifier", + input: "use %n", + params: []interface{}{"table`"}, + output: "use `table```", + err: "", + }, + { + name: "%n missing arguments", + input: "use %n", + params: []interface{}{}, + err: "missing arguments.*", + }, + { + name: "% escape", + input: "select * from t where val = '%%?'", + params: []interface{}{}, + output: "select * from t where val = '%?'", + err: "", + }, + { + name: "unknown specifier", + input: "%v", + params: []interface{}{}, + output: "%v", + err: "", + }, + { + name: "truncated specifier ", + input: "rv %", + params: []interface{}{}, + output: "rv %", + err: "", + }, + } + for _, t := range tests { + comment := Commentf("%s", t.name) + escaped, err := EscapeSQL(t.input, t.params...) + if t.err == "" { + c.Assert(err, IsNil, comment) + c.Assert(escaped, Equals, t.output, comment) + } else { + c.Assert(err, NotNil, comment) + c.Assert(err, ErrorMatches, t.err, comment) + } + } +} diff --git a/util/mock/context.go b/util/mock/context.go index 7dded0b330a0b..13c878c822bed 100644 --- a/util/mock/context.go +++ b/util/mock/context.go @@ -62,7 +62,7 @@ func (c *Context) Execute(ctx context.Context, sql string) ([]sqlexec.RecordSet, } // ExecuteInternal implements sqlexec.SQLExecutor ExecuteInternal interface. -func (c *Context) ExecuteInternal(ctx context.Context, sql string) ([]sqlexec.RecordSet, error) { +func (c *Context) ExecuteInternal(ctx context.Context, sql string, args ...interface{}) ([]sqlexec.RecordSet, error) { return nil, errors.Errorf("Not Support.") } diff --git a/util/sqlexec/restricted_sql_executor.go b/util/sqlexec/restricted_sql_executor.go index ce39db7fd00a7..6ea589ec5ff5d 100644 --- a/util/sqlexec/restricted_sql_executor.go +++ b/util/sqlexec/restricted_sql_executor.go @@ -51,7 +51,7 @@ type RestrictedSQLExecutor interface { type SQLExecutor interface { Execute(ctx context.Context, sql string) ([]RecordSet, error) // ExecuteInternal means execute sql as the internal sql. - ExecuteInternal(ctx context.Context, sql string) ([]RecordSet, error) + ExecuteInternal(ctx context.Context, sql string, args ...interface{}) ([]RecordSet, error) } // SQLParser is an interface provides parsing sql statement. From 4cf3284e7ffb0e265ea9d0ffa61e42f47b7f8b12 Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Tue, 2 Mar 2021 14:20:54 +0800 Subject: [PATCH 07/85] *: move new api out of session package (#22591) (#22656) --- session/session.go | 20 +- session/session_test.go | 15 +- util/sqlexec/utils.go | 260 ++++++++++++++++++++++ util/sqlexec/utils_test.go | 430 +++++++++++++++++++++++++++++++++++++ 4 files changed, 713 insertions(+), 12 deletions(-) create mode 100644 util/sqlexec/utils.go create mode 100644 util/sqlexec/utils_test.go diff --git a/session/session.go b/session/session.go index 7c0907eb3e585..1a8bcc7b21b59 100644 --- a/session/session.go +++ b/session/session.go @@ -113,7 +113,7 @@ type Session interface { // Attention: it does not prevent you from doing parse("select '%?", ";SQL injection!;") => "select '';SQL injection!;'". // One argument should be a standalone entity. It should not "concat" with other placeholders and characters. // This function only saves you from processing potentially unsafe parameters. - ParseWithParams(ctx context.Context, sql string, args ...interface{}) ([]ast.StmtNode, error) + ParseWithParams(ctx context.Context, sql string, args ...interface{}) (ast.StmtNode, error) // ExecuteInternal is a helper around ParseWithParams() and ExecuteStmt(). It is not allowed to execute multiple statements. ExecuteInternal(context.Context, string, ...interface{}) ([]sqlexec.RecordSet, error) String() string // String is used to debug. @@ -1147,15 +1147,12 @@ func (s *session) ExecuteInternal(ctx context.Context, sql string, args ...inter logutil.Eventf(ctx, "execute: %s", sql) } - stmtNodes, err := s.ParseWithParams(ctx, sql, args...) + stmt, err := s.ParseWithParams(ctx, sql, args...) if err != nil { return nil, err } - if len(stmtNodes) != 1 { - return nil, errors.New("Executing multiple statements internally is not supported") - } - rs, err := s.ExecuteStmt(ctx, stmtNodes[0]) + rs, err := s.ExecuteStmt(ctx, stmt) if err != nil { s.sessionVars.StmtCtx.AppendError(err) } @@ -1237,8 +1234,10 @@ func (s *session) Parse(ctx context.Context, sql string) ([]ast.StmtNode, error) } // ParseWithParams parses a query string, with arguments, to raw ast.StmtNode. -func (s *session) ParseWithParams(ctx context.Context, sql string, args ...interface{}) ([]ast.StmtNode, error) { - sql, err := EscapeSQL(sql, args...) +// Note that it will not do escaping if no variable arguments are passed. +func (s *session) ParseWithParams(ctx context.Context, sql string, args ...interface{}) (ast.StmtNode, error) { + var err error + sql, err = sqlexec.EscapeSQL(sql, args...) if err != nil { return nil, err } @@ -1259,6 +1258,9 @@ func (s *session) ParseWithParams(ctx context.Context, sql string, args ...inter parseStartTime = time.Now() stmts, warns, err = s.ParseSQL(ctx, sql, charsetInfo, collation) } + if len(stmts) != 1 { + err = errors.New("run multiple statements internally is not supported") + } if err != nil { s.rollbackOnError(ctx) // Only print log message when this SQL is from the user. @@ -1281,7 +1283,7 @@ func (s *session) ParseWithParams(ctx context.Context, sql string, args ...inter for _, warn := range warns { s.sessionVars.StmtCtx.AppendWarning(util.SyntaxWarn(warn)) } - return stmts, nil + return stmts[0], nil } func (s *session) ExecuteStmt(ctx context.Context, stmtNode ast.StmtNode) (sqlexec.RecordSet, error) { diff --git a/session/session_test.go b/session/session_test.go index 53685703c18c6..9a6b78c46beb5 100644 --- a/session/session_test.go +++ b/session/session_test.go @@ -3584,13 +3584,12 @@ func (s *testSessionSerialSuite) TestParseWithParams(c *C) { c.Assert(err, IsNil) // test charset attack - stmts, err := se.ParseWithParams(context.Background(), "SELECT * FROM test WHERE name = %? LIMIT 1", "\xbf\x27 OR 1=1 /*") + stmt, err := se.ParseWithParams(context.Background(), "SELECT * FROM test WHERE name = %? LIMIT 1", "\xbf\x27 OR 1=1 /*") c.Assert(err, IsNil) - c.Assert(stmts, HasLen, 1) var sb strings.Builder ctx := format.NewRestoreCtx(format.RestoreStringDoubleQuotes, &sb) - err = stmts[0].Restore(ctx) + err = stmt.Restore(ctx) c.Assert(err, IsNil) c.Assert(sb.String(), Equals, "SELECT * FROM test WHERE name=_utf8mb4\"\xbf' OR 1=1 /*\" LIMIT 1") @@ -3601,4 +3600,14 @@ func (s *testSessionSerialSuite) TestParseWithParams(c *C) { // test invalid arguments to escape _, err = se.ParseWithParams(context.Background(), "SELECT %?") c.Assert(err, ErrorMatches, "missing arguments.*") + + // test noescape + stmt, err = se.ParseWithParams(context.TODO(), "SELECT 3") + c.Assert(err, IsNil) + + sb.Reset() + ctx = format.NewRestoreCtx(0, &sb) + err = stmt.Restore(ctx) + c.Assert(err, IsNil) + c.Assert(sb.String(), Equals, "SELECT 3") } diff --git a/util/sqlexec/utils.go b/util/sqlexec/utils.go new file mode 100644 index 0000000000000..1ffc29b72d8e0 --- /dev/null +++ b/util/sqlexec/utils.go @@ -0,0 +1,260 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlexec + +import ( + "encoding/json" + "io" + "strconv" + "strings" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/tidb/util/hack" +) + +func reserveBuffer(buf []byte, appendSize int) []byte { + newSize := len(buf) + appendSize + if cap(buf) < newSize { + newBuf := make([]byte, len(buf)*2+appendSize) + copy(newBuf, buf) + buf = newBuf + } + return buf[:newSize] +} + +// escapeBytesBackslash will escape []byte into the buffer, with backslash. +func escapeBytesBackslash(buf []byte, v []byte) []byte { + pos := len(buf) + buf = reserveBuffer(buf, len(v)*2) + + for _, c := range v { + switch c { + case '\x00': + buf[pos] = '\\' + buf[pos+1] = '0' + pos += 2 + case '\n': + buf[pos] = '\\' + buf[pos+1] = 'n' + pos += 2 + case '\r': + buf[pos] = '\\' + buf[pos+1] = 'r' + pos += 2 + case '\x1a': + buf[pos] = '\\' + buf[pos+1] = 'Z' + pos += 2 + case '\'': + buf[pos] = '\\' + buf[pos+1] = '\'' + pos += 2 + case '"': + buf[pos] = '\\' + buf[pos+1] = '"' + pos += 2 + case '\\': + buf[pos] = '\\' + buf[pos+1] = '\\' + pos += 2 + default: + buf[pos] = c + pos++ + } + } + + return buf[:pos] +} + +// escapeStringBackslash will escape string into the buffer, with backslash. +func escapeStringBackslash(buf []byte, v string) []byte { + return escapeBytesBackslash(buf, hack.Slice(v)) +} + +// escapeSQL is the internal impl of EscapeSQL and FormatSQL. +func escapeSQL(sql string, args ...interface{}) ([]byte, error) { + buf := make([]byte, 0, len(sql)) + argPos := 0 + for i := 0; i < len(sql); i++ { + q := strings.IndexByte(sql[i:], '%') + if q == -1 { + buf = append(buf, sql[i:]...) + break + } + buf = append(buf, sql[i:i+q]...) + i += q + + ch := byte(0) + if i+1 < len(sql) { + ch = sql[i+1] // get the specifier + } + switch ch { + case 'n': + if argPos >= len(args) { + return nil, errors.Errorf("missing arguments, need %d-th arg, but only got %d args", argPos+1, len(args)) + } + arg := args[argPos] + argPos++ + + v, ok := arg.(string) + if !ok { + return nil, errors.Errorf("expect a string identifier, got %v", arg) + } + buf = append(buf, '`') + buf = append(buf, strings.Replace(v, "`", "``", -1)...) + buf = append(buf, '`') + i++ // skip specifier + case '?': + if argPos >= len(args) { + return nil, errors.Errorf("missing arguments, need %d-th arg, but only got %d args", argPos+1, len(args)) + } + arg := args[argPos] + argPos++ + + if arg == nil { + buf = append(buf, "NULL"...) + } else { + switch v := arg.(type) { + case int: + buf = strconv.AppendInt(buf, int64(v), 10) + case int8: + buf = strconv.AppendInt(buf, int64(v), 10) + case int16: + buf = strconv.AppendInt(buf, int64(v), 10) + case int32: + buf = strconv.AppendInt(buf, int64(v), 10) + case int64: + buf = strconv.AppendInt(buf, v, 10) + case uint: + buf = strconv.AppendUint(buf, uint64(v), 10) + case uint8: + buf = strconv.AppendUint(buf, uint64(v), 10) + case uint16: + buf = strconv.AppendUint(buf, uint64(v), 10) + case uint32: + buf = strconv.AppendUint(buf, uint64(v), 10) + case uint64: + buf = strconv.AppendUint(buf, v, 10) + case float32: + buf = strconv.AppendFloat(buf, float64(v), 'g', -1, 32) + case float64: + buf = strconv.AppendFloat(buf, v, 'g', -1, 64) + case bool: + if v { + buf = append(buf, '1') + } else { + buf = append(buf, '0') + } + case time.Time: + if v.IsZero() { + buf = append(buf, "'0000-00-00'"...) + } else { + buf = append(buf, '\'') + buf = v.AppendFormat(buf, "2006-01-02 15:04:05.999999") + buf = append(buf, '\'') + } + case json.RawMessage: + buf = append(buf, '\'') + buf = escapeBytesBackslash(buf, v) + buf = append(buf, '\'') + case []byte: + if v == nil { + buf = append(buf, "NULL"...) + } else { + buf = append(buf, "_binary'"...) + buf = escapeBytesBackslash(buf, v) + buf = append(buf, '\'') + } + case string: + buf = append(buf, '\'') + buf = escapeStringBackslash(buf, v) + buf = append(buf, '\'') + case []string: + for i, k := range v { + if i > 0 { + buf = append(buf, ',') + } + buf = append(buf, '\'') + buf = escapeStringBackslash(buf, k) + buf = append(buf, '\'') + } + case []float32: + for i, k := range v { + if i > 0 { + buf = append(buf, ',') + } + buf = strconv.AppendFloat(buf, float64(k), 'g', -1, 32) + } + case []float64: + for i, k := range v { + if i > 0 { + buf = append(buf, ',') + } + buf = strconv.AppendFloat(buf, k, 'g', -1, 64) + } + default: + return nil, errors.Errorf("unsupported %d-th argument: %v", argPos, arg) + } + } + i++ // skip specifier + case '%': + buf = append(buf, '%') + i++ // skip specifier + default: + buf = append(buf, '%') + } + } + return buf, nil +} + +// EscapeSQL will escape input arguments into the sql string, doing necessary processing. +// It works like printf() in c, there are following format specifiers: +// 1. %?: automatic conversion by the type of arguments. E.g. []string -> ('s1','s2'..) +// 2. %%: output % +// 3. %n: for identifiers, for example ("use %n", db) +// But it does not prevent you from doing EscapeSQL("select '%?", ";SQL injection!;") => "select '';SQL injection!;'". +// It is still your responsibility to write safe SQL. +func EscapeSQL(sql string, args ...interface{}) (string, error) { + str, err := escapeSQL(sql, args...) + return string(str), err +} + +// MustEscapeSQL is an helper around EscapeSQL. The error returned from escapeSQL can be avoided statically if you do not pass interface{}. +func MustEscapeSQL(sql string, args ...interface{}) string { + r, err := EscapeSQL(sql, args...) + if err != nil { + panic(err) + } + return r +} + +// FormatSQL is the io.Writer version of EscapeSQL. Please refer to EscapeSQL for details. +func FormatSQL(w io.Writer, sql string, args ...interface{}) error { + buf, err := escapeSQL(sql, args...) + if err != nil { + return err + } + _, err = w.Write(buf) + return err +} + +// MustFormatSQL is an helper around FormatSQL, like MustEscapeSQL. But it asks that the writer must be strings.Builder, +// which will not return error when w.Write(...). +func MustFormatSQL(w *strings.Builder, sql string, args ...interface{}) { + err := FormatSQL(w, sql, args...) + if err != nil { + panic(err) + } +} diff --git a/util/sqlexec/utils_test.go b/util/sqlexec/utils_test.go new file mode 100644 index 0000000000000..a8a912a33978f --- /dev/null +++ b/util/sqlexec/utils_test.go @@ -0,0 +1,430 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlexec + +import ( + "encoding/json" + "strings" + "testing" + "time" + + . "github.com/pingcap/check" +) + +func TestT(t *testing.T) { + TestingT(t) +} + +var _ = Suite(&testUtilsSuite{}) + +type testUtilsSuite struct{} + +func (s *testUtilsSuite) TestReserveBuffer(c *C) { + res0 := reserveBuffer(nil, 0) + c.Assert(res0, HasLen, 0) + + res1 := reserveBuffer(res0, 3) + c.Assert(res1, HasLen, 3) + res1[1] = 3 + + res2 := reserveBuffer(res1, 9) + c.Assert(res2, HasLen, 12) + c.Assert(cap(res2), Equals, 15) + c.Assert(res2[:3], DeepEquals, res1) +} + +func (s *testUtilsSuite) TestEscapeBackslash(c *C) { + type TestCase struct { + name string + input []byte + output []byte + } + tests := []TestCase{ + { + name: "normal", + input: []byte("hello"), + output: []byte("hello"), + }, + { + name: "0", + input: []byte("he\x00lo"), + output: []byte("he\\0lo"), + }, + { + name: "break line", + input: []byte("he\nlo"), + output: []byte("he\\nlo"), + }, + { + name: "carry", + input: []byte("he\rlo"), + output: []byte("he\\rlo"), + }, + { + name: "substitute", + input: []byte("he\x1alo"), + output: []byte("he\\Zlo"), + }, + { + name: "single quote", + input: []byte("he'lo"), + output: []byte("he\\'lo"), + }, + { + name: "double quote", + input: []byte("he\"lo"), + output: []byte("he\\\"lo"), + }, + { + name: "back slash", + input: []byte("he\\lo"), + output: []byte("he\\\\lo"), + }, + { + name: "double escape", + input: []byte("he\x00lo\""), + output: []byte("he\\0lo\\\""), + }, + { + name: "chinese", + input: []byte("中文?"), + output: []byte("中文?"), + }, + } + for _, t := range tests { + commentf := Commentf("%s", t.name) + c.Assert(escapeBytesBackslash(nil, t.input), DeepEquals, t.output, commentf) + c.Assert(escapeStringBackslash(nil, string(t.input)), DeepEquals, t.output, commentf) + } +} + +func (s *testUtilsSuite) TestEscapeSQL(c *C) { + type TestCase struct { + name string + input string + params []interface{} + output string + err string + } + time2, err := time.Parse("2006-01-02 15:04:05", "2018-01-23 04:03:05") + c.Assert(err, IsNil) + tests := []TestCase{ + { + name: "normal 1", + input: "select * from 1", + params: []interface{}{}, + output: "select * from 1", + err: "", + }, + { + name: "normal 2", + input: "WHERE source != 'builtin'", + params: []interface{}{}, + output: "WHERE source != 'builtin'", + err: "", + }, + { + name: "discard extra arguments", + input: "select * from 1", + params: []interface{}{4, 5, "rt"}, + output: "select * from 1", + err: "", + }, + { + name: "%? missing arguments", + input: "select %? from %?", + params: []interface{}{4}, + err: "missing arguments.*", + }, + { + name: "nil", + input: "select %?", + params: []interface{}{nil}, + output: "select NULL", + err: "", + }, + { + name: "int", + input: "select %?", + params: []interface{}{int(3)}, + output: "select 3", + err: "", + }, + { + name: "int8", + input: "select %?", + params: []interface{}{int8(4)}, + output: "select 4", + err: "", + }, + { + name: "int16", + input: "select %?", + params: []interface{}{int16(5)}, + output: "select 5", + err: "", + }, + { + name: "int32", + input: "select %?", + params: []interface{}{int32(6)}, + output: "select 6", + err: "", + }, + { + name: "int64", + input: "select %?", + params: []interface{}{int64(7)}, + output: "select 7", + err: "", + }, + { + name: "uint", + input: "select %?", + params: []interface{}{uint(8)}, + output: "select 8", + err: "", + }, + { + name: "uint8", + input: "select %?", + params: []interface{}{uint8(9)}, + output: "select 9", + err: "", + }, + { + name: "uint16", + input: "select %?", + params: []interface{}{uint16(10)}, + output: "select 10", + err: "", + }, + { + name: "uint32", + input: "select %?", + params: []interface{}{uint32(11)}, + output: "select 11", + err: "", + }, + { + name: "uint64", + input: "select %?", + params: []interface{}{uint64(12)}, + output: "select 12", + err: "", + }, + { + name: "float32", + input: "select %?", + params: []interface{}{float32(0.13)}, + output: "select 0.13", + err: "", + }, + { + name: "float64", + input: "select %?", + params: []interface{}{float64(0.14)}, + output: "select 0.14", + err: "", + }, + { + name: "bool on", + input: "select %?", + params: []interface{}{true}, + output: "select 1", + err: "", + }, + { + name: "bool off", + input: "select %?", + params: []interface{}{false}, + output: "select 0", + err: "", + }, + { + name: "time 0", + input: "select %?", + params: []interface{}{time.Time{}}, + output: "select '0000-00-00'", + err: "", + }, + { + name: "time 1", + input: "select %?", + params: []interface{}{time.Date(2019, 1, 1, 0, 0, 0, 0, time.UTC)}, + output: "select '2019-01-01 00:00:00'", + err: "", + }, + { + name: "time 2", + input: "select %?", + params: []interface{}{time2}, + output: "select '2018-01-23 04:03:05'", + err: "", + }, + { + name: "time 3", + input: "select %?", + params: []interface{}{time.Unix(0, 888888888).UTC()}, + output: "select '1970-01-01 00:00:00.888888'", + err: "", + }, + { + name: "empty byte slice1", + input: "select %?", + params: []interface{}{[]byte(nil)}, + output: "select NULL", + err: "", + }, + { + name: "empty byte slice2", + input: "select %?", + params: []interface{}{[]byte{}}, + output: "select _binary''", + err: "", + }, + { + name: "byte slice", + input: "select %?", + params: []interface{}{[]byte{2, 3}}, + output: "select _binary'\x02\x03'", + err: "", + }, + { + name: "string", + input: "select %?", + params: []interface{}{"33"}, + output: "select '33'", + }, + { + name: "string slice", + input: "select %?", + params: []interface{}{[]string{"33", "44"}}, + output: "select '33','44'", + }, + { + name: "raw json", + input: "select %?", + params: []interface{}{json.RawMessage(`{"h": "hello"}`)}, + output: "select '{\\\"h\\\": \\\"hello\\\"}'", + }, + { + name: "unsupported args", + input: "select %?", + params: []interface{}{make(chan byte)}, + err: "unsupported 1-th argument.*", + }, + { + name: "mixed arguments", + input: "select %?, %?, %?", + params: []interface{}{"33", 44, time.Time{}}, + output: "select '33', 44, '0000-00-00'", + }, + { + name: "simple injection", + input: "select %?", + params: []interface{}{"0; drop database"}, + output: "select '0; drop database'", + }, + { + name: "identifier, wrong arg", + input: "use %n", + params: []interface{}{3}, + err: "expect a string identifier.*", + }, + { + name: "identifier", + input: "use %n", + params: []interface{}{"table`"}, + output: "use `table```", + err: "", + }, + { + name: "%n missing arguments", + input: "use %n", + params: []interface{}{}, + err: "missing arguments.*", + }, + { + name: "% escape", + input: "select * from t where val = '%%?'", + params: []interface{}{}, + output: "select * from t where val = '%?'", + err: "", + }, + { + name: "unknown specifier", + input: "%v", + params: []interface{}{}, + output: "%v", + err: "", + }, + { + name: "truncated specifier ", + input: "rv %", + params: []interface{}{}, + output: "rv %", + err: "", + }, + { + name: "float32 slice", + input: "select %?", + params: []interface{}{[]float32{33.1, 0.44}}, + output: "select 33.1,0.44", + }, + { + name: "float64 slice", + input: "select %?", + params: []interface{}{[]float64{55.2, 0.66}}, + output: "select 55.2,0.66", + }, + } + for _, t := range tests { + comment := Commentf("%s", t.name) + r3 := new(strings.Builder) + r1, e1 := escapeSQL(t.input, t.params...) + r2, e2 := EscapeSQL(t.input, t.params...) + e3 := FormatSQL(r3, t.input, t.params...) + if t.err == "" { + c.Assert(e1, IsNil, comment) + c.Assert(string(r1), Equals, t.output, comment) + c.Assert(e2, IsNil, comment) + c.Assert(r2, Equals, t.output, comment) + c.Assert(e3, IsNil, comment) + c.Assert(r3.String(), Equals, t.output, comment) + } else { + c.Assert(e1, NotNil, comment) + c.Assert(e1, ErrorMatches, t.err, comment) + c.Assert(e2, NotNil, comment) + c.Assert(e2, ErrorMatches, t.err, comment) + c.Assert(e3, NotNil, comment) + c.Assert(e3, ErrorMatches, t.err, comment) + } + } +} + +func (s *testUtilsSuite) TestMustUtils(c *C) { + c.Assert(func() { + MustEscapeSQL("%?") + }, PanicMatches, "missing arguments.*") + + c.Assert(func() { + sql := new(strings.Builder) + MustFormatSQL(sql, "%?") + }, PanicMatches, "missing arguments.*") + + sql := new(strings.Builder) + MustFormatSQL(sql, "t") + MustEscapeSQL("tt") +} From 504936434eb76bf3be4d3ec2e5e071d56bfe8d5f Mon Sep 17 00:00:00 2001 From: Zhou Kunqin <25057648+time-and-fate@users.noreply.github.com> Date: Wed, 3 Mar 2021 10:54:17 +0800 Subject: [PATCH 08/85] planner, expression: fix error when using IN combined with subquery (#22080) (#23047) --- expression/builtin_other.go | 123 ++++++++++++---------- expression/builtin_other_vec_generated.go | 58 ++++++---- expression/generator/other_vec.go | 9 +- expression/scalar_function.go | 23 ---- planner/core/integration_test.go | 9 ++ 5 files changed, 120 insertions(+), 102 deletions(-) diff --git a/expression/builtin_other.go b/expression/builtin_other.go index 513dab9f206ad..faa41cd0a256a 100644 --- a/expression/builtin_other.go +++ b/expression/builtin_other.go @@ -155,8 +155,10 @@ func (c *inFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) type baseInSig struct { baseBuiltinFunc - nonConstArgs []Expression - hasNull bool + // nonConstArgsIdx stores the indices of non-constant args in the baseBuiltinFunc.args (the first arg is not included). + // It works with builtinInXXXSig.hashset to accelerate 'eval'. + nonConstArgsIdx []int + hasNull bool } // builtinInIntSig see https://dev.mysql.com/doc/refman/5.7/en/comparison-operators.html#function_in @@ -167,7 +169,7 @@ type builtinInIntSig struct { } func (b *builtinInIntSig) buildHashMapForConstArgs(ctx sessionctx.Context) error { - b.nonConstArgs = []Expression{b.args[0]} + b.nonConstArgsIdx = make([]int, 0) b.hashSet = make(map[int64]bool, len(b.args)-1) for i := 1; i < len(b.args); i++ { if b.args[i].ConstItem(b.ctx.GetSessionVars().StmtCtx) { @@ -181,7 +183,7 @@ func (b *builtinInIntSig) buildHashMapForConstArgs(ctx sessionctx.Context) error } b.hashSet[val] = mysql.HasUnsignedFlag(b.args[i].GetType().Flag) } else { - b.nonConstArgs = append(b.nonConstArgs, b.args[i]) + b.nonConstArgsIdx = append(b.nonConstArgsIdx, i) } } return nil @@ -190,10 +192,8 @@ func (b *builtinInIntSig) buildHashMapForConstArgs(ctx sessionctx.Context) error func (b *builtinInIntSig) Clone() builtinFunc { newSig := &builtinInIntSig{} newSig.cloneFrom(&b.baseBuiltinFunc) - newSig.nonConstArgs = make([]Expression, 0, len(b.nonConstArgs)) - for _, arg := range b.nonConstArgs { - newSig.nonConstArgs = append(newSig.nonConstArgs, arg.Clone()) - } + newSig.nonConstArgsIdx = make([]int, len(b.nonConstArgsIdx)) + copy(newSig.nonConstArgsIdx, b.nonConstArgsIdx) newSig.hashSet = b.hashSet newSig.hasNull = b.hasNull return newSig @@ -206,9 +206,8 @@ func (b *builtinInIntSig) evalInt(row chunk.Row) (int64, bool, error) { } isUnsigned0 := mysql.HasUnsignedFlag(b.args[0].GetType().Flag) - args := b.args + args := b.args[1:] if len(b.hashSet) != 0 { - args = b.nonConstArgs if isUnsigned, ok := b.hashSet[arg0]; ok { if (isUnsigned0 && isUnsigned) || (!isUnsigned0 && !isUnsigned) { return 1, false, nil @@ -217,10 +216,14 @@ func (b *builtinInIntSig) evalInt(row chunk.Row) (int64, bool, error) { return 1, false, nil } } + args = args[:0] + for _, i := range b.nonConstArgsIdx { + args = append(args, b.args[i]) + } } hasNull := b.hasNull - for _, arg := range args[1:] { + for _, arg := range args { evaledArg, isNull, err := arg.EvalInt(b.ctx, row) if err != nil { return 0, true, err @@ -258,7 +261,7 @@ type builtinInStringSig struct { } func (b *builtinInStringSig) buildHashMapForConstArgs(ctx sessionctx.Context) error { - b.nonConstArgs = []Expression{b.args[0]} + b.nonConstArgsIdx = make([]int, 0) b.hashSet = set.NewStringSet() collator := collate.GetCollator(b.collation) for i := 1; i < len(b.args); i++ { @@ -273,7 +276,7 @@ func (b *builtinInStringSig) buildHashMapForConstArgs(ctx sessionctx.Context) er } b.hashSet.Insert(string(collator.Key(val))) // should do memory copy here } else { - b.nonConstArgs = append(b.nonConstArgs, b.args[i]) + b.nonConstArgsIdx = append(b.nonConstArgsIdx, i) } } @@ -283,10 +286,8 @@ func (b *builtinInStringSig) buildHashMapForConstArgs(ctx sessionctx.Context) er func (b *builtinInStringSig) Clone() builtinFunc { newSig := &builtinInStringSig{} newSig.cloneFrom(&b.baseBuiltinFunc) - newSig.nonConstArgs = make([]Expression, 0, len(b.nonConstArgs)) - for _, arg := range b.nonConstArgs { - newSig.nonConstArgs = append(newSig.nonConstArgs, arg.Clone()) - } + newSig.nonConstArgsIdx = make([]int, len(b.nonConstArgsIdx)) + copy(newSig.nonConstArgsIdx, b.nonConstArgsIdx) newSig.hashSet = b.hashSet newSig.hasNull = b.hasNull return newSig @@ -298,17 +299,20 @@ func (b *builtinInStringSig) evalInt(row chunk.Row) (int64, bool, error) { return 0, isNull0, err } - args := b.args + args := b.args[1:] collator := collate.GetCollator(b.collation) if len(b.hashSet) != 0 { - args = b.nonConstArgs if b.hashSet.Exist(string(collator.Key(arg0))) { return 1, false, nil } + args = args[:0] + for _, i := range b.nonConstArgsIdx { + args = append(args, b.args[i]) + } } hasNull := b.hasNull - for _, arg := range args[1:] { + for _, arg := range args { evaledArg, isNull, err := arg.EvalString(b.ctx, row) if err != nil { return 0, true, err @@ -331,7 +335,7 @@ type builtinInRealSig struct { } func (b *builtinInRealSig) buildHashMapForConstArgs(ctx sessionctx.Context) error { - b.nonConstArgs = []Expression{b.args[0]} + b.nonConstArgsIdx = make([]int, 0) b.hashSet = set.NewFloat64Set() for i := 1; i < len(b.args); i++ { if b.args[i].ConstItem(b.ctx.GetSessionVars().StmtCtx) { @@ -345,7 +349,7 @@ func (b *builtinInRealSig) buildHashMapForConstArgs(ctx sessionctx.Context) erro } b.hashSet.Insert(val) } else { - b.nonConstArgs = append(b.nonConstArgs, b.args[i]) + b.nonConstArgsIdx = append(b.nonConstArgsIdx, i) } } @@ -355,10 +359,8 @@ func (b *builtinInRealSig) buildHashMapForConstArgs(ctx sessionctx.Context) erro func (b *builtinInRealSig) Clone() builtinFunc { newSig := &builtinInRealSig{} newSig.cloneFrom(&b.baseBuiltinFunc) - newSig.nonConstArgs = make([]Expression, 0, len(b.nonConstArgs)) - for _, arg := range b.nonConstArgs { - newSig.nonConstArgs = append(newSig.nonConstArgs, arg.Clone()) - } + newSig.nonConstArgsIdx = make([]int, len(b.nonConstArgsIdx)) + copy(newSig.nonConstArgsIdx, b.nonConstArgsIdx) newSig.hashSet = b.hashSet newSig.hasNull = b.hasNull return newSig @@ -369,15 +371,19 @@ func (b *builtinInRealSig) evalInt(row chunk.Row) (int64, bool, error) { if isNull0 || err != nil { return 0, isNull0, err } - args := b.args + args := b.args[1:] if len(b.hashSet) != 0 { - args = b.nonConstArgs if b.hashSet.Exist(arg0) { return 1, false, nil } + args = args[:0] + for _, i := range b.nonConstArgsIdx { + args = append(args, b.args[i]) + } } + hasNull := b.hasNull - for _, arg := range args[1:] { + for _, arg := range args { evaledArg, isNull, err := arg.EvalReal(b.ctx, row) if err != nil { return 0, true, err @@ -400,7 +406,7 @@ type builtinInDecimalSig struct { } func (b *builtinInDecimalSig) buildHashMapForConstArgs(ctx sessionctx.Context) error { - b.nonConstArgs = []Expression{b.args[0]} + b.nonConstArgsIdx = make([]int, 0) b.hashSet = set.NewStringSet() for i := 1; i < len(b.args); i++ { if b.args[i].ConstItem(b.ctx.GetSessionVars().StmtCtx) { @@ -418,7 +424,7 @@ func (b *builtinInDecimalSig) buildHashMapForConstArgs(ctx sessionctx.Context) e } b.hashSet.Insert(string(key)) } else { - b.nonConstArgs = append(b.nonConstArgs, b.args[i]) + b.nonConstArgsIdx = append(b.nonConstArgsIdx, i) } } @@ -428,10 +434,8 @@ func (b *builtinInDecimalSig) buildHashMapForConstArgs(ctx sessionctx.Context) e func (b *builtinInDecimalSig) Clone() builtinFunc { newSig := &builtinInDecimalSig{} newSig.cloneFrom(&b.baseBuiltinFunc) - newSig.nonConstArgs = make([]Expression, 0, len(b.nonConstArgs)) - for _, arg := range b.nonConstArgs { - newSig.nonConstArgs = append(newSig.nonConstArgs, arg.Clone()) - } + newSig.nonConstArgsIdx = make([]int, len(b.nonConstArgsIdx)) + copy(newSig.nonConstArgsIdx, b.nonConstArgsIdx) newSig.hashSet = b.hashSet newSig.hasNull = b.hasNull return newSig @@ -443,20 +447,23 @@ func (b *builtinInDecimalSig) evalInt(row chunk.Row) (int64, bool, error) { return 0, isNull0, err } - args := b.args + args := b.args[1:] key, err := arg0.ToHashKey() if err != nil { return 0, true, err } if len(b.hashSet) != 0 { - args = b.nonConstArgs if b.hashSet.Exist(string(key)) { return 1, false, nil } + args = args[:0] + for _, i := range b.nonConstArgsIdx { + args = append(args, b.args[i]) + } } hasNull := b.hasNull - for _, arg := range args[1:] { + for _, arg := range args { evaledArg, isNull, err := arg.EvalDecimal(b.ctx, row) if err != nil { return 0, true, err @@ -479,7 +486,7 @@ type builtinInTimeSig struct { } func (b *builtinInTimeSig) buildHashMapForConstArgs(ctx sessionctx.Context) error { - b.nonConstArgs = []Expression{b.args[0]} + b.nonConstArgsIdx = make([]int, 0) b.hashSet = make(map[types.CoreTime]struct{}, len(b.args)-1) for i := 1; i < len(b.args); i++ { if b.args[i].ConstItem(b.ctx.GetSessionVars().StmtCtx) { @@ -493,7 +500,7 @@ func (b *builtinInTimeSig) buildHashMapForConstArgs(ctx sessionctx.Context) erro } b.hashSet[val.CoreTime()] = struct{}{} } else { - b.nonConstArgs = append(b.nonConstArgs, b.args[i]) + b.nonConstArgsIdx = append(b.nonConstArgsIdx, i) } } @@ -503,10 +510,8 @@ func (b *builtinInTimeSig) buildHashMapForConstArgs(ctx sessionctx.Context) erro func (b *builtinInTimeSig) Clone() builtinFunc { newSig := &builtinInTimeSig{} newSig.cloneFrom(&b.baseBuiltinFunc) - newSig.nonConstArgs = make([]Expression, 0, len(b.nonConstArgs)) - for _, arg := range b.nonConstArgs { - newSig.nonConstArgs = append(newSig.nonConstArgs, arg.Clone()) - } + newSig.nonConstArgsIdx = make([]int, len(b.nonConstArgsIdx)) + copy(newSig.nonConstArgsIdx, b.nonConstArgsIdx) newSig.hashSet = b.hashSet newSig.hasNull = b.hasNull return newSig @@ -517,15 +522,19 @@ func (b *builtinInTimeSig) evalInt(row chunk.Row) (int64, bool, error) { if isNull0 || err != nil { return 0, isNull0, err } - args := b.args + args := b.args[1:] if len(b.hashSet) != 0 { - args = b.nonConstArgs if _, ok := b.hashSet[arg0.CoreTime()]; ok { return 1, false, nil } + args = args[:0] + for _, i := range b.nonConstArgsIdx { + args = append(args, b.args[i]) + } } + hasNull := b.hasNull - for _, arg := range args[1:] { + for _, arg := range args { evaledArg, isNull, err := arg.EvalTime(b.ctx, row) if err != nil { return 0, true, err @@ -548,7 +557,7 @@ type builtinInDurationSig struct { } func (b *builtinInDurationSig) buildHashMapForConstArgs(ctx sessionctx.Context) error { - b.nonConstArgs = []Expression{b.args[0]} + b.nonConstArgsIdx = make([]int, 0) b.hashSet = make(map[time.Duration]struct{}, len(b.args)-1) for i := 1; i < len(b.args); i++ { if b.args[i].ConstItem(b.ctx.GetSessionVars().StmtCtx) { @@ -562,7 +571,7 @@ func (b *builtinInDurationSig) buildHashMapForConstArgs(ctx sessionctx.Context) } b.hashSet[val.Duration] = struct{}{} } else { - b.nonConstArgs = append(b.nonConstArgs, b.args[i]) + b.nonConstArgsIdx = append(b.nonConstArgsIdx, i) } } @@ -572,10 +581,8 @@ func (b *builtinInDurationSig) buildHashMapForConstArgs(ctx sessionctx.Context) func (b *builtinInDurationSig) Clone() builtinFunc { newSig := &builtinInDurationSig{} newSig.cloneFrom(&b.baseBuiltinFunc) - newSig.nonConstArgs = make([]Expression, 0, len(b.nonConstArgs)) - for _, arg := range b.nonConstArgs { - newSig.nonConstArgs = append(newSig.nonConstArgs, arg.Clone()) - } + newSig.nonConstArgsIdx = make([]int, len(b.nonConstArgsIdx)) + copy(newSig.nonConstArgsIdx, b.nonConstArgsIdx) newSig.hashSet = b.hashSet newSig.hasNull = b.hasNull return newSig @@ -586,15 +593,19 @@ func (b *builtinInDurationSig) evalInt(row chunk.Row) (int64, bool, error) { if isNull0 || err != nil { return 0, isNull0, err } - args := b.args + args := b.args[1:] if len(b.hashSet) != 0 { - args = b.nonConstArgs if _, ok := b.hashSet[arg0.Duration]; ok { return 1, false, nil } + args = args[:0] + for _, i := range b.nonConstArgsIdx { + args = append(args, b.args[i]) + } } + hasNull := b.hasNull - for _, arg := range args[1:] { + for _, arg := range args { evaledArg, isNull, err := arg.EvalDuration(b.ctx, row) if err != nil { return 0, true, err diff --git a/expression/builtin_other_vec_generated.go b/expression/builtin_other_vec_generated.go index e44f2f6759ca0..0cdf900f0793d 100644 --- a/expression/builtin_other_vec_generated.go +++ b/expression/builtin_other_vec_generated.go @@ -53,9 +53,8 @@ func (b *builtinInIntSig) vecEvalInt(input *chunk.Chunk, result *chunk.Column) e } isUnsigned0 := mysql.HasUnsignedFlag(b.args[0].GetType().Flag) var compareResult int - args := b.args + args := b.args[1:] if len(b.hashSet) != 0 { - args = b.nonConstArgs for i := 0; i < n; i++ { if buf0.IsNull(i) { hasNull[i] = true @@ -73,9 +72,13 @@ func (b *builtinInIntSig) vecEvalInt(input *chunk.Chunk, result *chunk.Column) e } } } + args = args[:0] + for _, i := range b.nonConstArgsIdx { + args = append(args, b.args[i]) + } } - for j := 1; j < len(args); j++ { + for j := 0; j < len(args); j++ { if err := args[j].VecEvalInt(b.ctx, input, buf1); err != nil { return err } @@ -153,10 +156,9 @@ func (b *builtinInStringSig) vecEvalInt(input *chunk.Chunk, result *chunk.Column } } var compareResult int - args := b.args + args := b.args[1:] if len(b.hashSet) != 0 { collator := collate.GetCollator(b.collation) - args = b.nonConstArgs for i := 0; i < n; i++ { if buf0.IsNull(i) { hasNull[i] = true @@ -168,9 +170,13 @@ func (b *builtinInStringSig) vecEvalInt(input *chunk.Chunk, result *chunk.Column result.SetNull(i, false) } } + args = args[:0] + for _, i := range b.nonConstArgsIdx { + args = append(args, b.args[i]) + } } - for j := 1; j < len(args); j++ { + for j := 0; j < len(args); j++ { if err := args[j].VecEvalString(b.ctx, input, buf1); err != nil { return err } @@ -232,9 +238,8 @@ func (b *builtinInDecimalSig) vecEvalInt(input *chunk.Chunk, result *chunk.Colum } } var compareResult int - args := b.args + args := b.args[1:] if len(b.hashSet) != 0 { - args = b.nonConstArgs for i := 0; i < n; i++ { if buf0.IsNull(i) { hasNull[i] = true @@ -250,9 +255,13 @@ func (b *builtinInDecimalSig) vecEvalInt(input *chunk.Chunk, result *chunk.Colum result.SetNull(i, false) } } + args = args[:0] + for _, i := range b.nonConstArgsIdx { + args = append(args, b.args[i]) + } } - for j := 1; j < len(args); j++ { + for j := 0; j < len(args); j++ { if err := args[j].VecEvalDecimal(b.ctx, input, buf1); err != nil { return err } @@ -319,9 +328,8 @@ func (b *builtinInRealSig) vecEvalInt(input *chunk.Chunk, result *chunk.Column) } } var compareResult int - args := b.args + args := b.args[1:] if len(b.hashSet) != 0 { - args = b.nonConstArgs for i := 0; i < n; i++ { if buf0.IsNull(i) { hasNull[i] = true @@ -333,9 +341,13 @@ func (b *builtinInRealSig) vecEvalInt(input *chunk.Chunk, result *chunk.Column) result.SetNull(i, false) } } + args = args[:0] + for _, i := range b.nonConstArgsIdx { + args = append(args, b.args[i]) + } } - for j := 1; j < len(args); j++ { + for j := 0; j < len(args); j++ { if err := args[j].VecEvalReal(b.ctx, input, buf1); err != nil { return err } @@ -399,9 +411,8 @@ func (b *builtinInTimeSig) vecEvalInt(input *chunk.Chunk, result *chunk.Column) } } var compareResult int - args := b.args + args := b.args[1:] if len(b.hashSet) != 0 { - args = b.nonConstArgs for i := 0; i < n; i++ { if buf0.IsNull(i) { hasNull[i] = true @@ -413,9 +424,13 @@ func (b *builtinInTimeSig) vecEvalInt(input *chunk.Chunk, result *chunk.Column) result.SetNull(i, false) } } + args = args[:0] + for _, i := range b.nonConstArgsIdx { + args = append(args, b.args[i]) + } } - for j := 1; j < len(args); j++ { + for j := 0; j < len(args); j++ { if err := args[j].VecEvalTime(b.ctx, input, buf1); err != nil { return err } @@ -479,9 +494,8 @@ func (b *builtinInDurationSig) vecEvalInt(input *chunk.Chunk, result *chunk.Colu } } var compareResult int - args := b.args + args := b.args[1:] if len(b.hashSet) != 0 { - args = b.nonConstArgs for i := 0; i < n; i++ { if buf0.IsNull(i) { hasNull[i] = true @@ -493,9 +507,13 @@ func (b *builtinInDurationSig) vecEvalInt(input *chunk.Chunk, result *chunk.Colu result.SetNull(i, false) } } + args = args[:0] + for _, i := range b.nonConstArgsIdx { + args = append(args, b.args[i]) + } } - for j := 1; j < len(args); j++ { + for j := 0; j < len(args); j++ { if err := args[j].VecEvalDuration(b.ctx, input, buf1); err != nil { return err } @@ -553,9 +571,9 @@ func (b *builtinInJSONSig) vecEvalInt(input *chunk.Chunk, result *chunk.Column) } hasNull := make([]bool, n) var compareResult int - args := b.args + args := b.args[1:] - for j := 1; j < len(args); j++ { + for j := 0; j < len(args); j++ { if err := args[j].VecEvalJSON(b.ctx, input, buf1); err != nil { return err } diff --git a/expression/generator/other_vec.go b/expression/generator/other_vec.go index e04c629a53c3b..a1d574c80b901 100644 --- a/expression/generator/other_vec.go +++ b/expression/generator/other_vec.go @@ -144,13 +144,12 @@ func (b *{{.SigName}}) vecEvalInt(input *chunk.Chunk, result *chunk.Column) erro isUnsigned0 := mysql.HasUnsignedFlag(b.args[0].GetType().Flag) {{- end }} var compareResult int - args := b.args + args := b.args[1:] {{- if not $InputJSON}} if len(b.hashSet) != 0 { {{- if $InputString }} collator := collate.GetCollator(b.collation) {{- end }} - args = b.nonConstArgs for i := 0; i < n; i++ { if buf0.IsNull(i) { hasNull[i] = true @@ -202,10 +201,14 @@ func (b *{{.SigName}}) vecEvalInt(input *chunk.Chunk, result *chunk.Column) erro {{- end }} {{- end }} } + args = args[:0] + for _, i := range b.nonConstArgsIdx { + args = append(args, b.args[i]) + } } {{- end }} - for j := 1; j < len(args); j++ { + for j := 0; j < len(args); j++ { if err := args[j].VecEval{{ .Input.TypeName }}(b.ctx, input, buf1); err != nil { return err } diff --git a/expression/scalar_function.go b/expression/scalar_function.go index 9137e922df5e9..47aa3f4da5037 100755 --- a/expression/scalar_function.go +++ b/expression/scalar_function.go @@ -428,29 +428,6 @@ func (sf *ScalarFunction) ResolveIndices(schema *Schema) (Expression, error) { } func (sf *ScalarFunction) resolveIndices(schema *Schema) error { - if sf.FuncName.L == ast.In { - args := []Expression{} - switch inFunc := sf.Function.(type) { - case *builtinInIntSig: - args = inFunc.nonConstArgs - case *builtinInStringSig: - args = inFunc.nonConstArgs - case *builtinInTimeSig: - args = inFunc.nonConstArgs - case *builtinInDurationSig: - args = inFunc.nonConstArgs - case *builtinInRealSig: - args = inFunc.nonConstArgs - case *builtinInDecimalSig: - args = inFunc.nonConstArgs - } - for _, arg := range args { - err := arg.resolveIndices(schema) - if err != nil { - return err - } - } - } for _, arg := range sf.GetArgs() { err := arg.resolveIndices(schema) if err != nil { diff --git a/planner/core/integration_test.go b/planner/core/integration_test.go index 67d93f910824f..5779082c2a551 100644 --- a/planner/core/integration_test.go +++ b/planner/core/integration_test.go @@ -1823,3 +1823,12 @@ func (s *testIntegrationSuite) TestReorderSimplifiedOuterJoins(c *C) { tk.MustQuery(tt).Check(testkit.Rows(output[i].Plan...)) } } + +func (s *testIntegrationSuite) TestIssue22071(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("create table t (a int);") + tk.MustExec("insert into t values(1),(2),(5)") + tk.MustQuery("select n in (1,2) from (select a in (1,2) as n from t) g;").Sort().Check(testkit.Rows("0", "1", "1")) + tk.MustQuery("select n in (1,n) from (select a in (1,2) as n from t) g;").Check(testkit.Rows("1", "1", "1")) +} From ae010ce5c5670a0656650f214ee0b1397c2529d4 Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Wed, 3 Mar 2021 23:22:54 +0800 Subject: [PATCH 09/85] *: refactor the RestrictedSQLExecutor interface (#22579) (#22621) --- domain/domain.go | 9 ++- server/sql_info_fetcher.go | 2 +- session/session.go | 81 ++++++++++++++++++++++--- session/session_test.go | 11 ++-- store/tikv/gcworker/gc_worker.go | 18 +++--- util/admin/admin.go | 36 ++++++++--- util/gcutil/gcutil.go | 40 +++++++----- util/sqlexec/restricted_sql_executor.go | 37 +++++++++++ 8 files changed, 183 insertions(+), 51 deletions(-) diff --git a/domain/domain.go b/domain/domain.go index 649b3e7b1287d..1742102cc9445 100644 --- a/domain/domain.go +++ b/domain/domain.go @@ -1215,9 +1215,12 @@ func (do *Domain) NotifyUpdatePrivilege(ctx sessionctx.Context) { } } // update locally - _, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(`FLUSH PRIVILEGES`) - if err != nil { - logutil.BgLogger().Error("unable to update privileges", zap.Error(err)) + exec := ctx.(sqlexec.RestrictedSQLExecutor) + if stmt, err := exec.ParseWithParams(context.Background(), `FLUSH PRIVILEGES`); err == nil { + _, _, err := exec.ExecRestrictedStmt(context.Background(), stmt) + if err != nil { + logutil.BgLogger().Error("unable to update privileges", zap.Error(err)) + } } } diff --git a/server/sql_info_fetcher.go b/server/sql_info_fetcher.go index 34236f8eabe7d..76ba5d6682341 100644 --- a/server/sql_info_fetcher.go +++ b/server/sql_info_fetcher.go @@ -88,7 +88,7 @@ func (sh *sqlInfoFetcher) zipInfoForSQL(w http.ResponseWriter, r *http.Request) timeoutString := r.FormValue("timeout") curDB := strings.ToLower(r.FormValue("current_db")) if curDB != "" { - _, err = sh.s.Execute(reqCtx, fmt.Sprintf("use %v", curDB)) + _, err = sh.s.ExecuteInternal(reqCtx, "use %n", curDB) if err != nil { serveError(w, http.StatusInternalServerError, fmt.Sprintf("use database %v failed, err: %v", curDB, err)) return diff --git a/session/session.go b/session/session.go index 1a8bcc7b21b59..faa6ec06dfcc8 100644 --- a/session/session.go +++ b/session/session.go @@ -104,16 +104,6 @@ type Session interface { ExecuteStmt(context.Context, ast.StmtNode) (sqlexec.RecordSet, error) // Parse is deprecated, use ParseWithParams() instead. Parse(ctx context.Context, sql string) ([]ast.StmtNode, error) - // ParseWithParams is the parameterized version of Parse: it will try to prevent injection under utf8mb4. - // It works like printf() in c, there are following format specifiers: - // 1. %?: automatic conversion by the type of arguments. E.g. []string -> ('s1','s2'..) - // 2. %%: output % - // 3. %n: for identifiers, for example ("use %n", db) - // - // Attention: it does not prevent you from doing parse("select '%?", ";SQL injection!;") => "select '';SQL injection!;'". - // One argument should be a standalone entity. It should not "concat" with other placeholders and characters. - // This function only saves you from processing potentially unsafe parameters. - ParseWithParams(ctx context.Context, sql string, args ...interface{}) (ast.StmtNode, error) // ExecuteInternal is a helper around ParseWithParams() and ExecuteStmt(). It is not allowed to execute multiple statements. ExecuteInternal(context.Context, string, ...interface{}) ([]sqlexec.RecordSet, error) String() string // String is used to debug. @@ -1286,6 +1276,77 @@ func (s *session) ParseWithParams(ctx context.Context, sql string, args ...inter return stmts[0], nil } +// ExecRestrictedStmt implements RestrictedSQLExecutor interface. +func (s *session) ExecRestrictedStmt(ctx context.Context, stmtNode ast.StmtNode, opts ...sqlexec.OptionFuncAlias) ( + []chunk.Row, []*ast.ResultField, error) { + var execOption sqlexec.ExecOption + for _, opt := range opts { + opt(&execOption) + } + // Use special session to execute the sql. + tmp, err := s.sysSessionPool().Get() + if err != nil { + return nil, nil, err + } + defer s.sysSessionPool().Put(tmp) + se := tmp.(*session) + + startTime := time.Now() + // The special session will share the `InspectionTableCache` with current session + // if the current session in inspection mode. + if cache := s.sessionVars.InspectionTableCache; cache != nil { + se.sessionVars.InspectionTableCache = cache + defer func() { se.sessionVars.InspectionTableCache = nil }() + } + defer func() { + if !execOption.IgnoreWarning { + if se != nil && se.GetSessionVars().StmtCtx.WarningCount() > 0 { + warnings := se.GetSessionVars().StmtCtx.GetWarnings() + s.GetSessionVars().StmtCtx.AppendWarnings(warnings) + } + } + }() + + if execOption.SnapshotTS != 0 { + se.sessionVars.SnapshotInfoschema, err = domain.GetDomain(s).GetSnapshotInfoSchema(execOption.SnapshotTS) + if err != nil { + return nil, nil, err + } + if err := se.sessionVars.SetSystemVar(variable.TiDBSnapshot, strconv.FormatUint(execOption.SnapshotTS, 10)); err != nil { + return nil, nil, err + } + defer func() { + if err := se.sessionVars.SetSystemVar(variable.TiDBSnapshot, ""); err != nil { + logutil.BgLogger().Error("set tidbSnapshot error", zap.Error(err)) + } + se.sessionVars.SnapshotInfoschema = nil + }() + } + + metrics.SessionRestrictedSQLCounter.Inc() + + ctx = context.WithValue(ctx, execdetails.StmtExecDetailKey, &execdetails.StmtExecDetails{}) + rs, err := se.ExecuteStmt(ctx, stmtNode) + if err != nil { + se.sessionVars.StmtCtx.AppendError(err) + } + if rs == nil { + return nil, nil, err + } + defer func() { + if closeErr := rs.Close(); closeErr != nil { + err = closeErr + } + }() + var rows []chunk.Row + rows, err = drainRecordSet(ctx, se, rs) + if err != nil { + return nil, nil, err + } + metrics.QueryDurationHistogram.WithLabelValues(metrics.LblInternal).Observe(time.Since(startTime).Seconds()) + return rows, rs.Fields(), err +} + func (s *session) ExecuteStmt(ctx context.Context, stmtNode ast.StmtNode) (sqlexec.RecordSet, error) { if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { span1 := span.Tracer().StartSpan("session.ExecuteStmt", opentracing.ChildOf(span.Context())) diff --git a/session/session_test.go b/session/session_test.go index 9a6b78c46beb5..a32e9e05349ed 100644 --- a/session/session_test.go +++ b/session/session_test.go @@ -3573,6 +3573,7 @@ func (s *testSessionSuite2) TestRetryCommitWithSet(c *C) { func (s *testSessionSerialSuite) TestParseWithParams(c *C) { tk := testkit.NewTestKitWithInit(c, s.store) se := tk.Se + exec := se.(sqlexec.RestrictedSQLExecutor) // test compatibility with ExcuteInternal origin := se.GetSessionVars().InRestrictedSQL @@ -3580,11 +3581,11 @@ func (s *testSessionSerialSuite) TestParseWithParams(c *C) { defer func() { se.GetSessionVars().InRestrictedSQL = origin }() - _, err := se.ParseWithParams(context.Background(), "SELECT 4") + _, err := exec.ParseWithParams(context.Background(), "SELECT 4") c.Assert(err, IsNil) // test charset attack - stmt, err := se.ParseWithParams(context.Background(), "SELECT * FROM test WHERE name = %? LIMIT 1", "\xbf\x27 OR 1=1 /*") + stmt, err := exec.ParseWithParams(context.Background(), "SELECT * FROM test WHERE name = %? LIMIT 1", "\xbf\x27 OR 1=1 /*") c.Assert(err, IsNil) var sb strings.Builder @@ -3594,15 +3595,15 @@ func (s *testSessionSerialSuite) TestParseWithParams(c *C) { c.Assert(sb.String(), Equals, "SELECT * FROM test WHERE name=_utf8mb4\"\xbf' OR 1=1 /*\" LIMIT 1") // test invalid sql - _, err = se.ParseWithParams(context.Background(), "SELECT") + _, err = exec.ParseWithParams(context.Background(), "SELECT") c.Assert(err, ErrorMatches, ".*You have an error in your SQL syntax.*") // test invalid arguments to escape - _, err = se.ParseWithParams(context.Background(), "SELECT %?") + _, err = exec.ParseWithParams(context.Background(), "SELECT %?") c.Assert(err, ErrorMatches, "missing arguments.*") // test noescape - stmt, err = se.ParseWithParams(context.TODO(), "SELECT 3") + stmt, err = exec.ParseWithParams(context.TODO(), "SELECT 3") c.Assert(err, IsNil) sb.Reset() diff --git a/store/tikv/gcworker/gc_worker.go b/store/tikv/gcworker/gc_worker.go index 0ceea6292019b..b606dec2b1fac 100644 --- a/store/tikv/gcworker/gc_worker.go +++ b/store/tikv/gcworker/gc_worker.go @@ -285,7 +285,7 @@ func (w *GCWorker) prepare() (bool, uint64, error) { ctx := context.Background() se := createSession(w.store) defer se.Close() - _, err := se.Execute(ctx, "BEGIN") + _, err := se.ExecuteInternal(ctx, "BEGIN") if err != nil { return false, 0, errors.Trace(err) } @@ -1599,7 +1599,7 @@ func (w *GCWorker) checkLeader() (bool, error) { defer se.Close() ctx := context.Background() - _, err := se.Execute(ctx, "BEGIN") + _, err := se.ExecuteInternal(ctx, "BEGIN") if err != nil { return false, errors.Trace(err) } @@ -1624,7 +1624,7 @@ func (w *GCWorker) checkLeader() (bool, error) { se.RollbackTxn(ctx) - _, err = se.Execute(ctx, "BEGIN") + _, err = se.ExecuteInternal(ctx, "BEGIN") if err != nil { return false, errors.Trace(err) } @@ -1732,8 +1732,7 @@ func (w *GCWorker) loadValueFromSysTable(key string) (string, error) { ctx := context.Background() se := createSession(w.store) defer se.Close() - stmt := fmt.Sprintf(`SELECT HIGH_PRIORITY (variable_value) FROM mysql.tidb WHERE variable_name='%s' FOR UPDATE`, key) - rs, err := se.Execute(ctx, stmt) + rs, err := se.ExecuteInternal(ctx, `SELECT HIGH_PRIORITY (variable_value) FROM mysql.tidb WHERE variable_name=%? FOR UPDATE`, key) if len(rs) > 0 { defer terror.Call(rs[0].Close) } @@ -1758,13 +1757,14 @@ func (w *GCWorker) loadValueFromSysTable(key string) (string, error) { } func (w *GCWorker) saveValueToSysTable(key, value string) error { - stmt := fmt.Sprintf(`INSERT HIGH_PRIORITY INTO mysql.tidb VALUES ('%[1]s', '%[2]s', '%[3]s') + const stmt = `INSERT HIGH_PRIORITY INTO mysql.tidb VALUES (%?, %?, %?) ON DUPLICATE KEY - UPDATE variable_value = '%[2]s', comment = '%[3]s'`, - key, value, gcVariableComments[key]) + UPDATE variable_value = %?, comment = %?` se := createSession(w.store) defer se.Close() - _, err := se.Execute(context.Background(), stmt) + _, err := se.ExecuteInternal(context.Background(), stmt, + key, value, gcVariableComments[key], + value, gcVariableComments[key]) logutil.BgLogger().Debug("[gc worker] save kv", zap.String("key", key), zap.String("value", value), diff --git a/util/admin/admin.go b/util/admin/admin.go index a5f46de92f9aa..7419a1ab6d0b4 100644 --- a/util/admin/admin.go +++ b/util/admin/admin.go @@ -16,12 +16,12 @@ package admin import ( "context" "encoding/json" - "fmt" "math" "sort" "time" "github.com/pingcap/errors" + "github.com/pingcap/parser/ast" "github.com/pingcap/parser/model" "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/errno" @@ -288,13 +288,13 @@ type RecordData struct { Values []types.Datum } -func getCount(ctx sessionctx.Context, sql string) (int64, error) { - rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithSnapshot(sql) +func getCount(exec sqlexec.RestrictedSQLExecutor, stmt ast.StmtNode, snapshot uint64) (int64, error) { + rows, _, err := exec.ExecRestrictedStmt(context.Background(), stmt, sqlexec.ExecOptionWithSnapshot(snapshot)) if err != nil { return 0, errors.Trace(err) } if len(rows) != 1 { - return 0, errors.Errorf("can not get count, sql %s result rows %d", sql, len(rows)) + return 0, errors.Errorf("can not get count, rows count = %d", len(rows)) } return rows[0].GetInt64(0), nil } @@ -313,14 +313,34 @@ const ( // otherwise it returns an error and the corresponding index's offset. func CheckIndicesCount(ctx sessionctx.Context, dbName, tableName string, indices []string) (byte, int, error) { // Add `` for some names like `table name`. - sql := fmt.Sprintf("SELECT COUNT(*) FROM `%s`.`%s` USE INDEX()", dbName, tableName) - tblCnt, err := getCount(ctx, sql) + exec := ctx.(sqlexec.RestrictedSQLExecutor) + stmt, err := exec.ParseWithParams(context.Background(), "SELECT COUNT(*) FROM %n.%n USE INDEX()", dbName, tableName) + if err != nil { + return 0, 0, errors.Trace(err) + } + + var snapshot uint64 + txn, err := ctx.Txn(false) + if err != nil { + return 0, 0, err + } + if txn.Valid() { + snapshot = txn.StartTS() + } + if ctx.GetSessionVars().SnapshotTS != 0 { + snapshot = ctx.GetSessionVars().SnapshotTS + } + + tblCnt, err := getCount(exec, stmt, snapshot) if err != nil { return 0, 0, errors.Trace(err) } for i, idx := range indices { - sql = fmt.Sprintf("SELECT COUNT(*) FROM `%s`.`%s` USE INDEX(`%s`)", dbName, tableName, idx) - idxCnt, err := getCount(ctx, sql) + stmt, err := exec.ParseWithParams(context.Background(), "SELECT COUNT(*) FROM %n.%n USE INDEX(%n)", dbName, tableName, idx) + if err != nil { + return 0, i, errors.Trace(err) + } + idxCnt, err := getCount(exec, stmt, snapshot) if err != nil { return 0, i, errors.Trace(err) } diff --git a/util/gcutil/gcutil.go b/util/gcutil/gcutil.go index f265e4dd0603f..6d3c116cdb0a4 100644 --- a/util/gcutil/gcutil.go +++ b/util/gcutil/gcutil.go @@ -14,7 +14,7 @@ package gcutil import ( - "fmt" + "context" "github.com/pingcap/errors" "github.com/pingcap/parser/model" @@ -25,18 +25,20 @@ import ( ) const ( - selectVariableValueSQL = `SELECT HIGH_PRIORITY variable_value FROM mysql.tidb WHERE variable_name='%s'` - insertVariableValueSQL = `INSERT HIGH_PRIORITY INTO mysql.tidb VALUES ('%[1]s', '%[2]s', '%[3]s') - ON DUPLICATE KEY - UPDATE variable_value = '%[2]s', comment = '%[3]s'` + insertVariableValueSQL = `INSERT HIGH_PRIORITY INTO mysql.tidb VALUES (%?, %?, %?) + ON DUPLICATE KEY UPDATE variable_value = %?, comment = %?` + selectVariableValueSQL = `SELECT HIGH_PRIORITY variable_value FROM mysql.tidb WHERE variable_name=%?` ) // CheckGCEnable is use to check whether GC is enable. func CheckGCEnable(ctx sessionctx.Context) (enable bool, err error) { - sql := fmt.Sprintf(selectVariableValueSQL, "tikv_gc_enable") - rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql) - if err != nil { - return false, errors.Trace(err) + stmt, err1 := ctx.(sqlexec.RestrictedSQLExecutor).ParseWithParams(context.Background(), selectVariableValueSQL, "tikv_gc_enable") + if err1 != nil { + return false, errors.Trace(err1) + } + rows, _, err2 := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedStmt(context.Background(), stmt) + if err1 != nil { + return false, errors.Trace(err2) } if len(rows) != 1 { return false, errors.New("can not get 'tikv_gc_enable'") @@ -46,15 +48,19 @@ func CheckGCEnable(ctx sessionctx.Context) (enable bool, err error) { // DisableGC will disable GC enable variable. func DisableGC(ctx sessionctx.Context) error { - sql := fmt.Sprintf(insertVariableValueSQL, "tikv_gc_enable", "false", "Current GC enable status") - _, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql) + stmt, err := ctx.(sqlexec.RestrictedSQLExecutor).ParseWithParams(context.Background(), insertVariableValueSQL, "tikv_gc_enable", "false", "Current GC enable status", "false", "Current GC enable status") + if err == nil { + _, _, err = ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedStmt(context.Background(), stmt) + } return errors.Trace(err) } // EnableGC will enable GC enable variable. func EnableGC(ctx sessionctx.Context) error { - sql := fmt.Sprintf(insertVariableValueSQL, "tikv_gc_enable", "true", "Current GC enable status") - _, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql) + stmt, err := ctx.(sqlexec.RestrictedSQLExecutor).ParseWithParams(context.Background(), insertVariableValueSQL, "tikv_gc_enable", "true", "Current GC enable status", "true", "Current GC enable status") + if err == nil { + _, _, err = ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedStmt(context.Background(), stmt) + } return errors.Trace(err) } @@ -80,8 +86,12 @@ func ValidateSnapshotWithGCSafePoint(snapshotTS, safePointTS uint64) error { // GetGCSafePoint loads GC safe point time from mysql.tidb. func GetGCSafePoint(ctx sessionctx.Context) (uint64, error) { - sql := fmt.Sprintf(selectVariableValueSQL, "tikv_gc_safe_point") - rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql) + exec := ctx.(sqlexec.RestrictedSQLExecutor) + stmt, err := exec.ParseWithParams(context.Background(), selectVariableValueSQL, "tikv_gc_safe_point") + if err != nil { + return 0, errors.Trace(err) + } + rows, _, err := exec.ExecRestrictedStmt(context.Background(), stmt) if err != nil { return 0, errors.Trace(err) } diff --git a/util/sqlexec/restricted_sql_executor.go b/util/sqlexec/restricted_sql_executor.go index 6ea589ec5ff5d..92d6958a38667 100644 --- a/util/sqlexec/restricted_sql_executor.go +++ b/util/sqlexec/restricted_sql_executor.go @@ -42,6 +42,43 @@ type RestrictedSQLExecutor interface { // If current session sets the snapshot timestamp, then execute with this snapshot timestamp. // Otherwise, execute with the current transaction start timestamp if the transaction is valid. ExecRestrictedSQLWithSnapshot(sql string) ([]chunk.Row, []*ast.ResultField, error) + + // The above methods are all deprecated. + // After the refactor finish, they will be removed. + + // ParseWithParams is the parameterized version of Parse: it will try to prevent injection under utf8mb4. + // It works like printf() in c, there are following format specifiers: + // 1. %?: automatic conversion by the type of arguments. E.g. []string -> ('s1','s2'..) + // 2. %%: output % + // 3. %n: for identifiers, for example ("use %n", db) + // + // Attention: it does not prevent you from doing parse("select '%?", ";SQL injection!;") => "select '';SQL injection!;'". + // One argument should be a standalone entity. It should not "concat" with other placeholders and characters. + // This function only saves you from processing potentially unsafe parameters. + ParseWithParams(ctx context.Context, sql string, args ...interface{}) (ast.StmtNode, error) + // ExecRestrictedStmt run sql statement in ctx with some restriction. + ExecRestrictedStmt(ctx context.Context, stmt ast.StmtNode, opts ...OptionFuncAlias) ([]chunk.Row, []*ast.ResultField, error) +} + +// ExecOption is a struct defined for ExecRestrictedSQLWithContext option. +type ExecOption struct { + IgnoreWarning bool + SnapshotTS uint64 +} + +// OptionFuncAlias is defined for the optional paramater of ExecRestrictedSQLWithContext. +type OptionFuncAlias = func(option *ExecOption) + +// ExecOptionIgnoreWarning tells ExecRestrictedSQLWithContext to ignore the warnings. +var ExecOptionIgnoreWarning OptionFuncAlias = func(option *ExecOption) { + option.IgnoreWarning = true +} + +// ExecOptionWithSnapshot tells ExecRestrictedSQLWithContext to use a snapshot. +func ExecOptionWithSnapshot(snapshot uint64) OptionFuncAlias { + return func(option *ExecOption) { + option.SnapshotTS = snapshot + } } // SQLExecutor is an interface provides executing normal sql statement. From f587ee38d5da9264661ba31e8ac960be527ed9c1 Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Thu, 4 Mar 2021 10:14:54 +0800 Subject: [PATCH 10/85] sessionctx: add optimization-time and wait-TS-time into the slow log (#17869) (#22918) --- executor/adapter.go | 2 ++ executor/slow_query.go | 8 ++++++++ executor/slow_query_test.go | 2 +- infoschema/tables.go | 2 ++ infoschema/tables_test.go | 6 ++++-- planner/optimize.go | 3 +++ session/session.go | 3 +++ sessionctx/variable/session.go | 15 +++++++++++++++ sessionctx/variable/session_test.go | 4 ++++ 9 files changed, 42 insertions(+), 3 deletions(-) diff --git a/executor/adapter.go b/executor/adapter.go index cc07ea991a6c6..8a7974e35a081 100644 --- a/executor/adapter.go +++ b/executor/adapter.go @@ -907,6 +907,8 @@ func (a *ExecStmt) LogSlowQuery(txnTS uint64, succ bool, hasMoreResults bool) { TimeTotal: costTime, TimeParse: sessVars.DurationParse, TimeCompile: sessVars.DurationCompile, + TimeOptimize: sessVars.DurationOptimization, + TimeWaitTS: sessVars.DurationWaitTS, IndexNames: indexNames, StatsInfos: statsInfos, CopTasks: copTaskInfo, diff --git a/executor/slow_query.go b/executor/slow_query.go index b4bd930127ab0..68a5cf84f3a42 100755 --- a/executor/slow_query.go +++ b/executor/slow_query.go @@ -457,6 +457,8 @@ type slowQueryTuple struct { rewriteTime float64 preprocSubqueries uint64 preprocSubQueryTime float64 + optimizeTime float64 + waitTSTime float64 preWriteTime float64 waitPrewriteBinlogTime float64 commitTime float64 @@ -563,6 +565,10 @@ func (st *slowQueryTuple) setFieldValue(tz *time.Location, field, value string, st.parseTime, err = strconv.ParseFloat(value, 64) case variable.SlowLogCompileTimeStr: st.compileTime, err = strconv.ParseFloat(value, 64) + case variable.SlowLogOptimizeTimeStr: + st.optimizeTime, err = strconv.ParseFloat(value, 64) + case variable.SlowLogWaitTSTimeStr: + st.waitTSTime, err = strconv.ParseFloat(value, 64) case execdetails.PreWriteTimeStr: st.preWriteTime, err = strconv.ParseFloat(value, 64) case execdetails.WaitPrewriteBinlogTimeStr: @@ -687,6 +693,8 @@ func (st *slowQueryTuple) convertToDatumRow() []types.Datum { record = append(record, types.NewFloat64Datum(st.rewriteTime)) record = append(record, types.NewUintDatum(st.preprocSubqueries)) record = append(record, types.NewFloat64Datum(st.preprocSubQueryTime)) + record = append(record, types.NewFloat64Datum(st.optimizeTime)) + record = append(record, types.NewFloat64Datum(st.waitTSTime)) record = append(record, types.NewFloat64Datum(st.preWriteTime)) record = append(record, types.NewFloat64Datum(st.waitPrewriteBinlogTime)) record = append(record, types.NewFloat64Datum(st.commitTime)) diff --git a/executor/slow_query_test.go b/executor/slow_query_test.go index 64c63a08e5f94..09aaf9ad77e6e 100644 --- a/executor/slow_query_test.go +++ b/executor/slow_query_test.go @@ -130,7 +130,7 @@ select * from t;` } expectRecordString := `2019-04-28 15:24:04.309074,` + `405888132465033227,root,localhost,0,57,0.12,0.216905,` + - `0,0,0,0,0,0,0,0,0,0,,0,0,0,0,0,0,0.38,0.021,0,0,0,1,637,0,,,1,42a1c8aae6f133e934d4bf0147491709a8812ea05ff8819ec522780fe657b772,t1:1,t2:2,` + + `0,0,0,0,0,0,0,0,0,0,0,0,,0,0,0,0,0,0,0.38,0.021,0,0,0,1,637,0,,,1,42a1c8aae6f133e934d4bf0147491709a8812ea05ff8819ec522780fe657b772,t1:1,t2:2,` + `0.1,0.2,0.03,127.0.0.1:20160,0.05,0.6,0.8,0.0.0.0:20160,70724,65536,0,0,0,0,` + `Cop_backoff_regionMiss_total_times: 200 Cop_backoff_regionMiss_total_time: 0.2 Cop_backoff_regionMiss_max_time: 0.2 Cop_backoff_regionMiss_max_addr: 127.0.0.1 Cop_backoff_regionMiss_avg_time: 0.2 Cop_backoff_regionMiss_p90_time: 0.2 Cop_backoff_rpcPD_total_times: 200 Cop_backoff_rpcPD_total_time: 0.2 Cop_backoff_rpcPD_max_time: 0.2 Cop_backoff_rpcPD_max_addr: 127.0.0.1 Cop_backoff_rpcPD_avg_time: 0.2 Cop_backoff_rpcPD_p90_time: 0.2 Cop_backoff_rpcTiKV_total_times: 200 Cop_backoff_rpcTiKV_total_time: 0.2 Cop_backoff_rpcTiKV_max_time: 0.2 Cop_backoff_rpcTiKV_max_addr: 127.0.0.1 Cop_backoff_rpcTiKV_avg_time: 0.2 Cop_backoff_rpcTiKV_p90_time: 0.2,` + `0,0,1,,60e9378c746d9a2be1c791047e008967cf252eb6de9167ad3aa6098fa2d523f4,` + diff --git a/infoschema/tables.go b/infoschema/tables.go index fb0336dd07ff4..e4dd6fa2f7165 100644 --- a/infoschema/tables.go +++ b/infoschema/tables.go @@ -728,6 +728,8 @@ var slowQueryCols = []columnInfo{ {name: variable.SlowLogRewriteTimeStr, tp: mysql.TypeDouble, size: 22}, {name: variable.SlowLogPreprocSubQueriesStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.UnsignedFlag}, {name: variable.SlowLogPreProcSubQueryTimeStr, tp: mysql.TypeDouble, size: 22}, + {name: variable.SlowLogOptimizeTimeStr, tp: mysql.TypeDouble, size: 22}, + {name: variable.SlowLogWaitTSTimeStr, tp: mysql.TypeDouble, size: 22}, {name: execdetails.PreWriteTimeStr, tp: mysql.TypeDouble, size: 22}, {name: execdetails.WaitPrewriteBinlogTimeStr, tp: mysql.TypeDouble, size: 22}, {name: execdetails.CommitTimeStr, tp: mysql.TypeDouble, size: 22}, diff --git a/infoschema/tables_test.go b/infoschema/tables_test.go index 072cd1079ba92..6759efa55f8c1 100644 --- a/infoschema/tables_test.go +++ b/infoschema/tables_test.go @@ -552,6 +552,8 @@ func prepareSlowLogfile(c *C, slowLogFileName string) { # Parse_time: 0.4 # Compile_time: 0.2 # Rewrite_time: 0.000000003 Preproc_subqueries: 2 Preproc_subqueries_time: 0.000000002 +# Optimize_time: 0.00000001 +# Wait_TS: 0.000000003 # LockKeys_time: 1.71 Request_count: 1 Prewrite_time: 0.19 Wait_prewrite_binlog_time: 0.21 Commit_time: 0.01 Commit_backoff_time: 0.18 Backoff_types: [txnLock] Resolve_lock_time: 0.03 Write_keys: 15 Write_size: 480 Prewrite_region: 1 Txn_retry: 8 # Cop_time: 0.3824278 Process_time: 0.161 Request_count: 1 Total_keys: 100001 Process_keys: 100000 # Wait_time: 0.101 @@ -636,10 +638,10 @@ func (s *testTableSuite) TestSlowQuery(c *C) { tk.MustExec("set time_zone = '+08:00';") re := tk.MustQuery("select * from information_schema.slow_query") re.Check(testutil.RowsWithSep("|", - "2019-02-12 19:33:56.571953|406315658548871171|root|localhost|6|57|0.12|4.895492|0.4|0.2|0.000000003|2|0.000000002|0.19|0.21|0.01|0|0.18|[txnLock]|0.03|0|15|480|1|8|0.3824278|0.161|0.101|0.092|1.71|1|100001|100000|test||0|42a1c8aae6f133e934d4bf0147491709a8812ea05ff8819ec522780fe657b772|t1:1,t2:2|0.1|0.2|0.03|127.0.0.1:20160|0.05|0.6|0.8|0.0.0.0:20160|70724|65536|0|0|0|0||0|1|1|abcd|60e9378c746d9a2be1c791047e008967cf252eb6de9167ad3aa6098fa2d523f4|update t set i = 2;|select * from t_slim;")) + "2019-02-12 19:33:56.571953|406315658548871171|root|localhost|6|57|0.12|4.895492|0.4|0.2|0.000000003|2|0.000000002|0.00000001|0.000000003|0.19|0.21|0.01|0|0.18|[txnLock]|0.03|0|15|480|1|8|0.3824278|0.161|0.101|0.092|1.71|1|100001|100000|test||0|42a1c8aae6f133e934d4bf0147491709a8812ea05ff8819ec522780fe657b772|t1:1,t2:2|0.1|0.2|0.03|127.0.0.1:20160|0.05|0.6|0.8|0.0.0.0:20160|70724|65536|0|0|0|0||0|1|1|abcd|60e9378c746d9a2be1c791047e008967cf252eb6de9167ad3aa6098fa2d523f4|update t set i = 2;|select * from t_slim;")) tk.MustExec("set time_zone = '+00:00';") re = tk.MustQuery("select * from information_schema.slow_query") - re.Check(testutil.RowsWithSep("|", "2019-02-12 11:33:56.571953|406315658548871171|root|localhost|6|57|0.12|4.895492|0.4|0.2|0.000000003|2|0.000000002|0.19|0.21|0.01|0|0.18|[txnLock]|0.03|0|15|480|1|8|0.3824278|0.161|0.101|0.092|1.71|1|100001|100000|test||0|42a1c8aae6f133e934d4bf0147491709a8812ea05ff8819ec522780fe657b772|t1:1,t2:2|0.1|0.2|0.03|127.0.0.1:20160|0.05|0.6|0.8|0.0.0.0:20160|70724|65536|0|0|0|0||0|1|1|abcd|60e9378c746d9a2be1c791047e008967cf252eb6de9167ad3aa6098fa2d523f4|update t set i = 2;|select * from t_slim;")) + re.Check(testutil.RowsWithSep("|", "2019-02-12 11:33:56.571953|406315658548871171|root|localhost|6|57|0.12|4.895492|0.4|0.2|0.000000003|2|0.000000002|0.00000001|0.000000003|0.19|0.21|0.01|0|0.18|[txnLock]|0.03|0|15|480|1|8|0.3824278|0.161|0.101|0.092|1.71|1|100001|100000|test||0|42a1c8aae6f133e934d4bf0147491709a8812ea05ff8819ec522780fe657b772|t1:1,t2:2|0.1|0.2|0.03|127.0.0.1:20160|0.05|0.6|0.8|0.0.0.0:20160|70724|65536|0|0|0|0||0|1|1|abcd|60e9378c746d9a2be1c791047e008967cf252eb6de9167ad3aa6098fa2d523f4|update t set i = 2;|select * from t_slim;")) // Test for long query. f, err := os.OpenFile(slowLogFileName, os.O_CREATE|os.O_WRONLY, 0644) diff --git a/planner/optimize.go b/planner/optimize.go index 35c5981d36eb7..09d2de310616a 100644 --- a/planner/optimize.go +++ b/planner/optimize.go @@ -248,7 +248,10 @@ func optimize(ctx context.Context, sctx sessionctx.Context, node ast.Node, is in finalPlan, cost, err := cascades.DefaultOptimizer.FindBestPlan(sctx, logic) return finalPlan, names, cost, err } + + beginOpt := time.Now() finalPlan, cost, err := plannercore.DoOptimize(ctx, sctx, builder.GetOptFlag(), logic) + sctx.GetSessionVars().DurationOptimization = time.Since(beginOpt) return finalPlan, names, cost, err } diff --git a/session/session.go b/session/session.go index faa6ec06dfcc8..0c76f0015c9da 100644 --- a/session/session.go +++ b/session/session.go @@ -1587,6 +1587,9 @@ func (s *session) Txn(active bool) (kv.Transaction, error) { return &s.txn, errors.AddStack(kv.ErrInvalidTxn) } if s.txn.pending() && active { + defer func(begin time.Time) { + s.sessionVars.DurationWaitTS = time.Since(begin) + }(time.Now()) // Transaction is lazy initialized. // PrepareTxnCtx is called to get a tso future, makes s.txn a pending txn, // If Txn() is called later, wait for the future to get a valid txn. diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go index 53e018bfd52e1..d9c1c44fcca9d 100644 --- a/sessionctx/variable/session.go +++ b/sessionctx/variable/session.go @@ -588,6 +588,12 @@ type SessionVars struct { // RewritePhaseInfo records all information about the rewriting phase. RewritePhaseInfo + // DurationOptimization is the duration of optimizing a query. + DurationOptimization time.Duration + + // DurationWaitTS is the duration of waiting for a snapshot TS + DurationWaitTS time.Duration + // PrevStmt is used to store the previous executed statement in the current session. PrevStmt fmt.Stringer @@ -1589,6 +1595,10 @@ const ( SlowLogCompileTimeStr = "Compile_time" // SlowLogRewriteTimeStr is the rewrite time. SlowLogRewriteTimeStr = "Rewrite_time" + // SlowLogOptimizeTimeStr is the optimization time. + SlowLogOptimizeTimeStr = "Optimize_time" + // SlowLogWaitTSTimeStr is the time of waiting TS. + SlowLogWaitTSTimeStr = "Wait_TS" // SlowLogPreprocSubQueriesStr is the number of pre-processed sub-queries. SlowLogPreprocSubQueriesStr = "Preproc_subqueries" // SlowLogPreProcSubQueryTimeStr is the total time of pre-processing sub-queries. @@ -1674,6 +1684,8 @@ type SlowQueryLogItems struct { TimeTotal time.Duration TimeParse time.Duration TimeCompile time.Duration + TimeOptimize time.Duration + TimeWaitTS time.Duration IndexNames string StatsInfos map[string]uint64 CopTasks *stmtctx.CopTasksDetails @@ -1754,6 +1766,9 @@ func (s *SessionVars) SlowLogFormat(logItems *SlowQueryLogItems) string { } buf.WriteString("\n") + writeSlowLogItem(&buf, SlowLogOptimizeTimeStr, strconv.FormatFloat(logItems.TimeOptimize.Seconds(), 'f', -1, 64)) + writeSlowLogItem(&buf, SlowLogWaitTSTimeStr, strconv.FormatFloat(logItems.TimeWaitTS.Seconds(), 'f', -1, 64)) + if execDetailStr := logItems.ExecDetail.String(); len(execDetailStr) > 0 { buf.WriteString(SlowLogRowPrefixStr + execDetailStr + "\n") } diff --git a/sessionctx/variable/session_test.go b/sessionctx/variable/session_test.go index b58ab81a6b4d9..7ab9927005305 100644 --- a/sessionctx/variable/session_test.go +++ b/sessionctx/variable/session_test.go @@ -183,6 +183,8 @@ func (*testSessionSuite) TestSlowLogFormat(c *C) { # Parse_time: 0.00000001 # Compile_time: 0.00000001 # Rewrite_time: 0.000000003 Preproc_subqueries: 2 Preproc_subqueries_time: 0.000000002 +# Optimize_time: 0.00000001 +# Wait_TS: 0.000000003 # Process_time: 2 Wait_time: 60 Backoff_time: 0.001 Request_count: 2 Total_keys: 10000 Process_keys: 20001 # DB: test # Index_names: [t1:a,t2:b] @@ -214,6 +216,8 @@ func (*testSessionSuite) TestSlowLogFormat(c *C) { TimeTotal: costTime, TimeParse: time.Duration(10), TimeCompile: time.Duration(10), + TimeOptimize: time.Duration(10), + TimeWaitTS: time.Duration(3), IndexNames: "[t1:a,t2:b]", StatsInfos: statsInfos, CopTasks: copTasks, From 38f9bdd81dce8fd41dc49465514995e003520be2 Mon Sep 17 00:00:00 2001 From: xiongjiwei Date: Thu, 4 Mar 2021 13:42:54 +0800 Subject: [PATCH 11/85] parser: make some char constant do not restore with charset prefix (#23083) --- executor/ddl_test.go | 10 ++++++++++ go.mod | 4 ++-- go.sum | 8 ++++---- 3 files changed, 16 insertions(+), 6 deletions(-) diff --git a/executor/ddl_test.go b/executor/ddl_test.go index 9b0b5c3443da1..e46302e545f16 100644 --- a/executor/ddl_test.go +++ b/executor/ddl_test.go @@ -291,6 +291,16 @@ func (s *testSuite6) TestViewRecursion(c *C) { tk.MustExec("drop view recursive_view1, t") } +func (s *testSuite6) TestIssue23027(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t23027") + tk.MustExec("create table t23027(a char(10))") + tk.MustExec("insert into t23027 values ('a'), ('a')") + tk.MustExec("create definer='root'@'localhost' view v23027 as select group_concat(a) from t23027;") + tk.MustQuery("select * from v23027").Check(testkit.Rows("a,a")) +} + func (s *testSuite6) TestIssue16250(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") diff --git a/go.mod b/go.mod index e967e7cbdab41..b9ea7f5039c62 100644 --- a/go.mod +++ b/go.mod @@ -46,7 +46,7 @@ require ( github.com/pingcap/goleveldb v0.0.0-20191226122134-f82aafb29989 github.com/pingcap/kvproto v0.0.0-20201126113434-70db5fb4b0dc github.com/pingcap/log v0.0.0-20201112100606-8f1e84a3abc8 - github.com/pingcap/parser v0.0.0-20210107054750-53e33b4018fe + github.com/pingcap/parser v0.0.0-20210303062609-d1d977c9ceed github.com/pingcap/sysutil v0.0.0-20201130064824-f0c8aa6a6966 github.com/pingcap/tidb-tools v4.0.9-0.20201127090955-2707c97b3853+incompatible github.com/pingcap/tipb v0.0.0-20200618092958-4fad48b4c8c3 @@ -74,7 +74,7 @@ require ( golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d // indirect golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208 golang.org/x/sys v0.0.0-20200819171115-d785dc25833f - golang.org/x/text v0.3.4 + golang.org/x/text v0.3.5 golang.org/x/time v0.0.0-20191024005414-555d28b269f0 // indirect golang.org/x/tools v0.0.0-20200820010801-b793a1359eac golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect diff --git a/go.sum b/go.sum index 653411200fc3e..3c777e84def96 100644 --- a/go.sum +++ b/go.sum @@ -403,8 +403,8 @@ github.com/pingcap/log v0.0.0-20200117041106-d28c14d3b1cd/go.mod h1:4rbK1p9ILyIf github.com/pingcap/log v0.0.0-20200511115504-543df19646ad/go.mod h1:4rbK1p9ILyIfb6hU7OG2CiWSqMXnp3JMbiaVJ6mvoY8= github.com/pingcap/log v0.0.0-20201112100606-8f1e84a3abc8 h1:M+DNpOu/I3uDmwee6vcnoPd6GgSMqND4gxvDQ/W584U= github.com/pingcap/log v0.0.0-20201112100606-8f1e84a3abc8/go.mod h1:4rbK1p9ILyIfb6hU7OG2CiWSqMXnp3JMbiaVJ6mvoY8= -github.com/pingcap/parser v0.0.0-20210107054750-53e33b4018fe h1:sukVKRva68HNGZ4nuPvQS/wMvH7NMxTXV2NIhmoYP4Y= -github.com/pingcap/parser v0.0.0-20210107054750-53e33b4018fe/go.mod h1:GbEr2PgY72/4XqPZzmzstlOU/+il/wrjeTNFs6ihsSE= +github.com/pingcap/parser v0.0.0-20210303062609-d1d977c9ceed h1:+ENLMPNRG8+/YGNJChC5QRgfrcmFnsrHl9WoVLXRZok= +github.com/pingcap/parser v0.0.0-20210303062609-d1d977c9ceed/go.mod h1:GbEr2PgY72/4XqPZzmzstlOU/+il/wrjeTNFs6ihsSE= github.com/pingcap/sysutil v0.0.0-20200206130906-2bfa6dc40bcd/go.mod h1:EB/852NMQ+aRKioCpToQ94Wl7fktV+FNnxf3CX/TTXI= github.com/pingcap/sysutil v0.0.0-20201130064824-f0c8aa6a6966 h1:JI0wOAb8aQML0vAVLHcxTEEC0VIwrk6gtw3WjbHvJLA= github.com/pingcap/sysutil v0.0.0-20201130064824-f0c8aa6a6966/go.mod h1:EB/852NMQ+aRKioCpToQ94Wl7fktV+FNnxf3CX/TTXI= @@ -693,8 +693,8 @@ golang.org/x/sys v0.0.0-20200819171115-d785dc25833f/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= -golang.org/x/text v0.3.4 h1:0YWbFKbhXG/wIiuHDSKpS0Iy7FSA+u45VtBMfQcFTTc= -golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.5 h1:i6eZZ+zk0SOf0xgBpEpPD18qWcJda6q1sxt3S0kzyUQ= +golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= From 649d0e0143acf323b0f80bb4ceef56249364eac4 Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Fri, 5 Mar 2021 13:22:54 +0800 Subject: [PATCH 12/85] *: refactor ExecuteInternal to return single resultset (#22546) (#22640) --- bindinfo/handle.go | 23 +++++-------- session/bootstrap.go | 12 +++---- session/session.go | 46 ++++++------------------- store/tikv/gcworker/gc_worker.go | 8 ++--- util/mock/context.go | 4 +-- util/sqlexec/restricted_sql_executor.go | 2 +- 6 files changed, 31 insertions(+), 64 deletions(-) diff --git a/bindinfo/handle.go b/bindinfo/handle.go index 93d4b65391452..146eb0c96fa0c 100644 --- a/bindinfo/handle.go +++ b/bindinfo/handle.go @@ -661,16 +661,14 @@ func (h *BindHandle) CaptureBaselines() { func getHintsForSQL(sctx sessionctx.Context, sql string) (string, error) { oriVals := sctx.GetSessionVars().UsePlanBaselines sctx.GetSessionVars().UsePlanBaselines = false - recordSets, err := sctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), fmt.Sprintf("explain format='hint' %s", sql)) + rs, err := sctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), fmt.Sprintf("explain format='hint' %s", sql)) sctx.GetSessionVars().UsePlanBaselines = oriVals - if len(recordSets) > 0 { - defer terror.Log(recordSets[0].Close()) - } if err != nil { return "", err } - chk := recordSets[0].NewChunk() - err = recordSets[0].Next(context.TODO(), chk) + defer terror.Call(rs.Close) + chk := rs.NewChunk() + err = rs.Next(context.TODO(), chk) if err != nil { return "", err } @@ -873,23 +871,20 @@ func runSQL(ctx context.Context, sctx sessionctx.Context, sql string, resultChan resultChan <- fmt.Errorf("run sql panicked: %v", string(buf)) } }() - recordSets, err := sctx.(sqlexec.SQLExecutor).ExecuteInternal(ctx, sql) + rs, err := sctx.(sqlexec.SQLExecutor).ExecuteInternal(ctx, sql) if err != nil { - if len(recordSets) > 0 { - terror.Call(recordSets[0].Close) - } + terror.Call(rs.Close) resultChan <- err return } - recordSet := recordSets[0] - chk := recordSets[0].NewChunk() + chk := rs.NewChunk() for { - err = recordSet.Next(ctx, chk) + err = rs.Next(ctx, chk) if err != nil || chk.NumRows() == 0 { break } } - terror.Call(recordSets[0].Close) + terror.Call(rs.Close) resultChan <- err } diff --git a/session/bootstrap.go b/session/bootstrap.go index b85bb79e7de13..19834e428156b 100644 --- a/session/bootstrap.go +++ b/session/bootstrap.go @@ -1160,8 +1160,8 @@ func upgradeToVer51(s Session, ver int64) { mustExecute(s, "COMMIT") }() mustExecute(s, h.LockBindInfoSQL()) - var recordSets []sqlexec.RecordSet - recordSets, err = s.ExecuteInternal(context.Background(), + var rs sqlexec.RecordSet + rs, err = s.ExecuteInternal(context.Background(), `SELECT bind_sql, default_db, status, create_time, charset, collation, source FROM mysql.bind_info WHERE source != 'builtin' @@ -1169,15 +1169,13 @@ func upgradeToVer51(s Session, ver int64) { if err != nil { logutil.BgLogger().Fatal("upgradeToVer61 error", zap.Error(err)) } - if len(recordSets) > 0 { - defer terror.Call(recordSets[0].Close) - } - req := recordSets[0].NewChunk() + defer terror.Call(rs.Close) + req := rs.NewChunk() iter := chunk.NewIterator4Chunk(req) p := parser.New() now := types.NewTime(types.FromGoTime(time.Now()), mysql.TypeTimestamp, 3) for { - err = recordSets[0].Next(context.TODO(), req) + err = rs.Next(context.TODO(), req) if err != nil { logutil.BgLogger().Fatal("upgradeToVer61 error", zap.Error(err)) } diff --git a/session/session.go b/session/session.go index 0c76f0015c9da..d05e914e59e89 100644 --- a/session/session.go +++ b/session/session.go @@ -105,7 +105,7 @@ type Session interface { // Parse is deprecated, use ParseWithParams() instead. Parse(ctx context.Context, sql string) ([]ast.StmtNode, error) // ExecuteInternal is a helper around ParseWithParams() and ExecuteStmt(). It is not allowed to execute multiple statements. - ExecuteInternal(context.Context, string, ...interface{}) ([]sqlexec.RecordSet, error) + ExecuteInternal(context.Context, string, ...interface{}) (sqlexec.RecordSet, error) String() string // String is used to debug. CommitTxn(context.Context) error RollbackTxn(context.Context) @@ -860,37 +860,17 @@ func (s *session) ExecRestrictedSQLWithSnapshot(sql string) ([]chunk.Row, []*ast func execRestrictedSQL(ctx context.Context, se *session, sql string) ([]chunk.Row, []*ast.ResultField, error) { ctx = context.WithValue(ctx, execdetails.StmtExecDetailKey, &execdetails.StmtExecDetails{}) startTime := time.Now() - recordSets, err := se.ExecuteInternal(ctx, sql) - defer func() { - for _, rs := range recordSets { - closeErr := rs.Close() - if closeErr != nil && err == nil { - err = closeErr - } - } - }() - if err != nil { + rs, err := se.ExecuteInternal(ctx, sql) + if err != nil || rs == nil { return nil, nil, err } - - var ( - rows []chunk.Row - fields []*ast.ResultField - ) - // Execute all recordset, take out the first one as result. - for i, rs := range recordSets { - tmp, err := drainRecordSet(ctx, se, rs) - if err != nil { - return nil, nil, err - } - - if i == 0 { - rows = tmp - fields = rs.Fields() - } + defer terror.Call(rs.Close) + rows, err := drainRecordSet(ctx, se, rs) + if err != nil { + return nil, nil, err } metrics.QueryDurationHistogram.WithLabelValues(metrics.LblInternal).Observe(time.Since(startTime).Seconds()) - return rows, fields, err + return rows, rs.Fields(), err } func createSessionFunc(store kv.Storage) pools.Factory { @@ -1123,7 +1103,7 @@ func (rs *execStmtResult) Close() error { return finishStmt(context.Background(), se, err, rs.sql) } -func (s *session) ExecuteInternal(ctx context.Context, sql string, args ...interface{}) (recordSets []sqlexec.RecordSet, err error) { +func (s *session) ExecuteInternal(ctx context.Context, sql string, args ...interface{}) (rs sqlexec.RecordSet, err error) { origin := s.sessionVars.InRestrictedSQL s.sessionVars.InRestrictedSQL = true defer func() { @@ -1142,15 +1122,11 @@ func (s *session) ExecuteInternal(ctx context.Context, sql string, args ...inter return nil, err } - rs, err := s.ExecuteStmt(ctx, stmt) + rs, err = s.ExecuteStmt(ctx, stmt) if err != nil { s.sessionVars.StmtCtx.AppendError(err) } - if rs == nil { - return nil, err - } - - return []sqlexec.RecordSet{rs}, err + return rs, err } func (s *session) Execute(ctx context.Context, sql string) (recordSets []sqlexec.RecordSet, err error) { diff --git a/store/tikv/gcworker/gc_worker.go b/store/tikv/gcworker/gc_worker.go index b606dec2b1fac..41c7427779db5 100644 --- a/store/tikv/gcworker/gc_worker.go +++ b/store/tikv/gcworker/gc_worker.go @@ -1733,14 +1733,12 @@ func (w *GCWorker) loadValueFromSysTable(key string) (string, error) { se := createSession(w.store) defer se.Close() rs, err := se.ExecuteInternal(ctx, `SELECT HIGH_PRIORITY (variable_value) FROM mysql.tidb WHERE variable_name=%? FOR UPDATE`, key) - if len(rs) > 0 { - defer terror.Call(rs[0].Close) - } if err != nil { return "", errors.Trace(err) } - req := rs[0].NewChunk() - err = rs[0].Next(ctx, req) + defer terror.Call(rs.Close) + req := rs.NewChunk() + err = rs.Next(ctx, req) if err != nil { return "", errors.Trace(err) } diff --git a/util/mock/context.go b/util/mock/context.go index 13c878c822bed..f049a05beb842 100644 --- a/util/mock/context.go +++ b/util/mock/context.go @@ -62,8 +62,8 @@ func (c *Context) Execute(ctx context.Context, sql string) ([]sqlexec.RecordSet, } // ExecuteInternal implements sqlexec.SQLExecutor ExecuteInternal interface. -func (c *Context) ExecuteInternal(ctx context.Context, sql string, args ...interface{}) ([]sqlexec.RecordSet, error) { - return nil, errors.Errorf("Not Support.") +func (c *Context) ExecuteInternal(ctx context.Context, sql string, args ...interface{}) (sqlexec.RecordSet, error) { + return nil, errors.Errorf("Not Supported.") } type mockDDLOwnerChecker struct{} diff --git a/util/sqlexec/restricted_sql_executor.go b/util/sqlexec/restricted_sql_executor.go index 92d6958a38667..cc23bda405b9f 100644 --- a/util/sqlexec/restricted_sql_executor.go +++ b/util/sqlexec/restricted_sql_executor.go @@ -88,7 +88,7 @@ func ExecOptionWithSnapshot(snapshot uint64) OptionFuncAlias { type SQLExecutor interface { Execute(ctx context.Context, sql string) ([]RecordSet, error) // ExecuteInternal means execute sql as the internal sql. - ExecuteInternal(ctx context.Context, sql string, args ...interface{}) ([]RecordSet, error) + ExecuteInternal(ctx context.Context, sql string, args ...interface{}) (RecordSet, error) } // SQLParser is an interface provides parsing sql statement. From 31fd1f3e1e51542324211d5abbf045f7dc8fba8e Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Sun, 7 Mar 2021 15:30:54 +0800 Subject: [PATCH 13/85] metric: record prepare execute fail as "Failed Query OPM" in monitor (#22596) (#22672) --- server/conn_stmt.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/server/conn_stmt.go b/server/conn_stmt.go index c8e8545c43d08..402f4aae61905 100644 --- a/server/conn_stmt.go +++ b/server/conn_stmt.go @@ -46,6 +46,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/parser/mysql" "github.com/pingcap/parser/terror" + "github.com/pingcap/tidb/metrics" plannercore "github.com/pingcap/tidb/planner/core" "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/types" @@ -112,6 +113,11 @@ func (cc *clientConn) handleStmtPrepare(ctx context.Context, sql string) error { func (cc *clientConn) handleStmtExecute(ctx context.Context, data []byte) (err error) { defer trace.StartRegion(ctx, "HandleStmtExecute").End() + defer func() { + if err != nil { + metrics.ExecuteErrorCounter.WithLabelValues(metrics.ExecuteErrorToLabel(err)).Inc() + } + }() if len(data) < 9 { return mysql.ErrMalformPacket } From 6f0ebb55182c95604d489789f5a5fbb0cbd6e093 Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Mon, 8 Mar 2021 16:38:54 +0800 Subject: [PATCH 14/85] brie/: add GetVersion function for tidbGlueSession (#22731) (#23143) --- .github/labeler.yml | 4 ++++ executor/brie.go | 5 +++++ executor/brie_test.go | 28 ++++++++++++++++++++++++++++ 3 files changed, 37 insertions(+) create mode 100644 executor/brie_test.go diff --git a/.github/labeler.yml b/.github/labeler.yml index 5040b7cfc9da0..12a5b495d6a0e 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -1,6 +1,7 @@ component/executor: - distsql/* - executor/* + - !executor/brie* - util/chunk/* - util/disk/* - util/execdetails/* @@ -31,3 +32,6 @@ component/DDL: component/config: - config/* + +sig/migrate: + - executor/brie* diff --git a/executor/brie.go b/executor/brie.go index 4dd86aa9fdb1b..0d14265a8fd0e 100644 --- a/executor/brie.go +++ b/executor/brie.go @@ -42,6 +42,7 @@ import ( "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/printer" "github.com/pingcap/tidb/util/sqlexec" ) @@ -465,3 +466,7 @@ func (gs *tidbGlueSession) Record(name string, value uint64) { gs.info.archiveSize = value } } + +func (gs *tidbGlueSession) GetVersion() string { + return "TiDB\n" + printer.GetTiDBInfo() +} diff --git a/executor/brie_test.go b/executor/brie_test.go new file mode 100644 index 0000000000000..0116009efc368 --- /dev/null +++ b/executor/brie_test.go @@ -0,0 +1,28 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package executor + +import . "github.com/pingcap/check" + +type testBRIESuite struct{} + +var _ = Suite(&testBRIESuite{}) + +func (s *testBRIESuite) TestGlueGetVersion(c *C) { + g := tidbGlueSession{} + version := g.GetVersion() + c.Assert(version, Matches, `(.|\n)*Release Version(.|\n)*`) + c.Assert(version, Matches, `(.|\n)*Git Commit Hash(.|\n)*`) + c.Assert(version, Matches, `(.|\n)*GoVersion(.|\n)*`) +} From a7199ff916489749b0fd279f74a63e5d9ac7cebc Mon Sep 17 00:00:00 2001 From: Chunzhu Li Date: Mon, 8 Mar 2021 19:04:54 +0800 Subject: [PATCH 15/85] misc: Update kvproto (#23174) --- go.mod | 2 +- go.sum | 4 ++-- store/tikv/batch_coprocessor.go | 4 ++-- store/tikv/coprocessor.go | 16 ++++++++-------- 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/go.mod b/go.mod index b9ea7f5039c62..fcf11d9da816a 100644 --- a/go.mod +++ b/go.mod @@ -44,7 +44,7 @@ require ( github.com/pingcap/failpoint v0.0.0-20200702092429-9f69995143ce github.com/pingcap/fn v0.0.0-20200306044125-d5540d389059 github.com/pingcap/goleveldb v0.0.0-20191226122134-f82aafb29989 - github.com/pingcap/kvproto v0.0.0-20201126113434-70db5fb4b0dc + github.com/pingcap/kvproto v0.0.0-20210308075244-560097d1309b github.com/pingcap/log v0.0.0-20201112100606-8f1e84a3abc8 github.com/pingcap/parser v0.0.0-20210303062609-d1d977c9ceed github.com/pingcap/sysutil v0.0.0-20201130064824-f0c8aa6a6966 diff --git a/go.sum b/go.sum index 3c777e84def96..9d904cff63883 100644 --- a/go.sum +++ b/go.sum @@ -396,8 +396,8 @@ github.com/pingcap/kvproto v0.0.0-20191211054548-3c6b38ea5107/go.mod h1:WWLmULLO github.com/pingcap/kvproto v0.0.0-20200411081810-b85805c9476c/go.mod h1:IOdRDPLyda8GX2hE/jO7gqaCV/PNFh8BZQCQZXfIOqI= github.com/pingcap/kvproto v0.0.0-20200907074027-32a3a0accf7d h1:gvJScINTd/HFasp82W1paGTfbYe2Bnzn8QBOJXckLwY= github.com/pingcap/kvproto v0.0.0-20200907074027-32a3a0accf7d/go.mod h1:IOdRDPLyda8GX2hE/jO7gqaCV/PNFh8BZQCQZXfIOqI= -github.com/pingcap/kvproto v0.0.0-20201126113434-70db5fb4b0dc h1:BtszN3YR5EScxiGGTD3tAf4CQE90bczkOY0lLa07EJA= -github.com/pingcap/kvproto v0.0.0-20201126113434-70db5fb4b0dc/go.mod h1:IOdRDPLyda8GX2hE/jO7gqaCV/PNFh8BZQCQZXfIOqI= +github.com/pingcap/kvproto v0.0.0-20210308075244-560097d1309b h1:Jp0V5PDzdOy666n4XbDDaEjOKHsp2nk7b2uR6qjFI0s= +github.com/pingcap/kvproto v0.0.0-20210308075244-560097d1309b/go.mod h1:IOdRDPLyda8GX2hE/jO7gqaCV/PNFh8BZQCQZXfIOqI= github.com/pingcap/log v0.0.0-20191012051959-b742a5d432e9/go.mod h1:4rbK1p9ILyIfb6hU7OG2CiWSqMXnp3JMbiaVJ6mvoY8= github.com/pingcap/log v0.0.0-20200117041106-d28c14d3b1cd/go.mod h1:4rbK1p9ILyIfb6hU7OG2CiWSqMXnp3JMbiaVJ6mvoY8= github.com/pingcap/log v0.0.0-20200511115504-543df19646ad/go.mod h1:4rbK1p9ILyIfb6hU7OG2CiWSqMXnp3JMbiaVJ6mvoY8= diff --git a/store/tikv/batch_coprocessor.go b/store/tikv/batch_coprocessor.go index 56b1b1fed9f3e..118f3dffebdb6 100644 --- a/store/tikv/batch_coprocessor.go +++ b/store/tikv/batch_coprocessor.go @@ -350,8 +350,8 @@ func (b *batchCopIterator) handleTaskOnce(ctx context.Context, bo *Backoffer, ta IsolationLevel: pbIsolationLevel(b.req.IsolationLevel), Priority: kvPriorityToCommandPri(b.req.Priority), NotFillCache: b.req.NotFillCache, - HandleTime: true, - ScanDetail: true, + RecordTimeStat: true, + RecordScanStat: true, TaskId: b.req.TaskID, }) req.StoreTp = kv.TiFlash diff --git a/store/tikv/coprocessor.go b/store/tikv/coprocessor.go index 9022e877d53d6..7dbdcda6e0309 100644 --- a/store/tikv/coprocessor.go +++ b/store/tikv/coprocessor.go @@ -872,8 +872,8 @@ func (worker *copIteratorWorker) handleTaskOnce(bo *Backoffer, task *copTask, ch IsolationLevel: pbIsolationLevel(worker.req.IsolationLevel), Priority: kvPriorityToCommandPri(worker.req.Priority), NotFillCache: worker.req.NotFillCache, - HandleTime: true, - ScanDetail: true, + RecordTimeStat: true, + RecordScanStat: true, TaskId: worker.req.TaskID, }) req.StoreTp = task.storeType @@ -1013,9 +1013,9 @@ func (worker *copIteratorWorker) logTimeCopTask(costTime time.Duration, task *co } } - if detail != nil && detail.HandleTime != nil { - processMs := detail.HandleTime.ProcessMs - waitMs := detail.HandleTime.WaitMs + if detail != nil && detail.TimeDetail != nil { + processMs := detail.TimeDetail.ProcessWallTimeMs + waitMs := detail.TimeDetail.WaitWallTimeMs if processMs > minLogKVProcessTime { logStr += fmt.Sprintf(" kv_process_ms:%d", processMs) if detail.ScanDetail != nil { @@ -1150,9 +1150,9 @@ func (worker *copIteratorWorker) handleCopResponse(bo *Backoffer, rpcCtx *RPCCon } resp.respTime = costTime if pbDetails := resp.pbResp.ExecDetails; pbDetails != nil { - if handleTime := pbDetails.HandleTime; handleTime != nil { - resp.detail.WaitTime = time.Duration(handleTime.WaitMs) * time.Millisecond - resp.detail.ProcessTime = time.Duration(handleTime.ProcessMs) * time.Millisecond + if timeDetail := pbDetails.TimeDetail; timeDetail != nil { + resp.detail.WaitTime = time.Duration(timeDetail.WaitWallTimeMs) * time.Millisecond + resp.detail.ProcessTime = time.Duration(timeDetail.ProcessWallTimeMs) * time.Millisecond } if scanDetail := pbDetails.ScanDetail; scanDetail != nil { if scanDetail.Write != nil { From 5c1f32881437d7b5a5c31273a1c8b84cd9efd950 Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Mon, 8 Mar 2021 19:24:54 +0800 Subject: [PATCH 16/85] planner: fix wrong table filters for index merge plan (#23132) (#23165) --- planner/core/integration_test.go | 48 +++++++++++++++---- planner/core/stats.go | 11 ++--- .../core/testdata/integration_suite_out.json | 11 +++-- 3 files changed, 49 insertions(+), 21 deletions(-) diff --git a/planner/core/integration_test.go b/planner/core/integration_test.go index 5779082c2a551..a15d007dde541 100644 --- a/planner/core/integration_test.go +++ b/planner/core/integration_test.go @@ -1145,10 +1145,10 @@ func (s *testIntegrationSerialSuite) TestIssue16837(c *C) { tk.MustExec("drop table if exists t") tk.MustExec("create table t(a int,b int,c int,d int,e int,unique key idx_ab(a,b),unique key(c),unique key(d))") tk.MustQuery("explain select /*+ use_index_merge(t,c,idx_ab) */ * from t where a = 1 or (e = 1 and c = 1)").Check(testkit.Rows( - "IndexMerge_9 0.01 root ", + "IndexMerge_9 8.80 root ", "├─IndexRangeScan_5(Build) 10.00 cop[tikv] table:t, index:idx_ab(a, b) range:[1,1], keep order:false, stats:pseudo", "├─IndexRangeScan_6(Build) 1.00 cop[tikv] table:t, index:c(c) range:[1,1], keep order:false, stats:pseudo", - "└─Selection_8(Probe) 0.01 cop[tikv] eq(test.t.e, 1)", + "└─Selection_8(Probe) 8.80 cop[tikv] or(eq(test.t.a, 1), and(eq(test.t.e, 1), eq(test.t.c, 1)))", " └─TableRowIDScan_7 11.00 cop[tikv] table:t keep order:false, stats:pseudo")) tk.MustQuery("show warnings").Check(testkit.Rows()) tk.MustExec("insert into t values (2, 1, 1, 1, 2)") @@ -1186,10 +1186,13 @@ func (s *testIntegrationSerialSuite) TestIndexMerge(c *C) { tk.MustQuery("show warnings").Check(testkit.Rows()) tk.MustQuery("desc select /*+ use_index_merge(t) */ * from t where (a=1 and length(b)=1) or (b=1 and length(a)=1)").Check(testkit.Rows( - "TableReader_7 8000.00 root data:Selection_6", - "└─Selection_6 8000.00 cop[tikv] or(and(eq(test.t.a, 1), eq(length(cast(test.t.b)), 1)), and(eq(test.t.b, 1), eq(length(cast(test.t.a)), 1)))", - " └─TableFullScan_5 10000.00 cop[tikv] table:t keep order:false, stats:pseudo")) - tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1105 IndexMerge is inapplicable or disabled")) + "IndexMerge_9 1.60 root ", + "├─IndexRangeScan_5(Build) 1.00 cop[tikv] table:t, index:a(a) range:[1,1], keep order:false, stats:pseudo", + "├─IndexRangeScan_6(Build) 1.00 cop[tikv] table:t, index:b(b) range:[1,1], keep order:false, stats:pseudo", + "└─Selection_8(Probe) 1.60 cop[tikv] or(and(eq(test.t.a, 1), eq(length(cast(test.t.b)), 1)), and(eq(test.t.b, 1), eq(length(cast(test.t.a)), 1)))", + " └─TableRowIDScan_7 2.00 cop[tikv] table:t keep order:false, stats:pseudo", + )) + tk.MustQuery("show warnings").Check(testkit.Rows()) } func (s *testIntegrationSerialSuite) TestIssue16407(c *C) { @@ -1198,10 +1201,10 @@ func (s *testIntegrationSerialSuite) TestIssue16407(c *C) { tk.MustExec("drop table if exists t") tk.MustExec("create table t(a int,b char(100),key(a),key(b(10)))") tk.MustQuery("explain select /*+ use_index_merge(t) */ * from t where a=10 or b='x'").Check(testkit.Rows( - "IndexMerge_9 0.02 root ", + "IndexMerge_9 16.00 root ", "├─IndexRangeScan_5(Build) 10.00 cop[tikv] table:t, index:a(a) range:[10,10], keep order:false, stats:pseudo", "├─IndexRangeScan_6(Build) 10.00 cop[tikv] table:t, index:b(b) range:[\"x\",\"x\"], keep order:false, stats:pseudo", - "└─Selection_8(Probe) 0.02 cop[tikv] eq(test.t.b, \"x\")", + "└─Selection_8(Probe) 16.00 cop[tikv] or(eq(test.t.a, 10), eq(test.t.b, \"x\"))", " └─TableRowIDScan_7 20.00 cop[tikv] table:t keep order:false, stats:pseudo")) tk.MustQuery("show warnings").Check(testkit.Rows()) tk.MustExec("insert into t values (1, 'xx')") @@ -1832,3 +1835,32 @@ func (s *testIntegrationSuite) TestIssue22071(c *C) { tk.MustQuery("select n in (1,2) from (select a in (1,2) as n from t) g;").Sort().Check(testkit.Rows("0", "1", "1")) tk.MustQuery("select n in (1,n) from (select a in (1,2) as n from t) g;").Check(testkit.Rows("1", "1", "1")) } + +func (s *testIntegrationSuite) TestIndexMergeTableFilter(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t;") + tk.MustExec("create table t(a int, b int, c int, d int, key(a), key(b));") + tk.MustExec("insert into t values(10,1,1,10)") + + tk.MustQuery("explain select /*+ use_index_merge(t) */ * from t where a=10 or (b=10 and c=10)").Check(testkit.Rows( + "IndexMerge_9 16.00 root ", + "├─IndexRangeScan_5(Build) 10.00 cop[tikv] table:t, index:a(a) range:[10,10], keep order:false, stats:pseudo", + "├─IndexRangeScan_6(Build) 10.00 cop[tikv] table:t, index:b(b) range:[10,10], keep order:false, stats:pseudo", + "└─Selection_8(Probe) 16.00 cop[tikv] or(eq(test.t.a, 10), and(eq(test.t.b, 10), eq(test.t.c, 10)))", + " └─TableRowIDScan_7 20.00 cop[tikv] table:t keep order:false, stats:pseudo", + )) + tk.MustQuery("select /*+ use_index_merge(t) */ * from t where a=10 or (b=10 and c=10)").Check(testkit.Rows( + "10 1 1 10", + )) + tk.MustQuery("explain select /*+ use_index_merge(t) */ * from t where (a=10 and d=10) or (b=10 and c=10)").Check(testkit.Rows( + "IndexMerge_9 16.00 root ", + "├─IndexRangeScan_5(Build) 10.00 cop[tikv] table:t, index:a(a) range:[10,10], keep order:false, stats:pseudo", + "├─IndexRangeScan_6(Build) 10.00 cop[tikv] table:t, index:b(b) range:[10,10], keep order:false, stats:pseudo", + "└─Selection_8(Probe) 16.00 cop[tikv] or(and(eq(test.t.a, 10), eq(test.t.d, 10)), and(eq(test.t.b, 10), eq(test.t.c, 10)))", + " └─TableRowIDScan_7 20.00 cop[tikv] table:t keep order:false, stats:pseudo", + )) + tk.MustQuery("select /*+ use_index_merge(t) */ * from t where (a=10 and d=10) or (b=10 and c=10)").Check(testkit.Rows( + "10 1 1 10", + )) +} diff --git a/planner/core/stats.go b/planner/core/stats.go index a062fcf50b844..6a03e53b6ca97 100644 --- a/planner/core/stats.go +++ b/planner/core/stats.go @@ -464,16 +464,11 @@ func (ds *DataSource) buildIndexMergeOrPath(partialPaths []*util.AccessPath, cur indexMergePath := &util.AccessPath{PartialIndexPaths: partialPaths} indexMergePath.TableFilters = append(indexMergePath.TableFilters, ds.pushedDownConds[:current]...) indexMergePath.TableFilters = append(indexMergePath.TableFilters, ds.pushedDownConds[current+1:]...) - tableFilterCnt := 0 for _, path := range partialPaths { - // IndexMerge should not be used when the SQL is like 'select x from t WHERE (key1=1 AND key2=2) OR (key1=4 AND key3=6);'. - // Check issue https://github.com/pingcap/tidb/issues/22105 for details. + // If any partial path contains table filters, we need to keep the whole DNF filter in the Selection. if len(path.TableFilters) > 0 { - tableFilterCnt++ - if tableFilterCnt > 1 { - return nil - } - indexMergePath.TableFilters = append(indexMergePath.TableFilters, path.TableFilters...) + indexMergePath.TableFilters = append(indexMergePath.TableFilters, ds.pushedDownConds[current]) + break } } return indexMergePath diff --git a/planner/core/testdata/integration_suite_out.json b/planner/core/testdata/integration_suite_out.json index fa38b1dc73407..6ebc5adc2e922 100644 --- a/planner/core/testdata/integration_suite_out.json +++ b/planner/core/testdata/integration_suite_out.json @@ -897,11 +897,12 @@ { "SQL": "explain SELECT /*+ use_index_merge(t1)*/ COUNT(*) FROM t1 WHERE (key4=42 AND key6 IS NOT NULL) OR (key1=4 AND key3=6)", "Plan": [ - "StreamAgg_20 1.00 root funcs:count(Column#12)->Column#10", - "└─TableReader_21 1.00 root data:StreamAgg_9", - " └─StreamAgg_9 1.00 cop[tikv] funcs:count(1)->Column#12", - " └─Selection_19 8000.00 cop[tikv] or(and(eq(test.t1.key4, 42), not(isnull(test.t1.key6))), and(eq(test.t1.key1, 4), eq(test.t1.key3, 6)))", - " └─TableFullScan_18 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo" + "HashAgg_8 1.00 root funcs:count(1)->Column#10", + "└─IndexMerge_15 16.00 root ", + " ├─IndexRangeScan_11(Build) 10.00 cop[tikv] table:t1, index:i4(key4) range:[42,42], keep order:false, stats:pseudo", + " ├─IndexRangeScan_12(Build) 10.00 cop[tikv] table:t1, index:i1(key1) range:[4,4], keep order:false, stats:pseudo", + " └─Selection_14(Probe) 16.00 cop[tikv] or(and(eq(test.t1.key4, 42), not(isnull(test.t1.key6))), and(eq(test.t1.key1, 4), eq(test.t1.key3, 6)))", + " └─TableRowIDScan_13 20.00 cop[tikv] table:t1 keep order:false, stats:pseudo" ] } ] From c1ae2c658918c9215ca3259921233b9dc87c8722 Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Tue, 9 Mar 2021 12:04:55 +0800 Subject: [PATCH 17/85] executor: fix load data losing connection when batch_dml_size is set (#22724) (#22736) --- executor/builder.go | 2 ++ executor/insert_common.go | 19 ++++++++--- executor/load_data.go | 2 ++ server/server_test.go | 68 +++++++++++++++++++++++++++++++++++++++ server/tidb_test.go | 6 ++++ 5 files changed, 93 insertions(+), 4 deletions(-) diff --git a/executor/builder.go b/executor/builder.go index aef4c484f5f3e..0263f6f098ae4 100644 --- a/executor/builder.go +++ b/executor/builder.go @@ -737,6 +737,8 @@ func (b *executorBuilder) buildLoadData(v *plannercore.LoadData) Executor { Table: tbl, Columns: v.Columns, GenExprs: v.GenCols.Exprs, + isLoadData: true, + txnInUse: sync.Mutex{}, } err := insertVal.initInsertColumns() if err != nil { diff --git a/executor/insert_common.go b/executor/insert_common.go index 6c2f27700ed10..7c4a59fa2b3f3 100644 --- a/executor/insert_common.go +++ b/executor/insert_common.go @@ -18,6 +18,7 @@ import ( "context" "fmt" "math" + "sync" "time" "github.com/opentracing/opentracing-go" @@ -81,6 +82,12 @@ type InsertValues struct { memTracker *memory.Tracker stats *InsertRuntimeStat + + // LoadData use two goroutines. One for generate batch data, + // The other one for commit task, which will invalid txn. + // We use mutex to protect routine from using invalid txn. + isLoadData bool + txnInUse sync.Mutex } type defaultVal struct { @@ -857,10 +864,6 @@ func (e *InsertValues) adjustAutoRandomDatum(ctx context.Context, d types.Datum, // Change NULL to auto id. // Change value 0 to auto id, if NoAutoValueOnZero SQL mode is not set. if d.IsNull() || e.ctx.GetSessionVars().SQLMode&mysql.ModeNoAutoValueOnZero == 0 { - _, err := e.ctx.Txn(true) - if err != nil { - return types.Datum{}, errors.Trace(err) - } recordID, err = e.allocAutoRandomID(&c.FieldType) if err != nil { return types.Datum{}, err @@ -894,6 +897,14 @@ func (e *InsertValues) allocAutoRandomID(fieldType *types.FieldType) (int64, err if tables.OverflowShardBits(autoRandomID, tableInfo.AutoRandomBits, layout.TypeBitsLength, layout.HasSignBit) { return 0, autoid.ErrAutoRandReadFailed } + if e.isLoadData { + e.txnInUse.Lock() + defer e.txnInUse.Unlock() + } + _, err = e.ctx.Txn(true) + if err != nil { + return 0, err + } shard := tables.CalcShard(tableInfo.AutoRandomBits, e.ctx.GetSessionVars().TxnCtx.StartTS, layout.TypeBitsLength, layout.HasSignBit) autoRandomID |= shard return autoRandomID, nil diff --git a/executor/load_data.go b/executor/load_data.go index d69e66c33df6d..0449a78c4b6f7 100644 --- a/executor/load_data.go +++ b/executor/load_data.go @@ -199,6 +199,8 @@ func (e *LoadDataInfo) CommitOneTask(ctx context.Context, task CommitTask) error logutil.Logger(ctx).Error("commit error commit", zap.Error(err)) return err } + e.txnInUse.Lock() + defer e.txnInUse.Unlock() // Make sure that there are no retries when committing. if err = e.Ctx.RefreshTxnCtx(ctx); err != nil { logutil.Logger(ctx).Error("commit error refresh", zap.Error(err)) diff --git a/server/server_test.go b/server/server_test.go index f70c8e2fbda38..9c672162c3555 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -903,6 +903,74 @@ func (cli *testServerClient) runTestLoadData(c *C, server *Server) { }) } +func (cli *testServerClient) runTestLoadDataAutoRandom(c *C) { + path := "/tmp/load_data_txn_error.csv" + + fp, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) + c.Assert(err, IsNil) + c.Assert(fp, NotNil) + + defer func() { + _ = os.Remove(path) + }() + + cksum1 := 0 + cksum2 := 0 + for i := 0; i < 50000; i++ { + n1 := rand.Intn(1000) + n2 := rand.Intn(1000) + str1 := strconv.Itoa(n1) + str2 := strconv.Itoa(n2) + row := str1 + "\t" + str2 + _, err := fp.WriteString(row) + c.Assert(err, IsNil) + _, err = fp.WriteString("\n") + c.Assert(err, IsNil) + + if i == 0 { + cksum1 = n1 + cksum2 = n2 + } else { + cksum1 = cksum1 ^ n1 + cksum2 = cksum2 ^ n2 + } + } + + err = fp.Close() + c.Assert(err, IsNil) + + cli.runTestsOnNewDB(c, func(config *mysql.Config) { + config.AllowAllFiles = true + config.Params = map[string]string{"sql_mode": "''"} + }, "load_data_batch_dml", func(dbt *DBTest) { + // Set batch size, and check if load data got a invalid txn error. + dbt.mustExec("set @@session.tidb_dml_batch_size = 128") + dbt.mustExec("drop table if exists t") + dbt.mustExec("create table t(c1 bigint auto_random primary key, c2 bigint, c3 bigint)") + dbt.mustExec(fmt.Sprintf("load data local infile %q into table t (c2, c3)", path)) + + var ( + rowCnt int + colCkSum1 int + colCkSum2 int + ) + rows := dbt.mustQuery("select count(*) from t") + dbt.Check(rows.Next(), IsTrue, Commentf("unexpected data")) + err = rows.Scan(&rowCnt) + dbt.Check(err, IsNil) + dbt.Check(rowCnt, DeepEquals, 50000) + dbt.Check(rows.Next(), IsFalse, Commentf("unexpected data")) + + rows = dbt.mustQuery("select bit_xor(c2), bit_xor(c3) from t") + dbt.Check(rows.Next(), IsTrue, Commentf("unexpected data")) + err = rows.Scan(&colCkSum1, &colCkSum2) + dbt.Check(err, IsNil) + dbt.Check(colCkSum1, DeepEquals, cksum1) + dbt.Check(colCkSum2, DeepEquals, cksum2) + dbt.Check(rows.Next(), IsFalse, Commentf("unexpected data")) + }) +} + func (cli *testServerClient) runTestConcurrentUpdate(c *C) { dbName := "Concurrent" cli.runTestsOnNewDB(c, func(config *mysql.Config) { diff --git a/server/tidb_test.go b/server/tidb_test.go index 9290cd35b6d75..acb00f06348a1 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -148,6 +148,12 @@ func (ts *tidbTestSerialSuite) TestLoadData(c *C) { ts.runTestLoadDataForSlowLog(c, ts.server) } +// Fix issue#22540. Change tidb_dml_batch_size, +// then check if load data into table with auto random column works properly. +func (ts *tidbTestSerialSuite) TestLoadDataAutoRandom(c *C) { + ts.runTestLoadDataAutoRandom(c) +} + func (ts *tidbTestSerialSuite) TestExplainFor(c *C) { ts.runTestExplainForConn(c) } From 84270e964765abb103e814adfee736010a144731 Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Tue, 9 Mar 2021 14:40:55 +0800 Subject: [PATCH 18/85] planner: refine explain info for batch cop (#20360) (#23164) --- executor/builder.go | 22 +--- planner/core/common_plans.go | 6 +- planner/core/initialize.go | 16 +++ planner/core/integration_test.go | 4 +- planner/core/physical_plans.go | 3 + .../integration_serial_suite_out.json | 124 +++++++++--------- 6 files changed, 89 insertions(+), 86 deletions(-) diff --git a/executor/builder.go b/executor/builder.go index 0263f6f098ae4..e8f4bca522a8f 100644 --- a/executor/builder.go +++ b/executor/builder.go @@ -2297,26 +2297,6 @@ func containsLimit(execs []*tipb.Executor) bool { return false } -// When allow batch cop is 1, only agg / topN uses batch cop. -// When allow batch cop is 2, every query uses batch cop. -func (e *TableReaderExecutor) setBatchCop(v *plannercore.PhysicalTableReader) { - if e.storeType != kv.TiFlash || e.keepOrder { - return - } - switch e.ctx.GetSessionVars().AllowBatchCop { - case 1: - for _, p := range v.TablePlans { - switch p.(type) { - case *plannercore.PhysicalHashAgg, *plannercore.PhysicalStreamAgg, *plannercore.PhysicalTopN, *plannercore.PhysicalBroadCastJoin: - e.batchCop = true - } - } - case 2: - e.batchCop = true - } - return -} - func buildNoRangeTableReader(b *executorBuilder, v *plannercore.PhysicalTableReader) (*TableReaderExecutor, error) { tablePlans := v.TablePlans if v.StoreType == kv.TiFlash { @@ -2351,8 +2331,8 @@ func buildNoRangeTableReader(b *executorBuilder, v *plannercore.PhysicalTableRea plans: v.TablePlans, tablePlan: v.GetTablePlan(), storeType: v.StoreType, + batchCop: v.BatchCop, } - e.setBatchCop(v) e.buildVirtualColumnInfo() if containsLimit(dagReq.Executors) { e.feedback = statistics.NewQueryFeedback(0, nil, 0, ts.Desc) diff --git a/planner/core/common_plans.go b/planner/core/common_plans.go index 186993c2cd599..83d97ebecfd40 100644 --- a/planner/core/common_plans.go +++ b/planner/core/common_plans.go @@ -1012,7 +1012,11 @@ func (e *Explain) explainPlanInRowFormat(p Plan, taskType, driverSide, indent st return errors.Errorf("the store type %v is unknown", x.StoreType) } storeType = x.StoreType.Name() - err = e.explainPlanInRowFormat(x.tablePlan, "cop["+storeType+"]", "", childIndent, true) + taskName := "cop" + if x.BatchCop { + taskName = "batchCop" + } + err = e.explainPlanInRowFormat(x.tablePlan, taskName+"["+storeType+"]", "", childIndent, true) case *PhysicalIndexReader: err = e.explainPlanInRowFormat(x.indexPlan, "cop[tikv]", "", childIndent, true) case *PhysicalIndexLookUpReader: diff --git a/planner/core/initialize.go b/planner/core/initialize.go index acea1c78de2d0..d238cf038dd7e 100644 --- a/planner/core/initialize.go +++ b/planner/core/initialize.go @@ -15,6 +15,7 @@ package core import ( "github.com/pingcap/tidb/expression" + "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/planner/property" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/types" @@ -428,6 +429,21 @@ func (p PhysicalTableReader) Init(ctx sessionctx.Context, offset int) *PhysicalT if p.tablePlan != nil { p.TablePlans = flattenPushDownPlan(p.tablePlan) p.schema = p.tablePlan.Schema() + if p.StoreType == kv.TiFlash && !p.GetTableScan().KeepOrder { + // When allow batch cop is 1, only agg / topN uses batch cop. + // When allow batch cop is 2, every query uses batch cop. + switch ctx.GetSessionVars().AllowBatchCop { + case 1: + for _, plan := range p.TablePlans { + switch plan.(type) { + case *PhysicalHashAgg, *PhysicalStreamAgg, *PhysicalTopN, *PhysicalBroadCastJoin: + p.BatchCop = true + } + } + case 2: + p.BatchCop = true + } + } } return &p } diff --git a/planner/core/integration_test.go b/planner/core/integration_test.go index a15d007dde541..c83757a4dec6f 100644 --- a/planner/core/integration_test.go +++ b/planner/core/integration_test.go @@ -441,8 +441,8 @@ func (s *testIntegrationSerialSuite) TestAggPushDownEngine(c *C) { tk.MustQuery("desc select approx_count_distinct(a) from t").Check(testkit.Rows( "StreamAgg_16 1.00 root funcs:approx_count_distinct(Column#5)->Column#3", "└─TableReader_17 1.00 root data:StreamAgg_8", - " └─StreamAgg_8 1.00 cop[tiflash] funcs:approx_count_distinct(test.t.a)->Column#5", - " └─TableFullScan_15 10000.00 cop[tiflash] table:t keep order:false, stats:pseudo")) + " └─StreamAgg_8 1.00 batchCop[tiflash] funcs:approx_count_distinct(test.t.a)->Column#5", + " └─TableFullScan_15 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo")) tk.MustExec("set @@session.tidb_isolation_read_engines = 'tikv'") diff --git a/planner/core/physical_plans.go b/planner/core/physical_plans.go index 597bfeeb49fdc..1220254d90e77 100644 --- a/planner/core/physical_plans.go +++ b/planner/core/physical_plans.go @@ -72,6 +72,9 @@ type PhysicalTableReader struct { // StoreType indicates table read from which type of store. StoreType kv.StoreType + + // BatchCop = true means the cop task in the physical table reader will be executed in batch mode(use in TiFlash only) + BatchCop bool } // GetTablePlan exports the tablePlan. diff --git a/planner/core/testdata/integration_serial_suite_out.json b/planner/core/testdata/integration_serial_suite_out.json index a26e1116bfc94..99027be774531 100644 --- a/planner/core/testdata/integration_serial_suite_out.json +++ b/planner/core/testdata/integration_serial_suite_out.json @@ -25,9 +25,9 @@ "└─TopN_8 2.00 root Column#3:asc, offset:0, count:2", " └─Projection_19 2.00 root test.t.a, test.t.b, cast(test.t.b, bigint(22) UNSIGNED BINARY)->Column#3", " └─TableReader_14 2.00 root data:TopN_13", - " └─TopN_13 2.00 cop[tiflash] cast(test.t.b):asc, offset:0, count:2", - " └─Selection_12 3333.33 cop[tiflash] gt(test.t.b, \"a\")", - " └─TableFullScan_11 10000.00 cop[tiflash] table:t keep order:false, stats:pseudo" + " └─TopN_13 2.00 batchCop[tiflash] cast(test.t.b):asc, offset:0, count:2", + " └─Selection_12 3333.33 batchCop[tiflash] gt(test.t.b, \"a\")", + " └─TableFullScan_11 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" ] }, { @@ -35,9 +35,9 @@ "Result": [ "TopN_8 2.00 root test.t.b:asc, offset:0, count:2", "└─TableReader_17 2.00 root data:TopN_16", - " └─TopN_16 2.00 cop[tiflash] test.t.b:asc, offset:0, count:2", - " └─Selection_15 3333.33 cop[tiflash] gt(test.t.b, \"a\")", - " └─TableFullScan_14 10000.00 cop[tiflash] table:t keep order:false, stats:pseudo" + " └─TopN_16 2.00 batchCop[tiflash] test.t.b:asc, offset:0, count:2", + " └─Selection_15 3333.33 batchCop[tiflash] gt(test.t.b, \"a\")", + " └─TableFullScan_14 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" ] } ] @@ -50,12 +50,12 @@ "Plan": [ "StreamAgg_24 1.00 root funcs:count(Column#13)->Column#11", "└─TableReader_25 1.00 root data:StreamAgg_9", - " └─StreamAgg_9 1.00 cop[tiflash] funcs:count(1)->Column#13", - " └─BroadcastJoin_23 8.00 cop[tiflash] inner join, left key:test.fact_t.d1_k, right key:test.d1_t.d1_k", - " ├─Selection_19(Build) 2.00 cop[tiflash] not(isnull(test.d1_t.d1_k))", - " │ └─TableFullScan_18 2.00 cop[tiflash] table:d1_t keep order:false, global read", - " └─Selection_17(Probe) 8.00 cop[tiflash] not(isnull(test.fact_t.d1_k))", - " └─TableFullScan_16 8.00 cop[tiflash] table:fact_t keep order:false" + " └─StreamAgg_9 1.00 batchCop[tiflash] funcs:count(1)->Column#13", + " └─BroadcastJoin_23 8.00 batchCop[tiflash] inner join, left key:test.fact_t.d1_k, right key:test.d1_t.d1_k", + " ├─Selection_19(Build) 2.00 batchCop[tiflash] not(isnull(test.d1_t.d1_k))", + " │ └─TableFullScan_18 2.00 batchCop[tiflash] table:d1_t keep order:false, global read", + " └─Selection_17(Probe) 8.00 batchCop[tiflash] not(isnull(test.fact_t.d1_k))", + " └─TableFullScan_16 8.00 batchCop[tiflash] table:fact_t keep order:false" ] }, { @@ -63,18 +63,18 @@ "Plan": [ "StreamAgg_44 1.00 root funcs:count(Column#19)->Column#17", "└─TableReader_45 1.00 root data:StreamAgg_13", - " └─StreamAgg_13 1.00 cop[tiflash] funcs:count(1)->Column#19", - " └─BroadcastJoin_43 8.00 cop[tiflash] inner join, left key:test.fact_t.d3_k, right key:test.d3_t.d3_k", - " ├─Selection_39(Build) 2.00 cop[tiflash] not(isnull(test.d3_t.d3_k))", - " │ └─TableFullScan_38 2.00 cop[tiflash] table:d3_t keep order:false, global read", - " └─BroadcastJoin_29(Probe) 8.00 cop[tiflash] inner join, left key:test.fact_t.d2_k, right key:test.d2_t.d2_k", - " ├─Selection_25(Build) 2.00 cop[tiflash] not(isnull(test.d2_t.d2_k))", - " │ └─TableFullScan_24 2.00 cop[tiflash] table:d2_t keep order:false, global read", - " └─BroadcastJoin_33(Probe) 8.00 cop[tiflash] inner join, left key:test.fact_t.d1_k, right key:test.d1_t.d1_k", - " ├─Selection_23(Build) 2.00 cop[tiflash] not(isnull(test.d1_t.d1_k))", - " │ └─TableFullScan_22 2.00 cop[tiflash] table:d1_t keep order:false, global read", - " └─Selection_37(Probe) 8.00 cop[tiflash] not(isnull(test.fact_t.d1_k)), not(isnull(test.fact_t.d2_k)), not(isnull(test.fact_t.d3_k))", - " └─TableFullScan_36 8.00 cop[tiflash] table:fact_t keep order:false" + " └─StreamAgg_13 1.00 batchCop[tiflash] funcs:count(1)->Column#19", + " └─BroadcastJoin_43 8.00 batchCop[tiflash] inner join, left key:test.fact_t.d3_k, right key:test.d3_t.d3_k", + " ├─Selection_39(Build) 2.00 batchCop[tiflash] not(isnull(test.d3_t.d3_k))", + " │ └─TableFullScan_38 2.00 batchCop[tiflash] table:d3_t keep order:false, global read", + " └─BroadcastJoin_29(Probe) 8.00 batchCop[tiflash] inner join, left key:test.fact_t.d2_k, right key:test.d2_t.d2_k", + " ├─Selection_25(Build) 2.00 batchCop[tiflash] not(isnull(test.d2_t.d2_k))", + " │ └─TableFullScan_24 2.00 batchCop[tiflash] table:d2_t keep order:false, global read", + " └─BroadcastJoin_33(Probe) 8.00 batchCop[tiflash] inner join, left key:test.fact_t.d1_k, right key:test.d1_t.d1_k", + " ├─Selection_23(Build) 2.00 batchCop[tiflash] not(isnull(test.d1_t.d1_k))", + " │ └─TableFullScan_22 2.00 batchCop[tiflash] table:d1_t keep order:false, global read", + " └─Selection_37(Probe) 8.00 batchCop[tiflash] not(isnull(test.fact_t.d1_k)), not(isnull(test.fact_t.d2_k)), not(isnull(test.fact_t.d3_k))", + " └─TableFullScan_36 8.00 batchCop[tiflash] table:fact_t keep order:false" ] }, { @@ -82,12 +82,12 @@ "Plan": [ "StreamAgg_18 1.00 root funcs:count(Column#13)->Column#11", "└─TableReader_19 1.00 root data:StreamAgg_9", - " └─StreamAgg_9 1.00 cop[tiflash] funcs:count(1)->Column#13", - " └─BroadcastJoin_17 8.00 cop[tiflash] inner join, left key:test.fact_t.d1_k, right key:test.d1_t.d1_k", - " ├─Selection_14(Build) 2.00 cop[tiflash] not(isnull(test.d1_t.d1_k))", - " │ └─TableFullScan_13 2.00 cop[tiflash] table:d1_t keep order:false", - " └─Selection_12(Probe) 8.00 cop[tiflash] not(isnull(test.fact_t.d1_k))", - " └─TableFullScan_11 8.00 cop[tiflash] table:fact_t keep order:false, global read" + " └─StreamAgg_9 1.00 batchCop[tiflash] funcs:count(1)->Column#13", + " └─BroadcastJoin_17 8.00 batchCop[tiflash] inner join, left key:test.fact_t.d1_k, right key:test.d1_t.d1_k", + " ├─Selection_14(Build) 2.00 batchCop[tiflash] not(isnull(test.d1_t.d1_k))", + " │ └─TableFullScan_13 2.00 batchCop[tiflash] table:d1_t keep order:false", + " └─Selection_12(Probe) 8.00 batchCop[tiflash] not(isnull(test.fact_t.d1_k))", + " └─TableFullScan_11 8.00 batchCop[tiflash] table:fact_t keep order:false, global read" ] }, { @@ -95,18 +95,18 @@ "Plan": [ "StreamAgg_29 1.00 root funcs:count(Column#19)->Column#17", "└─TableReader_30 1.00 root data:StreamAgg_13", - " └─StreamAgg_13 1.00 cop[tiflash] funcs:count(1)->Column#19", - " └─BroadcastJoin_28 8.00 cop[tiflash] inner join, left key:test.fact_t.d3_k, right key:test.d3_t.d3_k", - " ├─Selection_25(Build) 2.00 cop[tiflash] not(isnull(test.d3_t.d3_k))", - " │ └─TableFullScan_24 2.00 cop[tiflash] table:d3_t keep order:false, global read", - " └─BroadcastJoin_15(Probe) 8.00 cop[tiflash] inner join, left key:test.fact_t.d2_k, right key:test.d2_t.d2_k", - " ├─Selection_23(Build) 2.00 cop[tiflash] not(isnull(test.d2_t.d2_k))", - " │ └─TableFullScan_22 2.00 cop[tiflash] table:d2_t keep order:false", - " └─BroadcastJoin_16(Probe) 8.00 cop[tiflash] inner join, left key:test.fact_t.d1_k, right key:test.d1_t.d1_k", - " ├─Selection_21(Build) 2.00 cop[tiflash] not(isnull(test.d1_t.d1_k))", - " │ └─TableFullScan_20 2.00 cop[tiflash] table:d1_t keep order:false, global read", - " └─Selection_19(Probe) 8.00 cop[tiflash] not(isnull(test.fact_t.d1_k)), not(isnull(test.fact_t.d2_k)), not(isnull(test.fact_t.d3_k))", - " └─TableFullScan_18 8.00 cop[tiflash] table:fact_t keep order:false, global read" + " └─StreamAgg_13 1.00 batchCop[tiflash] funcs:count(1)->Column#19", + " └─BroadcastJoin_28 8.00 batchCop[tiflash] inner join, left key:test.fact_t.d3_k, right key:test.d3_t.d3_k", + " ├─Selection_25(Build) 2.00 batchCop[tiflash] not(isnull(test.d3_t.d3_k))", + " │ └─TableFullScan_24 2.00 batchCop[tiflash] table:d3_t keep order:false, global read", + " └─BroadcastJoin_15(Probe) 8.00 batchCop[tiflash] inner join, left key:test.fact_t.d2_k, right key:test.d2_t.d2_k", + " ├─Selection_23(Build) 2.00 batchCop[tiflash] not(isnull(test.d2_t.d2_k))", + " │ └─TableFullScan_22 2.00 batchCop[tiflash] table:d2_t keep order:false", + " └─BroadcastJoin_16(Probe) 8.00 batchCop[tiflash] inner join, left key:test.fact_t.d1_k, right key:test.d1_t.d1_k", + " ├─Selection_21(Build) 2.00 batchCop[tiflash] not(isnull(test.d1_t.d1_k))", + " │ └─TableFullScan_20 2.00 batchCop[tiflash] table:d1_t keep order:false, global read", + " └─Selection_19(Probe) 8.00 batchCop[tiflash] not(isnull(test.fact_t.d1_k)), not(isnull(test.fact_t.d2_k)), not(isnull(test.fact_t.d3_k))", + " └─TableFullScan_18 8.00 batchCop[tiflash] table:fact_t keep order:false, global read" ] }, { @@ -114,11 +114,11 @@ "Plan": [ "StreamAgg_16 1.00 root funcs:count(Column#13)->Column#11", "└─TableReader_17 1.00 root data:StreamAgg_8", - " └─StreamAgg_8 1.00 cop[tiflash] funcs:count(1)->Column#13", - " └─BroadcastJoin_15 8.00 cop[tiflash] left outer join, left key:test.fact_t.d1_k, right key:test.d1_t.d1_k", - " ├─Selection_12(Build) 2.00 cop[tiflash] not(isnull(test.d1_t.d1_k))", - " │ └─TableFullScan_11 2.00 cop[tiflash] table:d1_t keep order:false, global read", - " └─TableFullScan_10(Probe) 8.00 cop[tiflash] table:fact_t keep order:false" + " └─StreamAgg_8 1.00 batchCop[tiflash] funcs:count(1)->Column#13", + " └─BroadcastJoin_15 8.00 batchCop[tiflash] left outer join, left key:test.fact_t.d1_k, right key:test.d1_t.d1_k", + " ├─Selection_12(Build) 2.00 batchCop[tiflash] not(isnull(test.d1_t.d1_k))", + " │ └─TableFullScan_11 2.00 batchCop[tiflash] table:d1_t keep order:false, global read", + " └─TableFullScan_10(Probe) 8.00 batchCop[tiflash] table:fact_t keep order:false" ] }, { @@ -126,11 +126,11 @@ "Plan": [ "StreamAgg_16 1.00 root funcs:count(Column#13)->Column#11", "└─TableReader_17 1.00 root data:StreamAgg_8", - " └─StreamAgg_8 1.00 cop[tiflash] funcs:count(1)->Column#13", - " └─BroadcastJoin_15 8.00 cop[tiflash] right outer join, left key:test.fact_t.d1_k, right key:test.d1_t.d1_k", - " ├─TableFullScan_12(Build) 2.00 cop[tiflash] table:d1_t keep order:false", - " └─Selection_11(Probe) 8.00 cop[tiflash] not(isnull(test.fact_t.d1_k))", - " └─TableFullScan_10 8.00 cop[tiflash] table:fact_t keep order:false, global read" + " └─StreamAgg_8 1.00 batchCop[tiflash] funcs:count(1)->Column#13", + " └─BroadcastJoin_15 8.00 batchCop[tiflash] right outer join, left key:test.fact_t.d1_k, right key:test.d1_t.d1_k", + " ├─TableFullScan_12(Build) 2.00 batchCop[tiflash] table:d1_t keep order:false", + " └─Selection_11(Probe) 8.00 batchCop[tiflash] not(isnull(test.fact_t.d1_k))", + " └─TableFullScan_10 8.00 batchCop[tiflash] table:fact_t keep order:false, global read" ] } ] @@ -143,8 +143,8 @@ "Plan": [ "StreamAgg_24 1.00 root funcs:avg(Column#7, Column#8)->Column#4", "└─TableReader_25 1.00 root data:StreamAgg_8", - " └─StreamAgg_8 1.00 cop[tiflash] funcs:count(test.t.a)->Column#7, funcs:sum(test.t.a)->Column#8", - " └─TableFullScan_22 10000.00 cop[tiflash] table:t keep order:false, stats:pseudo" + " └─StreamAgg_8 1.00 batchCop[tiflash] funcs:count(test.t.a)->Column#7, funcs:sum(test.t.a)->Column#8", + " └─TableFullScan_22 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" ], "Warn": null }, @@ -153,8 +153,8 @@ "Plan": [ "StreamAgg_16 1.00 root funcs:avg(Column#7, Column#8)->Column#4", "└─TableReader_17 1.00 root data:StreamAgg_8", - " └─StreamAgg_8 1.00 cop[tiflash] funcs:count(test.t.a)->Column#7, funcs:sum(test.t.a)->Column#8", - " └─TableFullScan_15 10000.00 cop[tiflash] table:t keep order:false, stats:pseudo" + " └─StreamAgg_8 1.00 batchCop[tiflash] funcs:count(test.t.a)->Column#7, funcs:sum(test.t.a)->Column#8", + " └─TableFullScan_15 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" ], "Warn": null }, @@ -163,8 +163,8 @@ "Plan": [ "StreamAgg_16 1.00 root funcs:sum(Column#6)->Column#4", "└─TableReader_17 1.00 root data:StreamAgg_8", - " └─StreamAgg_8 1.00 cop[tiflash] funcs:sum(test.t.a)->Column#6", - " └─TableFullScan_15 10000.00 cop[tiflash] table:t keep order:false, stats:pseudo" + " └─StreamAgg_8 1.00 batchCop[tiflash] funcs:sum(test.t.a)->Column#6", + " └─TableFullScan_15 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" ], "Warn": null }, @@ -173,8 +173,8 @@ "Plan": [ "StreamAgg_16 1.00 root funcs:sum(Column#6)->Column#4", "└─TableReader_17 1.00 root data:StreamAgg_8", - " └─StreamAgg_8 1.00 cop[tiflash] funcs:sum(plus(test.t.a, 1))->Column#6", - " └─TableFullScan_15 10000.00 cop[tiflash] table:t keep order:false, stats:pseudo" + " └─StreamAgg_8 1.00 batchCop[tiflash] funcs:sum(plus(test.t.a, 1))->Column#6", + " └─TableFullScan_15 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" ], "Warn": null }, @@ -183,8 +183,8 @@ "Plan": [ "StreamAgg_16 1.00 root funcs:sum(Column#6)->Column#4", "└─TableReader_17 1.00 root data:StreamAgg_8", - " └─StreamAgg_8 1.00 cop[tiflash] funcs:sum(isnull(test.t.a))->Column#6", - " └─TableFullScan_15 10000.00 cop[tiflash] table:t keep order:false, stats:pseudo" + " └─StreamAgg_8 1.00 batchCop[tiflash] funcs:sum(isnull(test.t.a))->Column#6", + " └─TableFullScan_15 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" ], "Warn": null }, From f29892c50ccf978d22f4b1e7e7c18601ce8c8b23 Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Tue, 9 Mar 2021 14:54:54 +0800 Subject: [PATCH 19/85] executor: add new format specifier(%# %@ %.) for str_to_date expression (#22790) (#22812) --- expression/builtin_time_test.go | 9 ++++++++ types/time.go | 39 +++++++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+) diff --git a/expression/builtin_time_test.go b/expression/builtin_time_test.go index 3ee725e78d578..16fcaa8fbfaa9 100644 --- a/expression/builtin_time_test.go +++ b/expression/builtin_time_test.go @@ -1411,6 +1411,15 @@ func (s *testEvaluatorSuite) TestStrToDate(c *C) { {"15-01-2001 1:9:8.999", "%d-%m-%Y %H:%i:%S.%f", true, time.Date(2001, 1, 15, 1, 9, 8, 999000000, time.Local)}, {"2003-01-02 10:11:12 PM", "%Y-%m-%d %H:%i:%S %p", false, time.Time{}}, {"10:20:10AM", "%H:%i:%S%p", false, time.Time{}}, + // test %@(skip alpha), %#(skip number), %.(skip punct) + {"2020-10-10ABCD", "%Y-%m-%d%@", true, time.Date(2020, 10, 10, 0, 0, 0, 0, time.Local)}, + {"2020-10-101234", "%Y-%m-%d%#", true, time.Date(2020, 10, 10, 0, 0, 0, 0, time.Local)}, + {"2020-10-10....", "%Y-%m-%d%.", true, time.Date(2020, 10, 10, 0, 0, 0, 0, time.Local)}, + {"2020-10-10.1", "%Y-%m-%d%.%#%@", true, time.Date(2020, 10, 10, 0, 0, 0, 0, time.Local)}, + {"abcd2020-10-10.1", "%@%Y-%m-%d%.%#%@", true, time.Date(2020, 10, 10, 0, 0, 0, 0, time.Local)}, + {"abcd-2020-10-10.1", "%@-%Y-%m-%d%.%#%@", true, time.Date(2020, 10, 10, 0, 0, 0, 0, time.Local)}, + {"2020-10-10", "%Y-%m-%d%@", true, time.Date(2020, 10, 10, 0, 0, 0, 0, time.Local)}, + {"2020-10-10abcde123abcdef", "%Y-%m-%d%@%#", true, time.Date(2020, 10, 10, 0, 0, 0, 0, time.Local)}, } fc := funcs[ast.StrToDate] diff --git a/types/time.go b/types/time.go index 842c62daaf5f8..3e86a1e9b88a0 100644 --- a/types/time.go +++ b/types/time.go @@ -2864,6 +2864,9 @@ var dateFormatParserTable = map[string]dateFormatParser{ "%S": secondsNumeric, // Seconds (00..59) "%T": time24Hour, // Time, 24-hour (hh:mm:ss) "%Y": yearNumericFourDigits, // Year, numeric, four digits + "%#": skipAllNums, // Skip all numbers + "%.": skipAllPunct, // Skip all punctation characters + "%@": skipAllAlpha, // Skip all alpha characters // Deprecated since MySQL 5.7.5 "%y": yearNumericTwoDigits, // Year, numeric (two digits) // TODO: Add the following... @@ -3242,3 +3245,39 @@ func DateTimeIsOverflow(sc *stmtctx.StatementContext, date Time) (bool, error) { inRange := (t.After(b) || t.Equal(b)) && (t.Before(e) || t.Equal(e)) return !inRange, nil } + +func skipAllNums(t *CoreTime, input string, ctx map[string]int) (string, bool) { + retIdx := 0 + for i, ch := range input { + if unicode.IsNumber(ch) { + retIdx = i + 1 + } else { + break + } + } + return input[retIdx:], true +} + +func skipAllPunct(t *CoreTime, input string, ctx map[string]int) (string, bool) { + retIdx := 0 + for i, ch := range input { + if unicode.IsPunct(ch) { + retIdx = i + 1 + } else { + break + } + } + return input[retIdx:], true +} + +func skipAllAlpha(t *CoreTime, input string, ctx map[string]int) (string, bool) { + retIdx := 0 + for i, ch := range input { + if unicode.IsLetter(ch) { + retIdx = i + 1 + } else { + break + } + } + return input[retIdx:], true +} From 142fd860ae699f354ab01a86347002e975847f80 Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Tue, 9 Mar 2021 15:30:55 +0800 Subject: [PATCH 20/85] privilege: remove any string concat (#22523) (#22689) --- privilege/privileges/cache.go | 40 ++++++++++++++++++++++------------- 1 file changed, 25 insertions(+), 15 deletions(-) diff --git a/privilege/privileges/cache.go b/privilege/privileges/cache.go index 7d39c46f31843..6c59e6aa3970b 100644 --- a/privilege/privileges/cache.go +++ b/privilege/privileges/cache.go @@ -47,6 +47,23 @@ var ( const globalDBVisible = mysql.CreatePriv | mysql.SelectPriv | mysql.InsertPriv | mysql.UpdatePriv | mysql.DeletePriv | mysql.ShowDBPriv | mysql.DropPriv | mysql.AlterPriv | mysql.IndexPriv | mysql.CreateViewPriv | mysql.ShowViewPriv | mysql.GrantPriv | mysql.TriggerPriv | mysql.ReferencesPriv | mysql.ExecutePriv +const ( + sqlLoadRoleGraph = "SELECT HIGH_PRIORITY FROM_USER, FROM_HOST, TO_USER, TO_HOST FROM mysql.role_edges" + sqlLoadGlobalPrivTable = "SELECT HIGH_PRIORITY Host,User,Priv FROM mysql.global_priv" + sqlLoadDBTable = "SELECT HIGH_PRIORITY Host,DB,User,Select_priv,Insert_priv,Update_priv,Delete_priv,Create_priv,Drop_priv,Grant_priv,Index_priv,Alter_priv,Execute_priv,Create_view_priv,Show_view_priv FROM mysql.db ORDER BY host, db, user" + sqlLoadTablePrivTable = "SELECT HIGH_PRIORITY Host,DB,User,Table_name,Grantor,Timestamp,Table_priv,Column_priv FROM mysql.tables_priv" + sqlLoadColumnsPrivTable = "SELECT HIGH_PRIORITY Host,DB,User,Table_name,Column_name,Timestamp,Column_priv FROM mysql.columns_priv" + sqlLoadDefaultRoles = "SELECT HIGH_PRIORITY HOST, USER, DEFAULT_ROLE_HOST, DEFAULT_ROLE_USER FROM mysql.default_roles" + // list of privileges from mysql.Priv2UserCol + sqlLoadUserTable = `SELECT HIGH_PRIORITY Host,User,authentication_string, + Create_priv, Select_priv, Insert_priv, Update_priv, Delete_priv, Show_db_priv, Super_priv, + Create_user_priv,Trigger_priv,Drop_priv,Process_priv,Grant_priv, + References_priv,Alter_priv,Execute_priv,Index_priv,Create_view_priv,Show_view_priv, + Create_role_priv,Drop_role_priv,Create_tmp_table_priv,Lock_tables_priv,Create_routine_priv, + Alter_routine_priv,Event_priv,Shutdown_priv,Reload_priv,File_priv,Config_priv, + account_locked FROM mysql.user` +) + func computePrivMask(privs []mysql.PrivilegeType) mysql.PrivilegeType { var mask mysql.PrivilegeType for _, p := range privs { @@ -347,7 +364,7 @@ func noSuchTable(err error) bool { // LoadRoleGraph loads the mysql.role_edges table from database. func (p *MySQLPrivilege) LoadRoleGraph(ctx sessionctx.Context) error { p.RoleGraph = make(map[string]roleGraphEdgesTable) - err := p.loadTable(ctx, "select FROM_USER, FROM_HOST, TO_USER, TO_HOST from mysql.role_edges;", p.decodeRoleEdgesTable) + err := p.loadTable(ctx, sqlLoadRoleGraph, p.decodeRoleEdgesTable) if err != nil { return errors.Trace(err) } @@ -356,12 +373,7 @@ func (p *MySQLPrivilege) LoadRoleGraph(ctx sessionctx.Context) error { // LoadUserTable loads the mysql.user table from database. func (p *MySQLPrivilege) LoadUserTable(ctx sessionctx.Context) error { - userPrivCols := make([]string, 0, len(mysql.Priv2UserCol)) - for _, v := range mysql.Priv2UserCol { - userPrivCols = append(userPrivCols, v) - } - query := fmt.Sprintf("select HIGH_PRIORITY Host,User,authentication_string,%s,account_locked from mysql.user;", strings.Join(userPrivCols, ", ")) - err := p.loadTable(ctx, query, p.decodeUserTableRow) + err := p.loadTable(ctx, sqlLoadUserTable, p.decodeUserTableRow) if err != nil { return errors.Trace(err) } @@ -467,12 +479,12 @@ func (p MySQLPrivilege) SortUserTable() { // LoadGlobalPrivTable loads the mysql.global_priv table from database. func (p *MySQLPrivilege) LoadGlobalPrivTable(ctx sessionctx.Context) error { - return p.loadTable(ctx, "select HIGH_PRIORITY Host,User,Priv from mysql.global_priv", p.decodeGlobalPrivTableRow) + return p.loadTable(ctx, sqlLoadGlobalPrivTable, p.decodeGlobalPrivTableRow) } // LoadDBTable loads the mysql.db table from database. func (p *MySQLPrivilege) LoadDBTable(ctx sessionctx.Context) error { - err := p.loadTable(ctx, "select HIGH_PRIORITY Host,DB,User,Select_priv,Insert_priv,Update_priv,Delete_priv,Create_priv,Drop_priv,Grant_priv,Index_priv,Alter_priv,Execute_priv,Create_view_priv,Show_view_priv from mysql.db order by host, db, user;", p.decodeDBTableRow) + err := p.loadTable(ctx, sqlLoadDBTable, p.decodeDBTableRow) if err != nil { return err } @@ -490,7 +502,7 @@ func (p *MySQLPrivilege) buildDBMap() { // LoadTablesPrivTable loads the mysql.tables_priv table from database. func (p *MySQLPrivilege) LoadTablesPrivTable(ctx sessionctx.Context) error { - err := p.loadTable(ctx, "select HIGH_PRIORITY Host,DB,User,Table_name,Grantor,Timestamp,Table_priv,Column_priv from mysql.tables_priv", p.decodeTablesPrivTableRow) + err := p.loadTable(ctx, sqlLoadTablePrivTable, p.decodeTablesPrivTableRow) if err != nil { return err } @@ -508,24 +520,22 @@ func (p *MySQLPrivilege) buildTablesPrivMap() { // LoadColumnsPrivTable loads the mysql.columns_priv table from database. func (p *MySQLPrivilege) LoadColumnsPrivTable(ctx sessionctx.Context) error { - return p.loadTable(ctx, "select HIGH_PRIORITY Host,DB,User,Table_name,Column_name,Timestamp,Column_priv from mysql.columns_priv", p.decodeColumnsPrivTableRow) + return p.loadTable(ctx, sqlLoadColumnsPrivTable, p.decodeColumnsPrivTableRow) } // LoadDefaultRoles loads the mysql.columns_priv table from database. func (p *MySQLPrivilege) LoadDefaultRoles(ctx sessionctx.Context) error { - return p.loadTable(ctx, "select HOST, USER, DEFAULT_ROLE_HOST, DEFAULT_ROLE_USER from mysql.default_roles", p.decodeDefaultRoleTableRow) + return p.loadTable(ctx, sqlLoadDefaultRoles, p.decodeDefaultRoleTableRow) } func (p *MySQLPrivilege) loadTable(sctx sessionctx.Context, sql string, decodeTableRow func(chunk.Row, []*ast.ResultField) error) error { ctx := context.Background() - tmp, err := sctx.(sqlexec.SQLExecutor).Execute(ctx, sql) + rs, err := sctx.(sqlexec.SQLExecutor).ExecuteInternal(ctx, sql) if err != nil { return errors.Trace(err) } - rs := tmp[0] defer terror.Call(rs.Close) - fs := rs.Fields() req := rs.NewChunk() for { From 765c362f974e7c8da2147222c5a924a96516ee2e Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Tue, 9 Mar 2021 15:46:55 +0800 Subject: [PATCH 21/85] session, util: update session to use new APIs (#22652) (#22804) --- session/bootstrap.go | 120 +++++++++++++++++++++---------------------- session/session.go | 61 ++++++++-------------- 2 files changed, 82 insertions(+), 99 deletions(-) diff --git a/session/bootstrap.go b/session/bootstrap.go index 19834e428156b..ace9d791bf5d6 100644 --- a/session/bootstrap.go +++ b/session/bootstrap.go @@ -465,7 +465,7 @@ var ( func checkBootstrapped(s Session) (bool, error) { // Check if system db exists. - _, err := s.Execute(context.Background(), fmt.Sprintf("USE %s;", mysql.SystemDB)) + _, err := s.ExecuteInternal(context.Background(), "USE %n", mysql.SystemDB) if err != nil && infoschema.ErrDatabaseNotExists.NotEqual(err) { logutil.BgLogger().Fatal("check bootstrap error", zap.Error(err)) @@ -491,20 +491,18 @@ func checkBootstrapped(s Session) (bool, error) { // getTiDBVar gets variable value from mysql.tidb table. // Those variables are used by TiDB server. func getTiDBVar(s Session, name string) (sVal string, isNull bool, e error) { - sql := fmt.Sprintf(`SELECT HIGH_PRIORITY VARIABLE_VALUE FROM %s.%s WHERE VARIABLE_NAME="%s"`, - mysql.SystemDB, mysql.TiDBTable, name) ctx := context.Background() - rs, err := s.Execute(ctx, sql) + rs, err := s.ExecuteInternal(ctx, `SELECT HIGH_PRIORITY VARIABLE_VALUE FROM %n.%n WHERE VARIABLE_NAME= %?`, + mysql.SystemDB, + mysql.TiDBTable, + name, + ) if err != nil { return "", true, errors.Trace(err) } - if len(rs) != 1 { - return "", true, errors.New("Wrong number of Recordset") - } - r := rs[0] - defer terror.Call(r.Close) - req := r.NewChunk() - err = r.Next(ctx, req) + defer terror.Call(rs.Close) + req := rs.NewChunk() + err = rs.Next(ctx, req) if err != nil || req.NumRows() == 0 { return "", true, errors.Trace(err) } @@ -530,7 +528,7 @@ func upgrade(s Session) { } updateBootstrapVer(s) - _, err = s.Execute(context.Background(), "COMMIT") + _, err = s.ExecuteInternal(context.Background(), "COMMIT") if err != nil { sleepTime := 1 * time.Second @@ -622,7 +620,7 @@ func upgradeToVer8(s Session, ver int64) { return } // This is a dummy upgrade, it checks whether upgradeToVer7 success, if not, do it again. - if _, err := s.Execute(context.Background(), "SELECT HIGH_PRIORITY `Process_priv` from mysql.user limit 0"); err == nil { + if _, err := s.ExecuteInternal(context.Background(), "SELECT HIGH_PRIORITY `Process_priv` from mysql.user limit 0"); err == nil { return } upgradeToVer7(s, ver) @@ -638,7 +636,7 @@ func upgradeToVer9(s Session, ver int64) { } func doReentrantDDL(s Session, sql string, ignorableErrs ...error) { - _, err := s.Execute(context.Background(), sql) + _, err := s.ExecuteInternal(context.Background(), sql) for _, ignorableErr := range ignorableErrs { if terror.ErrorEqual(err, ignorableErr) { return @@ -664,7 +662,7 @@ func upgradeToVer11(s Session, ver int64) { if ver >= version11 { return } - _, err := s.Execute(context.Background(), "ALTER TABLE mysql.user ADD COLUMN `References_priv` enum('N','Y') CHARACTER SET utf8 NOT NULL DEFAULT 'N' AFTER `Grant_priv`") + _, err := s.ExecuteInternal(context.Background(), "ALTER TABLE mysql.user ADD COLUMN `References_priv` enum('N','Y') CHARACTER SET utf8 NOT NULL DEFAULT 'N' AFTER `Grant_priv`") if err != nil { if terror.ErrorEqual(err, infoschema.ErrColumnExists) { return @@ -679,21 +677,20 @@ func upgradeToVer12(s Session, ver int64) { return } ctx := context.Background() - _, err := s.Execute(ctx, "BEGIN") + _, err := s.ExecuteInternal(ctx, "BEGIN") terror.MustNil(err) sql := "SELECT HIGH_PRIORITY user, host, password FROM mysql.user WHERE password != ''" - rs, err := s.Execute(ctx, sql) + rs, err := s.ExecuteInternal(ctx, sql) if terror.ErrorEqual(err, core.ErrUnknownColumn) { sql := "SELECT HIGH_PRIORITY user, host, authentication_string FROM mysql.user WHERE authentication_string != ''" - rs, err = s.Execute(ctx, sql) + rs, err = s.ExecuteInternal(ctx, sql) } terror.MustNil(err) - r := rs[0] sqls := make([]string, 0, 1) - defer terror.Call(r.Close) - req := r.NewChunk() + defer terror.Call(rs.Close) + req := rs.NewChunk() it := chunk.NewIterator4Chunk(req) - err = r.Next(ctx, req) + err = rs.Next(ctx, req) for err == nil && req.NumRows() != 0 { for row := it.Begin(); row != it.End(); row = it.Next() { user := row.GetString(0) @@ -705,7 +702,7 @@ func upgradeToVer12(s Session, ver int64) { updateSQL := fmt.Sprintf(`UPDATE HIGH_PRIORITY mysql.user set password = "%s" where user="%s" and host="%s"`, newPass, user, host) sqls = append(sqls, updateSQL) } - err = r.Next(ctx, req) + err = rs.Next(ctx, req) } terror.MustNil(err) @@ -735,7 +732,7 @@ func upgradeToVer13(s Session, ver int64) { } ctx := context.Background() for _, sql := range sqls { - _, err := s.Execute(ctx, sql) + _, err := s.ExecuteInternal(ctx, sql) if err != nil { if terror.ErrorEqual(err, infoschema.ErrColumnExists) { continue @@ -764,7 +761,7 @@ func upgradeToVer14(s Session, ver int64) { } ctx := context.Background() for _, sql := range sqls { - _, err := s.Execute(ctx, sql) + _, err := s.ExecuteInternal(ctx, sql) if err != nil { if terror.ErrorEqual(err, infoschema.ErrColumnExists) { continue @@ -779,7 +776,7 @@ func upgradeToVer15(s Session, ver int64) { return } var err error - _, err = s.Execute(context.Background(), CreateGCDeleteRangeTable) + _, err = s.ExecuteInternal(context.Background(), CreateGCDeleteRangeTable) if err != nil { logutil.BgLogger().Fatal("upgradeToVer15 error", zap.Error(err)) } @@ -849,9 +846,13 @@ func upgradeToVer23(s Session, ver int64) { // writeSystemTZ writes system timezone info into mysql.tidb func writeSystemTZ(s Session) { - sql := fmt.Sprintf(`INSERT HIGH_PRIORITY INTO %s.%s VALUES ("%s", "%s", "TiDB Global System Timezone.") ON DUPLICATE KEY UPDATE VARIABLE_VALUE="%s"`, - mysql.SystemDB, mysql.TiDBTable, tidbSystemTZ, timeutil.InferSystemTZ(), timeutil.InferSystemTZ()) - mustExecute(s, sql) + mustExecute(s, `INSERT HIGH_PRIORITY INTO %n.%n VALUES (%?, %?, "TiDB Global System Timezone.") ON DUPLICATE KEY UPDATE VARIABLE_VALUE= %?`, + mysql.SystemDB, + mysql.TiDBTable, + tidbSystemTZ, + timeutil.InferSystemTZ(), + timeutil.InferSystemTZ(), + ) } // upgradeToVer24 initializes `System` timezone according to docs/design/2018-09-10-adding-tz-env.md @@ -980,7 +981,7 @@ func upgradeToVer38(s Session, ver int64) { return } var err error - _, err = s.Execute(context.Background(), CreateGlobalPrivTable) + _, err = s.ExecuteInternal(context.Background(), CreateGlobalPrivTable) if err != nil { logutil.BgLogger().Fatal("upgradeToVer38 error", zap.Error(err)) } @@ -1002,9 +1003,9 @@ func writeNewCollationParameter(s Session, flag bool) { if flag { b = varTrue } - sql := fmt.Sprintf(`INSERT HIGH_PRIORITY INTO %s.%s VALUES ("%s", '%s', '%s') ON DUPLICATE KEY UPDATE VARIABLE_VALUE='%s'`, - mysql.SystemDB, mysql.TiDBTable, tidbNewCollationEnabled, b, comment, b) - mustExecute(s, sql) + mustExecute(s, `INSERT HIGH_PRIORITY INTO %n.%n VALUES (%?, %?, %?) ON DUPLICATE KEY UPDATE VARIABLE_VALUE=%?`, + mysql.SystemDB, mysql.TiDBTable, tidbNewCollationEnabled, b, comment, b, + ) } func upgradeToVer40(s Session, ver int64) { @@ -1040,14 +1041,14 @@ func upgradeToVer42(s Session, ver int64) { // Convert statement summary global variables to non-empty values. func writeStmtSummaryVars(s Session) { - sql := fmt.Sprintf("UPDATE %s.%s SET variable_value='%%s' WHERE variable_name='%%s' AND variable_value=''", mysql.SystemDB, mysql.GlobalVariablesTable) + sql := "UPDATE mysql.global_variables SET variable_value= %? WHERE variable_name= %? AND variable_value=''" stmtSummaryConfig := config.GetGlobalConfig().StmtSummary - mustExecute(s, fmt.Sprintf(sql, variable.BoolToIntStr(stmtSummaryConfig.Enable), variable.TiDBEnableStmtSummary)) - mustExecute(s, fmt.Sprintf(sql, variable.BoolToIntStr(stmtSummaryConfig.EnableInternalQuery), variable.TiDBStmtSummaryInternalQuery)) - mustExecute(s, fmt.Sprintf(sql, strconv.Itoa(stmtSummaryConfig.RefreshInterval), variable.TiDBStmtSummaryRefreshInterval)) - mustExecute(s, fmt.Sprintf(sql, strconv.Itoa(stmtSummaryConfig.HistorySize), variable.TiDBStmtSummaryHistorySize)) - mustExecute(s, fmt.Sprintf(sql, strconv.FormatUint(uint64(stmtSummaryConfig.MaxStmtCount), 10), variable.TiDBStmtSummaryMaxStmtCount)) - mustExecute(s, fmt.Sprintf(sql, strconv.FormatUint(uint64(stmtSummaryConfig.MaxSQLLength), 10), variable.TiDBStmtSummaryMaxSQLLength)) + mustExecute(s, sql, variable.BoolToIntStr(stmtSummaryConfig.Enable), variable.TiDBEnableStmtSummary) + mustExecute(s, sql, variable.BoolToIntStr(stmtSummaryConfig.EnableInternalQuery), variable.TiDBStmtSummaryInternalQuery) + mustExecute(s, sql, strconv.Itoa(stmtSummaryConfig.RefreshInterval), variable.TiDBStmtSummaryRefreshInterval) + mustExecute(s, sql, strconv.Itoa(stmtSummaryConfig.HistorySize), variable.TiDBStmtSummaryHistorySize) + mustExecute(s, sql, strconv.FormatUint(uint64(stmtSummaryConfig.MaxStmtCount), 10), variable.TiDBStmtSummaryMaxStmtCount) + mustExecute(s, sql, strconv.FormatUint(uint64(stmtSummaryConfig.MaxSQLLength), 10), variable.TiDBStmtSummaryMaxSQLLength) } func upgradeToVer43(s Session, ver int64) { @@ -1128,9 +1129,9 @@ func initBindInfoTable(s Session) { } func insertBuiltinBindInfoRow(s Session) { - sql := fmt.Sprintf(`INSERT HIGH_PRIORITY INTO mysql.bind_info VALUES ("%s", "%s", "mysql", "%s", "0000-00-00 00:00:00", "0000-00-00 00:00:00", "", "", "%s")`, - bindinfo.BuiltinPseudoSQL4BindLock, bindinfo.BuiltinPseudoSQL4BindLock, bindinfo.Builtin, bindinfo.Builtin) - mustExecute(s, sql) + mustExecute(s, `INSERT HIGH_PRIORITY INTO mysql.bind_info VALUES (%?, %?, "mysql", %?, "0000-00-00 00:00:00", "0000-00-00 00:00:00", "", "", %?)`, + bindinfo.BuiltinPseudoSQL4BindLock, bindinfo.BuiltinPseudoSQL4BindLock, bindinfo.Builtin, bindinfo.Builtin, + ) } type bindInfo struct { @@ -1233,17 +1234,17 @@ func updateBindInfo(iter *chunk.Iterator4Chunk, p *parser.Parser, bindMap map[st func writeMemoryQuotaQuery(s Session) { comment := "memory_quota_query is 32GB by default in v3.0.x, 1GB by default in v4.0.x" - sql := fmt.Sprintf(`INSERT HIGH_PRIORITY INTO %s.%s VALUES ("%s", '%d', '%s') ON DUPLICATE KEY UPDATE VARIABLE_VALUE='%d'`, - mysql.SystemDB, mysql.TiDBTable, tidbDefMemoryQuotaQuery, 32<<30, comment, 32<<30) - mustExecute(s, sql) + mustExecute(s, `INSERT HIGH_PRIORITY INTO %n.%n VALUES (%?, %?, %?) ON DUPLICATE KEY UPDATE VARIABLE_VALUE=%?`, + mysql.SystemDB, mysql.TiDBTable, tidbDefMemoryQuotaQuery, 32<<30, comment, 32<<30, + ) } // updateBootstrapVer updates bootstrap version variable in mysql.TiDB table. func updateBootstrapVer(s Session) { // Update bootstrap version. - sql := fmt.Sprintf(`INSERT HIGH_PRIORITY INTO %s.%s VALUES ("%s", "%d", "TiDB bootstrap version.") ON DUPLICATE KEY UPDATE VARIABLE_VALUE="%d"`, - mysql.SystemDB, mysql.TiDBTable, tidbServerVersionVar, currentBootstrapVersion, currentBootstrapVersion) - mustExecute(s, sql) + mustExecute(s, `INSERT HIGH_PRIORITY INTO %n.%n VALUES (%?, %?, "TiDB bootstrap version.") ON DUPLICATE KEY UPDATE VARIABLE_VALUE=%?`, + mysql.SystemDB, mysql.TiDBTable, tidbServerVersionVar, currentBootstrapVersion, currentBootstrapVersion, + ) } // getBootstrapVersion gets bootstrap version from mysql.tidb table; @@ -1263,7 +1264,7 @@ func doDDLWorks(s Session) { // Create a test database. mustExecute(s, "CREATE DATABASE IF NOT EXISTS test") // Create system db. - mustExecute(s, fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %s;", mysql.SystemDB)) + mustExecute(s, "CREATE DATABASE IF NOT EXISTS %n", mysql.SystemDB) // Create user table. mustExecute(s, CreateUserTable) // Create privilege tables. @@ -1332,14 +1333,13 @@ func doDMLWorks(s Session) { strings.Join(values, ", ")) mustExecute(s, sql) - sql = fmt.Sprintf(`INSERT HIGH_PRIORITY INTO %s.%s VALUES("%s", "%s", "Bootstrap flag. Do not delete.") - ON DUPLICATE KEY UPDATE VARIABLE_VALUE="%s"`, - mysql.SystemDB, mysql.TiDBTable, bootstrappedVar, varTrue, varTrue) - mustExecute(s, sql) + mustExecute(s, `INSERT HIGH_PRIORITY INTO %n.%n VALUES(%?, %?, "Bootstrap flag. Do not delete.") ON DUPLICATE KEY UPDATE VARIABLE_VALUE=%?`, + mysql.SystemDB, mysql.TiDBTable, bootstrappedVar, varTrue, varTrue, + ) - sql = fmt.Sprintf(`INSERT HIGH_PRIORITY INTO %s.%s VALUES("%s", "%d", "Bootstrap version. Do not delete.")`, - mysql.SystemDB, mysql.TiDBTable, tidbServerVersionVar, currentBootstrapVersion) - mustExecute(s, sql) + mustExecute(s, `INSERT HIGH_PRIORITY INTO %n.%n VALUES(%?, %?, "Bootstrap version. Do not delete.")`, + mysql.SystemDB, mysql.TiDBTable, tidbServerVersionVar, currentBootstrapVersion, + ) writeSystemTZ(s) @@ -1349,7 +1349,7 @@ func doDMLWorks(s Session) { writeStmtSummaryVars(s) - _, err := s.Execute(context.Background(), "COMMIT") + _, err := s.ExecuteInternal(context.Background(), "COMMIT") if err != nil { sleepTime := 1 * time.Second logutil.BgLogger().Info("doDMLWorks failed", zap.Error(err), zap.Duration("sleeping time", sleepTime)) @@ -1366,8 +1366,8 @@ func doDMLWorks(s Session) { } } -func mustExecute(s Session, sql string) { - _, err := s.Execute(context.Background(), sql) +func mustExecute(s Session, sql string, args ...interface{}) { + _, err := s.ExecuteInternal(context.Background(), sql, args...) if err != nil { debug.PrintStack() logutil.BgLogger().Fatal("mustExecute error", zap.Error(err)) diff --git a/session/session.go b/session/session.go index d05e914e59e89..1413b50426914 100644 --- a/session/session.go +++ b/session/session.go @@ -933,15 +933,19 @@ func drainRecordSet(ctx context.Context, se *session, rs sqlexec.RecordSet) ([]c } } -// getExecRet executes restricted sql and the result is one column. +// getTableValue executes restricted sql and the result is one column. // It returns a string value. -func (s *session) getExecRet(ctx sessionctx.Context, sql string) (string, error) { - rows, fields, err := s.ExecRestrictedSQL(sql) +func (s *session) getTableValue(ctx context.Context, tblName string, varName string) (string, error) { + stmt, err := s.ParseWithParams(ctx, "SELECT VARIABLE_VALUE FROM %n.%n WHERE VARIABLE_NAME=%?", mysql.SystemDB, tblName, varName) + if err != nil { + return "", err + } + rows, fields, err := s.ExecRestrictedStmt(ctx, stmt) if err != nil { return "", err } if len(rows) == 0 { - return "", executor.ErrResultIsEmpty + return "", errResultIsEmpty } d := rows[0].GetDatum(0, &fields[0].Column.FieldType) value, err := d.ToString() @@ -956,9 +960,11 @@ func (s *session) GetAllSysVars() (map[string]string, error) { if s.Value(sessionctx.Initing) != nil { return nil, nil } - sql := `SELECT VARIABLE_NAME, VARIABLE_VALUE FROM %s.%s;` - sql = fmt.Sprintf(sql, mysql.SystemDB, mysql.GlobalVariablesTable) - rows, _, err := s.ExecRestrictedSQL(sql) + stmt, err := s.ParseWithParams(context.TODO(), `SELECT VARIABLE_NAME, VARIABLE_VALUE FROM %n.%n`, mysql.SystemDB, mysql.GlobalVariablesTable) + if err != nil { + return nil, err + } + rows, _, err := s.ExecRestrictedStmt(context.TODO(), stmt) if err != nil { return nil, err } @@ -979,11 +985,9 @@ func (s *session) GetGlobalSysVar(name string) (string, error) { // When running bootstrap or upgrade, we should not access global storage. return "", nil } - sql := fmt.Sprintf(`SELECT VARIABLE_VALUE FROM %s.%s WHERE VARIABLE_NAME="%s";`, - mysql.SystemDB, mysql.GlobalVariablesTable, name) - sysVar, err := s.getExecRet(s, sql) + sysVar, err := s.getTableValue(context.TODO(), mysql.GlobalVariablesTable, name) if err != nil { - if executor.ErrResultIsEmpty.Equal(err) { + if errResultIsEmpty.Equal(err) { if sv, ok := variable.SysVars[name]; ok { return sv.Value, nil } @@ -1012,9 +1016,11 @@ func (s *session) SetGlobalSysVar(name, value string) error { return err } name = strings.ToLower(name) - sql := fmt.Sprintf(`REPLACE %s.%s VALUES ('%s', '%s');`, - mysql.SystemDB, mysql.GlobalVariablesTable, name, sVal) - _, _, err = s.ExecRestrictedSQL(sql) + stmt, err := s.ParseWithParams(context.TODO(), "REPLACE %n.%n VALUES (%?, %?)", mysql.SystemDB, mysql.GlobalVariablesTable, name, sVal) + if err != nil { + return err + } + _, _, err = s.ExecRestrictedStmt(context.TODO(), stmt) return err } @@ -1814,11 +1820,6 @@ func CreateSessionWithOpt(store kv.Storage, opt *Opt) (Session, error) { return s, nil } -// loadSystemTZ loads systemTZ from mysql.tidb -func loadSystemTZ(se *session) (string, error) { - return loadParameter(se, "system_tz") -} - // loadCollationParameter loads collation parameter from mysql.tidb func loadCollationParameter(se *session) (bool, error) { para, err := loadParameter(se, tidbNewCollationEnabled) @@ -1862,25 +1863,7 @@ var ( // loadParameter loads read-only parameter from mysql.tidb func loadParameter(se *session, name string) (string, error) { - sql := "select variable_value from mysql.tidb where variable_name = '" + name + "'" - rss, errLoad := se.Execute(context.Background(), sql) - if errLoad != nil { - return "", errLoad - } - // the record of mysql.tidb under where condition: variable_name = $name should shall only be one. - defer func() { - if err := rss[0].Close(); err != nil { - logutil.BgLogger().Error("close result set error", zap.Error(err)) - } - }() - req := rss[0].NewChunk() - if err := rss[0].Next(context.Background(), req); err != nil { - return "", err - } - if req.NumRows() == 0 { - return "", errResultIsEmpty - } - return req.GetRow(0).GetString(0), nil + return se.getTableValue(context.TODO(), mysql.TiDBTable, name) } // BootstrapSession runs the first time when the TiDB server start. @@ -1912,7 +1895,7 @@ func BootstrapSession(store kv.Storage) (*domain.Domain, error) { return nil, err } // get system tz from mysql.tidb - tz, err := loadSystemTZ(se) + tz, err := se.getTableValue(context.TODO(), mysql.TiDBTable, "system_tz") if err != nil { return nil, err } From 90972702e28df9b524162c271a44ca60bc37c4bb Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Wed, 10 Mar 2021 11:10:55 +0800 Subject: [PATCH 22/85] ddl: migrate part of ddl package code from Execute/ExecRestricted to safe API (2) (#22729) (#22935) --- ddl/delete_range.go | 18 ++++++++++-------- ddl/reorg.go | 8 ++++++-- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/ddl/delete_range.go b/ddl/delete_range.go index 45059716a824d..bf06b6392fea6 100644 --- a/ddl/delete_range.go +++ b/ddl/delete_range.go @@ -16,8 +16,8 @@ package ddl import ( "context" "encoding/hex" - "fmt" "math" + "strings" "sync" "sync/atomic" @@ -35,7 +35,7 @@ import ( const ( insertDeleteRangeSQLPrefix = `INSERT IGNORE INTO mysql.gc_delete_range VALUES ` - insertDeleteRangeSQLValue = `("%d", "%d", "%s", "%s", "%d")` + insertDeleteRangeSQLValue = `(%?, %?, %?, %?, %?)` insertDeleteRangeSQL = insertDeleteRangeSQLPrefix + insertDeleteRangeSQLValue delBatchSize = 65536 @@ -350,25 +350,27 @@ func doInsert(s sqlexec.SQLExecutor, jobID int64, elementID int64, startKey, end logutil.BgLogger().Info("[ddl] insert into delete-range table", zap.Int64("jobID", jobID), zap.Int64("elementID", elementID)) startKeyEncoded := hex.EncodeToString(startKey) endKeyEncoded := hex.EncodeToString(endKey) - sql := fmt.Sprintf(insertDeleteRangeSQL, jobID, elementID, startKeyEncoded, endKeyEncoded, ts) - _, err := s.Execute(context.Background(), sql) + _, err := s.ExecuteInternal(context.Background(), insertDeleteRangeSQL, jobID, elementID, startKeyEncoded, endKeyEncoded, ts) return errors.Trace(err) } func doBatchInsert(s sqlexec.SQLExecutor, jobID int64, tableIDs []int64, ts uint64) error { logutil.BgLogger().Info("[ddl] batch insert into delete-range table", zap.Int64("jobID", jobID), zap.Int64s("elementIDs", tableIDs)) - sql := insertDeleteRangeSQLPrefix + var buf strings.Builder + buf.WriteString(insertDeleteRangeSQLPrefix) + paramsList := make([]interface{}, 0, len(tableIDs)*5) for i, tableID := range tableIDs { startKey := tablecodec.EncodeTablePrefix(tableID) endKey := tablecodec.EncodeTablePrefix(tableID + 1) startKeyEncoded := hex.EncodeToString(startKey) endKeyEncoded := hex.EncodeToString(endKey) - sql += fmt.Sprintf(insertDeleteRangeSQLValue, jobID, tableID, startKeyEncoded, endKeyEncoded, ts) + buf.WriteString(insertDeleteRangeSQLValue) if i != len(tableIDs)-1 { - sql += "," + buf.WriteString(",") } + paramsList = append(paramsList, jobID, tableID, startKeyEncoded, endKeyEncoded, ts) } - _, err := s.Execute(context.Background(), sql) + _, err := s.ExecuteInternal(context.Background(), buf.String(), paramsList...) return errors.Trace(err) } diff --git a/ddl/reorg.go b/ddl/reorg.go index 1eb68a822a265..2d5a9dd7b4c31 100644 --- a/ddl/reorg.go +++ b/ddl/reorg.go @@ -198,8 +198,12 @@ func getTableTotalCount(w *worker, tblInfo *model.TableInfo) int64 { if !ok { return statistics.PseudoRowCount } - sql := fmt.Sprintf("select table_rows from information_schema.tables where tidb_table_id=%v;", tblInfo.ID) - rows, _, err := executor.ExecRestrictedSQL(sql) + sql := "select table_rows from information_schema.tables where tidb_table_id=%?;" + stmt, err := executor.ParseWithParams(context.Background(), sql, tblInfo.ID) + if err != nil { + return statistics.PseudoRowCount + } + rows, _, err := executor.ExecRestrictedStmt(context.Background(), stmt) if err != nil { return statistics.PseudoRowCount } From 51158bdb3ad65b24645b3b98d7c60c9cbc8189ea Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Wed, 10 Mar 2021 16:28:55 +0800 Subject: [PATCH 23/85] ddl: migrate part of ddl package code from Execute/ExecRestricted to safe API (1) (#22670) (#22929) --- ddl/column.go | 20 +++++++++++---- ddl/util/util.go | 65 +++++++++++++++++++++--------------------------- 2 files changed, 43 insertions(+), 42 deletions(-) diff --git a/ddl/column.go b/ddl/column.go index d264f669c1f37..42e5de30efa4f 100644 --- a/ddl/column.go +++ b/ddl/column.go @@ -14,6 +14,7 @@ package ddl import ( + "context" "fmt" "math/bits" "strings" @@ -538,16 +539,25 @@ func checkAndApplyNewAutoRandomBits(job *model.Job, t *meta.Meta, tblInfo *model // checkForNullValue ensure there are no null values of the column of this table. // `isDataTruncated` indicates whether the new field and the old field type are the same, in order to be compatible with mysql. func checkForNullValue(ctx sessionctx.Context, isDataTruncated bool, schema, table, newCol model.CIStr, oldCols ...*model.ColumnInfo) error { - colsStr := "" + var buf strings.Builder + buf.WriteString("select 1 from %n.%n where ") + paramsList := make([]interface{}, 0, 2+len(oldCols)) + paramsList = append(paramsList, schema.L, table.L) for i, col := range oldCols { if i == 0 { - colsStr += "`" + col.Name.L + "` is null" + buf.WriteString("%n is null") + paramsList = append(paramsList, col.Name.L) } else { - colsStr += " or `" + col.Name.L + "` is null" + buf.WriteString(" or %n is null") + paramsList = append(paramsList, col.Name.L) } } - sql := fmt.Sprintf("select 1 from `%s`.`%s` where %s limit 1;", schema.L, table.L, colsStr) - rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql) + buf.WriteString(" limit 1") + stmt, err := ctx.(sqlexec.RestrictedSQLExecutor).ParseWithParams(context.Background(), buf.String(), paramsList...) + if err != nil { + return errors.Trace(err) + } + rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedStmt(context.Background(), stmt) if err != nil { return errors.Trace(err) } diff --git a/ddl/util/util.go b/ddl/util/util.go index 69a6d3ced56f7..00c0f0e0aa021 100644 --- a/ddl/util/util.go +++ b/ddl/util/util.go @@ -14,12 +14,10 @@ package util import ( - "bytes" + "strings" + "context" "encoding/hex" - "fmt" - "strconv" - "github.com/pingcap/errors" "github.com/pingcap/parser/terror" "github.com/pingcap/tidb/kv" @@ -32,12 +30,13 @@ import ( const ( deleteRangesTable = `gc_delete_range` doneDeleteRangesTable = `gc_delete_range_done` - loadDeleteRangeSQL = `SELECT HIGH_PRIORITY job_id, element_id, start_key, end_key FROM mysql.%s WHERE ts < %v` - recordDoneDeletedRangeSQL = `INSERT IGNORE INTO mysql.gc_delete_range_done SELECT * FROM mysql.gc_delete_range WHERE job_id = %d AND element_id = %d` - completeDeleteRangeSQL = `DELETE FROM mysql.gc_delete_range WHERE job_id = %d AND element_id = %d` - completeDeleteMultiRangesSQL = `DELETE FROM mysql.gc_delete_range WHERE job_id = %d AND element_id in (%v)` - updateDeleteRangeSQL = `UPDATE mysql.gc_delete_range SET start_key = "%s" WHERE job_id = %d AND element_id = %d AND start_key = "%s"` - deleteDoneRecordSQL = `DELETE FROM mysql.gc_delete_range_done WHERE job_id = %d AND element_id = %d` + loadDeleteRangeSQL = `SELECT HIGH_PRIORITY job_id, element_id, start_key, end_key FROM mysql.%n WHERE ts < %?` + recordDoneDeletedRangeSQL = `INSERT IGNORE INTO mysql.gc_delete_range_done SELECT * FROM mysql.gc_delete_range WHERE job_id = %? AND element_id = %?` + completeDeleteRangeSQL = `DELETE FROM mysql.gc_delete_range WHERE job_id = %? AND element_id = %?` + completeDeleteMultiRangesSQL = `DELETE FROM mysql.gc_delete_range WHERE job_id = %? AND element_id in (` // + idList + ")" + updateDeleteRangeSQL = `UPDATE mysql.gc_delete_range SET start_key = %? WHERE job_id = %? AND element_id = %? AND start_key = %?` + deleteDoneRecordSQL = `DELETE FROM mysql.gc_delete_range_done WHERE job_id = %? AND element_id = %?` + loadGlobalVars = `SELECT HIGH_PRIORITY variable_name, variable_value from mysql.global_variables where variable_name in (%?)` ) // DelRangeTask is for run delete-range command in gc_worker. @@ -62,16 +61,14 @@ func LoadDoneDeleteRanges(ctx sessionctx.Context, safePoint uint64) (ranges []De } func loadDeleteRangesFromTable(ctx sessionctx.Context, table string, safePoint uint64) (ranges []DelRangeTask, _ error) { - sql := fmt.Sprintf(loadDeleteRangeSQL, table, safePoint) - rss, err := ctx.(sqlexec.SQLExecutor).Execute(context.TODO(), sql) - if len(rss) > 0 { - defer terror.Call(rss[0].Close) + rs, err := ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), loadDeleteRangeSQL, table, safePoint) + if rs != nil { + defer terror.Call(rs.Close) } if err != nil { return nil, errors.Trace(err) } - rs := rss[0] req := rs.NewChunk() it := chunk.NewIterator4Chunk(req) for { @@ -106,8 +103,7 @@ func loadDeleteRangesFromTable(ctx sessionctx.Context, table string, safePoint u // CompleteDeleteRange moves a record from gc_delete_range table to gc_delete_range_done table. // NOTE: This function WILL NOT start and run in a new transaction internally. func CompleteDeleteRange(ctx sessionctx.Context, dr DelRangeTask) error { - sql := fmt.Sprintf(recordDoneDeletedRangeSQL, dr.JobID, dr.ElementID) - _, err := ctx.(sqlexec.SQLExecutor).Execute(context.TODO(), sql) + _, err := ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), recordDoneDeletedRangeSQL, dr.JobID, dr.ElementID) if err != nil { return errors.Trace(err) } @@ -117,29 +113,31 @@ func CompleteDeleteRange(ctx sessionctx.Context, dr DelRangeTask) error { // RemoveFromGCDeleteRange is exported for ddl pkg to use. func RemoveFromGCDeleteRange(ctx sessionctx.Context, jobID, elementID int64) error { - sql := fmt.Sprintf(completeDeleteRangeSQL, jobID, elementID) - _, err := ctx.(sqlexec.SQLExecutor).Execute(context.TODO(), sql) + _, err := ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), completeDeleteRangeSQL, jobID, elementID) return errors.Trace(err) } // RemoveMultiFromGCDeleteRange is exported for ddl pkg to use. func RemoveMultiFromGCDeleteRange(ctx sessionctx.Context, jobID int64, elementIDs []int64) error { - var buf bytes.Buffer + var buf strings.Builder + buf.WriteString(completeDeleteMultiRangesSQL) + paramIDs := make([]interface{}, 0, 1+len(elementIDs)) + paramIDs = append(paramIDs, jobID) for i, elementID := range elementIDs { if i > 0 { buf.WriteString(", ") } - buf.WriteString(strconv.FormatInt(elementID, 10)) + buf.WriteString("%?") + paramIDs = append(paramIDs, elementID) } - sql := fmt.Sprintf(completeDeleteMultiRangesSQL, jobID, buf.String()) - _, err := ctx.(sqlexec.SQLExecutor).Execute(context.TODO(), sql) + buf.WriteString(")") + _, err := ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), buf.String(), paramIDs...) return errors.Trace(err) } // DeleteDoneRecord removes a record from gc_delete_range_done table. func DeleteDoneRecord(ctx sessionctx.Context, dr DelRangeTask) error { - sql := fmt.Sprintf(deleteDoneRecordSQL, dr.JobID, dr.ElementID) - _, err := ctx.(sqlexec.SQLExecutor).Execute(context.TODO(), sql) + _, err := ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), deleteDoneRecordSQL, dr.JobID, dr.ElementID) return errors.Trace(err) } @@ -147,8 +145,7 @@ func DeleteDoneRecord(ctx sessionctx.Context, dr DelRangeTask) error { func UpdateDeleteRange(ctx sessionctx.Context, dr DelRangeTask, newStartKey, oldStartKey kv.Key) error { newStartKeyHex := hex.EncodeToString(newStartKey) oldStartKeyHex := hex.EncodeToString(oldStartKey) - sql := fmt.Sprintf(updateDeleteRangeSQL, newStartKeyHex, dr.JobID, dr.ElementID, oldStartKeyHex) - _, err := ctx.(sqlexec.SQLExecutor).Execute(context.TODO(), sql) + _, err := ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), updateDeleteRangeSQL, newStartKeyHex, dr.JobID, dr.ElementID, oldStartKeyHex) return errors.Trace(err) } @@ -162,20 +159,14 @@ func LoadDDLVars(ctx sessionctx.Context) error { return LoadGlobalVars(ctx, []string{variable.TiDBDDLErrorCountLimit}) } -const loadGlobalVarsSQL = "select HIGH_PRIORITY variable_name, variable_value from mysql.global_variables where variable_name in (%s)" - // LoadGlobalVars loads global variable from mysql.global_variables. func LoadGlobalVars(ctx sessionctx.Context, varNames []string) error { if sctx, ok := ctx.(sqlexec.RestrictedSQLExecutor); ok { - nameList := "" - for i, name := range varNames { - if i > 0 { - nameList += ", " - } - nameList += fmt.Sprintf("'%s'", name) + stmt, err := sctx.ParseWithParams(context.Background(), loadGlobalVars, varNames) + if err != nil { + return errors.Trace(err) } - sql := fmt.Sprintf(loadGlobalVarsSQL, nameList) - rows, _, err := sctx.ExecRestrictedSQL(sql) + rows, _, err := sctx.ExecRestrictedStmt(context.Background(), stmt) if err != nil { return errors.Trace(err) } From 122ee4d8fc9cf1de007960d1c7f43781ac43d9e7 Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Thu, 11 Mar 2021 15:22:55 +0800 Subject: [PATCH 24/85] statistics: fix a case that auto-analyze is triggered outside its time range (#23214) (#23219) --- statistics/handle/update.go | 19 +++++++---- statistics/handle/update_test.go | 55 ++++++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+), 6 deletions(-) diff --git a/statistics/handle/update.go b/statistics/handle/update.go index affd2c0ca2457..a17357cef2c5b 100644 --- a/statistics/handle/update.go +++ b/statistics/handle/update.go @@ -671,6 +671,11 @@ func TableAnalyzed(tbl *statistics.Table) bool { // "tbl.ModifyCount/tbl.Count > autoAnalyzeRatio" and the current time is // between `start` and `end`. func NeedAnalyzeTable(tbl *statistics.Table, limit time.Duration, autoAnalyzeRatio float64, start, end, now time.Time) (bool, string) { + // Tests if current time is within the time period. + if !timeutil.WithinDayTimePeriod(start, end, now) { + return false, "" + } + analyzed := TableAnalyzed(tbl) if !analyzed { t := time.Unix(0, oracle.ExtractPhysical(tbl.Version)*int64(time.Millisecond)) @@ -685,8 +690,7 @@ func NeedAnalyzeTable(tbl *statistics.Table, limit time.Duration, autoAnalyzeRat if float64(tbl.ModifyCount)/float64(tbl.Count) <= autoAnalyzeRatio { return false, "" } - // Tests if current time is within the time period. - return timeutil.WithinDayTimePeriod(start, end, now), fmt.Sprintf("too many modifications(%v/%v>%v)", tbl.ModifyCount, tbl.Count, autoAnalyzeRatio) + return true, fmt.Sprintf("too many modifications(%v/%v>%v)", tbl.ModifyCount, tbl.Count, autoAnalyzeRatio) } func (h *Handle) getAutoAnalyzeParameters() map[string]string { @@ -727,14 +731,14 @@ func parseAnalyzePeriod(start, end string) (time.Time, time.Time, error) { } // HandleAutoAnalyze analyzes the newly created table or index. -func (h *Handle) HandleAutoAnalyze(is infoschema.InfoSchema) { +func (h *Handle) HandleAutoAnalyze(is infoschema.InfoSchema) (analyzed bool) { dbs := is.AllSchemaNames() parameters := h.getAutoAnalyzeParameters() autoAnalyzeRatio := parseAutoAnalyzeRatio(parameters[variable.TiDBAutoAnalyzeRatio]) start, end, err := parseAnalyzePeriod(parameters[variable.TiDBAutoAnalyzeStartTime], parameters[variable.TiDBAutoAnalyzeEndTime]) if err != nil { logutil.BgLogger().Error("[stats] parse auto analyze period failed", zap.Error(err)) - return + return false } for _, db := range dbs { tbls := is.SchemaTables(model.NewCIStr(db)) @@ -747,7 +751,9 @@ func (h *Handle) HandleAutoAnalyze(is infoschema.InfoSchema) { sql := fmt.Sprintf("analyze table %s", tblName) analyzed := h.autoAnalyzeTable(tblInfo, statsTbl, start, end, autoAnalyzeRatio, sql) if analyzed { - return + // analyze one table at a time to let it get the freshest parameters. + // others will be analyzed next round which is just 3s later. + return true } continue } @@ -756,12 +762,13 @@ func (h *Handle) HandleAutoAnalyze(is infoschema.InfoSchema) { statsTbl := h.GetPartitionStats(tblInfo, def.ID) analyzed := h.autoAnalyzeTable(tblInfo, statsTbl, start, end, autoAnalyzeRatio, sql) if analyzed { - return + return true } continue } } } + return false } func (h *Handle) autoAnalyzeTable(tblInfo *model.TableInfo, statsTbl *statistics.Table, start, end time.Time, ratio float64, sql string) bool { diff --git a/statistics/handle/update_test.go b/statistics/handle/update_test.go index 06196cb9f226f..712dc77cf3ca4 100644 --- a/statistics/handle/update_test.go +++ b/statistics/handle/update_test.go @@ -41,6 +41,26 @@ import ( ) var _ = Suite(&testStatsSuite{}) +var _ = SerialSuites(&testSerialStatsSuite{}) + +type testSerialStatsSuite struct { + store kv.Storage + do *domain.Domain +} + +func (s *testSerialStatsSuite) SetUpSuite(c *C) { + testleak.BeforeTest() + // Add the hook here to avoid data race. + var err error + s.store, s.do, err = newStoreWithBootstrap() + c.Assert(err, IsNil) +} + +func (s *testSerialStatsSuite) TearDownSuite(c *C) { + s.do.Close() + s.store.Close() + testleak.AfterTest(c)() +} type testStatsSuite struct { store kv.Storage @@ -465,6 +485,41 @@ func (s *testStatsSuite) TestAutoUpdate(c *C) { c.Assert(hg.Len(), Equals, 3) } +func (s *testSerialStatsSuite) TestAutoAnalyzeOnEmptyTable(c *C) { + defer cleanEnv(c, s.store, s.do) + tk := testkit.NewTestKit(c, s.store) + + oriStart := tk.MustQuery("select @@tidb_auto_analyze_start_time").Rows()[0][0].(string) + oriEnd := tk.MustQuery("select @@tidb_auto_analyze_end_time").Rows()[0][0].(string) + defer func() { + tk.MustExec(fmt.Sprintf("set global tidb_auto_analyze_start_time='%v'", oriStart)) + tk.MustExec(fmt.Sprintf("set global tidb_auto_analyze_end_time='%v'", oriEnd)) + }() + + t := time.Now().Add(-1 * time.Minute) + h, m := t.Hour(), t.Minute() + start, end := fmt.Sprintf("%02d:%02d +0000", h, m), fmt.Sprintf("%02d:%02d +0000", h, m) + tk.MustExec(fmt.Sprintf("set global tidb_auto_analyze_start_time='%v'", start)) + tk.MustExec(fmt.Sprintf("set global tidb_auto_analyze_end_time='%v'", end)) + s.do.StatsHandle().HandleAutoAnalyze(s.do.InfoSchema()) + + tk.MustExec("use test") + tk.MustExec("create table t (a int, index idx(a))") + // to pass the stats.Pseudo check in autoAnalyzeTable + tk.MustExec("analyze table t") + // to pass the AutoAnalyzeMinCnt check in autoAnalyzeTable + tk.MustExec("insert into t values (1)" + strings.Repeat(", (1)", int(handle.AutoAnalyzeMinCnt))) + c.Assert(s.do.StatsHandle().DumpStatsDeltaToKV(handle.DumpAll), IsNil) + c.Assert(s.do.StatsHandle().Update(s.do.InfoSchema()), IsNil) + + // test if it will be limited by the time range + c.Assert(s.do.StatsHandle().HandleAutoAnalyze(s.do.InfoSchema()), IsFalse) + + tk.MustExec(fmt.Sprintf("set global tidb_auto_analyze_start_time='00:00 +0000'")) + tk.MustExec(fmt.Sprintf("set global tidb_auto_analyze_end_time='23:59 +0000'")) + c.Assert(s.do.StatsHandle().HandleAutoAnalyze(s.do.InfoSchema()), IsTrue) +} + func (s *testStatsSuite) TestAutoUpdatePartition(c *C) { defer cleanEnv(c, s.store, s.do) testKit := testkit.NewTestKit(c, s.store) From 5edebf98fd349956459033a611e4c8299be8973a Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Fri, 12 Mar 2021 14:50:42 +0800 Subject: [PATCH 25/85] planner, sessionctx : Add 'last_plan_from_binding' to help know whether sql's plan is matched with the hints in the binding (#18017) (#21430) --- bindinfo/bind_test.go | 22 ++++++++++++++++++++++ executor/adapter.go | 2 ++ executor/executor.go | 2 ++ executor/set.go | 4 ++++ executor/slow_query.go | 8 ++++++++ executor/slow_query_test.go | 4 +++- infoschema/perfschema/const.go | 1 + infoschema/tables.go | 2 ++ infoschema/tables_test.go | 4 ++-- planner/optimize.go | 10 ++++++++++ sessionctx/variable/session.go | 12 ++++++++++++ sessionctx/variable/session_test.go | 2 ++ sessionctx/variable/sysvar.go | 1 + sessionctx/variable/tidb_vars.go | 4 ++++ sessionctx/variable/varsutil.go | 2 ++ sessionctx/variable/varsutil_test.go | 8 ++++++++ util/stmtsummary/statement_summary.go | 11 +++++++++++ util/stmtsummary/statement_summary_test.go | 2 +- 18 files changed, 97 insertions(+), 4 deletions(-) diff --git a/bindinfo/bind_test.go b/bindinfo/bind_test.go index cbc0a08852034..5daecc922a765 100644 --- a/bindinfo/bind_test.go +++ b/bindinfo/bind_test.go @@ -1541,6 +1541,28 @@ func (s *testSuite) TestCapturePlanBaselineIgnoreTiFlash(c *C) { c.Assert(rows[0][1], Equals, "SELECT /*+ use_index(@`sel_1` `test`.`t` )*/ * FROM `test`.`t`") } +func (s *testSuite) TestSPMHitInfo(c *C) { + tk := testkit.NewTestKit(c, s.store) + s.cleanBindingEnv(tk) + tk.MustExec("use test") + tk.MustExec("drop table if exists t1") + tk.MustExec("drop table if exists t2") + tk.MustExec("create table t1(id int)") + tk.MustExec("create table t2(id int)") + + c.Assert(tk.HasPlan("SELECT * from t1,t2 where t1.id = t2.id", "HashJoin"), IsTrue) + c.Assert(tk.HasPlan("SELECT /*+ TIDB_SMJ(t1, t2) */ * from t1,t2 where t1.id = t2.id", "MergeJoin"), IsTrue) + + tk.MustExec("SELECT * from t1,t2 where t1.id = t2.id") + tk.MustQuery(`select @@last_plan_from_binding;`).Check(testkit.Rows("0")) + tk.MustExec("create global binding for SELECT * from t1,t2 where t1.id = t2.id using SELECT /*+ TIDB_SMJ(t1, t2) */ * from t1,t2 where t1.id = t2.id") + + c.Assert(tk.HasPlan("SELECT * from t1,t2 where t1.id = t2.id", "MergeJoin"), IsTrue) + tk.MustExec("SELECT * from t1,t2 where t1.id = t2.id") + tk.MustQuery(`select @@last_plan_from_binding;`).Check(testkit.Rows("1")) + tk.MustExec("drop global binding for SELECT * from t1,t2 where t1.id = t2.id") +} + func (s *testSuite) TestNotEvolvePlanForReadStorageHint(c *C) { tk := testkit.NewTestKit(c, s.store) s.cleanBindingEnv(tk) diff --git a/executor/adapter.go b/executor/adapter.go index 8a7974e35a081..35476e1d22e75 100644 --- a/executor/adapter.go +++ b/executor/adapter.go @@ -921,6 +921,7 @@ func (a *ExecStmt) LogSlowQuery(txnTS uint64, succ bool, hasMoreResults bool) { Prepared: a.isPreparedStmt, HasMoreResults: hasMoreResults, PlanFromCache: sessVars.FoundInPlanCache, + PlanFromBinding: sessVars.FoundInBinding, KVTotal: time.Duration(atomic.LoadInt64(&stmtDetail.WaitKVRespDuration)), PDTotal: time.Duration(atomic.LoadInt64(&stmtDetail.WaitPDRespDuration)), BackoffTotal: time.Duration(atomic.LoadInt64(&stmtDetail.BackoffDuration)), @@ -1113,6 +1114,7 @@ func (a *ExecStmt) SummaryStmt(succ bool) { IsInternal: sessVars.InRestrictedSQL, Succeed: succ, PlanInCache: sessVars.FoundInPlanCache, + PlanInBinding: sessVars.FoundInBinding, ExecRetryCount: a.retryCount, StmtExecDetails: stmtDetail, Prepared: a.isPreparedStmt, diff --git a/executor/executor.go b/executor/executor.go index 852b05a72d702..7192a1ecc45d9 100644 --- a/executor/executor.go +++ b/executor/executor.go @@ -1772,6 +1772,8 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { vars.StmtCtx = sc vars.PrevFoundInPlanCache = vars.FoundInPlanCache vars.FoundInPlanCache = false + vars.PrevFoundInBinding = vars.FoundInBinding + vars.FoundInBinding = false return } diff --git a/executor/set.go b/executor/set.go index 75b6b0db95c60..c1753f3ed1313 100644 --- a/executor/set.go +++ b/executor/set.go @@ -184,6 +184,10 @@ func (e *SetExecutor) setSysVariable(name string, v *expression.VarAssignment) e sessionVars.StmtCtx.AppendWarning(fmt.Errorf("Set operation for '%s' will not take effect", variable.TiDBFoundInPlanCache)) return nil } + if name == variable.TiDBFoundInBinding { + sessionVars.StmtCtx.AppendWarning(fmt.Errorf("Set operation for '%s' will not take effect", variable.TiDBFoundInBinding)) + return nil + } err = variable.SetSessionSystemVar(sessionVars, name, value) if err != nil { return err diff --git a/executor/slow_query.go b/executor/slow_query.go index 68a5cf84f3a42..84bdeedf691f6 100755 --- a/executor/slow_query.go +++ b/executor/slow_query.go @@ -498,6 +498,7 @@ type slowQueryTuple struct { isInternal bool succ bool planFromCache bool + planFromBinding bool prepared bool kvTotal float64 pdTotal float64 @@ -641,6 +642,8 @@ func (st *slowQueryTuple) setFieldValue(tz *time.Location, field, value string, st.succ, err = strconv.ParseBool(value) case variable.SlowLogPlanFromCache: st.planFromCache, err = strconv.ParseBool(value) + case variable.SlowLogPlanFromBinding: + st.planFromBinding, err = strconv.ParseBool(value) case variable.SlowLogPlan: st.plan = value case variable.SlowLogPlanDigest: @@ -750,6 +753,11 @@ func (st *slowQueryTuple) convertToDatumRow() []types.Datum { } else { record = append(record, types.NewIntDatum(0)) } + if st.planFromBinding { + record = append(record, types.NewIntDatum(1)) + } else { + record = append(record, types.NewIntDatum(0)) + } record = append(record, types.NewStringDatum(parsePlan(st.plan))) record = append(record, types.NewStringDatum(st.planDigest)) record = append(record, types.NewStringDatum(st.prevStmt)) diff --git a/executor/slow_query_test.go b/executor/slow_query_test.go index 09aaf9ad77e6e..9d627d26f1687 100644 --- a/executor/slow_query_test.go +++ b/executor/slow_query_test.go @@ -68,6 +68,7 @@ func (s *testExecSuite) TestParseSlowLogPanic(c *C) { # Mem_max: 70724 # Disk_max: 65536 # Plan_from_cache: true +# Plan_from_binding: true # Succ: false # Plan_digest: 60e9378c746d9a2be1c791047e008967cf252eb6de9167ad3aa6098fa2d523f4 # Prev_stmt: update t set i = 1; @@ -106,6 +107,7 @@ func (s *testExecSuite) TestParseSlowLogFile(c *C) { # Mem_max: 70724 # Disk_max: 65536 # Plan_from_cache: true +# Plan_from_binding: true # Succ: false # Plan_digest: 60e9378c746d9a2be1c791047e008967cf252eb6de9167ad3aa6098fa2d523f4 # Prev_stmt: update t set i = 1; @@ -133,7 +135,7 @@ select * from t;` `0,0,0,0,0,0,0,0,0,0,0,0,,0,0,0,0,0,0,0.38,0.021,0,0,0,1,637,0,,,1,42a1c8aae6f133e934d4bf0147491709a8812ea05ff8819ec522780fe657b772,t1:1,t2:2,` + `0.1,0.2,0.03,127.0.0.1:20160,0.05,0.6,0.8,0.0.0.0:20160,70724,65536,0,0,0,0,` + `Cop_backoff_regionMiss_total_times: 200 Cop_backoff_regionMiss_total_time: 0.2 Cop_backoff_regionMiss_max_time: 0.2 Cop_backoff_regionMiss_max_addr: 127.0.0.1 Cop_backoff_regionMiss_avg_time: 0.2 Cop_backoff_regionMiss_p90_time: 0.2 Cop_backoff_rpcPD_total_times: 200 Cop_backoff_rpcPD_total_time: 0.2 Cop_backoff_rpcPD_max_time: 0.2 Cop_backoff_rpcPD_max_addr: 127.0.0.1 Cop_backoff_rpcPD_avg_time: 0.2 Cop_backoff_rpcPD_p90_time: 0.2 Cop_backoff_rpcTiKV_total_times: 200 Cop_backoff_rpcTiKV_total_time: 0.2 Cop_backoff_rpcTiKV_max_time: 0.2 Cop_backoff_rpcTiKV_max_addr: 127.0.0.1 Cop_backoff_rpcTiKV_avg_time: 0.2 Cop_backoff_rpcTiKV_p90_time: 0.2,` + - `0,0,1,,60e9378c746d9a2be1c791047e008967cf252eb6de9167ad3aa6098fa2d523f4,` + + `0,0,1,1,,60e9378c746d9a2be1c791047e008967cf252eb6de9167ad3aa6098fa2d523f4,` + `update t set i = 1;,select * from t;` c.Assert(recordString, Equals, expectRecordString) diff --git a/infoschema/perfschema/const.go b/infoschema/perfschema/const.go index 86a5f1766739f..157ce4c7ae7c9 100644 --- a/infoschema/perfschema/const.go +++ b/infoschema/perfschema/const.go @@ -415,6 +415,7 @@ const tableEventsStatementsSummaryByDigest = "CREATE TABLE if not exists perform "LAST_SEEN timestamp(6) NOT NULL DEFAULT '0000-00-00 00:00:00.000000'," + "PLAN_IN_CACHE bool NOT NULL," + "PLAN_CACHE_HITS bigint unsigned NOT NULL," + + "PLAN_IN_BINDING bool NOT NULL," + "QUANTILE_95 bigint unsigned NOT NULL," + "QUANTILE_99 bigint unsigned NOT NULL," + "QUANTILE_999 bigint unsigned NOT NULL," + diff --git a/infoschema/tables.go b/infoschema/tables.go index e4dd6fa2f7165..845f78877f271 100644 --- a/infoschema/tables.go +++ b/infoschema/tables.go @@ -773,6 +773,7 @@ var slowQueryCols = []columnInfo{ {name: variable.SlowLogPrepared, tp: mysql.TypeTiny, size: 1}, {name: variable.SlowLogSucc, tp: mysql.TypeTiny, size: 1}, {name: variable.SlowLogPlanFromCache, tp: mysql.TypeTiny, size: 1}, + {name: variable.SlowLogPlanFromBinding, tp: mysql.TypeTiny, size: 1}, {name: variable.SlowLogPlan, tp: mysql.TypeLongBlob, size: types.UnspecifiedLength}, {name: variable.SlowLogPlanDigest, tp: mysql.TypeVarchar, size: 128}, {name: variable.SlowLogPrevStmt, tp: mysql.TypeLongBlob, size: types.UnspecifiedLength}, @@ -1234,6 +1235,7 @@ var tableStatementsSummaryCols = []columnInfo{ {name: "LAST_SEEN", tp: mysql.TypeTimestamp, size: 26, flag: mysql.NotNullFlag, comment: "The time these statements are seen for the last time"}, {name: "PLAN_IN_CACHE", tp: mysql.TypeTiny, size: 1, flag: mysql.NotNullFlag, comment: "Whether the last statement hit plan cache"}, {name: "PLAN_CACHE_HITS", tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag, comment: "The number of times these statements hit plan cache"}, + {name: "PLAN_IN_BINDING", tp: mysql.TypeTiny, size: 1, flag: mysql.NotNullFlag, comment: "Whether the last statement is matched with the hints in the binding"}, {name: "QUERY_SAMPLE_TEXT", tp: mysql.TypeBlob, size: types.UnspecifiedLength, comment: "Sampled original statement"}, {name: "PREV_SAMPLE_TEXT", tp: mysql.TypeBlob, size: types.UnspecifiedLength, comment: "The previous statement before commit"}, {name: "PLAN_DIGEST", tp: mysql.TypeVarchar, size: 64, comment: "Digest of its execution plan"}, diff --git a/infoschema/tables_test.go b/infoschema/tables_test.go index 6759efa55f8c1..5e425730a9231 100644 --- a/infoschema/tables_test.go +++ b/infoschema/tables_test.go @@ -638,10 +638,10 @@ func (s *testTableSuite) TestSlowQuery(c *C) { tk.MustExec("set time_zone = '+08:00';") re := tk.MustQuery("select * from information_schema.slow_query") re.Check(testutil.RowsWithSep("|", - "2019-02-12 19:33:56.571953|406315658548871171|root|localhost|6|57|0.12|4.895492|0.4|0.2|0.000000003|2|0.000000002|0.00000001|0.000000003|0.19|0.21|0.01|0|0.18|[txnLock]|0.03|0|15|480|1|8|0.3824278|0.161|0.101|0.092|1.71|1|100001|100000|test||0|42a1c8aae6f133e934d4bf0147491709a8812ea05ff8819ec522780fe657b772|t1:1,t2:2|0.1|0.2|0.03|127.0.0.1:20160|0.05|0.6|0.8|0.0.0.0:20160|70724|65536|0|0|0|0||0|1|1|abcd|60e9378c746d9a2be1c791047e008967cf252eb6de9167ad3aa6098fa2d523f4|update t set i = 2;|select * from t_slim;")) + "2019-02-12 19:33:56.571953|406315658548871171|root|localhost|6|57|0.12|4.895492|0.4|0.2|0.000000003|2|0.000000002|0.00000001|0.000000003|0.19|0.21|0.01|0|0.18|[txnLock]|0.03|0|15|480|1|8|0.3824278|0.161|0.101|0.092|1.71|1|100001|100000|test||0|42a1c8aae6f133e934d4bf0147491709a8812ea05ff8819ec522780fe657b772|t1:1,t2:2|0.1|0.2|0.03|127.0.0.1:20160|0.05|0.6|0.8|0.0.0.0:20160|70724|65536|0|0|0|0||0|1|1|0|abcd|60e9378c746d9a2be1c791047e008967cf252eb6de9167ad3aa6098fa2d523f4|update t set i = 2;|select * from t_slim;")) tk.MustExec("set time_zone = '+00:00';") re = tk.MustQuery("select * from information_schema.slow_query") - re.Check(testutil.RowsWithSep("|", "2019-02-12 11:33:56.571953|406315658548871171|root|localhost|6|57|0.12|4.895492|0.4|0.2|0.000000003|2|0.000000002|0.00000001|0.000000003|0.19|0.21|0.01|0|0.18|[txnLock]|0.03|0|15|480|1|8|0.3824278|0.161|0.101|0.092|1.71|1|100001|100000|test||0|42a1c8aae6f133e934d4bf0147491709a8812ea05ff8819ec522780fe657b772|t1:1,t2:2|0.1|0.2|0.03|127.0.0.1:20160|0.05|0.6|0.8|0.0.0.0:20160|70724|65536|0|0|0|0||0|1|1|abcd|60e9378c746d9a2be1c791047e008967cf252eb6de9167ad3aa6098fa2d523f4|update t set i = 2;|select * from t_slim;")) + re.Check(testutil.RowsWithSep("|", "2019-02-12 11:33:56.571953|406315658548871171|root|localhost|6|57|0.12|4.895492|0.4|0.2|0.000000003|2|0.000000002|0.00000001|0.000000003|0.19|0.21|0.01|0|0.18|[txnLock]|0.03|0|15|480|1|8|0.3824278|0.161|0.101|0.092|1.71|1|100001|100000|test||0|42a1c8aae6f133e934d4bf0147491709a8812ea05ff8819ec522780fe657b772|t1:1,t2:2|0.1|0.2|0.03|127.0.0.1:20160|0.05|0.6|0.8|0.0.0.0:20160|70724|65536|0|0|0|0||0|1|1|0|abcd|60e9378c746d9a2be1c791047e008967cf252eb6de9167ad3aa6098fa2d523f4|update t set i = 2;|select * from t_slim;")) // Test for long query. f, err := os.OpenFile(slowLogFileName, os.O_CREATE|os.O_WRONLY, 0644) diff --git a/planner/optimize.go b/planner/optimize.go index 09d2de310616a..0eefceffc2262 100644 --- a/planner/optimize.go +++ b/planner/optimize.go @@ -124,6 +124,10 @@ func Optimize(ctx context.Context, sctx sessionctx.Context, node ast.Node, is in sctx.GetSessionVars().StmtCtx.AppendWarning(errors.New("sql_select_limit is set, so plan binding is not activated")) return bestPlan, names, nil } + err = setFoundInBinding(sctx, true) + if err != nil { + return nil, nil, err + } bestPlanHint := plannercore.GenHintsFromPhysicalPlan(bestPlan) if len(bindRecord.Bindings) > 0 { orgBinding := bindRecord.Bindings[0] // the first is the original binding @@ -495,6 +499,12 @@ func handleStmtHints(hints []*ast.TableOptimizerHint) (stmtHints stmtctx.StmtHin return } +func setFoundInBinding(sctx sessionctx.Context, opt bool) error { + vars := sctx.GetSessionVars() + err := vars.SetSystemVar(variable.TiDBFoundInBinding, variable.BoolToIntStr(opt)) + return err +} + func init() { plannercore.OptimizeAstNode = Optimize } diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go index d9c1c44fcca9d..2ec0d56e965c8 100644 --- a/sessionctx/variable/session.go +++ b/sessionctx/variable/session.go @@ -656,6 +656,10 @@ type SessionVars struct { // PrevFoundInPlanCache indicates whether the last statement was found in plan cache. PrevFoundInPlanCache bool + // FoundInBinding indicates whether the execution plan is matched with the hints in the binding. + FoundInBinding bool + // PrevFoundInBinding indicates whether the last execution plan is matched with the hints in the binding. + PrevFoundInBinding bool // SelectLimit limits the max counts of select statement's output SelectLimit uint64 @@ -760,6 +764,8 @@ func NewSessionVars() *SessionVars { WindowingUseHighPrecision: true, PrevFoundInPlanCache: DefTiDBFoundInPlanCache, FoundInPlanCache: DefTiDBFoundInPlanCache, + PrevFoundInBinding: DefTiDBFoundInBinding, + FoundInBinding: DefTiDBFoundInBinding, SelectLimit: math.MaxUint64, AllowAutoRandExplicitInsert: DefTiDBAllowAutoRandExplicitInsert, EnableAmendPessimisticTxn: DefTiDBEnableAmendPessimisticTxn, @@ -1370,6 +1376,8 @@ func (s *SessionVars) SetSystemVar(name string, val string) error { config.GetGlobalConfig().CheckMb4ValueInUTF8 = TiDBOptOn(val) case TiDBFoundInPlanCache: s.FoundInPlanCache = TiDBOptOn(val) + case TiDBFoundInBinding: + s.FoundInBinding = TiDBOptOn(val) case SQLSelectLimit: result, err := strconv.ParseUint(val, 10, 64) if err != nil { @@ -1643,6 +1651,8 @@ const ( SlowLogPrepared = "Prepared" // SlowLogPlanFromCache is used to indicate whether this plan is from plan cache. SlowLogPlanFromCache = "Plan_from_cache" + // SlowLogPlanFromBinding is used to indicate whether this plan is matched with the hints in the binding. + SlowLogPlanFromBinding = "Plan_from_binding" // SlowLogHasMoreResults is used to indicate whether this sql has more following results. SlowLogHasMoreResults = "Has_more_results" // SlowLogSucc is used to indicate whether this sql execute successfully. @@ -1695,6 +1705,7 @@ type SlowQueryLogItems struct { Succ bool Prepared bool PlanFromCache bool + PlanFromBinding bool HasMoreResults bool PrevStmt string Plan string @@ -1862,6 +1873,7 @@ func (s *SessionVars) SlowLogFormat(logItems *SlowQueryLogItems) string { writeSlowLogItem(&buf, SlowLogPrepared, strconv.FormatBool(logItems.Prepared)) writeSlowLogItem(&buf, SlowLogPlanFromCache, strconv.FormatBool(logItems.PlanFromCache)) + writeSlowLogItem(&buf, SlowLogPlanFromBinding, strconv.FormatBool(logItems.PlanFromBinding)) writeSlowLogItem(&buf, SlowLogHasMoreResults, strconv.FormatBool(logItems.HasMoreResults)) writeSlowLogItem(&buf, SlowLogKVTotal, strconv.FormatFloat(logItems.KVTotal.Seconds(), 'f', -1, 64)) writeSlowLogItem(&buf, SlowLogPDTotal, strconv.FormatFloat(logItems.PDTotal.Seconds(), 'f', -1, 64)) diff --git a/sessionctx/variable/session_test.go b/sessionctx/variable/session_test.go index 7ab9927005305..3e13fd0b0ba3e 100644 --- a/sessionctx/variable/session_test.go +++ b/sessionctx/variable/session_test.go @@ -201,6 +201,7 @@ func (*testSessionSuite) TestSlowLogFormat(c *C) { # Disk_max: 6666 # Prepared: true # Plan_from_cache: true +# Plan_from_binding: true # Has_more_results: true # KV_total: 10 # PD_total: 11 @@ -226,6 +227,7 @@ func (*testSessionSuite) TestSlowLogFormat(c *C) { DiskMax: diskMax, Prepared: true, PlanFromCache: true, + PlanFromBinding: true, HasMoreResults: true, KVTotal: 10 * time.Second, PDTotal: 11 * time.Second, diff --git a/sessionctx/variable/sysvar.go b/sessionctx/variable/sysvar.go index cef0093d86eb8..5049c650fb48a 100644 --- a/sessionctx/variable/sysvar.go +++ b/sessionctx/variable/sysvar.go @@ -734,6 +734,7 @@ var defaultSysVars = []*SysVar{ {ScopeSession, TiDBQueryLogMaxLen, strconv.Itoa(logutil.DefaultQueryLogMaxLen)}, {ScopeSession, TiDBCheckMb4ValueInUTF8, BoolToIntStr(config.GetGlobalConfig().CheckMb4ValueInUTF8)}, {ScopeSession, TiDBFoundInPlanCache, BoolToIntStr(DefTiDBFoundInPlanCache)}, + {ScopeSession, TiDBFoundInBinding, BoolToIntStr(DefTiDBFoundInBinding)}, {ScopeSession, TiDBEnableCollectExecutionInfo, BoolToIntStr(DefTiDBEnableCollectExecutionInfo)}, {ScopeGlobal | ScopeSession, TiDBAllowAutoRandExplicitInsert, boolToOnOff(DefTiDBAllowAutoRandExplicitInsert)}, {ScopeGlobal | ScopeSession, TiDBSlowLogMasking, BoolToIntStr(DefTiDBRedactLog)}, diff --git a/sessionctx/variable/tidb_vars.go b/sessionctx/variable/tidb_vars.go index d81cdf775db46..b48920e185f8b 100644 --- a/sessionctx/variable/tidb_vars.go +++ b/sessionctx/variable/tidb_vars.go @@ -181,6 +181,9 @@ const ( // TiDBFoundInPlanCache indicates whether the last statement was found in plan cache TiDBFoundInPlanCache = "last_plan_from_cache" + // TiDBFoundInBinding indicates whether the last statement was matched with the hints in the binding. + TiDBFoundInBinding = "last_plan_from_binding" + // TiDBAllowAutoRandExplicitInsert indicates whether explicit insertion on auto_random column is allowed. TiDBAllowAutoRandExplicitInsert = "allow_auto_random_explicit_insert" ) @@ -525,6 +528,7 @@ const ( DefTiDBMetricSchemaStep = 60 // 60s DefTiDBMetricSchemaRangeDuration = 60 // 60s DefTiDBFoundInPlanCache = false + DefTiDBFoundInBinding = false DefTiDBEnableCollectExecutionInfo = true DefTiDBAllowAutoRandExplicitInsert = false DefTiDBRedactLog = false diff --git a/sessionctx/variable/varsutil.go b/sessionctx/variable/varsutil.go index f5fcabe795211..676dcc2fb455c 100644 --- a/sessionctx/variable/varsutil.go +++ b/sessionctx/variable/varsutil.go @@ -163,6 +163,8 @@ func GetSessionOnlySysVars(s *SessionVars, key string) (string, bool, error) { return CapturePlanBaseline.GetVal(), true, nil case TiDBFoundInPlanCache: return BoolToIntStr(s.PrevFoundInPlanCache), true, nil + case TiDBFoundInBinding: + return BoolToIntStr(s.PrevFoundInBinding), true, nil case TiDBEnableCollectExecutionInfo: return BoolToIntStr(config.GetGlobalConfig().EnableCollectExecutionInfo), true, nil } diff --git a/sessionctx/variable/varsutil_test.go b/sessionctx/variable/varsutil_test.go index be47632f29636..536d490a57847 100644 --- a/sessionctx/variable/varsutil_test.go +++ b/sessionctx/variable/varsutil_test.go @@ -87,6 +87,7 @@ func (s *testVarsutilSuite) TestNewSessionVars(c *C) { c.Assert(vars.TiDBOptJoinReorderThreshold, Equals, DefTiDBOptJoinReorderThreshold) c.Assert(vars.EnableFastAnalyze, Equals, DefTiDBUseFastAnalyze) c.Assert(vars.FoundInPlanCache, Equals, DefTiDBFoundInPlanCache) + c.Assert(vars.FoundInBinding, Equals, DefTiDBFoundInBinding) c.Assert(vars.AllowAutoRandExplicitInsert, Equals, DefTiDBAllowAutoRandExplicitInsert) assertFieldsGreaterThanZero(c, reflect.ValueOf(vars.Concurrency)) @@ -456,6 +457,13 @@ func (s *testVarsutilSuite) TestVarsutil(c *C) { c.Assert(val, Equals, "0") c.Assert(v.systems[TiDBFoundInPlanCache], Equals, "1") + err = SetSessionSystemVar(v, TiDBFoundInBinding, types.NewStringDatum("1")) + c.Assert(err, IsNil) + val, err = GetSessionSystemVar(v, TiDBFoundInBinding) + c.Assert(err, IsNil) + c.Assert(val, Equals, "0") + c.Assert(v.systems[TiDBFoundInBinding], Equals, "1") + err = SetSessionSystemVar(v, "UnknownVariable", types.NewStringDatum("on")) c.Assert(err, ErrorMatches, ".*]Unknown system variable 'UnknownVariable'") } diff --git a/util/stmtsummary/statement_summary.go b/util/stmtsummary/statement_summary.go index ccd2d8b9b4f70..1250a6a8bf780 100644 --- a/util/stmtsummary/statement_summary.go +++ b/util/stmtsummary/statement_summary.go @@ -183,6 +183,7 @@ type stmtSummaryByDigestElement struct { // plan cache planInCache bool planCacheHits int64 + planInBinding bool // pessimistic execution retry information. execRetryCount uint execRetryTime time.Duration @@ -214,6 +215,7 @@ type StmtExecInfo struct { IsInternal bool Succeed bool PlanInCache bool + PlanInBinding bool ExecRetryCount uint ExecRetryTime time.Duration execdetails.StmtExecDetails @@ -627,6 +629,7 @@ func newStmtSummaryByDigestElement(sei *StmtExecInfo, beginTime int64, intervalS authUsers: make(map[string]struct{}), planInCache: false, planCacheHits: 0, + planInBinding: false, prepared: sei.Prepared, } ssElement.add(sei, intervalSeconds) @@ -782,6 +785,13 @@ func (ssElement *stmtSummaryByDigestElement) add(sei *StmtExecInfo, intervalSeco ssElement.planInCache = false } + // SPM + if sei.PlanInBinding { + ssElement.planInBinding = true + } else { + ssElement.planInBinding = false + } + // other ssElement.sumAffectedRows += sei.StmtCtx.AffectedRows() ssElement.sumMem += sei.MemMax @@ -899,6 +909,7 @@ func (ssElement *stmtSummaryByDigestElement) toDatum(ssbd *stmtSummaryByDigest) types.NewTime(types.FromGoTime(ssElement.lastSeen), mysql.TypeTimestamp, 0), ssElement.planInCache, ssElement.planCacheHits, + ssElement.planInBinding, ssElement.sampleSQL, ssElement.prevSQL, ssbd.planDigest, diff --git a/util/stmtsummary/statement_summary_test.go b/util/stmtsummary/statement_summary_test.go index 39df0a45147bb..51adb3d6f22a9 100644 --- a/util/stmtsummary/statement_summary_test.go +++ b/util/stmtsummary/statement_summary_test.go @@ -616,7 +616,7 @@ func (s *testStmtSummarySuite) TestToDatum(c *C) { stmtExecInfo1.ExecDetail.CommitDetail.TxnRetry, stmtExecInfo1.ExecDetail.CommitDetail.TxnRetry, 0, 0, 1, "txnLock:1", stmtExecInfo1.MemMax, stmtExecInfo1.MemMax, stmtExecInfo1.DiskMax, stmtExecInfo1.DiskMax, 0, 0, 0, 0, 0, stmtExecInfo1.StmtCtx.AffectedRows(), - t, t, 0, 0, stmtExecInfo1.OriginalSQL, stmtExecInfo1.PrevSQL, "plan_digest", ""} + t, t, 0, 0, 0, stmtExecInfo1.OriginalSQL, stmtExecInfo1.PrevSQL, "plan_digest", ""} match(c, datums[0], expectedDatum...) datums = s.ssMap.ToHistoryDatum(nil, true) c.Assert(len(datums), Equals, 1) From 8b9ed63fc31521552c2a4618de01c3046d64db01 Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Fri, 12 Mar 2021 16:56:56 +0800 Subject: [PATCH 26/85] *: fix a bug that collation is not handle for text type (#23045) (#23092) --- expression/integration_test.go | 14 ++++++++++++++ session/bootstrap_test.go | 4 ++-- table/column.go | 4 +--- table/column_test.go | 2 +- types/datum.go | 12 ++---------- util/chunk/row.go | 6 +----- util/rowcodec/decoder.go | 4 +--- util/rowcodec/rowcodec_test.go | 12 +++++++----- 8 files changed, 29 insertions(+), 29 deletions(-) diff --git a/expression/integration_test.go b/expression/integration_test.go index 7e7a51412df8a..fb2bbfbbf1cca 100755 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -7200,6 +7200,20 @@ func (s *testIntegrationSerialSuite) TestIssue18702(c *C) { tk.MustQuery("SELECT * FROM t FORCE INDEX(idx_bc);").Check(testkit.Rows("1 A 10 1", "2 B 20 1")) } +func (s *testIntegrationSerialSuite) TestCollationText(c *C) { + collate.SetNewCollationEnabledForTest(true) + defer collate.SetNewCollationEnabledForTest(false) + + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a TINYTEXT collate UTF8MB4_GENERAL_CI, UNIQUE KEY `a`(`a`(10)));") + tk.MustExec("insert into t (a) values ('A');") + tk.MustQuery("select * from t t1 inner join t t2 on t1.a = t2.a where t1.a = 'A';").Check(testkit.Rows("A A")) + tk.MustExec("update t set a = 'B';") + tk.MustExec("admin check table t;") +} + func (s *testIntegrationSuite) TestIssue18850(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") diff --git a/session/bootstrap_test.go b/session/bootstrap_test.go index e587aadbc8f74..2a59620a98d28 100644 --- a/session/bootstrap_test.go +++ b/session/bootstrap_test.go @@ -54,7 +54,7 @@ func (s *testBootstrapSuite) TestBootstrap(c *C) { c.Assert(err, IsNil) c.Assert(req.NumRows() == 0, IsFalse) datums := statistics.RowToDatums(req.GetRow(0), r.Fields()) - match(c, datums, `%`, "root", []byte(""), "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "N", "Y", "Y", "Y", "Y") + match(c, datums, `%`, "root", "", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "N", "Y", "Y", "Y", "Y") c.Assert(se.Auth(&auth.UserIdentity{Username: "root", Hostname: "anyhost"}, []byte(""), []byte("")), IsTrue) mustExecSQL(c, se, "USE test;") @@ -159,7 +159,7 @@ func (s *testBootstrapSuite) TestBootstrapWithError(c *C) { c.Assert(req.NumRows() == 0, IsFalse) row := req.GetRow(0) datums := statistics.RowToDatums(row, r.Fields()) - match(c, datums, `%`, "root", []byte(""), "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "N", "Y", "Y", "Y", "Y") + match(c, datums, `%`, "root", "", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "N", "Y", "Y", "Y", "Y") c.Assert(r.Close(), IsNil) mustExecSQL(c, se, "USE test;") diff --git a/table/column.go b/table/column.go index d3fa703266285..29df9267b0cc0 100644 --- a/table/column.go +++ b/table/column.go @@ -604,10 +604,8 @@ func GetZeroValue(col *model.ColumnInfo) types.Datum { } else { d.SetString("", col.Collate) } - case mysql.TypeVarString, mysql.TypeVarchar: + case mysql.TypeVarString, mysql.TypeVarchar, mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: d.SetString("", col.Collate) - case mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: - d.SetBytes([]byte{}) case mysql.TypeDuration: d.SetMysqlDuration(types.ZeroDuration) case mysql.TypeDate: diff --git a/table/column_test.go b/table/column_test.go index feffb801ebf6a..826c9a8a4f11b 100644 --- a/table/column_test.go +++ b/table/column_test.go @@ -183,7 +183,7 @@ func (t *testTableSuite) TestGetZeroValue(c *C) { }, { types.NewFieldType(mysql.TypeBlob), - types.NewBytesDatum([]byte{}), + types.NewStringDatum(""), }, { types.NewFieldType(mysql.TypeDuration), diff --git a/types/datum.go b/types/datum.go index c48e5637ab460..08c989d366a74 100644 --- a/types/datum.go +++ b/types/datum.go @@ -2136,14 +2136,10 @@ func GetMaxValue(ft *FieldType) (max Datum) { max.SetFloat32(float32(GetMaxFloat(ft.Flen, ft.Decimal))) case mysql.TypeDouble: max.SetFloat64(GetMaxFloat(ft.Flen, ft.Decimal)) - case mysql.TypeString, mysql.TypeVarString, mysql.TypeVarchar: + case mysql.TypeString, mysql.TypeVarString, mysql.TypeVarchar, mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: // codec.Encode KindMaxValue, to avoid import circle bytes := []byte{250} max.SetString(string(bytes), ft.Collate) - case mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: - // codec.Encode KindMaxValue, to avoid import circle - bytes := []byte{250} - max.SetBytes(bytes) case mysql.TypeNewDecimal: max.SetMysqlDecimal(NewMaxOrMinDec(false, ft.Flen, ft.Decimal)) case mysql.TypeDuration: @@ -2171,14 +2167,10 @@ func GetMinValue(ft *FieldType) (min Datum) { min.SetFloat32(float32(-GetMaxFloat(ft.Flen, ft.Decimal))) case mysql.TypeDouble: min.SetFloat64(-GetMaxFloat(ft.Flen, ft.Decimal)) - case mysql.TypeString, mysql.TypeVarString, mysql.TypeVarchar: + case mysql.TypeString, mysql.TypeVarString, mysql.TypeVarchar, mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: // codec.Encode KindMinNotNull, to avoid import circle bytes := []byte{1} min.SetString(string(bytes), ft.Collate) - case mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: - // codec.Encode KindMinNotNull, to avoid import circle - bytes := []byte{1} - min.SetBytes(bytes) case mysql.TypeNewDecimal: min.SetMysqlDecimal(NewMaxOrMinDec(true, ft.Flen, ft.Decimal)) case mysql.TypeDuration: diff --git a/util/chunk/row.go b/util/chunk/row.go index 993ec9b58b9d1..0951de6803900 100644 --- a/util/chunk/row.go +++ b/util/chunk/row.go @@ -148,14 +148,10 @@ func (r Row) GetDatum(colIdx int, tp *types.FieldType) types.Datum { if !r.IsNull(colIdx) { d.SetFloat64(r.GetFloat64(colIdx)) } - case mysql.TypeVarchar, mysql.TypeVarString, mysql.TypeString: + case mysql.TypeVarchar, mysql.TypeVarString, mysql.TypeString, mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: if !r.IsNull(colIdx) { d.SetString(r.GetString(colIdx), tp.Collate) } - case mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: - if !r.IsNull(colIdx) { - d.SetBytes(r.GetBytes(colIdx)) - } case mysql.TypeDate, mysql.TypeDatetime, mysql.TypeTimestamp: if !r.IsNull(colIdx) { d.SetMysqlTime(r.GetTime(colIdx)) diff --git a/util/rowcodec/decoder.go b/util/rowcodec/decoder.go index 8352fc8fa2efc..cd8f7438578de 100644 --- a/util/rowcodec/decoder.go +++ b/util/rowcodec/decoder.go @@ -133,10 +133,8 @@ func (decoder *DatumMapDecoder) decodeColDatum(col *ColInfo, colData []byte) (ty return d, err } d.SetFloat64(fVal) - case mysql.TypeVarString, mysql.TypeVarchar, mysql.TypeString: + case mysql.TypeVarString, mysql.TypeVarchar, mysql.TypeString, mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: d.SetString(string(colData), col.Collate) - case mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: - d.SetBytes(colData) case mysql.TypeNewDecimal: _, dec, precision, frac, err := codec.DecodeDecimal(colData) if err != nil { diff --git a/util/rowcodec/rowcodec_test.go b/util/rowcodec/rowcodec_test.go index 04b1c3455f5e1..dbe0c3568b8db 100644 --- a/util/rowcodec/rowcodec_test.go +++ b/util/rowcodec/rowcodec_test.go @@ -308,6 +308,8 @@ func (s *testSuite) TestTypesNewRowCodec(c *C) { c.Assert(len(remain), Equals, 0) if d.Kind() == types.KindMysqlDecimal { c.Assert(d.GetMysqlDecimal(), DeepEquals, t.bt.GetMysqlDecimal()) + } else if d.Kind() == types.KindBytes { + c.Assert(d.GetBytes(), DeepEquals, t.bt.GetBytes()) } else { c.Assert(d, DeepEquals, t.bt) } @@ -341,9 +343,9 @@ func (s *testSuite) TestTypesNewRowCodec(c *C) { }, { 24, - types.NewFieldType(mysql.TypeBlob), - types.NewBytesDatum([]byte("abc")), - types.NewBytesDatum([]byte("abc")), + types.NewFieldTypeWithCollation(mysql.TypeBlob, mysql.DefaultCollationName, types.UnspecifiedLength), + types.NewStringDatum("abc"), + types.NewStringDatum("abc"), nil, false, }, @@ -470,8 +472,8 @@ func (s *testSuite) TestTypesNewRowCodec(c *C) { testData[0].id = 1 // test large data - testData[3].dt = types.NewBytesDatum([]byte(strings.Repeat("a", math.MaxUint16+1))) - testData[3].bt = types.NewBytesDatum([]byte(strings.Repeat("a", math.MaxUint16+1))) + testData[3].dt = types.NewStringDatum(strings.Repeat("a", math.MaxUint16+1)) + testData[3].bt = types.NewStringDatum(strings.Repeat("a", math.MaxUint16+1)) encodeAndDecode(c, testData) } From e579d1cde98c4971ea4bca966a2eb7ccde75445e Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Fri, 12 Mar 2021 19:36:56 +0800 Subject: [PATCH 27/85] planner: fixed a bug that prevented SPM from taking effect (#23197) (#23209) --- bindinfo/bind_test.go | 14 ++++++++++++++ planner/optimize.go | 14 ++------------ 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/bindinfo/bind_test.go b/bindinfo/bind_test.go index 5daecc922a765..d0df1d6f54591 100644 --- a/bindinfo/bind_test.go +++ b/bindinfo/bind_test.go @@ -1865,3 +1865,17 @@ func (s *testSuite) TestCaptureWithZeroSlowLogThreshold(c *C) { c.Assert(len(rows), Equals, 1) c.Assert(rows[0][0], Equals, "select * from test . t") } + +func (s *testSuite) TestSPMWithoutUseDatabase(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk1 := testkit.NewTestKit(c, s.store) + s.cleanBindingEnv(tk) + s.cleanBindingEnv(tk1) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a int, b int, key(a))") + tk.MustExec("create global binding for select * from t using select * from t force index(a)") + + err := tk1.ExecToErr("select * from t") + c.Assert(err, ErrorMatches, "*No database selected") +} diff --git a/planner/optimize.go b/planner/optimize.go index 0eefceffc2262..bf08b87813124 100644 --- a/planner/optimize.go +++ b/planner/optimize.go @@ -273,12 +273,7 @@ func extractSelectAndNormalizeDigest(stmtNode ast.StmtNode, specifiledDB string) switch x.Stmt.(type) { case *ast.SelectStmt, *ast.DeleteStmt, *ast.UpdateStmt, *ast.InsertStmt: plannercore.EraseLastSemicolon(x) - var normalizeExplainSQL string - if specifiledDB != "" { - normalizeExplainSQL = parser.Normalize(utilparser.RestoreWithDefaultDB(x, specifiledDB)) - } else { - normalizeExplainSQL = parser.Normalize(x.Text()) - } + normalizeExplainSQL := parser.Normalize(utilparser.RestoreWithDefaultDB(x, specifiledDB)) idx := int(0) switch n := x.Stmt.(type) { case *ast.SelectStmt: @@ -308,12 +303,7 @@ func extractSelectAndNormalizeDigest(stmtNode ast.StmtNode, specifiledDB string) if len(x.Text()) == 0 { return x, "", "" } - var normalizedSQL, hash string - if specifiledDB != "" { - normalizedSQL, hash = parser.NormalizeDigest(utilparser.RestoreWithDefaultDB(x, specifiledDB)) - } else { - normalizedSQL, hash = parser.NormalizeDigest(x.Text()) - } + normalizedSQL, hash := parser.NormalizeDigest(utilparser.RestoreWithDefaultDB(x, specifiledDB)) return x, normalizedSQL, hash } return nil, "", "" From b848a5d34fc725907bbfcb3221ab3081b4962625 Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Mon, 15 Mar 2021 10:44:55 +0800 Subject: [PATCH 28/85] tikv: drop store's regions when resolve store with tombstone status (#22909) (#23071) --- store/mockstore/mocktikv/cluster.go | 18 ++++++++++++++---- store/mockstore/mocktikv/rpc.go | 14 ++++++++------ store/tikv/region_cache.go | 2 +- store/tikv/region_cache_test.go | 27 +++++++++++++++++++++++++++ store/tikv/region_request.go | 2 +- 5 files changed, 51 insertions(+), 12 deletions(-) diff --git a/store/mockstore/mocktikv/cluster.go b/store/mockstore/mocktikv/cluster.go index cfc1f09e5405a..8eeb9f676761e 100644 --- a/store/mockstore/mocktikv/cluster.go +++ b/store/mockstore/mocktikv/cluster.go @@ -178,19 +178,20 @@ func (c *Cluster) GetStoreByAddr(addr string) *metapb.Store { } // GetAndCheckStoreByAddr checks and returns a Store's meta by an addr -func (c *Cluster) GetAndCheckStoreByAddr(addr string) (*metapb.Store, error) { +func (c *Cluster) GetAndCheckStoreByAddr(addr string) (ss []*metapb.Store, err error) { c.RLock() defer c.RUnlock() for _, s := range c.stores { if s.cancel { - return nil, context.Canceled + err = context.Canceled + return } if s.meta.GetAddress() == addr { - return proto.Clone(s.meta).(*metapb.Store), nil + ss = append(ss, proto.Clone(s.meta).(*metapb.Store)) } } - return nil, nil + return } // AddStore add a new Store to the cluster. @@ -209,6 +210,15 @@ func (c *Cluster) RemoveStore(storeID uint64) { delete(c.stores, storeID) } +// MarkTombstone marks store as tombstone. +func (c *Cluster) MarkTombstone(storeID uint64) { + c.Lock() + defer c.Unlock() + nm := *c.stores[storeID].meta + nm.State = metapb.StoreState_Tombstone + c.stores[storeID].meta = &nm +} + // UpdateStoreAddr updates store address for cluster. func (c *Cluster) UpdateStoreAddr(storeID uint64, addr string, labels ...*metapb.StoreLabel) { c.Lock() diff --git a/store/mockstore/mocktikv/rpc.go b/store/mockstore/mocktikv/rpc.go index 8d0fb5ce39c98..290798cd087ad 100644 --- a/store/mockstore/mocktikv/rpc.go +++ b/store/mockstore/mocktikv/rpc.go @@ -740,18 +740,20 @@ func NewRPCClient(cluster *Cluster, mvccStore MVCCStore) *RPCClient { } func (c *RPCClient) getAndCheckStoreByAddr(addr string) (*metapb.Store, error) { - store, err := c.Cluster.GetAndCheckStoreByAddr(addr) + stores, err := c.Cluster.GetAndCheckStoreByAddr(addr) if err != nil { return nil, err } - if store == nil { + if len(stores) == 0 { return nil, errors.New("connect fail") } - if store.GetState() == metapb.StoreState_Offline || - store.GetState() == metapb.StoreState_Tombstone { - return nil, errors.New("connection refused") + for _, store := range stores { + if store.GetState() != metapb.StoreState_Offline && + store.GetState() != metapb.StoreState_Tombstone { + return store, nil + } } - return store, nil + return nil, errors.New("connection refused") } func (c *RPCClient) checkArgs(ctx context.Context, addr string) (*rpcHandler, error) { diff --git a/store/tikv/region_cache.go b/store/tikv/region_cache.go index 456a56ae347b7..e629d81a5a0a8 100644 --- a/store/tikv/region_cache.go +++ b/store/tikv/region_cache.go @@ -1522,7 +1522,7 @@ func (s *Store) reResolve(c *RegionCache) { // we cannot do backoff in reResolve loop but try check other store and wait tick. return } - if store == nil { + if store == nil || store.State == metapb.StoreState_Tombstone { // store has be removed in PD, we should invalidate all regions using those store. logutil.BgLogger().Info("invalidate regions in removed store", zap.Uint64("store", s.storeID), zap.String("add", s.addr)) diff --git a/store/tikv/region_cache_test.go b/store/tikv/region_cache_test.go index c2f2d6e9d7cf3..b2dc597524b41 100644 --- a/store/tikv/region_cache_test.go +++ b/store/tikv/region_cache_test.go @@ -833,6 +833,33 @@ func (s *testRegionCacheSuite) TestReplaceNewAddrAndOldOfflineImmediately(c *C) c.Assert(getVal, BytesEquals, testValue) } +func (s *testRegionCacheSuite) TestReplaceStore(c *C) { + mvccStore := mocktikv.MustNewMVCCStore() + defer mvccStore.Close() + + client := &RawKVClient{ + clusterID: 0, + regionCache: NewRegionCache(mocktikv.NewPDClient(s.cluster)), + rpcClient: mocktikv.NewRPCClient(s.cluster, mvccStore), + } + defer client.Close() + testKey := []byte("test_key") + testValue := []byte("test_value") + err := client.Put(testKey, testValue) + c.Assert(err, IsNil) + + s.cluster.MarkTombstone(s.store1) + store3 := s.cluster.AllocID() + peer3 := s.cluster.AllocID() + s.cluster.AddStore(store3, s.storeAddr(s.store1)) + s.cluster.AddPeer(s.region1, store3, peer3) + s.cluster.RemovePeer(s.region1, s.peer1) + s.cluster.ChangeLeader(s.region1, peer3) + + err = client.Put(testKey, testValue) + c.Assert(err, IsNil) +} + func (s *testRegionCacheSuite) TestListRegionIDsInCache(c *C) { // ['' - 'm' - 'z'] region2 := s.cluster.AllocID() diff --git a/store/tikv/region_request.go b/store/tikv/region_request.go index f1523456427b2..85cf978c85454 100644 --- a/store/tikv/region_request.go +++ b/store/tikv/region_request.go @@ -604,7 +604,7 @@ func (s *RegionRequestSender) onRegionError(bo *Backoffer, ctx *RPCContext, seed if storeNotMatch := regionErr.GetStoreNotMatch(); storeNotMatch != nil { // store not match - logutil.BgLogger().Warn("tikv reports `StoreNotMatch` retry later", + logutil.BgLogger().Debug("tikv reports `StoreNotMatch` retry later", zap.Stringer("storeNotMatch", storeNotMatch), zap.Stringer("ctx", ctx)) ctx.Store.markNeedCheck(s.regionCache.notifyCheckCh) From 866c24b7ab7a736a7a3ddddaf9cfe2136f51fdc1 Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Mon, 15 Mar 2021 14:55:59 +0800 Subject: [PATCH 29/85] executor: fix cast function will ignore tht error for point-get key construction (#22869) (#23211) --- executor/point_get.go | 5 ++++- executor/point_get_test.go | 19 +++++++++++++++++++ table/column.go | 4 ++-- 3 files changed, 25 insertions(+), 3 deletions(-) diff --git a/executor/point_get.go b/executor/point_get.go index 0fdeca2de81fe..ae8bbcfbb37d5 100644 --- a/executor/point_get.go +++ b/executor/point_get.go @@ -345,8 +345,11 @@ func encodeIndexKey(e *baseExecutor, tblInfo *model.TableInfo, idxInfo *model.In str, err = idxVals[i].ToString() idxVals[i].SetString(str, colInfo.FieldType.Collate) } else { + // If a truncated error or an overflow error is thrown when converting the type of `idxVal[i]` to + // the type of `colInfo`, the `idxVal` does not exist in the `idxInfo` for sure. idxVals[i], err = table.CastValue(e.ctx, idxVals[i], colInfo, true, false) - if types.ErrOverflow.Equal(err) { + if types.ErrOverflow.Equal(err) || types.ErrDataTooLong.Equal(err) || + types.ErrTruncated.Equal(err) || types.ErrTruncatedWrongVal.Equal(err) { return nil, false, kv.ErrNotExist } } diff --git a/executor/point_get_test.go b/executor/point_get_test.go index 0269a1f0d65d5..b70677c17be72 100644 --- a/executor/point_get_test.go +++ b/executor/point_get_test.go @@ -125,6 +125,25 @@ func (s *testPointGetSuite) TestPointGetOverflow(c *C) { tk.MustQuery("SELECT t0.c1 FROM t0 WHERE t0.c1=127").Check(testkit.Rows("127")) } +// Close issue #22839 +func (s *testPointGetSuite) TestPointGetDataTooLong(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists PK_1389;") + tk.MustExec("CREATE TABLE `PK_1389` ( " + + " `COL1` bit(1) NOT NULL," + + " `COL2` varchar(20) DEFAULT NULL," + + " `COL3` datetime DEFAULT NULL," + + " `COL4` bigint(20) DEFAULT NULL," + + " `COL5` float DEFAULT NULL," + + " PRIMARY KEY (`COL1`)" + + ");") + tk.MustExec("insert into PK_1389 values(0, \"皟钹糁泅埞礰喾皑杏灚暋蛨歜檈瓗跾咸滐梀揉\", \"7701-12-27 23:58:43\", 4806951672419474695, -1.55652e38);") + tk.MustQuery("select count(1) from PK_1389 where col1 = 0x30;").Check(testkit.Rows("0")) + tk.MustQuery("select count(1) from PK_1389 where col1 in ( 0x30);").Check(testkit.Rows("0")) + tk.MustExec("drop table if exists PK_1389;") +} + func (s *testPointGetSuite) TestPointGetCharPK(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec(`use test;`) diff --git a/table/column.go b/table/column.go index 29df9267b0cc0..72fedda4fd796 100644 --- a/table/column.go +++ b/table/column.go @@ -248,11 +248,11 @@ func handleZeroDatetime(ctx sessionctx.Context, col *model.ColumnInfo, casted ty // Set it to true only in FillVirtualColumnValue and UnionScanExec.Next() // If the handle of err is changed latter, the behavior of forceIgnoreTruncate also need to change. // TODO: change the third arg to TypeField. Not pass ColumnInfo. -func CastValue(ctx sessionctx.Context, val types.Datum, col *model.ColumnInfo, returnOverflow, forceIgnoreTruncate bool) (casted types.Datum, err error) { +func CastValue(ctx sessionctx.Context, val types.Datum, col *model.ColumnInfo, returnErr, forceIgnoreTruncate bool) (casted types.Datum, err error) { sc := ctx.GetSessionVars().StmtCtx casted, err = val.ConvertTo(sc, &col.FieldType) // TODO: make sure all truncate errors are handled by ConvertTo. - if types.ErrOverflow.Equal(err) && returnOverflow { + if returnErr && err != nil { return casted, err } if err != nil && types.ErrTruncated.Equal(err) && col.Tp != mysql.TypeSet && col.Tp != mysql.TypeEnum { From 430a81cf240b15e46269eae71c6a437c848e4807 Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Tue, 16 Mar 2021 05:42:55 +0800 Subject: [PATCH 30/85] bindinfo: use new sql apis (#22653) (#22733) --- bindinfo/handle.go | 121 ++++++++++++++++++++++----------------------- 1 file changed, 59 insertions(+), 62 deletions(-) diff --git a/bindinfo/handle.go b/bindinfo/handle.go index 146eb0c96fa0c..2a2355c7c2014 100644 --- a/bindinfo/handle.go +++ b/bindinfo/handle.go @@ -28,7 +28,6 @@ import ( "github.com/pingcap/parser/format" "github.com/pingcap/parser/mysql" "github.com/pingcap/parser/terror" - "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/metrics" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/variable" @@ -123,17 +122,21 @@ func NewBindHandle(ctx sessionctx.Context) *BindHandle { func (h *BindHandle) Update(fullLoad bool) (err error) { h.bindInfo.Lock() lastUpdateTime := h.bindInfo.lastUpdateTime + updateTime := lastUpdateTime.String() + if fullLoad { + updateTime = "0000-00-00 00:00:00" + } - sql := "select original_sql, bind_sql, default_db, status, create_time, update_time, charset, collation, source from mysql.bind_info" - if !fullLoad { - sql += " where update_time > \"" + lastUpdateTime.String() + "\"" + exec := h.sctx.Context.(sqlexec.RestrictedSQLExecutor) + stmt, err := exec.ParseWithParams(context.TODO(), `SELECT original_sql, bind_sql, default_db, status, create_time, update_time, charset, collation, source + FROM mysql.bind_info WHERE update_time > %? ORDER BY update_time`, updateTime) + if err != nil { + return err } - // We need to apply the updates by order, wrong apply order of same original sql may cause inconsistent state. - sql += " order by update_time" - // No need to acquire the session context lock for ExecRestrictedSQL, it + // No need to acquire the session context lock for ExecRestrictedStmt, it // uses another background session. - rows, _, err := h.sctx.Context.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql) + rows, _, err := exec.ExecRestrictedStmt(context.Background(), stmt) if err != nil { h.bindInfo.Unlock() return err @@ -215,7 +218,7 @@ func (h *BindHandle) CreateBindRecord(sctx sessionctx.Context, record *BindRecor return err } // Binding recreation should physically delete previous bindings. - _, err = exec.ExecuteInternal(context.TODO(), h.deleteBindInfoSQL(record.OriginalSQL, record.Db, "")) + _, err = exec.ExecuteInternal(context.TODO(), `DELETE FROM mysql.bind_info WHERE original_sql = %?`, record.OriginalSQL) if err != nil { return err } @@ -227,7 +230,17 @@ func (h *BindHandle) CreateBindRecord(sctx sessionctx.Context, record *BindRecor record.Bindings[i].UpdateTime = now // Insert the BindRecord to the storage. - _, err = exec.ExecuteInternal(context.TODO(), h.insertBindInfoSQL(record.OriginalSQL, record.Db, record.Bindings[i])) + _, err = exec.ExecuteInternal(context.TODO(), `INSERT INTO mysql.bind_info VALUES (%?,%?, %?, %?, %?, %?, %?, %?, %?)`, + record.OriginalSQL, + record.Bindings[i].BindSQL, + record.Db, + record.Bindings[i].Status, + record.Bindings[i].CreateTime.String(), + record.Bindings[i].UpdateTime.String(), + record.Bindings[i].Charset, + record.Bindings[i].Collation, + record.Bindings[i].Source, + ) if err != nil { return err } @@ -289,7 +302,7 @@ func (h *BindHandle) AddBindRecord(sctx sessionctx.Context, record *BindRecord) return err } if duplicateBinding != nil { - _, err = exec.ExecuteInternal(context.TODO(), h.deleteBindInfoSQL(record.OriginalSQL, record.Db, duplicateBinding.BindSQL)) + _, err = exec.ExecuteInternal(context.TODO(), `DELETE FROM mysql.bind_info WHERE original_sql = %? AND bind_sql = %?`, record.OriginalSQL, duplicateBinding.BindSQL) if err != nil { return err } @@ -305,7 +318,17 @@ func (h *BindHandle) AddBindRecord(sctx sessionctx.Context, record *BindRecord) record.Bindings[i].UpdateTime = now // Insert the BindRecord to the storage. - _, err = exec.ExecuteInternal(context.TODO(), h.insertBindInfoSQL(record.OriginalSQL, record.Db, record.Bindings[i])) + _, err = exec.ExecuteInternal(context.TODO(), `INSERT INTO mysql.bind_info VALUES (%?, %?, %?, %?, %?, %?, %?, %?, %?)`, + record.OriginalSQL, + record.Bindings[i].BindSQL, + record.Db, + record.Bindings[i].Status, + record.Bindings[i].CreateTime.String(), + record.Bindings[i].UpdateTime.String(), + record.Bindings[i].Charset, + record.Bindings[i].Collation, + record.Bindings[i].Source, + ) if err != nil { return err } @@ -349,17 +372,19 @@ func (h *BindHandle) DropBindRecord(originalSQL, db string, binding *Binding) (e // Lock mysql.bind_info to synchronize with CreateBindRecord / AddBindRecord / DropBindRecord on other tidb instances. if err = h.lockBindInfoTable(); err != nil { - return + return err } - updateTs := types.NewTime(types.FromGoTime(time.Now()), mysql.TypeTimestamp, 3) + updateTs := types.NewTime(types.FromGoTime(time.Now()), mysql.TypeTimestamp, 3).String() - bindSQL := "" - if binding != nil { - bindSQL = binding.BindSQL + if binding == nil { + _, err = exec.ExecuteInternal(context.TODO(), `UPDATE mysql.bind_info SET status = %?, update_time = %? WHERE original_sql = %? AND update_time < %?`, + deleted, updateTs, originalSQL, updateTs) + } else { + _, err = exec.ExecuteInternal(context.TODO(), `UPDATE mysql.bind_info SET status = %?, update_time = %? WHERE original_sql = %? AND update_time < %? AND bind_sql = %?`, + deleted, updateTs, originalSQL, updateTs, binding.BindSQL) } - _, err = exec.ExecuteInternal(context.TODO(), h.logicalDeleteBindInfoSQL(originalSQL, db, updateTs, bindSQL)) deleteRows = int(h.sctx.Context.GetSessionVars().StmtCtx.AffectedRows()) return err } @@ -575,49 +600,13 @@ func (c cache) getBindRecord(hash, normdOrigSQL, db string) *BindRecord { return nil } -func (h *BindHandle) deleteBindInfoSQL(normdOrigSQL, db, bindSQL string) string { - sql := fmt.Sprintf( - `DELETE FROM mysql.bind_info WHERE original_sql=%s`, - expression.Quote(normdOrigSQL), - ) - if bindSQL == "" { - return sql - } - return sql + fmt.Sprintf(` and bind_sql = %s`, expression.Quote(bindSQL)) -} - -func (h *BindHandle) insertBindInfoSQL(orignalSQL string, db string, info Binding) string { - return fmt.Sprintf(`INSERT INTO mysql.bind_info VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)`, - expression.Quote(orignalSQL), - expression.Quote(info.BindSQL), - expression.Quote(db), - expression.Quote(info.Status), - expression.Quote(info.CreateTime.String()), - expression.Quote(info.UpdateTime.String()), - expression.Quote(info.Charset), - expression.Quote(info.Collation), - expression.Quote(info.Source), - ) -} - // LockBindInfoSQL simulates LOCK TABLE by updating a same row in each pessimistic transaction. func (h *BindHandle) LockBindInfoSQL() string { - return fmt.Sprintf("UPDATE mysql.bind_info SET source=%s WHERE original_sql=%s", - expression.Quote(Builtin), - expression.Quote(BuiltinPseudoSQL4BindLock)) -} - -func (h *BindHandle) logicalDeleteBindInfoSQL(originalSQL, db string, updateTs types.Time, bindingSQL string) string { - updateTsStr := updateTs.String() - sql := fmt.Sprintf(`UPDATE mysql.bind_info SET status=%s,update_time=%s WHERE original_sql=%s and update_time<%s`, - expression.Quote(deleted), - expression.Quote(updateTsStr), - expression.Quote(originalSQL), - expression.Quote(updateTsStr)) - if bindingSQL == "" { - return sql + sql, err := sqlexec.EscapeSQL("UPDATE mysql.bind_info SET source= %? WHERE original_sql= %?", Builtin, BuiltinPseudoSQL4BindLock) + if err != nil { + return "" } - return sql + fmt.Sprintf(` and bind_sql = %s`, expression.Quote(bindingSQL)) + return sql } // CaptureBaselines is used to automatically capture plan baselines. @@ -764,9 +753,17 @@ func (h *BindHandle) SaveEvolveTasksToStore() { } func getEvolveParameters(ctx sessionctx.Context) (time.Duration, time.Time, time.Time, error) { - sql := fmt.Sprintf("select variable_name, variable_value from mysql.global_variables where variable_name in ('%s', '%s', '%s')", - variable.TiDBEvolvePlanTaskMaxTime, variable.TiDBEvolvePlanTaskStartTime, variable.TiDBEvolvePlanTaskEndTime) - rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql) + stmt, err := ctx.(sqlexec.RestrictedSQLExecutor).ParseWithParams( + context.TODO(), + "SELECT variable_name, variable_value FROM mysql.global_variables WHERE variable_name IN (%?, %?, %?)", + variable.TiDBEvolvePlanTaskMaxTime, + variable.TiDBEvolvePlanTaskStartTime, + variable.TiDBEvolvePlanTaskEndTime, + ) + if err != nil { + return 0, time.Time{}, time.Time{}, err + } + rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedStmt(context.TODO(), stmt) if err != nil { return 0, time.Time{}, time.Time{}, err } @@ -837,7 +834,7 @@ func (h *BindHandle) getOnePendingVerifyJob() (string, string, Binding) { func (h *BindHandle) getRunningDuration(sctx sessionctx.Context, db, sql string, maxTime time.Duration) (time.Duration, error) { ctx := context.TODO() if db != "" { - _, err := sctx.(sqlexec.SQLExecutor).ExecuteInternal(ctx, fmt.Sprintf("use `%s`", db)) + _, err := sctx.(sqlexec.SQLExecutor).ExecuteInternal(ctx, "use %n", db) if err != nil { return 0, err } From d4750fe0ab799e395e3c537f549231b20a63d3e2 Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Tue, 16 Mar 2021 10:42:55 +0800 Subject: [PATCH 31/85] statistics: refactor the statistics package use the RestrictedSQLExecutor API (#22636) (#22961) (#23225) --- domain/domain.go | 5 +- server/statistics_handler.go | 5 +- statistics/handle/bootstrap.go | 44 +++----- statistics/handle/ddl.go | 70 ++++++------ statistics/handle/dump.go | 21 +++- statistics/handle/gc.go | 67 +++++++---- statistics/handle/handle.go | 187 ++++++++++++++++++------------- statistics/handle/handle_test.go | 10 +- statistics/handle/update.go | 64 +++++------ types/datum.go | 5 +- 10 files changed, 262 insertions(+), 216 deletions(-) diff --git a/domain/domain.go b/domain/domain.go index 1742102cc9445..c581ea037fcc7 100644 --- a/domain/domain.go +++ b/domain/domain.go @@ -1015,7 +1015,8 @@ func (do *Domain) StatsHandle() *handle.Handle { // CreateStatsHandle is used only for test. func (do *Domain) CreateStatsHandle(ctx sessionctx.Context) { - atomic.StorePointer(&do.statsHandle, unsafe.Pointer(handle.NewHandle(ctx, do.statsLease))) + h := handle.NewHandle(ctx, do.statsLease, do.sysSessionPool) + atomic.StorePointer(&do.statsHandle, unsafe.Pointer(h)) } // StatsUpdating checks if the stats worker is updating. @@ -1040,7 +1041,7 @@ var RunAutoAnalyze = true // It should be called only once in BootstrapSession. func (do *Domain) UpdateTableStatsLoop(ctx sessionctx.Context) error { ctx.GetSessionVars().InRestrictedSQL = true - statsHandle := handle.NewHandle(ctx, do.statsLease) + statsHandle := handle.NewHandle(ctx, do.statsLease, do.sysSessionPool) atomic.StorePointer(&do.statsHandle, unsafe.Pointer(statsHandle)) do.ddl.RegisterEventCh(statsHandle.DDLEventCh()) // Negative stats lease indicates that it is in test, it does not need update. diff --git a/server/statistics_handler.go b/server/statistics_handler.go index a40e1b19b321f..ecb246e46cf48 100644 --- a/server/statistics_handler.go +++ b/server/statistics_handler.go @@ -25,7 +25,6 @@ import ( "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/gcutil" - "github.com/pingcap/tidb/util/sqlexec" ) // StatsHandler is the handler for dumping statistics. @@ -122,9 +121,7 @@ func (sh StatsHistoryHandler) ServeHTTP(w http.ResponseWriter, req *http.Request writeError(w, err) return } - se.GetSessionVars().SnapshotInfoschema, se.GetSessionVars().SnapshotTS = is, snapshot - historyStatsExec := se.(sqlexec.RestrictedSQLExecutor) - js, err := h.DumpStatsToJSON(params[pDBName], tbl.Meta(), historyStatsExec) + js, err := h.DumpStatsToJSONBySnapshot(params[pDBName], tbl.Meta(), snapshot) if err != nil { writeError(w, err) } else { diff --git a/statistics/handle/bootstrap.go b/statistics/handle/bootstrap.go index 4bc8257537eda..bc87750400f2a 100644 --- a/statistics/handle/bootstrap.go +++ b/statistics/handle/bootstrap.go @@ -60,18 +60,16 @@ func (h *Handle) initStatsMeta4Chunk(is infoschema.InfoSchema, cache *statsCache func (h *Handle) initStatsMeta(is infoschema.InfoSchema) (statsCache, error) { sql := "select HIGH_PRIORITY version, table_id, modify_count, count from mysql.stats_meta" - rc, err := h.mu.ctx.(sqlexec.SQLExecutor).Execute(context.TODO(), sql) - if len(rc) > 0 { - defer terror.Call(rc[0].Close) - } + rc, err := h.mu.ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), sql) if err != nil { return statsCache{}, errors.Trace(err) } + defer terror.Call(rc.Close) tables := statsCache{tables: make(map[int64]*statistics.Table)} - req := rc[0].NewChunk() + req := rc.NewChunk() iter := chunk.NewIterator4Chunk(req) for { - err := rc[0].Next(context.TODO(), req) + err := rc.Next(context.TODO(), req) if err != nil { return statsCache{}, errors.Trace(err) } @@ -147,17 +145,15 @@ func (h *Handle) initStatsHistograms4Chunk(is infoschema.InfoSchema, cache *stat func (h *Handle) initStatsHistograms(is infoschema.InfoSchema, cache *statsCache) error { sql := "select HIGH_PRIORITY table_id, is_index, hist_id, distinct_count, version, null_count, cm_sketch, tot_col_size, stats_ver, correlation, flag, last_analyze_pos from mysql.stats_histograms" - rc, err := h.mu.ctx.(sqlexec.SQLExecutor).Execute(context.TODO(), sql) - if len(rc) > 0 { - defer terror.Call(rc[0].Close) - } + rc, err := h.mu.ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), sql) if err != nil { return errors.Trace(err) } - req := rc[0].NewChunk() + defer terror.Call(rc.Close) + req := rc.NewChunk() iter := chunk.NewIterator4Chunk(req) for { - err := rc[0].Next(context.TODO(), req) + err := rc.Next(context.TODO(), req) if err != nil { return errors.Trace(err) } @@ -187,17 +183,15 @@ func (h *Handle) initStatsTopN4Chunk(cache *statsCache, iter *chunk.Iterator4Chu func (h *Handle) initStatsTopN(cache *statsCache) error { sql := "select HIGH_PRIORITY table_id, hist_id, value, count from mysql.stats_top_n where is_index = 1" - rc, err := h.mu.ctx.(sqlexec.SQLExecutor).Execute(context.TODO(), sql) - if len(rc) > 0 { - defer terror.Call(rc[0].Close) - } + rc, err := h.mu.ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), sql) if err != nil { return errors.Trace(err) } - req := rc[0].NewChunk() + defer terror.Call(rc.Close) + req := rc.NewChunk() iter := chunk.NewIterator4Chunk(req) for { - err := rc[0].Next(context.TODO(), req) + err := rc.Next(context.TODO(), req) if err != nil { return errors.Trace(err) } @@ -257,17 +251,15 @@ func initStatsBuckets4Chunk(ctx sessionctx.Context, cache *statsCache, iter *chu func (h *Handle) initStatsBuckets(cache *statsCache) error { sql := "select HIGH_PRIORITY table_id, is_index, hist_id, count, repeats, lower_bound, upper_bound from mysql.stats_buckets order by table_id, is_index, hist_id, bucket_id" - rc, err := h.mu.ctx.(sqlexec.SQLExecutor).Execute(context.TODO(), sql) - if len(rc) > 0 { - defer terror.Call(rc[0].Close) - } + rc, err := h.mu.ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), sql) if err != nil { return errors.Trace(err) } - req := rc[0].NewChunk() + defer terror.Call(rc.Close) + req := rc.NewChunk() iter := chunk.NewIterator4Chunk(req) for { - err := rc[0].Next(context.TODO(), req) + err := rc.Next(context.TODO(), req) if err != nil { return errors.Trace(err) } @@ -300,13 +292,13 @@ func (h *Handle) initStatsBuckets(cache *statsCache) error { func (h *Handle) InitStats(is infoschema.InfoSchema) (err error) { h.mu.Lock() defer func() { - _, err1 := h.mu.ctx.(sqlexec.SQLExecutor).Execute(context.TODO(), "commit") + _, err1 := h.mu.ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), "commit") if err == nil && err1 != nil { err = err1 } h.mu.Unlock() }() - _, err = h.mu.ctx.(sqlexec.SQLExecutor).Execute(context.TODO(), "begin") + _, err = h.mu.ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), "begin") if err != nil { return err } diff --git a/statistics/handle/ddl.go b/statistics/handle/ddl.go index 127608edb7bb8..7a1770d592be0 100644 --- a/statistics/handle/ddl.go +++ b/statistics/handle/ddl.go @@ -15,7 +15,6 @@ package handle import ( "context" - "fmt" "github.com/pingcap/errors" "github.com/pingcap/parser/model" @@ -75,28 +74,34 @@ func (h *Handle) DDLEventCh() chan *util.Event { func (h *Handle) insertTableStats2KV(info *model.TableInfo, physicalID int64) (err error) { h.mu.Lock() defer h.mu.Unlock() + ctx := context.Background() exec := h.mu.ctx.(sqlexec.SQLExecutor) - _, err = exec.Execute(context.Background(), "begin") + _, err = exec.ExecuteInternal(ctx, "begin") if err != nil { return errors.Trace(err) } defer func() { - err = finishTransaction(context.Background(), exec, err) + err = finishTransaction(ctx, exec, err) }() txn, err := h.mu.ctx.Txn(true) if err != nil { return errors.Trace(err) } startTS := txn.StartTS() - sqls := make([]string, 0, 1+len(info.Columns)+len(info.Indices)) - sqls = append(sqls, fmt.Sprintf("insert into mysql.stats_meta (version, table_id) values(%d, %d)", startTS, physicalID)) + if _, err := exec.ExecuteInternal(ctx, "insert into mysql.stats_meta (version, table_id) values(%?, %?)", startTS, physicalID); err != nil { + return err + } for _, col := range info.Columns { - sqls = append(sqls, fmt.Sprintf("insert into mysql.stats_histograms (table_id, is_index, hist_id, distinct_count, version) values(%d, 0, %d, 0, %d)", physicalID, col.ID, startTS)) + if _, err := exec.ExecuteInternal(ctx, "insert into mysql.stats_histograms (table_id, is_index, hist_id, distinct_count, version) values(%?, 0, %?, 0, %?)", physicalID, col.ID, startTS); err != nil { + return err + } } for _, idx := range info.Indices { - sqls = append(sqls, fmt.Sprintf("insert into mysql.stats_histograms (table_id, is_index, hist_id, distinct_count, version) values(%d, 1, %d, 0, %d)", physicalID, idx.ID, startTS)) + if _, err := exec.ExecuteInternal(ctx, "insert into mysql.stats_histograms (table_id, is_index, hist_id, distinct_count, version) values(%?, 1, %?, 0, %?)", physicalID, idx.ID, startTS); err != nil { + return err + } } - return execSQLs(context.Background(), exec, sqls) + return nil } // insertColStats2KV insert a record to stats_histograms with distinct_count 1 and insert a bucket to stats_buckets with default value. @@ -105,13 +110,14 @@ func (h *Handle) insertColStats2KV(physicalID int64, colInfo *model.ColumnInfo) h.mu.Lock() defer h.mu.Unlock() + ctx := context.TODO() exec := h.mu.ctx.(sqlexec.SQLExecutor) - _, err = exec.Execute(context.Background(), "begin") + _, err = exec.ExecuteInternal(ctx, "begin") if err != nil { return errors.Trace(err) } defer func() { - err = finishTransaction(context.Background(), exec, err) + err = finishTransaction(ctx, exec, err) }() txn, err := h.mu.ctx.Txn(true) if err != nil { @@ -119,24 +125,21 @@ func (h *Handle) insertColStats2KV(physicalID int64, colInfo *model.ColumnInfo) } startTS := txn.StartTS() // First of all, we update the version. - _, err = exec.Execute(context.Background(), fmt.Sprintf("update mysql.stats_meta set version = %d where table_id = %d ", startTS, physicalID)) + _, err = exec.ExecuteInternal(ctx, "update mysql.stats_meta set version = %? where table_id = %?", startTS, physicalID) if err != nil { return } - ctx := context.TODO() // If we didn't update anything by last SQL, it means the stats of this table does not exist. if h.mu.ctx.GetSessionVars().StmtCtx.AffectedRows() > 0 { // By this step we can get the count of this table, then we can sure the count and repeats of bucket. - var rs []sqlexec.RecordSet - rs, err = exec.Execute(ctx, fmt.Sprintf("select count from mysql.stats_meta where table_id = %d", physicalID)) - if len(rs) > 0 { - defer terror.Call(rs[0].Close) - } + var rs sqlexec.RecordSet + rs, err = exec.ExecuteInternal(ctx, "select count from mysql.stats_meta where table_id = %?", physicalID) if err != nil { return } - req := rs[0].NewChunk() - err = rs[0].Next(ctx, req) + defer terror.Call(rs.Close) + req := rs.NewChunk() + err = rs.Next(ctx, req) if err != nil { return } @@ -146,21 +149,26 @@ func (h *Handle) insertColStats2KV(physicalID int64, colInfo *model.ColumnInfo) if err != nil { return } - sqls := make([]string, 0, 1) if value.IsNull() { // If the adding column has default value null, all the existing rows have null value on the newly added column. - sqls = append(sqls, fmt.Sprintf("insert into mysql.stats_histograms (version, table_id, is_index, hist_id, distinct_count, null_count) values (%d, %d, 0, %d, 0, %d)", startTS, physicalID, colInfo.ID, count)) + if _, err := exec.ExecuteInternal(ctx, "insert into mysql.stats_histograms (version, table_id, is_index, hist_id, distinct_count, null_count) values (%?, %?, 0, %?, 0, %?)", startTS, physicalID, colInfo.ID, count); err != nil { + return err + } } else { // If this stats exists, we insert histogram meta first, the distinct_count will always be one. - sqls = append(sqls, fmt.Sprintf("insert into mysql.stats_histograms (version, table_id, is_index, hist_id, distinct_count, tot_col_size) values (%d, %d, 0, %d, 1, %d)", startTS, physicalID, colInfo.ID, int64(len(value.GetBytes()))*count)) + if _, err := exec.ExecuteInternal(ctx, "insert into mysql.stats_histograms (version, table_id, is_index, hist_id, distinct_count, tot_col_size) values (%?, %?, 0, %?, 1, %?)", startTS, physicalID, colInfo.ID, int64(len(value.GetBytes()))*count); err != nil { + return err + } + value, err = value.ConvertTo(h.mu.ctx.GetSessionVars().StmtCtx, types.NewFieldType(mysql.TypeBlob)) if err != nil { return } // There must be only one bucket for this new column and the value is the default value. - sqls = append(sqls, fmt.Sprintf("insert into mysql.stats_buckets (table_id, is_index, hist_id, bucket_id, repeats, count, lower_bound, upper_bound) values (%d, 0, %d, 0, %d, %d, X'%X', X'%X')", physicalID, colInfo.ID, count, count, value.GetBytes(), value.GetBytes())) + if _, err := exec.ExecuteInternal(ctx, "insert into mysql.stats_buckets (table_id, is_index, hist_id, bucket_id, repeats, count, lower_bound, upper_bound) values (%?, 0, %?, 0, %?, %?, %?, %?)", physicalID, colInfo.ID, count, count, value.GetBytes(), value.GetBytes()); err != nil { + return err + } } - return execSQLs(context.Background(), exec, sqls) } return } @@ -168,20 +176,10 @@ func (h *Handle) insertColStats2KV(physicalID int64, colInfo *model.ColumnInfo) // finishTransaction will execute `commit` when error is nil, otherwise `rollback`. func finishTransaction(ctx context.Context, exec sqlexec.SQLExecutor, err error) error { if err == nil { - _, err = exec.Execute(ctx, "commit") + _, err = exec.ExecuteInternal(ctx, "commit") } else { - _, err1 := exec.Execute(ctx, "rollback") + _, err1 := exec.ExecuteInternal(ctx, "rollback") terror.Log(errors.Trace(err1)) } return errors.Trace(err) } - -func execSQLs(ctx context.Context, exec sqlexec.SQLExecutor, sqls []string) error { - for _, sql := range sqls { - _, err := exec.Execute(ctx, sql) - if err != nil { - return err - } - } - return nil -} diff --git a/statistics/handle/dump.go b/statistics/handle/dump.go index 16295569d76c0..33ebfdd92fb94 100644 --- a/statistics/handle/dump.go +++ b/statistics/handle/dump.go @@ -20,6 +20,7 @@ import ( "github.com/pingcap/parser/model" "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/infoschema" + "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/statistics" "github.com/pingcap/tidb/types" @@ -63,9 +64,19 @@ func dumpJSONCol(hist *statistics.Histogram, CMSketch *statistics.CMSketch) *jso // DumpStatsToJSON dumps statistic to json. func (h *Handle) DumpStatsToJSON(dbName string, tableInfo *model.TableInfo, historyStatsExec sqlexec.RestrictedSQLExecutor) (*JSONTable, error) { + var snapshot uint64 + if historyStatsExec != nil { + sctx := historyStatsExec.(sessionctx.Context) + snapshot = sctx.GetSessionVars().SnapshotTS + } + return h.DumpStatsToJSONBySnapshot(dbName, tableInfo, snapshot) +} + +// DumpStatsToJSONBySnapshot dumps statistic to json. +func (h *Handle) DumpStatsToJSONBySnapshot(dbName string, tableInfo *model.TableInfo, snapshot uint64) (*JSONTable, error) { pi := tableInfo.GetPartitionInfo() if pi == nil { - return h.tableStatsToJSON(dbName, tableInfo, tableInfo.ID, historyStatsExec) + return h.tableStatsToJSON(dbName, tableInfo, tableInfo.ID, snapshot) } jsonTbl := &JSONTable{ DatabaseName: dbName, @@ -73,7 +84,7 @@ func (h *Handle) DumpStatsToJSON(dbName string, tableInfo *model.TableInfo, hist Partitions: make(map[string]*JSONTable, len(pi.Definitions)), } for _, def := range pi.Definitions { - tbl, err := h.tableStatsToJSON(dbName, tableInfo, def.ID, historyStatsExec) + tbl, err := h.tableStatsToJSON(dbName, tableInfo, def.ID, snapshot) if err != nil { return nil, errors.Trace(err) } @@ -85,12 +96,12 @@ func (h *Handle) DumpStatsToJSON(dbName string, tableInfo *model.TableInfo, hist return jsonTbl, nil } -func (h *Handle) tableStatsToJSON(dbName string, tableInfo *model.TableInfo, physicalID int64, historyStatsExec sqlexec.RestrictedSQLExecutor) (*JSONTable, error) { - tbl, err := h.tableStatsFromStorage(tableInfo, physicalID, true, historyStatsExec) +func (h *Handle) tableStatsToJSON(dbName string, tableInfo *model.TableInfo, physicalID int64, snapshot uint64) (*JSONTable, error) { + tbl, err := h.tableStatsFromStorage(tableInfo, physicalID, true, snapshot) if err != nil || tbl == nil { return nil, err } - tbl.Version, tbl.ModifyCount, tbl.Count, err = h.statsMetaByTableIDFromStorage(physicalID, historyStatsExec) + tbl.Version, tbl.ModifyCount, tbl.Count, err = h.statsMetaByTableIDFromStorage(physicalID, snapshot) if err != nil { return nil, err } diff --git a/statistics/handle/gc.go b/statistics/handle/gc.go index 3264485e6b703..7bb4bcd6a426f 100644 --- a/statistics/handle/gc.go +++ b/statistics/handle/gc.go @@ -15,7 +15,6 @@ package handle import ( "context" - "fmt" "time" "github.com/cznic/mathutil" @@ -27,6 +26,7 @@ import ( // GCStats will garbage collect the useless stats info. For dropped tables, we will first update their version so that // other tidb could know that table is deleted. func (h *Handle) GCStats(is infoschema.InfoSchema, ddlLease time.Duration) error { + ctx := context.Background() // To make sure that all the deleted tables' schema and stats info have been acknowledged to all tidb, // we only garbage collect version before 10 lease. lease := mathutil.MaxInt64(int64(h.Lease()), int64(ddlLease)) @@ -34,8 +34,8 @@ func (h *Handle) GCStats(is infoschema.InfoSchema, ddlLease time.Duration) error if h.LastUpdateVersion() < offset { return nil } - sql := fmt.Sprintf("select table_id from mysql.stats_meta where version < %d", h.LastUpdateVersion()-offset) - rows, _, err := h.restrictedExec.ExecRestrictedSQL(sql) + gcVer := h.LastUpdateVersion() - offset + rows, _, err := h.execRestrictedSQL(ctx, "select table_id from mysql.stats_meta where version < %?", gcVer) if err != nil { return errors.Trace(err) } @@ -48,17 +48,18 @@ func (h *Handle) GCStats(is infoschema.InfoSchema, ddlLease time.Duration) error } func (h *Handle) gcTableStats(is infoschema.InfoSchema, physicalID int64) error { - sql := fmt.Sprintf("select is_index, hist_id from mysql.stats_histograms where table_id = %d", physicalID) - rows, _, err := h.restrictedExec.ExecRestrictedSQL(sql) + ctx := context.Background() + rows, _, err := h.execRestrictedSQL(ctx, "select is_index, hist_id from mysql.stats_histograms where table_id = %?", physicalID) if err != nil { return errors.Trace(err) } // The table has already been deleted in stats and acknowledged to all tidb, // we can safely remove the meta info now. if len(rows) == 0 { - sql := fmt.Sprintf("delete from mysql.stats_meta where table_id = %d", physicalID) - _, _, err := h.restrictedExec.ExecRestrictedSQL(sql) - return errors.Trace(err) + _, _, err = h.execRestrictedSQL(ctx, "delete from mysql.stats_meta where table_id = %?", physicalID) + if err != nil { + return errors.Trace(err) + } } h.mu.Lock() tbl, ok := h.getTableByPhysicalID(is, physicalID) @@ -99,29 +100,37 @@ func (h *Handle) deleteHistStatsFromKV(physicalID int64, histID int64, isIndex i h.mu.Lock() defer h.mu.Unlock() + ctx := context.Background() exec := h.mu.ctx.(sqlexec.SQLExecutor) - _, err = exec.Execute(context.Background(), "begin") + _, err = exec.ExecuteInternal(ctx, "begin") if err != nil { return errors.Trace(err) } defer func() { - err = finishTransaction(context.Background(), exec, err) + err = finishTransaction(ctx, exec, err) }() txn, err := h.mu.ctx.Txn(true) if err != nil { return errors.Trace(err) } startTS := txn.StartTS() - sqls := make([]string, 0, 4) // First of all, we update the version. If this table doesn't exist, it won't have any problem. Because we cannot delete anything. - sqls = append(sqls, fmt.Sprintf("update mysql.stats_meta set version = %d where table_id = %d ", startTS, physicalID)) + if _, err = exec.ExecuteInternal(ctx, "update mysql.stats_meta set version = %? where table_id = %? ", startTS, physicalID); err != nil { + return err + } // delete histogram meta - sqls = append(sqls, fmt.Sprintf("delete from mysql.stats_histograms where table_id = %d and hist_id = %d and is_index = %d", physicalID, histID, isIndex)) + if _, err = exec.ExecuteInternal(ctx, "delete from mysql.stats_histograms where table_id = %? and hist_id = %? and is_index = %?", physicalID, histID, isIndex); err != nil { + return err + } // delete top n data - sqls = append(sqls, fmt.Sprintf("delete from mysql.stats_top_n where table_id = %d and hist_id = %d and is_index = %d", physicalID, histID, isIndex)) + if _, err = exec.ExecuteInternal(ctx, "delete from mysql.stats_top_n where table_id = %? and hist_id = %? and is_index = %?", physicalID, histID, isIndex); err != nil { + return err + } // delete all buckets - sqls = append(sqls, fmt.Sprintf("delete from mysql.stats_buckets where table_id = %d and hist_id = %d and is_index = %d", physicalID, histID, isIndex)) - return execSQLs(context.Background(), exec, sqls) + if _, err = exec.ExecuteInternal(ctx, "delete from mysql.stats_buckets where table_id = %? and hist_id = %? and is_index = %?", physicalID, histID, isIndex); err != nil { + return err + } + return nil } // DeleteTableStatsFromKV deletes table statistics from kv. @@ -129,7 +138,7 @@ func (h *Handle) DeleteTableStatsFromKV(physicalID int64) (err error) { h.mu.Lock() defer h.mu.Unlock() exec := h.mu.ctx.(sqlexec.SQLExecutor) - _, err = exec.Execute(context.Background(), "begin") + _, err = exec.ExecuteInternal(context.Background(), "begin") if err != nil { return errors.Trace(err) } @@ -140,13 +149,23 @@ func (h *Handle) DeleteTableStatsFromKV(physicalID int64) (err error) { if err != nil { return errors.Trace(err) } + ctx := context.Background() startTS := txn.StartTS() - sqls := make([]string, 0, 5) // We only update the version so that other tidb will know that this table is deleted. - sqls = append(sqls, fmt.Sprintf("update mysql.stats_meta set version = %d where table_id = %d ", startTS, physicalID)) - sqls = append(sqls, fmt.Sprintf("delete from mysql.stats_histograms where table_id = %d", physicalID)) - sqls = append(sqls, fmt.Sprintf("delete from mysql.stats_buckets where table_id = %d", physicalID)) - sqls = append(sqls, fmt.Sprintf("delete from mysql.stats_top_n where table_id = %d", physicalID)) - sqls = append(sqls, fmt.Sprintf("delete from mysql.stats_feedback where table_id = %d", physicalID)) - return execSQLs(context.Background(), exec, sqls) + if _, err = exec.ExecuteInternal(ctx, "update mysql.stats_meta set version = %? where table_id = %? ", startTS, physicalID); err != nil { + return err + } + if _, err = exec.ExecuteInternal(ctx, "delete from mysql.stats_histograms where table_id = %?", physicalID); err != nil { + return err + } + if _, err = exec.ExecuteInternal(ctx, "delete from mysql.stats_buckets where table_id = %?", physicalID); err != nil { + return err + } + if _, err = exec.ExecuteInternal(ctx, "delete from mysql.stats_top_n where table_id = %?", physicalID); err != nil { + return err + } + if _, err = exec.ExecuteInternal(ctx, "delete from mysql.stats_feedback where table_id = %?", physicalID); err != nil { + return err + } + return nil } diff --git a/statistics/handle/handle.go b/statistics/handle/handle.go index 3c681aafb3804..2d8915a8e0478 100644 --- a/statistics/handle/handle.go +++ b/statistics/handle/handle.go @@ -20,12 +20,12 @@ import ( "sync/atomic" "time" + "github.com/ngaut/pools" "github.com/pingcap/errors" "github.com/pingcap/failpoint" "github.com/pingcap/parser/ast" "github.com/pingcap/parser/model" "github.com/pingcap/parser/mysql" - "github.com/pingcap/parser/terror" "github.com/pingcap/tidb/ddl/util" "github.com/pingcap/tidb/infoschema" "github.com/pingcap/tidb/sessionctx" @@ -68,7 +68,7 @@ type Handle struct { atomic.Value } - restrictedExec sqlexec.RestrictedSQLExecutor + pool sessionPool // ddlEventCh is a channel to notify a ddl operation has happened. // It is sent only by owner or the drop stats executor, and read by stats handle. @@ -83,6 +83,37 @@ type Handle struct { lease atomic2.Duration } +func (h *Handle) withRestrictedSQLExecutor(ctx context.Context, fn func(context.Context, sqlexec.RestrictedSQLExecutor) ([]chunk.Row, []*ast.ResultField, error)) ([]chunk.Row, []*ast.ResultField, error) { + se, err := h.pool.Get() + if err != nil { + return nil, nil, errors.Trace(err) + } + defer h.pool.Put(se) + + exec := se.(sqlexec.RestrictedSQLExecutor) + return fn(ctx, exec) +} + +func (h *Handle) execRestrictedSQL(ctx context.Context, sql string, params ...interface{}) ([]chunk.Row, []*ast.ResultField, error) { + return h.withRestrictedSQLExecutor(ctx, func(ctx context.Context, exec sqlexec.RestrictedSQLExecutor) ([]chunk.Row, []*ast.ResultField, error) { + stmt, err := exec.ParseWithParams(ctx, sql, params...) + if err != nil { + return nil, nil, errors.Trace(err) + } + return exec.ExecRestrictedStmt(ctx, stmt) + }) +} + +func (h *Handle) execRestrictedSQLWithSnapshot(ctx context.Context, sql string, snapshot uint64, params ...interface{}) ([]chunk.Row, []*ast.ResultField, error) { + return h.withRestrictedSQLExecutor(ctx, func(ctx context.Context, exec sqlexec.RestrictedSQLExecutor) ([]chunk.Row, []*ast.ResultField, error) { + stmt, err := exec.ParseWithParams(ctx, sql, params...) + if err != nil { + return nil, nil, errors.Trace(err) + } + return exec.ExecRestrictedStmt(ctx, stmt, sqlexec.ExecOptionWithSnapshot(snapshot)) + }) +} + // Clear the statsCache, only for test. func (h *Handle) Clear() { h.mu.Lock() @@ -101,19 +132,21 @@ func (h *Handle) Clear() { h.mu.Unlock() } +type sessionPool interface { + Get() (pools.Resource, error) + Put(pools.Resource) +} + // NewHandle creates a Handle for update stats. -func NewHandle(ctx sessionctx.Context, lease time.Duration) *Handle { +func NewHandle(ctx sessionctx.Context, lease time.Duration, pool sessionPool) *Handle { handle := &Handle{ ddlEventCh: make(chan *util.Event, 100), listHead: &SessionStatsCollector{mapper: make(tableDeltaMap), rateMap: make(errorRateDeltaMap)}, globalMap: make(tableDeltaMap), feedback: statistics.NewQueryFeedbackMap(), + pool: pool, } handle.lease.Store(lease) - // It is safe to use it concurrently because the exec won't touch the ctx. - if exec, ok := ctx.(sqlexec.RestrictedSQLExecutor); ok { - handle.restrictedExec = exec - } handle.mu.ctx = ctx handle.mu.rateMap = make(errorRateDeltaMap) handle.statsCache.Store(statsCache{tables: make(map[int64]*statistics.Table)}) @@ -158,8 +191,8 @@ func (h *Handle) Update(is infoschema.InfoSchema) error { } else { lastVersion = 0 } - sql := fmt.Sprintf("SELECT version, table_id, modify_count, count from mysql.stats_meta where version > %d order by version", lastVersion) - rows, _, err := h.restrictedExec.ExecRestrictedSQL(sql) + ctx := context.Background() + rows, _, err := h.execRestrictedSQL(ctx, "SELECT version, table_id, modify_count, count from mysql.stats_meta where version > %? order by version", lastVersion) if err != nil { return errors.Trace(err) } @@ -181,7 +214,7 @@ func (h *Handle) Update(is infoschema.InfoSchema) error { continue } tableInfo := table.Meta() - tbl, err := h.tableStatsFromStorage(tableInfo, physicalID, false, nil) + tbl, err := h.tableStatsFromStorage(tableInfo, physicalID, false, 0) // Error is not nil may mean that there are some ddl changes on this table, we will not update it. if err != nil { logutil.BgLogger().Error("[stats] error occurred when read table stats", zap.String("table", tableInfo.Name.O), zap.Error(err)) @@ -281,7 +314,7 @@ func (sc statsCache) update(tables []*statistics.Table, deletedIDs []int64, newV // LoadNeededHistograms will load histograms for those needed columns. func (h *Handle) LoadNeededHistograms() (err error) { cols := statistics.HistogramNeededColumns.AllCols() - reader, err := h.getStatsReader(nil) + reader, err := h.getStatsReader(0) if err != nil { return err } @@ -355,13 +388,11 @@ func (h *Handle) FlushStats() { } func (h *Handle) cmSketchFromStorage(reader *statsReader, tblID int64, isIndex, histID int64) (_ *statistics.CMSketch, err error) { - selSQL := fmt.Sprintf("select cm_sketch from mysql.stats_histograms where table_id = %d and is_index = %d and hist_id = %d", tblID, isIndex, histID) - rows, _, err := reader.read(selSQL) + rows, _, err := reader.read("select cm_sketch from mysql.stats_histograms where table_id = %? and is_index = %? and hist_id = %?", tblID, isIndex, histID) if err != nil || len(rows) == 0 { return nil, err } - selSQL = fmt.Sprintf("select HIGH_PRIORITY value, count from mysql.stats_top_n where table_id = %d and is_index = %d and hist_id = %d", tblID, isIndex, histID) - topNRows, _, err := reader.read(selSQL) + topNRows, _, err := reader.read("select HIGH_PRIORITY value, count from mysql.stats_top_n where table_id = %? and is_index = %? and hist_id = %?", tblID, isIndex, histID) if err != nil { return nil, err } @@ -497,8 +528,8 @@ func (h *Handle) columnStatsFromStorage(reader *statsReader, row chunk.Row, tabl } // tableStatsFromStorage loads table stats info from storage. -func (h *Handle) tableStatsFromStorage(tableInfo *model.TableInfo, physicalID int64, loadAll bool, historyStatsExec sqlexec.RestrictedSQLExecutor) (_ *statistics.Table, err error) { - reader, err := h.getStatsReader(historyStatsExec) +func (h *Handle) tableStatsFromStorage(tableInfo *model.TableInfo, physicalID int64, loadAll bool, snapshot uint64) (_ *statistics.Table, err error) { + reader, err := h.getStatsReader(snapshot) if err != nil { return nil, err } @@ -511,7 +542,7 @@ func (h *Handle) tableStatsFromStorage(tableInfo *model.TableInfo, physicalID in table, ok := h.statsCache.Load().(statsCache).tables[physicalID] // If table stats is pseudo, we also need to copy it, since we will use the column stats when // the average error rate of it is small. - if !ok || historyStatsExec != nil { + if !ok || snapshot > 0 { histColl := statistics.HistColl{ PhysicalID: physicalID, HavePhysicalID: true, @@ -526,8 +557,7 @@ func (h *Handle) tableStatsFromStorage(tableInfo *model.TableInfo, physicalID in table = table.Copy() } table.Pseudo = false - selSQL := fmt.Sprintf("select table_id, is_index, hist_id, distinct_count, version, null_count, tot_col_size, stats_ver, flag, correlation, last_analyze_pos from mysql.stats_histograms where table_id = %d", physicalID) - rows, _, err := reader.read(selSQL) + rows, _, err := reader.read("select table_id, is_index, hist_id, distinct_count, version, null_count, tot_col_size, stats_ver, flag, correlation, last_analyze_pos from mysql.stats_histograms where table_id = %?", physicalID) // Check deleted table. if err != nil || len(rows) == 0 { return nil, nil @@ -551,7 +581,7 @@ func (h *Handle) SaveStatsToStorage(tableID int64, count int64, isIndex int, hg defer h.mu.Unlock() ctx := context.TODO() exec := h.mu.ctx.(sqlexec.SQLExecutor) - _, err = exec.Execute(ctx, "begin") + _, err = exec.ExecuteInternal(ctx, "begin") if err != nil { return errors.Trace(err) } @@ -564,29 +594,40 @@ func (h *Handle) SaveStatsToStorage(tableID int64, count int64, isIndex int, hg } version := txn.StartTS() - sqls := make([]string, 0, 4) // If the count is less than 0, then we do not want to update the modify count and count. if count >= 0 { - sqls = append(sqls, fmt.Sprintf("replace into mysql.stats_meta (version, table_id, count) values (%d, %d, %d)", version, tableID, count)) + _, err = exec.ExecuteInternal(ctx, "replace into mysql.stats_meta (version, table_id, count) values (%?, %?, %?)", version, tableID, count) } else { - sqls = append(sqls, fmt.Sprintf("update mysql.stats_meta set version = %d where table_id = %d", version, tableID)) + _, err = exec.ExecuteInternal(ctx, "update mysql.stats_meta set version = %? where table_id = %?", version, tableID) + } + if err != nil { + return err } data, err := statistics.EncodeCMSketchWithoutTopN(cms) if err != nil { - return + return err } // Delete outdated data - sqls = append(sqls, fmt.Sprintf("delete from mysql.stats_top_n where table_id = %d and is_index = %d and hist_id = %d", tableID, isIndex, hg.ID)) + if _, err = exec.ExecuteInternal(ctx, "delete from mysql.stats_top_n where table_id = %? and is_index = %? and hist_id = %?", tableID, isIndex, hg.ID); err != nil { + return err + } for _, meta := range cms.TopN() { - sqls = append(sqls, fmt.Sprintf("insert into mysql.stats_top_n (table_id, is_index, hist_id, value, count) values (%d, %d, %d, X'%X', %d)", tableID, isIndex, hg.ID, meta.Data, meta.Count)) + _, err = exec.ExecuteInternal(ctx, "insert into mysql.stats_top_n (table_id, is_index, hist_id, value, count) values (%?, %?, %?, %?, %?)", tableID, isIndex, hg.ID, meta.Data, meta.Count) + if err != nil { + return err + } } flag := 0 if isAnalyzed == 1 { flag = statistics.AnalyzeFlag } - sqls = append(sqls, fmt.Sprintf("replace into mysql.stats_histograms (table_id, is_index, hist_id, distinct_count, version, null_count, cm_sketch, tot_col_size, stats_ver, flag, correlation) values (%d, %d, %d, %d, %d, %d, X'%X', %d, %d, %d, %f)", - tableID, isIndex, hg.ID, hg.NDV, version, hg.NullCount, data, hg.TotColSize, statistics.CurStatsVersion, flag, hg.Correlation)) - sqls = append(sqls, fmt.Sprintf("delete from mysql.stats_buckets where table_id = %d and is_index = %d and hist_id = %d", tableID, isIndex, hg.ID)) + if _, err = exec.ExecuteInternal(ctx, "replace into mysql.stats_histograms (table_id, is_index, hist_id, distinct_count, version, null_count, cm_sketch, tot_col_size, stats_ver, flag, correlation) values (%?, %?, %?, %?, %?, %?, %?, %?, %?, %?, %?)", + tableID, isIndex, hg.ID, hg.NDV, version, hg.NullCount, data, hg.TotColSize, statistics.CurStatsVersion, flag, hg.Correlation); err != nil { + return err + } + if _, err = exec.ExecuteInternal(ctx, "delete from mysql.stats_buckets where table_id = %? and is_index = %? and hist_id = %?", tableID, isIndex, hg.ID); err != nil { + return err + } sc := h.mu.ctx.GetSessionVars().StmtCtx var lastAnalyzePos []byte for i := range hg.Buckets { @@ -607,12 +648,16 @@ func (h *Handle) SaveStatsToStorage(tableID int64, count int64, isIndex int, hg if err != nil { return } - sqls = append(sqls, fmt.Sprintf("insert into mysql.stats_buckets(table_id, is_index, hist_id, bucket_id, count, repeats, lower_bound, upper_bound) values(%d, %d, %d, %d, %d, %d, X'%X', X'%X')", tableID, isIndex, hg.ID, i, count, hg.Buckets[i].Repeat, lowerBound.GetBytes(), upperBound.GetBytes())) + if _, err = exec.ExecuteInternal(ctx, "insert into mysql.stats_buckets(table_id, is_index, hist_id, bucket_id, count, repeats, lower_bound, upper_bound) values(%?, %?, %?, %?, %?, %?, %?, %?)", tableID, isIndex, hg.ID, i, count, hg.Buckets[i].Repeat, lowerBound.GetBytes(), upperBound.GetBytes()); err != nil { + return err + } } if isAnalyzed == 1 && len(lastAnalyzePos) > 0 { - sqls = append(sqls, fmt.Sprintf("update mysql.stats_histograms set last_analyze_pos = X'%X' where table_id = %d and is_index = %d and hist_id = %d", lastAnalyzePos, tableID, isIndex, hg.ID)) + if _, err = exec.ExecuteInternal(ctx, "update mysql.stats_histograms set last_analyze_pos = %? where table_id = %? and is_index = %? and hist_id = %?", lastAnalyzePos, tableID, isIndex, hg.ID); err != nil { + return err + } } - return execSQLs(context.Background(), exec, sqls) + return } // SaveMetaToStorage will save stats_meta to storage. @@ -621,7 +666,7 @@ func (h *Handle) SaveMetaToStorage(tableID, count, modifyCount int64) (err error defer h.mu.Unlock() ctx := context.TODO() exec := h.mu.ctx.(sqlexec.SQLExecutor) - _, err = exec.Execute(ctx, "begin") + _, err = exec.ExecuteInternal(ctx, "begin") if err != nil { return errors.Trace(err) } @@ -632,16 +677,13 @@ func (h *Handle) SaveMetaToStorage(tableID, count, modifyCount int64) (err error if err != nil { return errors.Trace(err) } - var sql string version := txn.StartTS() - sql = fmt.Sprintf("replace into mysql.stats_meta (version, table_id, count, modify_count) values (%d, %d, %d, %d)", version, tableID, count, modifyCount) - _, err = exec.Execute(ctx, sql) - return + _, err = exec.ExecuteInternal(ctx, "replace into mysql.stats_meta (version, table_id, count, modify_count) values (%?, %?, %?, %?)", version, tableID, count, modifyCount) + return err } func (h *Handle) histogramFromStorage(reader *statsReader, tableID int64, colID int64, tp *types.FieldType, distinct int64, isIndex int, ver uint64, nullCount int64, totColSize int64, corr float64) (_ *statistics.Histogram, err error) { - selSQL := fmt.Sprintf("select count, repeats, lower_bound, upper_bound from mysql.stats_buckets where table_id = %d and is_index = %d and hist_id = %d order by bucket_id", tableID, isIndex, colID) - rows, fields, err := reader.read(selSQL) + rows, fields, err := reader.read("select count, repeats, lower_bound, upper_bound from mysql.stats_buckets where table_id = %? and is_index = %? and hist_id = %? order by bucket_id", tableID, isIndex, colID) if err != nil { return nil, errors.Trace(err) } @@ -677,8 +719,7 @@ func (h *Handle) histogramFromStorage(reader *statsReader, tableID int64, colID } func (h *Handle) columnCountFromStorage(reader *statsReader, tableID, colID int64) (int64, error) { - selSQL := fmt.Sprintf("select sum(count) from mysql.stats_buckets where table_id = %d and is_index = %d and hist_id = %d", tableID, 0, colID) - rows, _, err := reader.read(selSQL) + rows, _, err := reader.read("select sum(count) from mysql.stats_buckets where table_id = %? and is_index = %? and hist_id = %?", tableID, 0, colID) if err != nil { return 0, errors.Trace(err) } @@ -688,13 +729,16 @@ func (h *Handle) columnCountFromStorage(reader *statsReader, tableID, colID int6 return rows[0].GetMyDecimal(0).ToInt() } -func (h *Handle) statsMetaByTableIDFromStorage(tableID int64, historyStatsExec sqlexec.RestrictedSQLExecutor) (version uint64, modifyCount, count int64, err error) { - selSQL := fmt.Sprintf("SELECT version, modify_count, count from mysql.stats_meta where table_id = %d order by version", tableID) +func (h *Handle) statsMetaByTableIDFromStorage(tableID int64, snapshot uint64) (version uint64, modifyCount, count int64, err error) { + ctx := context.Background() var rows []chunk.Row - if historyStatsExec == nil { - rows, _, err = h.restrictedExec.ExecRestrictedSQL(selSQL) + if snapshot == 0 { + rows, _, err = h.execRestrictedSQL(ctx, "SELECT version, modify_count, count from mysql.stats_meta where table_id = %? order by version", tableID) } else { - rows, _, err = historyStatsExec.ExecRestrictedSQLWithSnapshot(selSQL) + rows, _, err = h.execRestrictedSQLWithSnapshot(ctx, "SELECT version, modify_count, count from mysql.stats_meta where table_id = %? order by version", snapshot, tableID) + if err != nil { + return 0, 0, 0, err + } } if err != nil || len(rows) == 0 { return @@ -708,49 +752,34 @@ func (h *Handle) statsMetaByTableIDFromStorage(tableID int64, historyStatsExec s // statsReader is used for simplify code that needs to read system tables in different sqls // but requires the same transactions. type statsReader struct { - ctx sessionctx.Context - history sqlexec.RestrictedSQLExecutor + ctx sqlexec.RestrictedSQLExecutor + snapshot uint64 } -func (sr *statsReader) read(sql string) (rows []chunk.Row, fields []*ast.ResultField, err error) { - if sr.history != nil { - return sr.history.ExecRestrictedSQLWithSnapshot(sql) - } - rc, err := sr.ctx.(sqlexec.SQLExecutor).Execute(context.TODO(), sql) - if len(rc) > 0 { - defer terror.Call(rc[0].Close) - } +func (sr *statsReader) read(sql string, args ...interface{}) (rows []chunk.Row, fields []*ast.ResultField, err error) { + ctx := context.TODO() + stmt, err := sr.ctx.ParseWithParams(ctx, sql, args...) if err != nil { - return nil, nil, err + return nil, nil, errors.Trace(err) } - for { - req := rc[0].NewChunk() - err := rc[0].Next(context.TODO(), req) - if err != nil { - return nil, nil, err - } - if req.NumRows() == 0 { - break - } - for i := 0; i < req.NumRows(); i++ { - rows = append(rows, req.GetRow(i)) - } + if sr.snapshot > 0 { + return sr.ctx.ExecRestrictedStmt(ctx, stmt, sqlexec.ExecOptionWithSnapshot(sr.snapshot)) } - return rows, rc[0].Fields(), nil + return sr.ctx.ExecRestrictedStmt(ctx, stmt) } func (sr *statsReader) isHistory() bool { - return sr.history != nil + return sr.snapshot > 0 } -func (h *Handle) getStatsReader(history sqlexec.RestrictedSQLExecutor) (reader *statsReader, err error) { +func (h *Handle) getStatsReader(snapshot uint64) (reader *statsReader, err error) { failpoint.Inject("mockGetStatsReaderFail", func(val failpoint.Value) { if val.(bool) { failpoint.Return(nil, errors.New("gofail genStatsReader error")) } }) - if history != nil { - return &statsReader{history: history}, nil + if snapshot > 0 { + return &statsReader{ctx: h.mu.ctx.(sqlexec.RestrictedSQLExecutor), snapshot: snapshot}, nil } h.mu.Lock() defer func() { @@ -762,18 +791,18 @@ func (h *Handle) getStatsReader(history sqlexec.RestrictedSQLExecutor) (reader * } }() failpoint.Inject("mockGetStatsReaderPanic", nil) - _, err = h.mu.ctx.(sqlexec.SQLExecutor).Execute(context.TODO(), "begin") + _, err = h.mu.ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), "begin") if err != nil { return nil, err } - return &statsReader{ctx: h.mu.ctx}, nil + return &statsReader{ctx: h.mu.ctx.(sqlexec.RestrictedSQLExecutor)}, nil } func (h *Handle) releaseStatsReader(reader *statsReader) error { - if reader.history != nil { + if reader.snapshot > 0 { return nil } - _, err := h.mu.ctx.(sqlexec.SQLExecutor).Execute(context.TODO(), "commit") + _, err := h.mu.ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), "commit") h.mu.Unlock() return err } diff --git a/statistics/handle/handle_test.go b/statistics/handle/handle_test.go index eefcc08dee748..32b491ce7ef35 100644 --- a/statistics/handle/handle_test.go +++ b/statistics/handle/handle_test.go @@ -265,7 +265,7 @@ func (s *testStatsSuite) TestVersion(c *C) { tbl1, err := is.TableByName(model.NewCIStr("test"), model.NewCIStr("t1")) c.Assert(err, IsNil) tableInfo1 := tbl1.Meta() - h := handle.NewHandle(testKit.Se, time.Millisecond) + h := handle.NewHandle(testKit.Se, time.Millisecond, do.SysSessionPool()) unit := oracle.ComposeTS(1, 0) testKit.MustExec("update mysql.stats_meta set version = ? where table_id = ?", 2*unit, tableInfo1.ID) @@ -499,7 +499,7 @@ func (s *testStatsSuite) TestCorrelation(c *C) { result = testKit.MustQuery("show stats_histograms where Table_name = 't'").Sort() c.Assert(len(result.Rows()), Equals, 2) c.Assert(result.Rows()[0][9], Equals, "0") - c.Assert(result.Rows()[1][9], Equals, "0.828571") + c.Assert(result.Rows()[1][9], Equals, "0.8285714285714286") testKit.MustExec("truncate table t") result = testKit.MustQuery("show stats_histograms where Table_name = 't'").Sort() @@ -515,7 +515,7 @@ func (s *testStatsSuite) TestCorrelation(c *C) { result = testKit.MustQuery("show stats_histograms where Table_name = 't'").Sort() c.Assert(len(result.Rows()), Equals, 2) c.Assert(result.Rows()[0][9], Equals, "0") - c.Assert(result.Rows()[1][9], Equals, "-0.942857") + c.Assert(result.Rows()[1][9], Equals, "-0.9428571428571428") testKit.MustExec("truncate table t") testKit.MustExec("insert into t values (1,1),(2,1),(3,1),(4,1),(5,1),(6,1),(7,1),(8,1),(9,1),(10,1),(11,1),(12,1),(13,1),(14,1),(15,1),(16,1),(17,1),(18,1),(19,1),(20,2),(21,2),(22,2),(23,2),(24,2),(25,2)") @@ -532,14 +532,14 @@ func (s *testStatsSuite) TestCorrelation(c *C) { result = testKit.MustQuery("show stats_histograms where Table_name = 't'").Sort() c.Assert(len(result.Rows()), Equals, 2) c.Assert(result.Rows()[0][9], Equals, "1") - c.Assert(result.Rows()[1][9], Equals, "0.828571") + c.Assert(result.Rows()[1][9], Equals, "0.8285714285714286") testKit.MustExec("truncate table t") testKit.MustExec("insert into t values(1,1),(2,7),(3,12),(8,18),(4,20),(5,21)") testKit.MustExec("analyze table t") result = testKit.MustQuery("show stats_histograms where Table_name = 't'").Sort() c.Assert(len(result.Rows()), Equals, 2) - c.Assert(result.Rows()[0][9], Equals, "0.828571") + c.Assert(result.Rows()[0][9], Equals, "0.8285714285714286") c.Assert(result.Rows()[1][9], Equals, "1") testKit.MustExec("drop table t") diff --git a/statistics/handle/update.go b/statistics/handle/update.go index a17357cef2c5b..0ee940b559f05 100644 --- a/statistics/handle/update.go +++ b/statistics/handle/update.go @@ -331,7 +331,7 @@ func (h *Handle) dumpTableStatCountToKV(id int64, delta variable.TableDelta) (up defer h.mu.Unlock() ctx := context.TODO() exec := h.mu.ctx.(sqlexec.SQLExecutor) - _, err = exec.Execute(ctx, "begin") + _, err = exec.ExecuteInternal(ctx, "begin") if err != nil { return false, errors.Trace(err) } @@ -344,13 +344,14 @@ func (h *Handle) dumpTableStatCountToKV(id int64, delta variable.TableDelta) (up return false, errors.Trace(err) } startTS := txn.StartTS() - var sql string if delta.Delta < 0 { - sql = fmt.Sprintf("update mysql.stats_meta set version = %d, count = count - %d, modify_count = modify_count + %d where table_id = %d and count >= %d", startTS, -delta.Delta, delta.Count, id, -delta.Delta) + _, err = exec.ExecuteInternal(ctx, "update mysql.stats_meta set version = %?, count = count - %?, modify_count = modify_count + %? where table_id = %? and count >= %?", startTS, -delta.Delta, delta.Count, id, -delta.Delta) } else { - sql = fmt.Sprintf("update mysql.stats_meta set version = %d, count = count + %d, modify_count = modify_count + %d where table_id = %d", startTS, delta.Delta, delta.Count, id) + _, err = exec.ExecuteInternal(ctx, "update mysql.stats_meta set version = %?, count = count + %?, modify_count = modify_count + %? where table_id = %?", startTS, delta.Delta, delta.Count, id) + } + if err != nil { + return false, errors.Trace(err) } - err = execSQLs(context.Background(), exec, []string{sql}) updated = h.mu.ctx.GetSessionVars().StmtCtx.AffectedRows() > 0 return } @@ -371,7 +372,7 @@ func (h *Handle) dumpTableStatColSizeToKV(id int64, delta variable.TableDelta) e } sql := fmt.Sprintf("insert into mysql.stats_histograms (table_id, is_index, hist_id, distinct_count, tot_col_size) "+ "values %s on duplicate key update tot_col_size = tot_col_size + values(tot_col_size)", strings.Join(values, ",")) - _, _, err := h.restrictedExec.ExecRestrictedSQL(sql) + _, _, err := h.execRestrictedSQL(context.Background(), sql) return errors.Trace(err) } @@ -409,10 +410,9 @@ func (h *Handle) DumpFeedbackToKV(fb *statistics.QueryFeedback) error { if fb.Tp == statistics.IndexType { isIndex = 1 } - sql := fmt.Sprintf("insert into mysql.stats_feedback (table_id, hist_id, is_index, feedback) values "+ - "(%d, %d, %d, X'%X')", fb.PhysicalID, fb.Hist.ID, isIndex, vals) + const sql = "insert into mysql.stats_feedback (table_id, hist_id, is_index, feedback) values (%?, %?, %?, %?)" h.mu.Lock() - _, err = h.mu.ctx.(sqlexec.SQLExecutor).Execute(context.TODO(), sql) + _, err = h.mu.ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), sql, fb.PhysicalID, fb.Hist.ID, isIndex, vals) h.mu.Unlock() if err != nil { metrics.DumpFeedbackCounter.WithLabelValues(metrics.LblError).Inc() @@ -503,8 +503,8 @@ func (h *Handle) UpdateErrorRate(is infoschema.InfoSchema) { // HandleUpdateStats update the stats using feedback. func (h *Handle) HandleUpdateStats(is infoschema.InfoSchema) error { - sql := "SELECT distinct table_id from mysql.stats_feedback" - tables, _, err := h.restrictedExec.ExecRestrictedSQL(sql) + ctx := context.Background() + tables, _, err := h.execRestrictedSQL(ctx, "SELECT distinct table_id from mysql.stats_feedback") if err != nil { return errors.Trace(err) } @@ -516,20 +516,18 @@ func (h *Handle) HandleUpdateStats(is infoschema.InfoSchema) error { // this func lets `defer` works normally, where `Close()` should be called before any return err = func() error { tbl := ptbl.GetInt64(0) - sql = fmt.Sprintf("select table_id, hist_id, is_index, feedback from mysql.stats_feedback where table_id=%d order by hist_id, is_index", tbl) - rc, err := h.mu.ctx.(sqlexec.SQLExecutor).Execute(context.TODO(), sql) - if len(rc) > 0 { - defer terror.Call(rc[0].Close) - } + const sql = "select table_id, hist_id, is_index, feedback from mysql.stats_feedback where table_id=%? order by hist_id, is_index" + rc, err := h.mu.ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), sql, tbl) if err != nil { return errors.Trace(err) } + defer terror.Call(rc.Close) tableID, histID, isIndex := int64(-1), int64(-1), int64(-1) var rows []chunk.Row for { - req := rc[0].NewChunk() + req := rc.NewChunk() iter := chunk.NewIterator4Chunk(req) - err := rc[0].Next(context.TODO(), req) + err := rc.Next(context.TODO(), req) if err != nil { return errors.Trace(err) } @@ -622,8 +620,8 @@ func (h *Handle) deleteOutdatedFeedback(tableID, histID, isIndex int64) error { defer h.mu.Unlock() hasData := true for hasData { - sql := fmt.Sprintf("delete from mysql.stats_feedback where table_id = %d and hist_id = %d and is_index = %d limit 10000", tableID, histID, isIndex) - _, err := h.mu.ctx.(sqlexec.SQLExecutor).Execute(context.TODO(), sql) + sql := "delete from mysql.stats_feedback where table_id = %? and hist_id = %? and is_index = %? limit 10000" + _, err := h.mu.ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), sql, tableID, histID, isIndex) if err != nil { return errors.Trace(err) } @@ -694,9 +692,9 @@ func NeedAnalyzeTable(tbl *statistics.Table, limit time.Duration, autoAnalyzeRat } func (h *Handle) getAutoAnalyzeParameters() map[string]string { - sql := fmt.Sprintf("select variable_name, variable_value from mysql.global_variables where variable_name in ('%s', '%s', '%s')", - variable.TiDBAutoAnalyzeRatio, variable.TiDBAutoAnalyzeStartTime, variable.TiDBAutoAnalyzeEndTime) - rows, _, err := h.restrictedExec.ExecRestrictedSQL(sql) + ctx := context.Background() + sql := "select variable_name, variable_value from mysql.global_variables where variable_name in (%?, %?, %?)" + rows, _, err := h.execRestrictedSQL(ctx, sql, variable.TiDBAutoAnalyzeRatio, variable.TiDBAutoAnalyzeStartTime, variable.TiDBAutoAnalyzeEndTime) if err != nil { return map[string]string{} } @@ -745,11 +743,10 @@ func (h *Handle) HandleAutoAnalyze(is infoschema.InfoSchema) (analyzed bool) { for _, tbl := range tbls { tblInfo := tbl.Meta() pi := tblInfo.GetPartitionInfo() - tblName := "`" + db + "`.`" + tblInfo.Name.O + "`" if pi == nil { statsTbl := h.GetTableStats(tblInfo) - sql := fmt.Sprintf("analyze table %s", tblName) - analyzed := h.autoAnalyzeTable(tblInfo, statsTbl, start, end, autoAnalyzeRatio, sql) + sql := "analyze table %n.%n" + analyzed := h.autoAnalyzeTable(tblInfo, statsTbl, start, end, autoAnalyzeRatio, sql, db, tblInfo.Name.O) if analyzed { // analyze one table at a time to let it get the freshest parameters. // others will be analyzed next round which is just 3s later. @@ -758,9 +755,9 @@ func (h *Handle) HandleAutoAnalyze(is infoschema.InfoSchema) (analyzed bool) { continue } for _, def := range pi.Definitions { - sql := fmt.Sprintf("analyze table %s partition `%s`", tblName, def.Name.O) + sql := "analyze table %n.%n partition %n" statsTbl := h.GetPartitionStats(tblInfo, def.ID) - analyzed := h.autoAnalyzeTable(tblInfo, statsTbl, start, end, autoAnalyzeRatio, sql) + analyzed := h.autoAnalyzeTable(tblInfo, statsTbl, start, end, autoAnalyzeRatio, sql, db, tblInfo.Name.O, def.Name.O) if analyzed { return true } @@ -771,29 +768,28 @@ func (h *Handle) HandleAutoAnalyze(is infoschema.InfoSchema) (analyzed bool) { return false } -func (h *Handle) autoAnalyzeTable(tblInfo *model.TableInfo, statsTbl *statistics.Table, start, end time.Time, ratio float64, sql string) bool { +func (h *Handle) autoAnalyzeTable(tblInfo *model.TableInfo, statsTbl *statistics.Table, start, end time.Time, ratio float64, sql string, params ...interface{}) bool { if statsTbl.Pseudo || statsTbl.Count < AutoAnalyzeMinCnt { return false } if needAnalyze, reason := NeedAnalyzeTable(statsTbl, 20*h.Lease(), ratio, start, end, time.Now()); needAnalyze { logutil.BgLogger().Info("[stats] auto analyze triggered", zap.String("sql", sql), zap.String("reason", reason)) - h.execAutoAnalyze(sql) + h.execAutoAnalyze(sql, params...) return true } for _, idx := range tblInfo.Indices { if _, ok := statsTbl.Indices[idx.ID]; !ok && idx.State == model.StatePublic { - sql = fmt.Sprintf("%s index `%s`", sql, idx.Name.O) logutil.BgLogger().Info("[stats] auto analyze for unanalyzed", zap.String("sql", sql)) - h.execAutoAnalyze(sql) + h.execAutoAnalyze(sql+" index %n", append(params, idx.Name.O)...) return true } } return false } -func (h *Handle) execAutoAnalyze(sql string) { +func (h *Handle) execAutoAnalyze(sql string, params ...interface{}) { startTime := time.Now() - _, _, err := h.restrictedExec.ExecRestrictedSQL(sql) + _, _, err := h.execRestrictedSQL(context.Background(), sql, params...) dur := time.Since(startTime) metrics.AutoAnalyzeHistogram.Observe(dur.Seconds()) if err != nil { diff --git a/types/datum.go b/types/datum.go index 08c989d366a74..7a72024488006 100644 --- a/types/datum.go +++ b/types/datum.go @@ -194,7 +194,10 @@ var sink = func(s string) { // GetBytes gets bytes value. func (d *Datum) GetBytes() []byte { - return d.b + if d.b != nil { + return d.b + } + return []byte{} } // SetBytes sets bytes value to datum. From 4a672e77b03b239e08bb9dd280437af64776c600 Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Tue, 16 Mar 2021 10:56:56 +0800 Subject: [PATCH 32/85] ddl: fix ddl hang over when it meets panic in cancelling path (#23204) (#23297) --- ddl/db_test.go | 51 +++++++++++++++++++++++++++++++++++ ddl/ddl_worker.go | 69 ++++++++++++++++++++++++++++++++--------------- ddl/index.go | 6 +++++ 3 files changed, 104 insertions(+), 22 deletions(-) diff --git a/ddl/db_test.go b/ddl/db_test.go index 262677eeca2a4..109d39aff9be5 100644 --- a/ddl/db_test.go +++ b/ddl/db_test.go @@ -4988,3 +4988,54 @@ func (s *testSerialDBSuite) TestColumnTypeChangeIgnoreDisplayLength(c *C) { tk.MustExec("alter table t modify column a bigint(1)") tk.MustExec("drop table if exists t") } + +// Close issue #23202 +func (s *testSerialDBSuite) TestDDLExitWhenCancelMeetPanic(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a int, b int)") + tk.MustExec("insert into t values(1,1),(2,2)") + tk.MustExec("alter table t add index(b)") + tk.MustExec("set @@global.tidb_ddl_error_count_limit=3") + + failpoint.Enable("github.com/pingcap/tidb/ddl/mockExceedErrorLimit", `return(true)`) + defer func() { + failpoint.Disable("github.com/pingcap/tidb/ddl/mockExceedErrorLimit") + }() + + originalHook := s.dom.DDL().GetHook() + defer s.dom.DDL().(ddl.DDLForTest).SetHook(originalHook) + + hook := &ddl.TestDDLCallback{Do: s.dom} + var jobID int64 + hook.OnJobRunBeforeExported = func(job *model.Job) { + if jobID != 0 { + return + } + if job.Type == model.ActionDropIndex { + jobID = job.ID + } + } + s.dom.DDL().(ddl.DDLForTest).SetHook(hook) + + // when it panics in write-reorg state, the job will be pulled up as a cancelling job. Since drop-index with + // write-reorg can't be cancelled, so it will be converted to running state and try again (dead loop). + _, err := tk.Exec("alter table t drop index b") + c.Assert(err, NotNil) + c.Assert(err.Error(), Equals, "[ddl:-1]panic in handling DDL logic and error count beyond the limitation 3, cancelled") + c.Assert(jobID > 0, Equals, true) + + // Verification of the history job state. + var job *model.Job + err = kv.RunInNewTxn(s.store, false, func(txn kv.Transaction) error { + t := meta.NewMeta(txn) + var err1 error + job, err1 = t.GetHistoryDDLJob(jobID) + return errors.Trace(err1) + }) + c.Assert(err, IsNil) + c.Assert(job.ErrorCount, Equals, int64(4)) + c.Assert(job.Error.Error(), Equals, "[ddl:-1]panic in handling DDL logic and error count beyond the limitation 3, cancelled") +} diff --git a/ddl/ddl_worker.go b/ddl/ddl_worker.go index 00280bbdd0b6a..5870624dde5ae 100644 --- a/ddl/ddl_worker.go +++ b/ddl/ddl_worker.go @@ -582,12 +582,55 @@ func chooseLeaseTime(t, max time.Duration) time.Duration { return t } +// countForPanic records the error count for DDL job. +func (w *worker) countForPanic(job *model.Job) { + // If run DDL job panic, just cancel the DDL jobs. + job.State = model.JobStateCancelling + job.ErrorCount++ + + // Load global DDL variables. + if err1 := loadDDLVars(w); err1 != nil { + logutil.Logger(w.logCtx).Error("[ddl] load DDL global variable failed", zap.Error(err1)) + } + errorCount := variable.GetDDLErrorCountLimit() + + if job.ErrorCount > errorCount { + msg := fmt.Sprintf("panic in handling DDL logic and error count beyond the limitation %d, cancelled", errorCount) + logutil.Logger(w.logCtx).Warn(msg) + job.Error = toTError(errors.New(msg)) + job.State = model.JobStateCancelled + } +} + +// countForError records the error count for DDL job. +func (w *worker) countForError(err error, job *model.Job) error { + job.Error = toTError(err) + job.ErrorCount++ + + // If job is cancelled, we shouldn't return an error and shouldn't load DDL variables. + if job.State == model.JobStateCancelled { + logutil.Logger(w.logCtx).Info("[ddl] DDL job is cancelled normally", zap.Error(err)) + return nil + } + logutil.Logger(w.logCtx).Error("[ddl] run DDL job error", zap.Error(err)) + + // Load global DDL variables. + if err1 := loadDDLVars(w); err1 != nil { + logutil.Logger(w.logCtx).Error("[ddl] load DDL global variable failed", zap.Error(err1)) + } + // Check error limit to avoid falling into an infinite loop. + if job.ErrorCount > variable.GetDDLErrorCountLimit() && job.State == model.JobStateRunning && admin.IsJobRollbackable(job) { + logutil.Logger(w.logCtx).Warn("[ddl] DDL job error count exceed the limit, cancelling it now", zap.Int64("jobID", job.ID), zap.Int64("errorCountLimit", variable.GetDDLErrorCountLimit())) + job.State = model.JobStateCancelling + } + return err +} + // runDDLJob runs a DDL job. It returns the current schema version in this transaction and the error. func (w *worker) runDDLJob(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, err error) { defer tidbutil.Recover(metrics.LabelDDLWorker, fmt.Sprintf("%s runDDLJob", w), func() { - // If run DDL job panic, just cancel the DDL jobs. - job.State = model.JobStateCancelling + w.countForPanic(job) }, false) // Mock for run ddl job panic. @@ -690,27 +733,9 @@ func (w *worker) runDDLJob(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, err = errInvalidDDLJob.GenWithStack("invalid ddl job type: %v", job.Type) } - // Save errors in job, so that others can know errors happened. + // Save errors in job if any, so that others can know errors happened. if err != nil { - job.Error = toTError(err) - job.ErrorCount++ - - // If job is cancelled, we shouldn't return an error and shouldn't load DDL variables. - if job.State == model.JobStateCancelled { - logutil.Logger(w.logCtx).Info("[ddl] DDL job is cancelled normally", zap.Error(err)) - return ver, nil - } - logutil.Logger(w.logCtx).Error("[ddl] run DDL job error", zap.Error(err)) - - // Load global ddl variables. - if err1 := loadDDLVars(w); err1 != nil { - logutil.Logger(w.logCtx).Error("[ddl] load DDL global variable failed", zap.Error(err1)) - } - // Check error limit to avoid falling into an infinite loop. - if job.ErrorCount > variable.GetDDLErrorCountLimit() && job.State == model.JobStateRunning && admin.IsJobRollbackable(job) { - logutil.Logger(w.logCtx).Warn("[ddl] DDL job error count exceed the limit, cancelling it now", zap.Int64("jobID", job.ID), zap.Int64("errorCountLimit", variable.GetDDLErrorCountLimit())) - job.State = model.JobStateCancelling - } + err = w.countForError(err, job) } return } diff --git a/ddl/index.go b/ddl/index.go index a05a125d9af94..e6cb693cd1c3c 100644 --- a/ddl/index.go +++ b/ddl/index.go @@ -618,6 +618,12 @@ func onDropIndex(t *meta.Meta, job *model.Job) (ver int64, _ error) { // Set column index flag. dropIndexColumnFlag(tblInfo, indexInfo) + failpoint.Inject("mockExceedErrorLimit", func(val failpoint.Value) { + if val.(bool) { + panic("panic test in cancelling add index") + } + }) + tblInfo.Columns = tblInfo.Columns[:len(tblInfo.Columns)-len(dependentHiddenCols)] ver, err = updateVersionAndTableInfo(t, job, tblInfo, originalState != model.StateNone) From a6d8bcb9298e5dc0c58d20c4eaf6af72385d6ab1 Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Tue, 16 Mar 2021 16:52:55 +0800 Subject: [PATCH 33/85] executor: fix unexpected NotNullFlag in case when expr ret type (#23102) (#23135) --- expression/builtin_control.go | 3 +++ expression/integration_test.go | 10 ++++++++++ types/field_type.go | 7 ++++--- 3 files changed, 17 insertions(+), 3 deletions(-) diff --git a/expression/builtin_control.go b/expression/builtin_control.go index 69b61323bc535..91e3cc526694e 100644 --- a/expression/builtin_control.go +++ b/expression/builtin_control.go @@ -173,6 +173,9 @@ func (c *caseWhenFunctionClass) getFunction(ctx sessionctx.Context, args []Expre } fieldTp := types.AggFieldType(fieldTps) + // Here we turn off NotNullFlag. Because if all when-clauses are false, + // the result of case-when expr is NULL. + types.SetTypeFlag(&fieldTp.Flag, mysql.NotNullFlag, false) tp := fieldTp.EvalType() if tp == types.ETInt { diff --git a/expression/integration_test.go b/expression/integration_test.go index fb2bbfbbf1cca..1cebcc6cbedbb 100755 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -2920,6 +2920,16 @@ func (s *testIntegrationSuite2) TestBuiltin(c *C) { result.Check(testkit.Rows(" 4")) result = tk.MustQuery("select * from t where b = case when a is null then 4 when a = 'str5' then 7 else 9 end") result.Check(testkit.Rows(" 4")) + + // return type of case when expr should not include NotNullFlag. issue-23036 + tk.MustExec("drop table if exists t1") + tk.MustExec("create table t1(c1 int not null)") + tk.MustExec("insert into t1 values(1)") + result = tk.MustQuery("select (case when null then c1 end) is null from t1") + result.Check(testkit.Rows("1")) + result = tk.MustQuery("select (case when null then c1 end) is not null from t1") + result.Check(testkit.Rows("0")) + // test warnings tk.MustQuery("select case when b=0 then 1 else 1/b end from t") tk.MustQuery("show warnings").Check(testkit.Rows()) diff --git a/types/field_type.go b/types/field_type.go index b93147564c5fe..9e470892875a1 100644 --- a/types/field_type.go +++ b/types/field_type.go @@ -133,8 +133,8 @@ func AggregateEvalType(fts []*FieldType, flag *uint) EvalType { } lft = rft } - setTypeFlag(flag, mysql.UnsignedFlag, unsigned) - setTypeFlag(flag, mysql.BinaryFlag, !aggregatedEvalType.IsStringKind() || gotBinString) + SetTypeFlag(flag, mysql.UnsignedFlag, unsigned) + SetTypeFlag(flag, mysql.BinaryFlag, !aggregatedEvalType.IsStringKind() || gotBinString) return aggregatedEvalType } @@ -159,7 +159,8 @@ func mergeEvalType(lhs, rhs EvalType, lft, rft *FieldType, isLHSUnsigned, isRHSU return ETInt } -func setTypeFlag(flag *uint, flagItem uint, on bool) { +// SetTypeFlag turns the flagItem on or off. +func SetTypeFlag(flag *uint, flagItem uint, on bool) { if on { *flag |= flagItem } else { From f055086260441924609723cc1ecfdc1e82672212 Mon Sep 17 00:00:00 2001 From: xhe Date: Wed, 17 Mar 2021 09:58:55 +0800 Subject: [PATCH 34/85] *: adapt new api for the executor package (#22644) (#23156) --- executor/brie.go | 1 + executor/ddl.go | 8 +- executor/grant.go | 342 +++++++++------------ executor/infoschema_reader.go | 16 +- executor/inspection_profile.go | 8 +- executor/inspection_result.go | 154 +++++++--- executor/inspection_summary.go | 7 +- executor/metrics_reader.go | 16 +- executor/opt_rule_blacklist.go | 8 +- executor/prepared.go | 1 + executor/reload_expr_pushdown_blacklist.go | 8 +- executor/revoke.go | 103 ++++++- executor/show.go | 64 +++- executor/simple.go | 281 ++++++++++------- executor/trace.go | 19 +- executor/utils.go | 46 +++ telemetry/data_cluster_hardware.go | 8 +- telemetry/data_cluster_info.go | 9 +- util/mock/context.go | 6 + util/sqlexec/restricted_sql_executor.go | 1 + 20 files changed, 694 insertions(+), 412 deletions(-) create mode 100644 executor/utils.go diff --git a/executor/brie.go b/executor/brie.go index 0d14265a8fd0e..492a3d68d8121 100644 --- a/executor/brie.go +++ b/executor/brie.go @@ -404,6 +404,7 @@ func (gs *tidbGlueSession) CreateSession(store kv.Storage) (glue.Session, error) // Execute implements glue.Session func (gs *tidbGlueSession) Execute(ctx context.Context, sql string) error { + // FIXME: br relies on a deprecated API, it may be unsafe _, err := gs.se.(sqlexec.SQLExecutor).Execute(ctx, sql) return err } diff --git a/executor/ddl.go b/executor/ddl.go index ae2d685a65ccf..ef903bd4e377d 100644 --- a/executor/ddl.go +++ b/executor/ddl.go @@ -309,8 +309,12 @@ func (e *DDLExec) dropTableObject(objects []*ast.TableName, obt objectType, ifEx zap.String("database", fullti.Schema.O), zap.String("table", fullti.Name.O), ) - sql := fmt.Sprintf("admin check table `%s`.`%s`", fullti.Schema.O, fullti.Name.O) - _, _, err = e.ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql) + exec := e.ctx.(sqlexec.RestrictedSQLExecutor) + stmt, err := exec.ParseWithParams(context.TODO(), "admin check table %n.%n", fullti.Schema.O, fullti.Name.O) + if err != nil { + return err + } + _, _, err = exec.ExecRestrictedStmt(context.TODO(), stmt) if err != nil { return err } diff --git a/executor/grant.go b/executor/grant.go index 097e2d5dcc349..86b279727aaa5 100644 --- a/executor/grant.go +++ b/executor/grant.go @@ -16,7 +16,6 @@ package executor import ( "context" "encoding/json" - "fmt" "strings" "github.com/pingcap/errors" @@ -106,7 +105,7 @@ func (e *GrantExec) Next(ctx context.Context, req *chunk.Chunk) error { } defer func() { if !isCommit { - _, err := internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), "rollback") + _, err := internalSession.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), "rollback") if err != nil { logutil.BgLogger().Error("rollback error occur at grant privilege", zap.Error(err)) } @@ -114,7 +113,7 @@ func (e *GrantExec) Next(ctx context.Context, req *chunk.Chunk) error { e.releaseSysSession(internalSession) }() - _, err = internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), "begin") + _, err = internalSession.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), "begin") if err != nil { return err } @@ -132,9 +131,7 @@ func (e *GrantExec) Next(ctx context.Context, req *chunk.Chunk) error { if !ok { return errors.Trace(ErrPasswordFormat) } - user := fmt.Sprintf(`('%s', '%s', '%s')`, user.User.Hostname, user.User.Username, pwd) - sql := fmt.Sprintf(`INSERT INTO %s.%s (Host, User, authentication_string) VALUES %s;`, mysql.SystemDB, mysql.UserTable, user) - _, err := internalSession.(sqlexec.SQLExecutor).Execute(ctx, sql) + _, err := internalSession.(sqlexec.SQLExecutor).ExecuteInternal(ctx, `INSERT INTO %n.%n (Host, User, authentication_string) VALUES (%?, %?, %?);`, mysql.SystemDB, mysql.UserTable, user.User.Hostname, user.User.Username, pwd) if err != nil { return err } @@ -193,7 +190,7 @@ func (e *GrantExec) Next(ctx context.Context, req *chunk.Chunk) error { } } - _, err = internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), "commit") + _, err = internalSession.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), "commit") if err != nil { return err } @@ -274,29 +271,25 @@ func (e *GrantExec) checkAndInitColumnPriv(user string, host string, cols []*ast // initGlobalPrivEntry inserts a new row into mysql.DB with empty privilege. func initGlobalPrivEntry(ctx sessionctx.Context, user string, host string) error { - sql := fmt.Sprintf(`INSERT INTO %s.%s (Host, User, PRIV) VALUES ('%s', '%s', '%s')`, mysql.SystemDB, mysql.GlobalPrivTable, host, user, "{}") - _, err := ctx.(sqlexec.SQLExecutor).Execute(context.Background(), sql) + _, err := ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), `INSERT INTO %n.%n (Host, User, PRIV) VALUES (%?, %?, %?)`, mysql.SystemDB, mysql.GlobalPrivTable, host, user, "{}") return err } // initDBPrivEntry inserts a new row into mysql.DB with empty privilege. func initDBPrivEntry(ctx sessionctx.Context, user string, host string, db string) error { - sql := fmt.Sprintf(`INSERT INTO %s.%s (Host, User, DB) VALUES ('%s', '%s', '%s')`, mysql.SystemDB, mysql.DBTable, host, user, db) - _, err := ctx.(sqlexec.SQLExecutor).Execute(context.Background(), sql) + _, err := ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), `INSERT INTO %n.%n (Host, User, DB) VALUES (%?, %?, %?)`, mysql.SystemDB, mysql.DBTable, host, user, db) return err } // initTablePrivEntry inserts a new row into mysql.Tables_priv with empty privilege. func initTablePrivEntry(ctx sessionctx.Context, user string, host string, db string, tbl string) error { - sql := fmt.Sprintf(`INSERT INTO %s.%s (Host, User, DB, Table_name, Table_priv, Column_priv) VALUES ('%s', '%s', '%s', '%s', '', '')`, mysql.SystemDB, mysql.TablePrivTable, host, user, db, tbl) - _, err := ctx.(sqlexec.SQLExecutor).Execute(context.Background(), sql) + _, err := ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), `INSERT INTO %n.%n (Host, User, DB, Table_name, Table_priv, Column_priv) VALUES (%?, %?, %?, %?, '', '')`, mysql.SystemDB, mysql.TablePrivTable, host, user, db, tbl) return err } // initColumnPrivEntry inserts a new row into mysql.Columns_priv with empty privilege. func initColumnPrivEntry(ctx sessionctx.Context, user string, host string, db string, tbl string, col string) error { - sql := fmt.Sprintf(`INSERT INTO %s.%s (Host, User, DB, Table_name, Column_name, Column_priv) VALUES ('%s', '%s', '%s', '%s', '%s', '')`, mysql.SystemDB, mysql.ColumnPrivTable, host, user, db, tbl, col) - _, err := ctx.(sqlexec.SQLExecutor).Execute(context.Background(), sql) + _, err := ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), `INSERT INTO %n.%n (Host, User, DB, Table_name, Column_name, Column_priv) VALUES (%?, %?, %?, %?, %?, '')`, mysql.SystemDB, mysql.ColumnPrivTable, host, user, db, tbl, col) return err } @@ -309,8 +302,7 @@ func (e *GrantExec) grantGlobalPriv(ctx sessionctx.Context, user *ast.UserSpec) if err != nil { return errors.Trace(err) } - sql := fmt.Sprintf(`UPDATE %s.%s SET PRIV = '%s' WHERE User='%s' AND Host='%s'`, mysql.SystemDB, mysql.GlobalPrivTable, priv, user.User.Username, user.User.Hostname) - _, err = ctx.(sqlexec.SQLExecutor).Execute(context.Background(), sql) + _, err = ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), `UPDATE %n.%n SET PRIV=%? WHERE User=%? AND Host=%?`, mysql.SystemDB, mysql.GlobalPrivTable, priv, user.User.Username, user.User.Hostname) return err } @@ -415,12 +407,16 @@ func (e *GrantExec) grantGlobalLevel(priv *ast.PrivElem, user *ast.UserSpec, int if priv.Priv == 0 { return nil } - asgns, err := composeGlobalPrivUpdate(priv.Priv, "Y") + + sql := new(strings.Builder) + sqlexec.MustFormatSQL(sql, `UPDATE %n.%n SET `, mysql.SystemDB, mysql.UserTable) + err := composeGlobalPrivUpdate(sql, priv.Priv, "Y") if err != nil { return err } - sql := fmt.Sprintf(`UPDATE %s.%s SET %s WHERE User='%s' AND Host='%s'`, mysql.SystemDB, mysql.UserTable, asgns, user.User.Username, user.User.Hostname) - _, err = internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), sql) + sqlexec.MustFormatSQL(sql, ` WHERE User=%? AND Host=%?`, user.User.Username, user.User.Hostname) + + _, err = internalSession.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), sql.String()) return err } @@ -430,12 +426,16 @@ func (e *GrantExec) grantDBLevel(priv *ast.PrivElem, user *ast.UserSpec, interna if len(dbName) == 0 { dbName = e.ctx.GetSessionVars().CurrentDB } - asgns, err := composeDBPrivUpdate(priv.Priv, "Y") + + sql := new(strings.Builder) + sqlexec.MustFormatSQL(sql, "UPDATE %n.%n SET ", mysql.SystemDB, mysql.DBTable) + err := composeDBPrivUpdate(sql, priv.Priv, "Y") if err != nil { return err } - sql := fmt.Sprintf(`UPDATE %s.%s SET %s WHERE User='%s' AND Host='%s' AND DB='%s';`, mysql.SystemDB, mysql.DBTable, asgns, user.User.Username, user.User.Hostname, dbName) - _, err = internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), sql) + sqlexec.MustFormatSQL(sql, " WHERE User=%? AND Host=%? AND DB=%?", user.User.Username, user.User.Hostname, dbName) + + _, err = internalSession.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), sql.String()) return err } @@ -446,12 +446,16 @@ func (e *GrantExec) grantTableLevel(priv *ast.PrivElem, user *ast.UserSpec, inte dbName = e.ctx.GetSessionVars().CurrentDB } tblName := e.Level.TableName - asgns, err := composeTablePrivUpdateForGrant(internalSession, priv.Priv, user.User.Username, user.User.Hostname, dbName, tblName) + + sql := new(strings.Builder) + sqlexec.MustFormatSQL(sql, "UPDATE %n.%n SET ", mysql.SystemDB, mysql.TablePrivTable) + err := composeTablePrivUpdateForGrant(internalSession, sql, priv.Priv, user.User.Username, user.User.Hostname, dbName, tblName) if err != nil { return err } - sql := fmt.Sprintf(`UPDATE %s.%s SET %s WHERE User='%s' AND Host='%s' AND DB='%s' AND Table_name='%s';`, mysql.SystemDB, mysql.TablePrivTable, asgns, user.User.Username, user.User.Hostname, dbName, tblName) - _, err = internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), sql) + sqlexec.MustFormatSQL(sql, " WHERE User=%? AND Host=%? AND DB=%? AND Table_name=%?", user.User.Username, user.User.Hostname, dbName, tblName) + + _, err = internalSession.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), sql.String()) return err } @@ -467,12 +471,16 @@ func (e *GrantExec) grantColumnLevel(priv *ast.PrivElem, user *ast.UserSpec, int if col == nil { return errors.Errorf("Unknown column: %s", c) } - asgns, err := composeColumnPrivUpdateForGrant(internalSession, priv.Priv, user.User.Username, user.User.Hostname, dbName, tbl.Meta().Name.O, col.Name.O) + + sql := new(strings.Builder) + sqlexec.MustFormatSQL(sql, "UPDATE %n.%n SET ", mysql.SystemDB, mysql.ColumnPrivTable) + err := composeColumnPrivUpdateForGrant(internalSession, sql, priv.Priv, user.User.Username, user.User.Hostname, dbName, tbl.Meta().Name.O, col.Name.O) if err != nil { return err } - sql := fmt.Sprintf(`UPDATE %s.%s SET %s WHERE User='%s' AND Host='%s' AND DB='%s' AND Table_name='%s' AND Column_name='%s';`, mysql.SystemDB, mysql.ColumnPrivTable, asgns, user.User.Username, user.User.Hostname, dbName, tbl.Meta().Name.O, col.Name.O) - _, err = internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), sql) + sqlexec.MustFormatSQL(sql, " WHERE User=%? AND Host=%? AND DB=%? AND Table_name=%? AND Column_name=%?", user.User.Username, user.User.Hostname, dbName, tbl.Meta().Name.O, col.Name.O) + + _, err = internalSession.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), sql.String()) if err != nil { return err } @@ -481,178 +489,143 @@ func (e *GrantExec) grantColumnLevel(priv *ast.PrivElem, user *ast.UserSpec, int } // composeGlobalPrivUpdate composes update stmt assignment list string for global scope privilege update. -func composeGlobalPrivUpdate(priv mysql.PrivilegeType, value string) (string, error) { - if priv == mysql.AllPriv { - strs := make([]string, 0, len(mysql.Priv2UserCol)) - for _, v := range mysql.AllGlobalPrivs { - strs = append(strs, fmt.Sprintf(`%s='%s'`, mysql.Priv2UserCol[v], value)) +func composeGlobalPrivUpdate(sql *strings.Builder, priv mysql.PrivilegeType, value string) error { + if priv != mysql.AllPriv { + col, ok := mysql.Priv2UserCol[priv] + if !ok { + return errors.Errorf("Unknown priv: %v", priv) } - return strings.Join(strs, ", "), nil + sqlexec.MustFormatSQL(sql, "%n=%?", col, value) + return nil } - col, ok := mysql.Priv2UserCol[priv] - if !ok { - return "", errors.Errorf("Unknown priv: %v", priv) + + for i, v := range mysql.AllGlobalPrivs { + if i > 0 { + sqlexec.MustFormatSQL(sql, ",") + } + + k, ok := mysql.Priv2UserCol[v] + if !ok { + return errors.Errorf("Unknown priv %v", priv) + } + + sqlexec.MustFormatSQL(sql, "%n=%?", k, value) } - return fmt.Sprintf(`%s='%s'`, col, value), nil + return nil } // composeDBPrivUpdate composes update stmt assignment list for db scope privilege update. -func composeDBPrivUpdate(priv mysql.PrivilegeType, value string) (string, error) { - if priv == mysql.AllPriv { - strs := make([]string, 0, len(mysql.AllDBPrivs)) - for _, p := range mysql.AllDBPrivs { - v, ok := mysql.Priv2UserCol[p] - if !ok { - return "", errors.Errorf("Unknown db privilege %v", priv) - } - strs = append(strs, fmt.Sprintf(`%s='%s'`, v, value)) +func composeDBPrivUpdate(sql *strings.Builder, priv mysql.PrivilegeType, value string) error { + if priv != mysql.AllPriv { + col, ok := mysql.Priv2UserCol[priv] + if !ok { + return errors.Errorf("Unknown priv: %v", priv) } - return strings.Join(strs, ", "), nil - } - col, ok := mysql.Priv2UserCol[priv] - if !ok { - return "", errors.Errorf("Unknown priv: %v", priv) + sqlexec.MustFormatSQL(sql, "%n=%?", col, value) + return nil } - return fmt.Sprintf(`%s='%s'`, col, value), nil -} -// composeTablePrivUpdateForGrant composes update stmt assignment list for table scope privilege update. -func composeTablePrivUpdateForGrant(ctx sessionctx.Context, priv mysql.PrivilegeType, name string, host string, db string, tbl string) (string, error) { - var newTablePriv, newColumnPriv string - if priv == mysql.AllPriv { - for _, p := range mysql.AllTablePrivs { - v, ok := mysql.Priv2SetStr[p] - if !ok { - return "", errors.Errorf("Unknown table privilege %v", p) - } - newTablePriv = addToSet(newTablePriv, v) - } - for _, p := range mysql.AllColumnPrivs { - v, ok := mysql.Priv2SetStr[p] - if !ok { - return "", errors.Errorf("Unknown column privilege %v", p) - } - newColumnPriv = addToSet(newColumnPriv, v) - } - } else { - currTablePriv, currColumnPriv, err := getTablePriv(ctx, name, host, db, tbl) - if err != nil { - return "", err + for i, p := range mysql.AllDBPrivs { + if i > 0 { + sqlexec.MustFormatSQL(sql, ",") } - p, ok := mysql.Priv2SetStr[priv] + + v, ok := mysql.Priv2UserCol[p] if !ok { - return "", errors.Errorf("Unknown priv: %v", priv) + return errors.Errorf("Unknown priv %v", priv) } - newTablePriv = addToSet(currTablePriv, p) - for _, cp := range mysql.AllColumnPrivs { - if priv == cp { - newColumnPriv = addToSet(currColumnPriv, p) - break - } - } + sqlexec.MustFormatSQL(sql, "%n=%?", v, value) } - return fmt.Sprintf(`Table_priv='%s', Column_priv='%s', Grantor='%s'`, newTablePriv, newColumnPriv, ctx.GetSessionVars().User), nil + return nil } -func composeTablePrivUpdateForRevoke(ctx sessionctx.Context, priv mysql.PrivilegeType, name string, host string, db string, tbl string) (string, error) { - var newTablePriv, newColumnPriv string - if priv == mysql.AllPriv { - newTablePriv = "" - newColumnPriv = "" - } else { +func privUpdateForGrant(cur []string, priv mysql.PrivilegeType) ([]string, error) { + p, ok := mysql.Priv2SetStr[priv] + if !ok { + return nil, errors.Errorf("Unknown priv: %v", priv) + } + cur = addToSet(cur, p) + return cur, nil +} + +// composeTablePrivUpdateForGrant composes update stmt assignment list for table scope privilege update. +func composeTablePrivUpdateForGrant(ctx sessionctx.Context, sql *strings.Builder, priv mysql.PrivilegeType, name string, host string, db string, tbl string) error { + var newTablePriv, newColumnPriv []string + var tblPrivs, colPrivs []mysql.PrivilegeType + if priv != mysql.AllPriv { currTablePriv, currColumnPriv, err := getTablePriv(ctx, name, host, db, tbl) if err != nil { - return "", err - } - p, ok := mysql.Priv2SetStr[priv] - if !ok { - return "", errors.Errorf("Unknown priv: %v", priv) + return err } - newTablePriv = deleteFromSet(currTablePriv, p) - + newTablePriv = setFromString(currTablePriv) + newColumnPriv = setFromString(currColumnPriv) + tblPrivs = []mysql.PrivilegeType{priv} for _, cp := range mysql.AllColumnPrivs { - if priv == cp { - newColumnPriv = deleteFromSet(currColumnPriv, p) + // in case it is not a column priv + if cp == priv { + colPrivs = []mysql.PrivilegeType{priv} break } } + } else { + tblPrivs = mysql.AllTablePrivs + colPrivs = mysql.AllColumnPrivs } - return fmt.Sprintf(`Table_priv='%s', Column_priv='%s', Grantor='%s'`, newTablePriv, newColumnPriv, ctx.GetSessionVars().User), nil -} -// addToSet add a value to the set, e.g: -// addToSet("Select,Insert", "Update") returns "Select,Insert,Update". -func addToSet(set string, value string) string { - if set == "" { - return value + var err error + for _, p := range tblPrivs { + newTablePriv, err = privUpdateForGrant(newTablePriv, p) + if err != nil { + return err + } } - return fmt.Sprintf("%s,%s", set, value) -} -// deleteFromSet delete the value from the set, e.g: -// deleteFromSet("Select,Insert,Update", "Update") returns "Select,Insert". -func deleteFromSet(set string, value string) string { - sets := strings.Split(set, ",") - res := make([]string, 0, len(sets)) - for _, v := range sets { - if v != value { - res = append(res, v) + for _, p := range colPrivs { + newColumnPriv, err = privUpdateForGrant(newColumnPriv, p) + if err != nil { + return err } } - return strings.Join(res, ",") + + sqlexec.MustFormatSQL(sql, `Table_priv=%?, Column_priv=%?, Grantor=%?`, strings.Join(newTablePriv, ","), strings.Join(newColumnPriv, ","), ctx.GetSessionVars().User.String()) + return nil } // composeColumnPrivUpdateForGrant composes update stmt assignment list for column scope privilege update. -func composeColumnPrivUpdateForGrant(ctx sessionctx.Context, priv mysql.PrivilegeType, name string, host string, db string, tbl string, col string) (string, error) { - newColumnPriv := "" - if priv == mysql.AllPriv { - for _, p := range mysql.AllColumnPrivs { - v, ok := mysql.Priv2SetStr[p] - if !ok { - return "", errors.Errorf("Unknown column privilege %v", p) - } - newColumnPriv = addToSet(newColumnPriv, v) - } - } else { +func composeColumnPrivUpdateForGrant(ctx sessionctx.Context, sql *strings.Builder, priv mysql.PrivilegeType, name string, host string, db string, tbl string, col string) error { + var newColumnPriv []string + var colPrivs []mysql.PrivilegeType + if priv != mysql.AllPriv { currColumnPriv, err := getColumnPriv(ctx, name, host, db, tbl, col) if err != nil { - return "", err - } - p, ok := mysql.Priv2SetStr[priv] - if !ok { - return "", errors.Errorf("Unknown priv: %v", priv) + return err } - newColumnPriv = addToSet(currColumnPriv, p) + newColumnPriv = setFromString(currColumnPriv) + colPrivs = []mysql.PrivilegeType{priv} + } else { + colPrivs = mysql.AllColumnPrivs } - return fmt.Sprintf(`Column_priv='%s'`, newColumnPriv), nil -} -func composeColumnPrivUpdateForRevoke(ctx sessionctx.Context, priv mysql.PrivilegeType, name string, host string, db string, tbl string, col string) (string, error) { - newColumnPriv := "" - if priv == mysql.AllPriv { - newColumnPriv = "" - } else { - currColumnPriv, err := getColumnPriv(ctx, name, host, db, tbl, col) + var err error + for _, p := range colPrivs { + newColumnPriv, err = privUpdateForGrant(newColumnPriv, p) if err != nil { - return "", err - } - p, ok := mysql.Priv2SetStr[priv] - if !ok { - return "", errors.Errorf("Unknown priv: %v", priv) + return err } - newColumnPriv = deleteFromSet(currColumnPriv, p) } - return fmt.Sprintf(`Column_priv='%s'`, newColumnPriv), nil + + sqlexec.MustFormatSQL(sql, `Column_priv=%?`, strings.Join(newColumnPriv, ",")) + return nil } // recordExists is a helper function to check if the sql returns any row. -func recordExists(ctx sessionctx.Context, sql string) (bool, error) { - recordSets, err := ctx.(sqlexec.SQLExecutor).Execute(context.Background(), sql) +func recordExists(ctx sessionctx.Context, sql string, args ...interface{}) (bool, error) { + rs, err := ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), sql, args...) if err != nil { return false, err } - rows, _, err := getRowsAndFields(ctx, recordSets) + rows, _, err := getRowsAndFields(ctx, rs) if err != nil { return false, err } @@ -661,43 +634,35 @@ func recordExists(ctx sessionctx.Context, sql string) (bool, error) { // globalPrivEntryExists checks if there is an entry with key user-host in mysql.global_priv. func globalPrivEntryExists(ctx sessionctx.Context, name string, host string) (bool, error) { - sql := fmt.Sprintf(`SELECT * FROM %s.%s WHERE User='%s' AND Host='%s';`, mysql.SystemDB, mysql.GlobalPrivTable, name, host) - return recordExists(ctx, sql) + return recordExists(ctx, `SELECT * FROM %n.%n WHERE User=%? AND Host=%?;`, mysql.SystemDB, mysql.GlobalPrivTable, name, host) } // dbUserExists checks if there is an entry with key user-host-db in mysql.DB. func dbUserExists(ctx sessionctx.Context, name string, host string, db string) (bool, error) { - sql := fmt.Sprintf(`SELECT * FROM %s.%s WHERE User='%s' AND Host='%s' AND DB='%s';`, mysql.SystemDB, mysql.DBTable, name, host, db) - return recordExists(ctx, sql) + return recordExists(ctx, `SELECT * FROM %n.%n WHERE User=%? AND Host=%? AND DB=%?;`, mysql.SystemDB, mysql.DBTable, name, host, db) } // tableUserExists checks if there is an entry with key user-host-db-tbl in mysql.Tables_priv. func tableUserExists(ctx sessionctx.Context, name string, host string, db string, tbl string) (bool, error) { - sql := fmt.Sprintf(`SELECT * FROM %s.%s WHERE User='%s' AND Host='%s' AND DB='%s' AND Table_name='%s';`, mysql.SystemDB, mysql.TablePrivTable, name, host, db, tbl) - return recordExists(ctx, sql) + return recordExists(ctx, `SELECT * FROM %n.%n WHERE User=%? AND Host=%? AND DB=%? AND Table_name=%?;`, mysql.SystemDB, mysql.TablePrivTable, name, host, db, tbl) } // columnPrivEntryExists checks if there is an entry with key user-host-db-tbl-col in mysql.Columns_priv. func columnPrivEntryExists(ctx sessionctx.Context, name string, host string, db string, tbl string, col string) (bool, error) { - sql := fmt.Sprintf(`SELECT * FROM %s.%s WHERE User='%s' AND Host='%s' AND DB='%s' AND Table_name='%s' AND Column_name='%s';`, mysql.SystemDB, mysql.ColumnPrivTable, name, host, db, tbl, col) - return recordExists(ctx, sql) + return recordExists(ctx, `SELECT * FROM %n.%n WHERE User=%? AND Host=%? AND DB=%? AND Table_name=%? AND Column_name=%?;`, mysql.SystemDB, mysql.ColumnPrivTable, name, host, db, tbl, col) } // getTablePriv gets current table scope privilege set from mysql.Tables_priv. // Return Table_priv and Column_priv. func getTablePriv(ctx sessionctx.Context, name string, host string, db string, tbl string) (string, string, error) { - sql := fmt.Sprintf(`SELECT Table_priv, Column_priv FROM %s.%s WHERE User='%s' AND Host='%s' AND DB='%s' AND Table_name='%s';`, mysql.SystemDB, mysql.TablePrivTable, name, host, db, tbl) - rs, err := ctx.(sqlexec.SQLExecutor).Execute(context.Background(), sql) + rs, err := ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), `SELECT Table_priv, Column_priv FROM %n.%n WHERE User=%? AND Host=%? AND DB=%? AND Table_name=%?`, mysql.SystemDB, mysql.TablePrivTable, name, host, db, tbl) if err != nil { return "", "", err } - if len(rs) < 1 { - return "", "", errors.Errorf("get table privilege fail for %s %s %s %s", name, host, db, tbl) - } var tPriv, cPriv string rows, fields, err := getRowsAndFields(ctx, rs) if err != nil { - return "", "", err + return "", "", errors.Errorf("get table privilege fail for %s %s %s %s: %v", name, host, db, tbl, err) } if len(rows) < 1 { return "", "", errors.Errorf("get table privilege fail for %s %s %s %s", name, host, db, tbl) @@ -717,17 +682,13 @@ func getTablePriv(ctx sessionctx.Context, name string, host string, db string, t // getColumnPriv gets current column scope privilege set from mysql.Columns_priv. // Return Column_priv. func getColumnPriv(ctx sessionctx.Context, name string, host string, db string, tbl string, col string) (string, error) { - sql := fmt.Sprintf(`SELECT Column_priv FROM %s.%s WHERE User='%s' AND Host='%s' AND DB='%s' AND Table_name='%s' AND Column_name='%s';`, mysql.SystemDB, mysql.ColumnPrivTable, name, host, db, tbl, col) - rs, err := ctx.(sqlexec.SQLExecutor).Execute(context.Background(), sql) + rs, err := ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), `SELECT Column_priv FROM %n.%n WHERE User=%? AND Host=%? AND DB=%? AND Table_name=%? AND Column_name=%?;`, mysql.SystemDB, mysql.ColumnPrivTable, name, host, db, tbl, col) if err != nil { return "", err } - if len(rs) < 1 { - return "", errors.Errorf("get column privilege fail for %s %s %s %s", name, host, db, tbl) - } rows, fields, err := getRowsAndFields(ctx, rs) if err != nil { - return "", err + return "", errors.Errorf("get column privilege fail for %s %s %s %s: %s", name, host, db, tbl, err) } if len(rows) < 1 { return "", errors.Errorf("get column privilege fail for %s %s %s %s %s", name, host, db, tbl, col) @@ -757,27 +718,18 @@ func getTargetSchemaAndTable(ctx sessionctx.Context, dbName, tableName string, i } // getRowsAndFields is used to extract rows from record sets. -func getRowsAndFields(ctx sessionctx.Context, recordSets []sqlexec.RecordSet) ([]chunk.Row, []*ast.ResultField, error) { - var ( - rows []chunk.Row - fields []*ast.ResultField - ) - - for i, rs := range recordSets { - tmp, err := getRowFromRecordSet(context.Background(), ctx, rs) - if err != nil { - return nil, nil, err - } - if err = rs.Close(); err != nil { - return nil, nil, err - } - - if i == 0 { - rows = tmp - fields = rs.Fields() - } +func getRowsAndFields(ctx sessionctx.Context, rs sqlexec.RecordSet) ([]chunk.Row, []*ast.ResultField, error) { + if rs == nil { + return nil, nil, errors.Errorf("nil recordset") + } + rows, err := getRowFromRecordSet(context.Background(), ctx, rs) + if err != nil { + return nil, nil, err + } + if err = rs.Close(); err != nil { + return nil, nil, err } - return rows, fields, nil + return rows, rs.Fields(), nil } func getRowFromRecordSet(ctx context.Context, se sessionctx.Context, rs sqlexec.RecordSet) ([]chunk.Row, error) { diff --git a/executor/infoschema_reader.go b/executor/infoschema_reader.go index 72880529ad08f..23cacc1558e76 100644 --- a/executor/infoschema_reader.go +++ b/executor/infoschema_reader.go @@ -154,10 +154,16 @@ func (e *memtableRetriever) retrieve(ctx context.Context, sctx sessionctx.Contex } func getRowCountAllTable(ctx sessionctx.Context) (map[int64]uint64, error) { - rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL("select table_id, count from mysql.stats_meta") + exec := ctx.(sqlexec.RestrictedSQLExecutor) + stmt, err := exec.ParseWithParams(context.TODO(), "select table_id, count from mysql.stats_meta") if err != nil { return nil, err } + rows, _, err := exec.ExecRestrictedStmt(context.TODO(), stmt) + if err != nil { + return nil, err + } + rowCountMap := make(map[int64]uint64, len(rows)) for _, row := range rows { tableID := row.GetInt64(0) @@ -173,10 +179,16 @@ type tableHistID struct { } func getColLengthAllTables(ctx sessionctx.Context) (map[tableHistID]uint64, error) { - rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL("select table_id, hist_id, tot_col_size from mysql.stats_histograms where is_index = 0") + exec := ctx.(sqlexec.RestrictedSQLExecutor) + stmt, err := exec.ParseWithParams(context.TODO(), "select table_id, hist_id, tot_col_size from mysql.stats_histograms where is_index = 0") if err != nil { return nil, err } + rows, _, err := exec.ExecRestrictedStmt(context.TODO(), stmt) + if err != nil { + return nil, err + } + colLengthMap := make(map[tableHistID]uint64, len(rows)) for _, row := range rows { tableID := row.GetInt64(0) diff --git a/executor/inspection_profile.go b/executor/inspection_profile.go index f243db364f0d8..f15dd6ef5e6ff 100644 --- a/executor/inspection_profile.go +++ b/executor/inspection_profile.go @@ -161,7 +161,13 @@ func (n *metricNode) getLabelValue(label string) *metricValue { } func (n *metricNode) queryRowsByLabel(pb *profileBuilder, query string, handleRowFn func(label string, v float64)) error { - rows, _, err := pb.sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithContext(context.Background(), query) + exec := pb.sctx.(sqlexec.RestrictedSQLExecutor) + stmt, err := exec.ParseWithParams(context.TODO(), query) + if err != nil { + return err + } + + rows, _, err := pb.sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedStmt(context.TODO(), stmt) if err != nil { return err } diff --git a/executor/inspection_result.go b/executor/inspection_result.go index 4d8bd69ca3edb..5bb858a8224e9 100644 --- a/executor/inspection_result.go +++ b/executor/inspection_result.go @@ -141,8 +141,12 @@ func (e *inspectionResultRetriever) retrieve(ctx context.Context, sctx sessionct // Get cluster info. e.instanceToStatusAddress = make(map[string]string) e.statusToInstanceAddress = make(map[string]string) - sql := "select instance,status_address from information_schema.cluster_info;" - rows, _, err := sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithContext(ctx, sql) + var rows []chunk.Row + exec := sctx.(sqlexec.RestrictedSQLExecutor) + stmt, err := exec.ParseWithParams(ctx, "select instance,status_address from information_schema.cluster_info;") + if err == nil { + rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) + } if err != nil { sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("get cluster info failed: %v", err)) } @@ -247,16 +251,22 @@ func (configInspection) inspectDiffConfig(ctx context.Context, sctx sessionctx.C "storage.data-dir", "storage.block-cache.capacity", } - sql := fmt.Sprintf("select type, `key`, count(distinct value) as c from information_schema.cluster_config where `key` not in ('%s') group by type, `key` having c > 1", - strings.Join(ignoreConfigKey, "','")) - rows, _, err := sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithContext(ctx, sql) + var rows []chunk.Row + exec := sctx.(sqlexec.RestrictedSQLExecutor) + stmt, err := exec.ParseWithParams(ctx, "select type, `key`, count(distinct value) as c from information_schema.cluster_config where `key` not in (%?) group by type, `key` having c > 1", ignoreConfigKey) + if err == nil { + rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) + } if err != nil { sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("check configuration consistency failed: %v", err)) } generateDetail := func(tp, item string) string { - query := fmt.Sprintf("select value, instance from information_schema.cluster_config where type='%s' and `key`='%s';", tp, item) - rows, _, err := sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithContext(ctx, query) + var rows []chunk.Row + stmt, err := exec.ParseWithParams(ctx, "select value, instance from information_schema.cluster_config where type=%? and `key`=%?;", tp, item) + if err == nil { + rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) + } if err != nil { sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("check configuration consistency failed: %v", err)) return fmt.Sprintf("the cluster has different config value of %[2]s, execute the sql to see more detail: select * from information_schema.cluster_config where type='%[1]s' and `key`='%[2]s'", @@ -318,13 +328,18 @@ func (c configInspection) inspectCheckConfig(ctx context.Context, sctx sessionct } var results []inspectionResult + var rows []chunk.Row + sql := new(strings.Builder) + exec := sctx.(sqlexec.RestrictedSQLExecutor) for _, cas := range cases { if !filter.enable(cas.key) { continue } - sql := fmt.Sprintf("select instance from information_schema.cluster_config where type = '%s' and `key` = '%s' and value = '%s'", - cas.tp, cas.key, cas.value) - rows, _, err := sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithContext(ctx, sql) + sql.Reset() + stmt, err := exec.ParseWithParams(ctx, "select instance from information_schema.cluster_config where type = %? and %n = %? and value = %?", cas.tp, "key", cas.key, cas.value) + if err == nil { + rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) + } if err != nil { sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("check configuration in reason failed: %v", err)) } @@ -350,8 +365,12 @@ func (c configInspection) checkTiKVBlockCacheSizeConfig(ctx context.Context, sct if !filter.enable(item) { return nil } - sql := "select instance,value from information_schema.cluster_config where type='tikv' and `key` = 'storage.block-cache.capacity'" - rows, _, err := sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithContext(ctx, sql) + var rows []chunk.Row + exec := sctx.(sqlexec.RestrictedSQLExecutor) + stmt, err := exec.ParseWithParams(ctx, "select instance,value from information_schema.cluster_config where type='tikv' and `key` = 'storage.block-cache.capacity'") + if err == nil { + rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) + } if err != nil { sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("check configuration in reason failed: %v", err)) } @@ -375,8 +394,10 @@ func (c configInspection) checkTiKVBlockCacheSizeConfig(ctx context.Context, sct ipToCount[ip]++ } - sql = "select instance, value from metrics_schema.node_total_memory where time=now()" - rows, _, err = sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithContext(ctx, sql) + stmt, err = exec.ParseWithParams(ctx, "select instance, value from metrics_schema.node_total_memory where time=now()") + if err == nil { + rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) + } if err != nil { sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("check configuration in reason failed: %v", err)) } @@ -438,9 +459,13 @@ func (configInspection) convertReadableSizeToByteSize(sizeStr string) (uint64, e } func (versionInspection) inspect(ctx context.Context, sctx sessionctx.Context, filter inspectionFilter) []inspectionResult { + exec := sctx.(sqlexec.RestrictedSQLExecutor) + var rows []chunk.Row // check the configuration consistent - sql := "select type, count(distinct git_hash) as c from information_schema.cluster_info group by type having c > 1;" - rows, _, err := sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithContext(ctx, sql) + stmt, err := exec.ParseWithParams(ctx, "select type, count(distinct git_hash) as c from information_schema.cluster_info group by type having c > 1;") + if err == nil { + rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) + } if err != nil { sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("check version consistency failed: %v", err)) } @@ -594,6 +619,9 @@ func (criticalErrorInspection) inspectError(ctx context.Context, sctx sessionctx condition := filter.timeRange.Condition() var results []inspectionResult + var rows []chunk.Row + exec := sctx.(sqlexec.RestrictedSQLExecutor) + sql := new(strings.Builder) for _, rule := range rules { if filter.enable(rule.item) { def, found := infoschema.MetricTableMap[rule.tbl] @@ -601,9 +629,13 @@ func (criticalErrorInspection) inspectError(ctx context.Context, sctx sessionctx sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("metrics table: %s not found", rule.tbl)) continue } - sql := fmt.Sprintf("select `%[1]s`,sum(value) as total from `%[2]s`.`%[3]s` %[4]s group by `%[1]s` having total>=1.0", + sql.Reset() + fmt.Fprintf(sql, "select `%[1]s`,sum(value) as total from `%[2]s`.`%[3]s` %[4]s group by `%[1]s` having total>=1.0", strings.Join(def.Labels, "`,`"), util.MetricSchemaName.L, rule.tbl, condition) - rows, _, err := sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithContext(ctx, sql) + stmt, err := exec.ParseWithParams(ctx, sql.String()) + if err == nil { + rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) + } if err != nil { sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("execute '%s' failed: %v", sql, err)) continue @@ -649,10 +681,16 @@ func (criticalErrorInspection) inspectForServerDown(ctx context.Context, sctx se return nil } condition := filter.timeRange.Condition() - sql := fmt.Sprintf(`select t1.job,t1.instance, t2.min_time from + exec := sctx.(sqlexec.RestrictedSQLExecutor) + sql := new(strings.Builder) + fmt.Fprintf(sql, `select t1.job,t1.instance, t2.min_time from (select instance,job from metrics_schema.up %[1]s group by instance,job having max(value)-min(value)>0) as t1 join (select instance,min(time) as min_time from metrics_schema.up %[1]s and value=0 group by instance,job) as t2 on t1.instance=t2.instance order by job`, condition) - rows, _, err := sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithContext(ctx, sql) + var rows []chunk.Row + stmt, err := exec.ParseWithParams(ctx, sql.String()) + if err == nil { + rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) + } if err != nil { sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("execute '%s' failed: %v", sql, err)) } @@ -675,8 +713,12 @@ func (criticalErrorInspection) inspectForServerDown(ctx context.Context, sctx se results = append(results, result) } // Check from log. - sql = fmt.Sprintf("select type,instance,time from information_schema.cluster_log %s and level = 'info' and message like '%%Welcome to'", condition) - rows, _, err = sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithContext(ctx, sql) + sql.Reset() + fmt.Fprintf(sql, "select type,instance,time from information_schema.cluster_log %s and level = 'info' and message like '%%Welcome to'", condition) + stmt, err = exec.ParseWithParams(ctx, sql.String()) + if err == nil { + rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) + } if err != nil { sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("execute '%s' failed: %v", sql, err)) } @@ -790,24 +832,30 @@ func (thresholdCheckInspection) inspectThreshold1(ctx context.Context, sctx sess condition := filter.timeRange.Condition() var results []inspectionResult + var rows []chunk.Row + exec := sctx.(sqlexec.RestrictedSQLExecutor) + sql := new(strings.Builder) for _, rule := range rules { if !filter.enable(rule.item) { continue } - var sql string + sql.Reset() if len(rule.configKey) > 0 { - sql = fmt.Sprintf("select t1.status_address, t1.cpu, (t2.value * %[2]f) as threshold, t2.value from "+ - "(select status_address, max(sum_value) as cpu from (select instance as status_address, sum(value) as sum_value from metrics_schema.tikv_thread_cpu %[4]s and name like '%[1]s' group by instance, time) as tmp group by tmp.status_address) as t1 join "+ - "(select instance, value from information_schema.cluster_config where type='tikv' and `key` = '%[3]s') as t2 join "+ - "(select instance,status_address from information_schema.cluster_info where type='tikv') as t3 "+ - "on t1.status_address=t3.status_address and t2.instance=t3.instance where t1.cpu > (t2.value * %[2]f)", rule.component, rule.threshold, rule.configKey, condition) + fmt.Fprintf(sql, `select t1.status_address, t1.cpu, (t2.value * %[2]f) as threshold, t2.value from + (select status_address, max(sum_value) as cpu from (select instance as status_address, sum(value) as sum_value from metrics_schema.tikv_thread_cpu %[4]s and name like '%[1]s' group by instance, time) as tmp group by tmp.status_address) as t1 join + (select instance, value from information_schema.cluster_config where type='tikv' and %[5]s = '%[3]s') as t2 join + (select instance,status_address from information_schema.cluster_info where type='tikv') as t3 + on t1.status_address=t3.status_address and t2.instance=t3.instance where t1.cpu > (t2.value * %[2]f)`, rule.component, rule.threshold, rule.configKey, condition, "`key`") } else { - sql = fmt.Sprintf("select t1.instance, t1.cpu, %[2]f from "+ - "(select instance, max(value) as cpu from metrics_schema.tikv_thread_cpu %[3]s and name like '%[1]s' group by instance) as t1 "+ - "where t1.cpu > %[2]f;", rule.component, rule.threshold, condition) + fmt.Fprintf(sql, `select t1.instance, t1.cpu, %[2]f from + (select instance, max(value) as cpu from metrics_schema.tikv_thread_cpu %[3]s and name like '%[1]s' group by instance) as t1 + where t1.cpu > %[2]f;`, rule.component, rule.threshold, condition) + } + stmt, err := exec.ParseWithParams(ctx, sql.String()) + if err == nil { + rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) } - rows, _, err := sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithContext(ctx, sql) if err != nil { sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("execute '%s' failed: %v", sql, err)) continue @@ -957,11 +1005,13 @@ func (thresholdCheckInspection) inspectThreshold2(ctx context.Context, sctx sess condition := filter.timeRange.Condition() var results []inspectionResult + var rows []chunk.Row + sql := new(strings.Builder) + exec := sctx.(sqlexec.RestrictedSQLExecutor) for _, rule := range rules { if !filter.enable(rule.item) { continue } - var sql string cond := condition if len(rule.condition) > 0 { cond = fmt.Sprintf("%s and %s", cond, rule.condition) @@ -969,12 +1019,16 @@ func (thresholdCheckInspection) inspectThreshold2(ctx context.Context, sctx sess if rule.factor == 0 { rule.factor = 1 } + sql.Reset() if rule.isMin { - sql = fmt.Sprintf("select instance, min(value)/%.0f as min_value from metrics_schema.%s %s group by instance having min_value < %f;", rule.factor, rule.tbl, cond, rule.threshold) + fmt.Fprintf(sql, "select instance, min(value)/%.0f as min_value from metrics_schema.%s %s group by instance having min_value < %f;", rule.factor, rule.tbl, cond, rule.threshold) } else { - sql = fmt.Sprintf("select instance, max(value)/%.0f as max_value from metrics_schema.%s %s group by instance having max_value > %f;", rule.factor, rule.tbl, cond, rule.threshold) + fmt.Fprintf(sql, "select instance, max(value)/%.0f as max_value from metrics_schema.%s %s group by instance having max_value > %f;", rule.factor, rule.tbl, cond, rule.threshold) + } + stmt, err := exec.ParseWithParams(ctx, sql.String()) + if err == nil { + rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) } - rows, _, err := sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithContext(ctx, sql) if err != nil { sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("execute '%s' failed: %v", sql, err)) continue @@ -1150,12 +1204,17 @@ func (thresholdCheckInspection) inspectThreshold3(ctx context.Context, sctx sess func checkRules(ctx context.Context, sctx sessionctx.Context, filter inspectionFilter, rules []ruleChecker) []inspectionResult { var results []inspectionResult + var rows []chunk.Row + exec := sctx.(sqlexec.RestrictedSQLExecutor) for _, rule := range rules { if !filter.enable(rule.getItem()) { continue } sql := rule.genSQL(filter.timeRange) - rows, _, err := sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithContext(ctx, sql) + stmt, err := exec.ParseWithParams(ctx, sql) + if err == nil { + rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) + } if err != nil { sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("execute '%s' failed: %v", sql, err)) continue @@ -1170,8 +1229,15 @@ func checkRules(ctx context.Context, sctx sessionctx.Context, filter inspectionF func (c thresholdCheckInspection) inspectForLeaderDrop(ctx context.Context, sctx sessionctx.Context, filter inspectionFilter) []inspectionResult { condition := filter.timeRange.Condition() threshold := 50.0 - sql := fmt.Sprintf(`select address,min(value) as mi,max(value) as mx from metrics_schema.pd_scheduler_store_status %s and type='leader_count' group by address having mx-mi>%v`, condition, threshold) - rows, _, err := sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithContext(ctx, sql) + sql := new(strings.Builder) + fmt.Fprintf(sql, `select address,min(value) as mi,max(value) as mx from metrics_schema.pd_scheduler_store_status %s and type='leader_count' group by address having mx-mi>%v`, condition, threshold) + exec := sctx.(sqlexec.RestrictedSQLExecutor) + + var rows []chunk.Row + stmt, err := exec.ParseWithParams(ctx, sql.String()) + if err == nil { + rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) + } if err != nil { sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("execute '%s' failed: %v", sql, err)) return nil @@ -1179,12 +1245,18 @@ func (c thresholdCheckInspection) inspectForLeaderDrop(ctx context.Context, sctx var results []inspectionResult for _, row := range rows { address := row.GetString(0) - sql := fmt.Sprintf(`select time, value from metrics_schema.pd_scheduler_store_status %s and type='leader_count' and address = '%s' order by time`, condition, address) - subRows, _, err := sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithContext(ctx, sql) + sql.Reset() + fmt.Fprintf(sql, `select time, value from metrics_schema.pd_scheduler_store_status %s and type='leader_count' and address = '%s' order by time`, condition, address) + var subRows []chunk.Row + stmt, err := exec.ParseWithParams(ctx, sql.String()) + if err == nil { + subRows, _, err = exec.ExecRestrictedStmt(ctx, stmt) + } if err != nil { sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("execute '%s' failed: %v", sql, err)) continue } + lastValue := float64(0) for i, subRows := range subRows { v := subRows.GetFloat64(1) diff --git a/executor/inspection_summary.go b/executor/inspection_summary.go index 37aef042a62f9..2a01cab6dc402 100644 --- a/executor/inspection_summary.go +++ b/executor/inspection_summary.go @@ -458,7 +458,12 @@ func (e *inspectionSummaryRetriever) retrieve(ctx context.Context, sctx sessionc sql = fmt.Sprintf("select avg(value),min(value),max(value) from `%s`.`%s` %s", util.MetricSchemaName.L, name, cond) } - rows, _, err := sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithContext(ctx, sql) + exec := sctx.(sqlexec.RestrictedSQLExecutor) + stmt, err := exec.ParseWithParams(ctx, sql) + if err != nil { + return nil, errors.Errorf("execute '%s' failed: %v", sql, err) + } + rows, _, err := exec.ExecRestrictedStmt(ctx, stmt) if err != nil { return nil, errors.Errorf("execute '%s' failed: %v", sql, err) } diff --git a/executor/metrics_reader.go b/executor/metrics_reader.go index 91cb8159bfc27..ff582c23b9935 100644 --- a/executor/metrics_reader.go +++ b/executor/metrics_reader.go @@ -189,7 +189,7 @@ type MetricsSummaryRetriever struct { retrieved bool } -func (e *MetricsSummaryRetriever) retrieve(_ context.Context, sctx sessionctx.Context) ([][]types.Datum, error) { +func (e *MetricsSummaryRetriever) retrieve(ctx context.Context, sctx sessionctx.Context) ([][]types.Datum, error) { if e.retrieved || e.extractor.SkipRequest { return nil, nil } @@ -229,7 +229,12 @@ func (e *MetricsSummaryRetriever) retrieve(_ context.Context, sctx sessionctx.Co name, util.MetricSchemaName.L, condition) } - rows, _, err := sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql) + exec := sctx.(sqlexec.RestrictedSQLExecutor) + stmt, err := exec.ParseWithParams(ctx, sql) + if err != nil { + return nil, errors.Errorf("execute '%s' failed: %v", sql, err) + } + rows, _, err := exec.ExecRestrictedStmt(ctx, stmt) if err != nil { return nil, errors.Errorf("execute '%s' failed: %v", sql, err) } @@ -306,7 +311,12 @@ func (e *MetricsSummaryByLabelRetriever) retrieve(ctx context.Context, sctx sess sql = fmt.Sprintf("select sum(value),avg(value),min(value),max(value) from `%s`.`%s` %s", util.MetricSchemaName.L, name, cond) } - rows, _, err := sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithContext(ctx, sql) + exec := sctx.(sqlexec.RestrictedSQLExecutor) + stmt, err := exec.ParseWithParams(ctx, sql) + if err != nil { + return nil, errors.Errorf("execute '%s' failed: %v", sql, err) + } + rows, _, err := exec.ExecRestrictedStmt(ctx, stmt) if err != nil { return nil, errors.Errorf("execute '%s' failed: %v", sql, err) } diff --git a/executor/opt_rule_blacklist.go b/executor/opt_rule_blacklist.go index 8bb55c16f52e5..76cdc74ea1d11 100644 --- a/executor/opt_rule_blacklist.go +++ b/executor/opt_rule_blacklist.go @@ -35,8 +35,12 @@ func (e *ReloadOptRuleBlacklistExec) Next(ctx context.Context, _ *chunk.Chunk) e // LoadOptRuleBlacklist loads the latest data from table mysql.opt_rule_blacklist. func LoadOptRuleBlacklist(ctx sessionctx.Context) (err error) { - sql := "select HIGH_PRIORITY name from mysql.opt_rule_blacklist" - rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql) + exec := ctx.(sqlexec.RestrictedSQLExecutor) + stmt, err := exec.ParseWithParams(context.TODO(), "select HIGH_PRIORITY name from mysql.opt_rule_blacklist") + if err != nil { + return err + } + rows, _, err := exec.ExecRestrictedStmt(context.TODO(), stmt) if err != nil { return err } diff --git a/executor/prepared.go b/executor/prepared.go index f57a1806b1ed9..9f8a427b84afa 100644 --- a/executor/prepared.go +++ b/executor/prepared.go @@ -116,6 +116,7 @@ func (e *PrepareExec) Next(ctx context.Context, req *chunk.Chunk) error { err error ) if sqlParser, ok := e.ctx.(sqlexec.SQLParser); ok { + // FIXME: ok... yet another parse API, may need some api interface clean. stmts, err = sqlParser.ParseSQL(e.sqlText, charset, collation) } else { p := parser.New() diff --git a/executor/reload_expr_pushdown_blacklist.go b/executor/reload_expr_pushdown_blacklist.go index 3d0752e08463d..5783438813954 100644 --- a/executor/reload_expr_pushdown_blacklist.go +++ b/executor/reload_expr_pushdown_blacklist.go @@ -37,8 +37,12 @@ func (e *ReloadExprPushdownBlacklistExec) Next(ctx context.Context, _ *chunk.Chu // LoadExprPushdownBlacklist loads the latest data from table mysql.expr_pushdown_blacklist. func LoadExprPushdownBlacklist(ctx sessionctx.Context) (err error) { - sql := "select HIGH_PRIORITY name, store_type from mysql.expr_pushdown_blacklist" - rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql) + exec := ctx.(sqlexec.RestrictedSQLExecutor) + stmt, err := exec.ParseWithParams(context.TODO(), "select HIGH_PRIORITY name, store_type from mysql.expr_pushdown_blacklist") + if err != nil { + return err + } + rows, _, err := exec.ExecRestrictedStmt(context.TODO(), stmt) if err != nil { return err } diff --git a/executor/revoke.go b/executor/revoke.go index 7e51aa8ac82a4..fb722c89acf40 100644 --- a/executor/revoke.go +++ b/executor/revoke.go @@ -15,7 +15,7 @@ package executor import ( "context" - "fmt" + "strings" "github.com/pingcap/errors" "github.com/pingcap/parser/ast" @@ -73,7 +73,7 @@ func (e *RevokeExec) Next(ctx context.Context, req *chunk.Chunk) error { } defer func() { if !isCommit { - _, err := internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), "rollback") + _, err := internalSession.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), "rollback") if err != nil { logutil.BgLogger().Error("rollback error occur at grant privilege", zap.Error(err)) } @@ -81,7 +81,7 @@ func (e *RevokeExec) Next(ctx context.Context, req *chunk.Chunk) error { e.releaseSysSession(internalSession) }() - _, err = internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), "begin") + _, err = internalSession.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), "begin") if err != nil { return err } @@ -103,7 +103,7 @@ func (e *RevokeExec) Next(ctx context.Context, req *chunk.Chunk) error { } } - _, err = internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), "commit") + _, err = internalSession.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), "commit") if err != nil { return err } @@ -166,12 +166,15 @@ func (e *RevokeExec) revokePriv(internalSession sessionctx.Context, priv *ast.Pr } func (e *RevokeExec) revokeGlobalPriv(internalSession sessionctx.Context, priv *ast.PrivElem, user, host string) error { - asgns, err := composeGlobalPrivUpdate(priv.Priv, "N") + sql := new(strings.Builder) + sqlexec.MustFormatSQL(sql, "UPDATE %n.%n SET ", mysql.SystemDB, mysql.UserTable) + err := composeGlobalPrivUpdate(sql, priv.Priv, "N") if err != nil { return err } - sql := fmt.Sprintf(`UPDATE %s.%s SET %s WHERE User='%s' AND Host='%s'`, mysql.SystemDB, mysql.UserTable, asgns, user, host) - _, err = internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), sql) + sqlexec.MustFormatSQL(sql, " WHERE User=%? AND Host=%?", user, host) + + _, err = internalSession.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), sql.String()) return err } @@ -180,12 +183,16 @@ func (e *RevokeExec) revokeDBPriv(internalSession sessionctx.Context, priv *ast. if len(dbName) == 0 { dbName = e.ctx.GetSessionVars().CurrentDB } - asgns, err := composeDBPrivUpdate(priv.Priv, "N") + + sql := new(strings.Builder) + sqlexec.MustFormatSQL(sql, "UPDATE %n.%n SET ", mysql.SystemDB, mysql.DBTable) + err := composeDBPrivUpdate(sql, priv.Priv, "N") if err != nil { return err } - sql := fmt.Sprintf(`UPDATE %s.%s SET %s WHERE User='%s' AND Host='%s' AND DB='%s';`, mysql.SystemDB, mysql.DBTable, asgns, userName, host, dbName) - _, err = internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), sql) + sqlexec.MustFormatSQL(sql, " WHERE User=%? AND Host=%? AND DB=%?", userName, host, dbName) + + _, err = internalSession.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), sql.String()) return err } @@ -194,12 +201,16 @@ func (e *RevokeExec) revokeTablePriv(internalSession sessionctx.Context, priv *a if err != nil { return err } - asgns, err := composeTablePrivUpdateForRevoke(internalSession, priv.Priv, user, host, dbName, tbl.Meta().Name.O) + + sql := new(strings.Builder) + sqlexec.MustFormatSQL(sql, "UPDATE %n.%n SET ", mysql.SystemDB, mysql.TablePrivTable) + err = composeTablePrivUpdateForRevoke(internalSession, sql, priv.Priv, user, host, dbName, tbl.Meta().Name.O) if err != nil { return err } - sql := fmt.Sprintf(`UPDATE %s.%s SET %s WHERE User='%s' AND Host='%s' AND DB='%s' AND Table_name='%s';`, mysql.SystemDB, mysql.TablePrivTable, asgns, user, host, dbName, tbl.Meta().Name.O) - _, err = internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), sql) + sqlexec.MustFormatSQL(sql, " WHERE User=%? AND Host=%? AND DB=%? AND Table_name=%?", user, host, dbName, tbl.Meta().Name.O) + + _, err = internalSession.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), sql.String()) return err } @@ -208,20 +219,80 @@ func (e *RevokeExec) revokeColumnPriv(internalSession sessionctx.Context, priv * if err != nil { return err } + sql := new(strings.Builder) for _, c := range priv.Cols { col := table.FindCol(tbl.Cols(), c.Name.L) if col == nil { return errors.Errorf("Unknown column: %s", c) } - asgns, err := composeColumnPrivUpdateForRevoke(internalSession, priv.Priv, user, host, dbName, tbl.Meta().Name.O, col.Name.O) + + sql.Reset() + sqlexec.MustFormatSQL(sql, "UPDATE %n.%n SET ", mysql.SystemDB, mysql.ColumnPrivTable) + err = composeColumnPrivUpdateForRevoke(internalSession, sql, priv.Priv, user, host, dbName, tbl.Meta().Name.O, col.Name.O) + if err != nil { + return err + } + sqlexec.MustFormatSQL(sql, " WHERE User=%? AND Host=%? AND DB=%? AND Table_name=%? AND Column_name=%?", user, host, dbName, tbl.Meta().Name.O, col.Name.O) + + _, err = internalSession.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), sql.String()) + if err != nil { + return err + } + } + return nil +} + +func privUpdateForRevoke(cur []string, priv mysql.PrivilegeType) ([]string, error) { + p, ok := mysql.Priv2SetStr[priv] + if !ok { + return nil, errors.Errorf("Unknown priv: %v", priv) + } + cur = deleteFromSet(cur, p) + return cur, nil +} + +func composeTablePrivUpdateForRevoke(ctx sessionctx.Context, sql *strings.Builder, priv mysql.PrivilegeType, name string, host string, db string, tbl string) error { + var newTablePriv, newColumnPriv []string + + if priv != mysql.AllPriv { + currTablePriv, currColumnPriv, err := getTablePriv(ctx, name, host, db, tbl) + if err != nil { + return err + } + + newTablePriv = setFromString(currTablePriv) + newTablePriv, err = privUpdateForRevoke(newTablePriv, priv) + if err != nil { + return err + } + + newColumnPriv = setFromString(currColumnPriv) + newColumnPriv, err = privUpdateForRevoke(newColumnPriv, priv) + if err != nil { + return err + } + } + + sqlexec.MustFormatSQL(sql, `Table_priv=%?, Column_priv=%?, Grantor=%?`, strings.Join(newTablePriv, ","), strings.Join(newColumnPriv, ","), ctx.GetSessionVars().User.String()) + return nil +} + +func composeColumnPrivUpdateForRevoke(ctx sessionctx.Context, sql *strings.Builder, priv mysql.PrivilegeType, name string, host string, db string, tbl string, col string) error { + var newColumnPriv []string + + if priv != mysql.AllPriv { + currColumnPriv, err := getColumnPriv(ctx, name, host, db, tbl, col) if err != nil { return err } - sql := fmt.Sprintf(`UPDATE %s.%s SET %s WHERE User='%s' AND Host='%s' AND DB='%s' AND Table_name='%s' AND Column_name='%s';`, mysql.SystemDB, mysql.ColumnPrivTable, asgns, user, host, dbName, tbl.Meta().Name.O, col.Name.O) - _, err = internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), sql) + + newColumnPriv = setFromString(currColumnPriv) + newColumnPriv, err = privUpdateForRevoke(newColumnPriv, priv) if err != nil { return err } } + + sqlexec.MustFormatSQL(sql, `Column_priv=%?`, strings.Join(newColumnPriv, ",")) return nil } diff --git a/executor/show.go b/executor/show.go index bb856ac7657e0..8d4095756b137 100644 --- a/executor/show.go +++ b/executor/show.go @@ -284,12 +284,17 @@ func (e *ShowExec) fetchShowBind() error { } func (e *ShowExec) fetchShowEngines() error { - sql := `SELECT * FROM information_schema.engines` - rows, _, err := e.ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql) + exec := e.ctx.(sqlexec.RestrictedSQLExecutor) + stmt, err := exec.ParseWithParams(context.TODO(), `SELECT * FROM information_schema.engines`) if err != nil { return errors.Trace(err) } + rows, _, err := exec.ExecRestrictedStmt(context.TODO(), stmt) + if err != nil { + return errors.Trace(err) + } + for _, row := range rows { e.result.AppendRow(row) } @@ -410,16 +415,32 @@ func (e *ShowExec) fetchShowTableStatus() error { return ErrBadDB.GenWithStackByArgs(e.DBName) } - sql := fmt.Sprintf(`SELECT - table_name, engine, version, row_format, table_rows, - avg_row_length, data_length, max_data_length, index_length, - data_free, auto_increment, create_time, update_time, check_time, - table_collation, IFNULL(checksum,''), create_options, table_comment - FROM information_schema.tables - WHERE table_schema='%s' ORDER BY table_name`, e.DBName) + exec := e.ctx.(sqlexec.RestrictedSQLExecutor) - rows, _, err := e.ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithSnapshot(sql) + stmt, err := exec.ParseWithParams(context.TODO(), `SELECT + table_name, engine, version, row_format, table_rows, + avg_row_length, data_length, max_data_length, index_length, + data_free, auto_increment, create_time, update_time, check_time, + table_collation, IFNULL(checksum,''), create_options, table_comment + FROM information_schema.tables + WHERE table_schema=%? ORDER BY table_name`, e.DBName.L) + if err != nil { + return errors.Trace(err) + } + var snapshot uint64 + txn, err := e.ctx.Txn(false) + if err != nil { + return errors.Trace(err) + } + if txn.Valid() { + snapshot = txn.StartTS() + } + if e.ctx.GetSessionVars().SnapshotTS != 0 { + snapshot = e.ctx.GetSessionVars().SnapshotTS + } + + rows, _, err := exec.ExecRestrictedStmt(context.TODO(), stmt, sqlexec.ExecOptionWithSnapshot(snapshot)) if err != nil { return errors.Trace(err) } @@ -1199,22 +1220,32 @@ func (e *ShowExec) fetchShowCreateUser() error { } } - sql := fmt.Sprintf(`SELECT * FROM %s.%s WHERE User='%s' AND Host='%s';`, - mysql.SystemDB, mysql.UserTable, userName, hostName) - rows, _, err := e.ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql) + exec := e.ctx.(sqlexec.RestrictedSQLExecutor) + + stmt, err := exec.ParseWithParams(context.TODO(), `SELECT * FROM %n.%n WHERE User=%? AND Host=%?`, mysql.SystemDB, mysql.UserTable, userName, hostName) + if err != nil { + return errors.Trace(err) + } + rows, _, err := exec.ExecRestrictedStmt(context.TODO(), stmt) if err != nil { return errors.Trace(err) } + if len(rows) == 0 { + // FIXME: the error returned is not escaped safely return ErrCannotUser.GenWithStackByArgs("SHOW CREATE USER", fmt.Sprintf("'%s'@'%s'", e.User.Username, e.User.Hostname)) } - sql = fmt.Sprintf(`SELECT PRIV FROM %s.%s WHERE User='%s' AND Host='%s'`, - mysql.SystemDB, mysql.GlobalPrivTable, userName, hostName) - rows, _, err = e.ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql) + + stmt, err = exec.ParseWithParams(context.TODO(), `SELECT Priv FROM %n.%n WHERE User=%? AND Host=%?`, mysql.SystemDB, mysql.GlobalPrivTable, userName, hostName) if err != nil { return errors.Trace(err) } + rows, _, err = exec.ExecRestrictedStmt(context.TODO(), stmt) + if err != nil { + return errors.Trace(err) + } + require := "NONE" if len(rows) == 1 { privData := rows[0].GetString(0) @@ -1225,6 +1256,7 @@ func (e *ShowExec) fetchShowCreateUser() error { } require = privValue.RequireStr() } + // FIXME: the returned string is not escaped safely showStr := fmt.Sprintf("CREATE USER '%s'@'%s' IDENTIFIED WITH 'mysql_native_password' AS '%s' REQUIRE %s PASSWORD EXPIRE DEFAULT ACCOUNT UNLOCK", e.User.Username, e.User.Hostname, checker.GetEncodedPassword(e.User.Username, e.User.Hostname), require) e.appendRow([]interface{}{showStr}) diff --git a/executor/simple.go b/executor/simple.go index 2f1dca0cb982d..9a91191035cbd 100644 --- a/executor/simple.go +++ b/executor/simple.go @@ -80,7 +80,7 @@ func (e *baseExecutor) releaseSysSession(ctx sessionctx.Context) { } dom := domain.GetDomain(e.ctx) sysSessionPool := dom.SysSessionPool() - if _, err := ctx.(sqlexec.SQLExecutor).Execute(context.Background(), "rollback"); err != nil { + if _, err := ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), "rollback"); err != nil { ctx.(pools.Resource).Close() return } @@ -151,23 +151,25 @@ func (e *SimpleExec) setDefaultRoleNone(s *ast.SetDefaultRoleStmt) error { } defer e.releaseSysSession(restrictedCtx) sqlExecutor := restrictedCtx.(sqlexec.SQLExecutor) - if _, err := sqlExecutor.Execute(context.Background(), "begin"); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "begin"); err != nil { return err } + sql := new(strings.Builder) for _, u := range s.UserList { if u.Hostname == "" { u.Hostname = "%" } - sql := fmt.Sprintf("DELETE IGNORE FROM mysql.default_roles WHERE USER='%s' AND HOST='%s';", u.Username, u.Hostname) - if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil { + sql.Reset() + sqlexec.MustFormatSQL(sql, "DELETE IGNORE FROM mysql.default_roles WHERE USER=%? AND HOST=%?;", u.Username, u.Hostname) + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), sql.String()); err != nil { logutil.BgLogger().Error(fmt.Sprintf("Error occur when executing %s", sql)) - if _, rollbackErr := sqlExecutor.Execute(context.Background(), "rollback"); rollbackErr != nil { + if _, rollbackErr := sqlExecutor.ExecuteInternal(context.TODO(), "rollback"); rollbackErr != nil { return rollbackErr } return err } } - if _, err := sqlExecutor.Execute(context.Background(), "commit"); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "commit"); err != nil { return err } return nil @@ -199,42 +201,45 @@ func (e *SimpleExec) setDefaultRoleRegular(s *ast.SetDefaultRoleStmt) error { } defer e.releaseSysSession(restrictedCtx) sqlExecutor := restrictedCtx.(sqlexec.SQLExecutor) - if _, err := sqlExecutor.Execute(context.Background(), "begin"); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "begin"); err != nil { return err } + sql := new(strings.Builder) for _, user := range s.UserList { if user.Hostname == "" { user.Hostname = "%" } - sql := fmt.Sprintf("DELETE IGNORE FROM mysql.default_roles WHERE USER='%s' AND HOST='%s';", user.Username, user.Hostname) - if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil { + sql.Reset() + sqlexec.MustFormatSQL(sql, "DELETE IGNORE FROM mysql.default_roles WHERE USER=%? AND HOST=%?;", user.Username, user.Hostname) + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), sql.String()); err != nil { logutil.BgLogger().Error(fmt.Sprintf("Error occur when executing %s", sql)) - if _, rollbackErr := sqlExecutor.Execute(context.Background(), "rollback"); rollbackErr != nil { + if _, rollbackErr := sqlExecutor.ExecuteInternal(context.TODO(), "rollback"); rollbackErr != nil { return rollbackErr } return err } for _, role := range s.RoleList { - sql := fmt.Sprintf("INSERT IGNORE INTO mysql.default_roles values('%s', '%s', '%s', '%s');", user.Hostname, user.Username, role.Hostname, role.Username) checker := privilege.GetPrivilegeManager(e.ctx) ok := checker.FindEdge(e.ctx, role, user) if ok { - if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil { + sql.Reset() + sqlexec.MustFormatSQL(sql, "INSERT IGNORE INTO mysql.default_roles values(%?, %?, %?, %?);", user.Hostname, user.Username, role.Hostname, role.Username) + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), sql.String()); err != nil { logutil.BgLogger().Error(fmt.Sprintf("Error occur when executing %s", sql)) - if _, rollbackErr := sqlExecutor.Execute(context.Background(), "rollback"); rollbackErr != nil { + if _, rollbackErr := sqlExecutor.ExecuteInternal(context.TODO(), "rollback"); rollbackErr != nil { return rollbackErr } return err } } else { - if _, rollbackErr := sqlExecutor.Execute(context.Background(), "rollback"); rollbackErr != nil { + if _, rollbackErr := sqlExecutor.ExecuteInternal(context.TODO(), "rollback"); rollbackErr != nil { return rollbackErr } return ErrRoleNotGranted.GenWithStackByArgs(role.String(), user.String()) } } } - if _, err := sqlExecutor.Execute(context.Background(), "commit"); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "commit"); err != nil { return err } return nil @@ -256,31 +261,34 @@ func (e *SimpleExec) setDefaultRoleAll(s *ast.SetDefaultRoleStmt) error { } defer e.releaseSysSession(restrictedCtx) sqlExecutor := restrictedCtx.(sqlexec.SQLExecutor) - if _, err := sqlExecutor.Execute(context.Background(), "begin"); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "begin"); err != nil { return err } + sql := new(strings.Builder) for _, user := range s.UserList { if user.Hostname == "" { user.Hostname = "%" } - sql := fmt.Sprintf("DELETE IGNORE FROM mysql.default_roles WHERE USER='%s' AND HOST='%s';", user.Username, user.Hostname) - if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil { + sql.Reset() + sqlexec.MustFormatSQL(sql, "DELETE IGNORE FROM mysql.default_roles WHERE USER=%? AND HOST=%?;", user.Username, user.Hostname) + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), sql.String()); err != nil { logutil.BgLogger().Error(fmt.Sprintf("Error occur when executing %s", sql)) - if _, rollbackErr := sqlExecutor.Execute(context.Background(), "rollback"); rollbackErr != nil { + if _, rollbackErr := sqlExecutor.ExecuteInternal(context.TODO(), "rollback"); rollbackErr != nil { return rollbackErr } return err } - sql = fmt.Sprintf("INSERT IGNORE INTO mysql.default_roles(HOST,USER,DEFAULT_ROLE_HOST,DEFAULT_ROLE_USER) "+ - "SELECT TO_HOST,TO_USER,FROM_HOST,FROM_USER FROM mysql.role_edges WHERE TO_HOST='%s' AND TO_USER='%s';", user.Hostname, user.Username) - if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil { - if _, rollbackErr := sqlExecutor.Execute(context.Background(), "rollback"); rollbackErr != nil { + sql.Reset() + sqlexec.MustFormatSQL(sql, "INSERT IGNORE INTO mysql.default_roles(HOST,USER,DEFAULT_ROLE_HOST,DEFAULT_ROLE_USER) SELECT TO_HOST,TO_USER,FROM_HOST,FROM_USER FROM mysql.role_edges WHERE TO_HOST=%? AND TO_USER=%?;", user.Hostname, user.Username) + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), sql.String()); err != nil { + logutil.BgLogger().Error(fmt.Sprintf("Error occur when executing %s", sql)) + if _, rollbackErr := sqlExecutor.ExecuteInternal(context.TODO(), "rollback"); rollbackErr != nil { return rollbackErr } return err } } - if _, err := sqlExecutor.Execute(context.Background(), "commit"); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "commit"); err != nil { return err } return nil @@ -288,29 +296,10 @@ func (e *SimpleExec) setDefaultRoleAll(s *ast.SetDefaultRoleStmt) error { func (e *SimpleExec) setDefaultRoleForCurrentUser(s *ast.SetDefaultRoleStmt) (err error) { checker := privilege.GetPrivilegeManager(e.ctx) - user, sql := s.UserList[0], "" + user := s.UserList[0] if user.Hostname == "" { user.Hostname = "%" } - switch s.SetRoleOpt { - case ast.SetRoleNone: - sql = fmt.Sprintf("DELETE IGNORE FROM mysql.default_roles WHERE USER='%s' AND HOST='%s';", user.Username, user.Hostname) - case ast.SetRoleAll: - sql = fmt.Sprintf("INSERT IGNORE INTO mysql.default_roles(HOST,USER,DEFAULT_ROLE_HOST,DEFAULT_ROLE_USER) "+ - "SELECT TO_HOST,TO_USER,FROM_HOST,FROM_USER FROM mysql.role_edges WHERE TO_HOST='%s' AND TO_USER='%s';", user.Hostname, user.Username) - case ast.SetRoleRegular: - sql = "INSERT IGNORE INTO mysql.default_roles values" - for i, role := range s.RoleList { - ok := checker.FindEdge(e.ctx, role, user) - if !ok { - return ErrRoleNotGranted.GenWithStackByArgs(role.String(), user.String()) - } - sql += fmt.Sprintf("('%s', '%s', '%s', '%s')", user.Hostname, user.Username, role.Hostname, role.Username) - if i != len(s.RoleList)-1 { - sql += "," - } - } - } restrictedCtx, err := e.getSysSession() if err != nil { @@ -319,27 +308,48 @@ func (e *SimpleExec) setDefaultRoleForCurrentUser(s *ast.SetDefaultRoleStmt) (er defer e.releaseSysSession(restrictedCtx) sqlExecutor := restrictedCtx.(sqlexec.SQLExecutor) - if _, err := sqlExecutor.Execute(context.Background(), "begin"); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "begin"); err != nil { return err } - deleteSQL := fmt.Sprintf("DELETE IGNORE FROM mysql.default_roles WHERE USER='%s' AND HOST='%s';", user.Username, user.Hostname) - if _, err := sqlExecutor.Execute(context.Background(), deleteSQL); err != nil { + sql := new(strings.Builder) + sqlexec.MustFormatSQL(sql, "DELETE IGNORE FROM mysql.default_roles WHERE USER=%? AND HOST=%?;", user.Username, user.Hostname) + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), sql.String()); err != nil { logutil.BgLogger().Error(fmt.Sprintf("Error occur when executing %s", sql)) - if _, rollbackErr := sqlExecutor.Execute(context.Background(), "rollback"); rollbackErr != nil { + if _, rollbackErr := sqlExecutor.ExecuteInternal(context.TODO(), "rollback"); rollbackErr != nil { return rollbackErr } return err } - if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil { + sql.Reset() + switch s.SetRoleOpt { + case ast.SetRoleNone: + sqlexec.MustFormatSQL(sql, "DELETE IGNORE FROM mysql.default_roles WHERE USER=%? AND HOST=%?;", user.Username, user.Hostname) + case ast.SetRoleAll: + sqlexec.MustFormatSQL(sql, "INSERT IGNORE INTO mysql.default_roles(HOST,USER,DEFAULT_ROLE_HOST,DEFAULT_ROLE_USER) SELECT TO_HOST,TO_USER,FROM_HOST,FROM_USER FROM mysql.role_edges WHERE TO_HOST=%? AND TO_USER=%?;", user.Hostname, user.Username) + case ast.SetRoleRegular: + sqlexec.MustFormatSQL(sql, "INSERT IGNORE INTO mysql.default_roles values") + for i, role := range s.RoleList { + if i > 0 { + sqlexec.MustFormatSQL(sql, ",") + } + ok := checker.FindEdge(e.ctx, role, user) + if !ok { + return ErrRoleNotGranted.GenWithStackByArgs(role.String(), user.String()) + } + sqlexec.MustFormatSQL(sql, "(%?, %?, %?, %?)", user.Hostname, user.Username, role.Hostname, role.Username) + } + } + + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), sql.String()); err != nil { logutil.BgLogger().Error(fmt.Sprintf("Error occur when executing %s", sql)) - if _, rollbackErr := sqlExecutor.Execute(context.Background(), "rollback"); rollbackErr != nil { + if _, rollbackErr := sqlExecutor.ExecuteInternal(context.TODO(), "rollback"); rollbackErr != nil { return rollbackErr } return err } - if _, err := sqlExecutor.Execute(context.Background(), "commit"); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "commit"); err != nil { return err } return nil @@ -595,16 +605,17 @@ func (e *SimpleExec) executeRevokeRole(s *ast.RevokeRoleStmt) error { sqlExecutor := restrictedCtx.(sqlexec.SQLExecutor) // begin a transaction to insert role graph edges. - if _, err := sqlExecutor.Execute(context.Background(), "begin"); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "begin"); err != nil { return errors.Trace(err) } + sql := new(strings.Builder) for _, user := range s.Users { exists, err := userExists(e.ctx, user.Username, user.Hostname) if err != nil { return errors.Trace(err) } if !exists { - if _, err := sqlExecutor.Execute(context.Background(), "rollback"); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "rollback"); err != nil { return errors.Trace(err) } return ErrCannotUser.GenWithStackByArgs("REVOKE ROLE", user.String()) @@ -613,23 +624,26 @@ func (e *SimpleExec) executeRevokeRole(s *ast.RevokeRoleStmt) error { if role.Hostname == "" { role.Hostname = "%" } - sql := fmt.Sprintf(`DELETE IGNORE FROM %s.%s WHERE FROM_HOST='%s' and FROM_USER='%s' and TO_HOST='%s' and TO_USER='%s'`, mysql.SystemDB, mysql.RoleEdgeTable, role.Hostname, role.Username, user.Hostname, user.Username) - if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil { - if _, err := sqlExecutor.Execute(context.Background(), "rollback"); err != nil { + sql.Reset() + sqlexec.MustFormatSQL(sql, `DELETE IGNORE FROM %n.%n WHERE FROM_HOST=%? and FROM_USER=%? and TO_HOST=%? and TO_USER=%?`, mysql.SystemDB, mysql.RoleEdgeTable, role.Hostname, role.Username, user.Hostname, user.Username) + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), sql.String()); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "rollback"); err != nil { return errors.Trace(err) } return ErrCannotUser.GenWithStackByArgs("REVOKE ROLE", role.String()) } - sql = fmt.Sprintf(`DELETE IGNORE FROM %s.%s WHERE DEFAULT_ROLE_HOST='%s' and DEFAULT_ROLE_USER='%s' and HOST='%s' and USER='%s'`, mysql.SystemDB, mysql.DefaultRoleTable, role.Hostname, role.Username, user.Hostname, user.Username) - if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil { - if _, err := sqlExecutor.Execute(context.Background(), "rollback"); err != nil { + + sql.Reset() + sqlexec.MustFormatSQL(sql, `DELETE IGNORE FROM %n.%n WHERE DEFAULT_ROLE_HOST=%? and DEFAULT_ROLE_USER=%? and HOST=%? and USER=%?`, mysql.SystemDB, mysql.DefaultRoleTable, role.Hostname, role.Username, user.Hostname, user.Username) + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), sql.String()); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "rollback"); err != nil { return errors.Trace(err) } return ErrCannotUser.GenWithStackByArgs("REVOKE ROLE", role.String()) } } } - if _, err := sqlExecutor.Execute(context.Background(), "commit"); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "commit"); err != nil { return err } domain.GetDomain(e.ctx).NotifyUpdatePrivilege(e.ctx) @@ -687,9 +701,18 @@ func (e *SimpleExec) executeCreateUser(ctx context.Context, s *ast.CreateUserStm return err } - users := make([]string, 0, len(s.Specs)) - privs := make([]string, 0, len(s.Specs)) + sql := new(strings.Builder) + if s.IsCreateRole { + sqlexec.MustFormatSQL(sql, `INSERT INTO %n.%n (Host, User, authentication_string, Account_locked) VALUES `, mysql.SystemDB, mysql.UserTable) + } else { + sqlexec.MustFormatSQL(sql, `INSERT INTO %n.%n (Host, User, authentication_string) VALUES `, mysql.SystemDB, mysql.UserTable) + } + + users := make([]*auth.UserIdentity, 0, len(s.Specs)) for _, spec := range s.Specs { + if len(users) > 0 { + sqlexec.MustFormatSQL(sql, ",") + } exists, err1 := userExists(e.ctx, spec.User.Username, spec.User.Hostname) if err1 != nil { return err1 @@ -710,26 +733,17 @@ func (e *SimpleExec) executeCreateUser(ctx context.Context, s *ast.CreateUserStm if !ok { return errors.Trace(ErrPasswordFormat) } - user := fmt.Sprintf(`('%s', '%s', '%s')`, spec.User.Hostname, spec.User.Username, pwd) if s.IsCreateRole { - user = fmt.Sprintf(`('%s', '%s', '%s', 'Y')`, spec.User.Hostname, spec.User.Username, pwd) - } - users = append(users, user) - - if len(privData) != 0 { - priv := fmt.Sprintf(`('%s', '%s', '%s')`, spec.User.Hostname, spec.User.Username, hack.String(privData)) - privs = append(privs, priv) + sqlexec.MustFormatSQL(sql, `(%?, %?, %?, %?)`, spec.User.Hostname, spec.User.Username, pwd, "Y") + } else { + sqlexec.MustFormatSQL(sql, `(%?, %?, %?)`, spec.User.Hostname, spec.User.Username, pwd) } + users = append(users, spec.User) } if len(users) == 0 { return nil } - sql := fmt.Sprintf(`INSERT INTO %s.%s (Host, User, authentication_string) VALUES %s;`, mysql.SystemDB, mysql.UserTable, strings.Join(users, ", ")) - if s.IsCreateRole { - sql = fmt.Sprintf(`INSERT INTO %s.%s (Host, User, authentication_string, Account_locked) VALUES %s;`, mysql.SystemDB, mysql.UserTable, strings.Join(users, ", ")) - } - restrictedCtx, err := e.getSysSession() if err != nil { return err @@ -737,27 +751,34 @@ func (e *SimpleExec) executeCreateUser(ctx context.Context, s *ast.CreateUserStm defer e.releaseSysSession(restrictedCtx) sqlExecutor := restrictedCtx.(sqlexec.SQLExecutor) - if _, err := sqlExecutor.Execute(context.Background(), "begin"); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "begin"); err != nil { return errors.Trace(err) } - _, err = sqlExecutor.Execute(context.Background(), sql) + _, err = sqlExecutor.ExecuteInternal(context.TODO(), sql.String()) if err != nil { - if _, rollbackErr := sqlExecutor.Execute(context.Background(), "rollback"); rollbackErr != nil { + if _, rollbackErr := sqlExecutor.ExecuteInternal(context.TODO(), "rollback"); rollbackErr != nil { return rollbackErr } return err } - if len(privs) != 0 { - sql = fmt.Sprintf("INSERT IGNORE INTO %s.%s (Host, User, Priv) VALUES %s", mysql.SystemDB, mysql.GlobalPrivTable, strings.Join(privs, ", ")) - _, err = sqlExecutor.Execute(context.Background(), sql) + if len(privData) != 0 { + sql.Reset() + sqlexec.MustFormatSQL(sql, "INSERT IGNORE INTO %n.%n (Host, User, Priv) VALUES ", mysql.SystemDB, mysql.GlobalPrivTable) + for i, user := range users { + if i > 0 { + sqlexec.MustFormatSQL(sql, ",") + } + sqlexec.MustFormatSQL(sql, `(%?, %?, %?)`, user.Hostname, user.Username, string(hack.String(privData))) + } + _, err = sqlExecutor.ExecuteInternal(context.TODO(), sql.String()) if err != nil { - if _, rollbackErr := sqlExecutor.Execute(context.Background(), "rollback"); rollbackErr != nil { + if _, rollbackErr := sqlExecutor.ExecuteInternal(context.TODO(), "rollback"); rollbackErr != nil { return rollbackErr } return err } } - if _, err := sqlExecutor.Execute(context.Background(), "commit"); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "commit"); err != nil { return errors.Trace(err) } domain.GetDomain(e.ctx).NotifyUpdatePrivilege(e.ctx) @@ -806,17 +827,22 @@ func (e *SimpleExec) executeAlterUser(s *ast.AlterUserStmt) error { if !ok { return errors.Trace(ErrPasswordFormat) } - sql := fmt.Sprintf(`UPDATE %s.%s SET authentication_string = '%s' WHERE Host = '%s' and User = '%s';`, - mysql.SystemDB, mysql.UserTable, pwd, spec.User.Hostname, spec.User.Username) - _, _, err = e.ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql) + exec := e.ctx.(sqlexec.RestrictedSQLExecutor) + stmt, err := exec.ParseWithParams(context.TODO(), `UPDATE %n.%n SET authentication_string=%? WHERE Host=%? and User=%?;`, mysql.SystemDB, mysql.UserTable, pwd, spec.User.Hostname, spec.User.Username) + if err != nil { + return err + } + _, _, err = exec.ExecRestrictedStmt(context.TODO(), stmt) if err != nil { failedUsers = append(failedUsers, spec.User.String()) } if len(privData) > 0 { - sql = fmt.Sprintf("INSERT INTO %s.%s (Host, User, Priv) VALUES ('%s','%s','%s') ON DUPLICATE KEY UPDATE Priv = values(Priv)", - mysql.SystemDB, mysql.GlobalPrivTable, spec.User.Hostname, spec.User.Username, hack.String(privData)) - _, _, err = e.ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql) + stmt, err = exec.ParseWithParams(context.TODO(), "INSERT INTO %n.%n (Host, User, Priv) VALUES (%?,%?,%?) ON DUPLICATE KEY UPDATE Priv = values(Priv)", mysql.SystemDB, mysql.GlobalPrivTable, spec.User.Hostname, spec.User.Username, string(hack.String(privData))) + if err != nil { + return err + } + _, _, err = exec.ExecRestrictedStmt(context.TODO(), stmt) if err != nil { failedUsers = append(failedUsers, spec.User.String()) } @@ -880,23 +906,25 @@ func (e *SimpleExec) executeGrantRole(s *ast.GrantRoleStmt) error { sqlExecutor := restrictedCtx.(sqlexec.SQLExecutor) // begin a transaction to insert role graph edges. - if _, err := sqlExecutor.Execute(context.Background(), "begin"); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "begin"); err != nil { return err } + sql := new(strings.Builder) for _, user := range s.Users { for _, role := range s.Roles { - sql := fmt.Sprintf(`INSERT IGNORE INTO %s.%s (FROM_HOST, FROM_USER, TO_HOST, TO_USER) VALUES ('%s','%s','%s','%s')`, mysql.SystemDB, mysql.RoleEdgeTable, role.Hostname, role.Username, user.Hostname, user.Username) - if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil { + sql.Reset() + sqlexec.MustFormatSQL(sql, `INSERT IGNORE INTO %n.%n (FROM_HOST, FROM_USER, TO_HOST, TO_USER) VALUES (%?,%?,%?,%?)`, mysql.SystemDB, mysql.RoleEdgeTable, role.Hostname, role.Username, user.Hostname, user.Username) + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), sql.String()); err != nil { logutil.BgLogger().Error(fmt.Sprintf("Error occur when executing %s", sql)) - if _, err := sqlExecutor.Execute(context.Background(), "rollback"); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "rollback"); err != nil { return err } return ErrCannotUser.GenWithStackByArgs("GRANT ROLE", user.String()) } } } - if _, err := sqlExecutor.Execute(context.Background(), "commit"); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "commit"); err != nil { return err } domain.GetDomain(e.ctx).NotifyUpdatePrivilege(e.ctx) @@ -933,10 +961,11 @@ func (e *SimpleExec) executeDropUser(s *ast.DropUserStmt) error { } sqlExecutor := sysSession.(sqlexec.SQLExecutor) - if _, err := sqlExecutor.Execute(context.Background(), "begin"); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "begin"); err != nil { return err } + sql := new(strings.Builder) for _, user := range s.UserList { exists, err := userExists(e.ctx, user.Username, user.Hostname) if err != nil { @@ -952,58 +981,66 @@ func (e *SimpleExec) executeDropUser(s *ast.DropUserStmt) error { } // begin a transaction to delete a user. - sql := fmt.Sprintf(`DELETE FROM %s.%s WHERE Host = '%s' and User = '%s';`, mysql.SystemDB, mysql.UserTable, user.Hostname, user.Username) - if _, err = sqlExecutor.Execute(context.Background(), sql); err != nil { + sql.Reset() + sqlexec.MustFormatSQL(sql, `DELETE FROM %n.%n WHERE Host = %? and User = %?;`, mysql.SystemDB, mysql.UserTable, user.Hostname, user.Username) + if _, err = sqlExecutor.ExecuteInternal(context.TODO(), sql.String()); err != nil { failedUsers = append(failedUsers, user.String()) break } // delete privileges from mysql.global_priv - sql = fmt.Sprintf(`DELETE FROM %s.%s WHERE Host = '%s' and User = '%s';`, mysql.SystemDB, mysql.GlobalPrivTable, user.Hostname, user.Username) - if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil { + sql.Reset() + sqlexec.MustFormatSQL(sql, `DELETE FROM %n.%n WHERE Host = %? and User = %?;`, mysql.SystemDB, mysql.GlobalPrivTable, user.Hostname, user.Username) + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), sql.String()); err != nil { failedUsers = append(failedUsers, user.String()) - if _, err := sqlExecutor.Execute(context.Background(), "rollback"); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "rollback"); err != nil { return err } continue } // delete privileges from mysql.db - sql = fmt.Sprintf(`DELETE FROM %s.%s WHERE Host = '%s' and User = '%s';`, mysql.SystemDB, mysql.DBTable, user.Hostname, user.Username) - if _, err = sqlExecutor.Execute(context.Background(), sql); err != nil { + sql.Reset() + sqlexec.MustFormatSQL(sql, `DELETE FROM %n.%n WHERE Host = %? and User = %?;`, mysql.SystemDB, mysql.DBTable, user.Hostname, user.Username) + if _, err = sqlExecutor.ExecuteInternal(context.TODO(), sql.String()); err != nil { failedUsers = append(failedUsers, user.String()) break } // delete privileges from mysql.tables_priv - sql = fmt.Sprintf(`DELETE FROM %s.%s WHERE Host = '%s' and User = '%s';`, mysql.SystemDB, mysql.TablePrivTable, user.Hostname, user.Username) - if _, err = sqlExecutor.Execute(context.Background(), sql); err != nil { + sql.Reset() + sqlexec.MustFormatSQL(sql, `DELETE FROM %n.%n WHERE Host = %? and User = %?;`, mysql.SystemDB, mysql.TablePrivTable, user.Hostname, user.Username) + if _, err = sqlExecutor.ExecuteInternal(context.TODO(), sql.String()); err != nil { failedUsers = append(failedUsers, user.String()) break } // delete relationship from mysql.role_edges - sql = fmt.Sprintf(`DELETE FROM %s.%s WHERE TO_HOST = '%s' and TO_USER = '%s';`, mysql.SystemDB, mysql.RoleEdgeTable, user.Hostname, user.Username) - if _, err = sqlExecutor.Execute(context.Background(), sql); err != nil { + sql.Reset() + sqlexec.MustFormatSQL(sql, `DELETE FROM %n.%n WHERE TO_HOST = %? and TO_USER = %?;`, mysql.SystemDB, mysql.RoleEdgeTable, user.Hostname, user.Username) + if _, err = sqlExecutor.ExecuteInternal(context.TODO(), sql.String()); err != nil { failedUsers = append(failedUsers, user.String()) break } - sql = fmt.Sprintf(`DELETE FROM %s.%s WHERE FROM_HOST = '%s' and FROM_USER = '%s';`, mysql.SystemDB, mysql.RoleEdgeTable, user.Hostname, user.Username) - if _, err = sqlExecutor.Execute(context.Background(), sql); err != nil { + sql.Reset() + sqlexec.MustFormatSQL(sql, `DELETE FROM %n.%n WHERE FROM_HOST = %? and FROM_USER = %?;`, mysql.SystemDB, mysql.RoleEdgeTable, user.Hostname, user.Username) + if _, err = sqlExecutor.ExecuteInternal(context.TODO(), sql.String()); err != nil { failedUsers = append(failedUsers, user.String()) break } // delete relationship from mysql.default_roles - sql = fmt.Sprintf(`DELETE FROM %s.%s WHERE DEFAULT_ROLE_HOST = '%s' and DEFAULT_ROLE_USER = '%s';`, mysql.SystemDB, mysql.DefaultRoleTable, user.Hostname, user.Username) - if _, err = sqlExecutor.Execute(context.Background(), sql); err != nil { + sql.Reset() + sqlexec.MustFormatSQL(sql, `DELETE FROM %n.%n WHERE DEFAULT_ROLE_HOST = %? and DEFAULT_ROLE_USER = %?;`, mysql.SystemDB, mysql.DefaultRoleTable, user.Hostname, user.Username) + if _, err = sqlExecutor.ExecuteInternal(context.TODO(), sql.String()); err != nil { failedUsers = append(failedUsers, user.String()) break } - sql = fmt.Sprintf(`DELETE FROM %s.%s WHERE HOST = '%s' and USER = '%s';`, mysql.SystemDB, mysql.DefaultRoleTable, user.Hostname, user.Username) - if _, err = sqlExecutor.Execute(context.Background(), sql); err != nil { + sql.Reset() + sqlexec.MustFormatSQL(sql, `DELETE FROM %n.%n WHERE HOST = %? and USER = %?;`, mysql.SystemDB, mysql.DefaultRoleTable, user.Hostname, user.Username) + if _, err = sqlExecutor.ExecuteInternal(context.TODO(), sql.String()); err != nil { failedUsers = append(failedUsers, user.String()) break } @@ -1011,11 +1048,11 @@ func (e *SimpleExec) executeDropUser(s *ast.DropUserStmt) error { } if len(failedUsers) == 0 { - if _, err := sqlExecutor.Execute(context.Background(), "commit"); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "commit"); err != nil { return err } } else { - if _, err := sqlExecutor.Execute(context.Background(), "rollback"); err != nil { + if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "rollback"); err != nil { return err } if s.IsDropRole { @@ -1028,8 +1065,12 @@ func (e *SimpleExec) executeDropUser(s *ast.DropUserStmt) error { } func userExists(ctx sessionctx.Context, name string, host string) (bool, error) { - sql := fmt.Sprintf(`SELECT * FROM %s.%s WHERE User='%s' AND Host='%s';`, mysql.SystemDB, mysql.UserTable, name, host) - rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql) + exec := ctx.(sqlexec.RestrictedSQLExecutor) + stmt, err := exec.ParseWithParams(context.TODO(), `SELECT * FROM %n.%n WHERE User=%? AND Host=%?;`, mysql.SystemDB, mysql.UserTable, name, host) + if err != nil { + return false, err + } + rows, _, err := exec.ExecRestrictedStmt(context.TODO(), stmt) if err != nil { return false, err } @@ -1062,8 +1103,12 @@ func (e *SimpleExec) executeSetPwd(s *ast.SetPwdStmt) error { } // update mysql.user - sql := fmt.Sprintf(`UPDATE %s.%s SET authentication_string='%s' WHERE User='%s' AND Host='%s';`, mysql.SystemDB, mysql.UserTable, auth.EncodePassword(s.Password), u, h) - _, _, err = e.ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql) + exec := e.ctx.(sqlexec.RestrictedSQLExecutor) + stmt, err := exec.ParseWithParams(context.TODO(), `UPDATE %n.%n SET authentication_string=%? WHERE User=%? AND Host=%?;`, mysql.SystemDB, mysql.UserTable, auth.EncodePassword(s.Password), u, h) + if err != nil { + return err + } + _, _, err = exec.ExecRestrictedStmt(context.TODO(), stmt) domain.GetDomain(e.ctx).NotifyUpdatePrivilege(e.ctx) return err } diff --git a/executor/trace.go b/executor/trace.go index bf9150f357081..8eb6592c4822f 100644 --- a/executor/trace.go +++ b/executor/trace.go @@ -132,24 +132,21 @@ func (e *TraceExec) nextRowJSON(ctx context.Context, se sqlexec.SQLExecutor, req } func (e *TraceExec) executeChild(ctx context.Context, se sqlexec.SQLExecutor) { - recordSets, err := se.Execute(ctx, e.stmtNode.Text()) - if len(recordSets) == 0 { - if err != nil { - var errCode uint16 - if te, ok := err.(*terror.Error); ok { - errCode = terror.ToSQLError(te).Code - } - logutil.Eventf(ctx, "execute with error(%d): %s", errCode, err.Error()) - } else { - logutil.Eventf(ctx, "execute done, modify row: %d", e.ctx.GetSessionVars().StmtCtx.AffectedRows()) + rs, err := se.ExecuteStmt(ctx, e.stmtNode) + if err != nil { + var errCode uint16 + if te, ok := err.(*terror.Error); ok { + errCode = terror.ToSQLError(te).Code } + logutil.Eventf(ctx, "execute with error(%d): %s", errCode, err.Error()) } - for _, rs := range recordSets { + if rs != nil { drainRecordSet(ctx, e.ctx, rs) if err = rs.Close(); err != nil { logutil.Logger(ctx).Error("run trace close result with error", zap.Error(err)) } } + logutil.Eventf(ctx, "execute done, modify row: %d", e.ctx.GetSessionVars().StmtCtx.AffectedRows()) } func drainRecordSet(ctx context.Context, sctx sessionctx.Context, rs sqlexec.RecordSet) { diff --git a/executor/utils.go b/executor/utils.go new file mode 100644 index 0000000000000..fbc9ab4dcff30 --- /dev/null +++ b/executor/utils.go @@ -0,0 +1,46 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package executor + +import "strings" + +func setFromString(value string) []string { + if len(value) == 0 { + return nil + } + return strings.Split(value, ",") +} + +// addToSet add a value to the set, e.g: +// addToSet("Select,Insert,Update", "Update") returns "Select,Insert,Update". +func addToSet(set []string, value string) []string { + for _, v := range set { + if v == value { + return set + } + } + return append(set, value) +} + +// deleteFromSet delete the value from the set, e.g: +// deleteFromSet("Select,Insert,Update", "Update") returns "Select,Insert". +func deleteFromSet(set []string, value string) []string { + for i, v := range set { + if v == value { + copy(set[i:], set[i+1:]) + return set[:len(set)-1] + } + } + return set +} diff --git a/telemetry/data_cluster_hardware.go b/telemetry/data_cluster_hardware.go index 318e1ba63f48f..eb380e9503ac7 100644 --- a/telemetry/data_cluster_hardware.go +++ b/telemetry/data_cluster_hardware.go @@ -14,6 +14,7 @@ package telemetry import ( + "context" "regexp" "sort" "strings" @@ -66,7 +67,12 @@ func normalizeFieldName(name string) string { } func getClusterHardware(ctx sessionctx.Context) ([]*clusterHardwareItem, error) { - rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(`SELECT TYPE, INSTANCE, DEVICE_TYPE, DEVICE_NAME, NAME, VALUE FROM information_schema.cluster_hardware`) + exec := ctx.(sqlexec.RestrictedSQLExecutor) + stmt, err := exec.ParseWithParams(context.TODO(), `SELECT TYPE, INSTANCE, DEVICE_TYPE, DEVICE_NAME, NAME, VALUE FROM information_schema.cluster_hardware`) + if err != nil { + return nil, errors.Trace(err) + } + rows, _, err := exec.ExecRestrictedStmt(context.TODO(), stmt) if err != nil { return nil, errors.Trace(err) } diff --git a/telemetry/data_cluster_info.go b/telemetry/data_cluster_info.go index fdb1be6bafc27..46b8cfb8f7b47 100644 --- a/telemetry/data_cluster_info.go +++ b/telemetry/data_cluster_info.go @@ -14,6 +14,8 @@ package telemetry import ( + "context" + "github.com/pingcap/errors" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/util/sqlexec" @@ -33,7 +35,12 @@ type clusterInfoItem struct { func getClusterInfo(ctx sessionctx.Context) ([]*clusterInfoItem, error) { // Explicitly list all field names instead of using `*` to avoid potential leaking sensitive info when adding new fields in future. - rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(`SELECT TYPE, INSTANCE, STATUS_ADDRESS, VERSION, GIT_HASH, START_TIME, UPTIME FROM information_schema.cluster_info`) + exec := ctx.(sqlexec.RestrictedSQLExecutor) + stmt, err := exec.ParseWithParams(context.TODO(), `SELECT TYPE, INSTANCE, STATUS_ADDRESS, VERSION, GIT_HASH, START_TIME, UPTIME FROM information_schema.cluster_info`) + if err != nil { + return nil, errors.Trace(err) + } + rows, _, err := exec.ExecRestrictedStmt(context.TODO(), stmt) if err != nil { return nil, errors.Trace(err) } diff --git a/util/mock/context.go b/util/mock/context.go index f049a05beb842..e2461c7bc8446 100644 --- a/util/mock/context.go +++ b/util/mock/context.go @@ -20,6 +20,7 @@ import ( "time" "github.com/pingcap/errors" + "github.com/pingcap/parser/ast" "github.com/pingcap/parser/model" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/owner" @@ -61,6 +62,11 @@ func (c *Context) Execute(ctx context.Context, sql string) ([]sqlexec.RecordSet, return nil, errors.Errorf("Not Support.") } +// ExecuteStmt implements sqlexec.SQLExecutor ExecuteStmt interface. +func (c *Context) ExecuteStmt(ctx context.Context, stmtNode ast.StmtNode) (sqlexec.RecordSet, error) { + return nil, errors.Errorf("Not Supported.") +} + // ExecuteInternal implements sqlexec.SQLExecutor ExecuteInternal interface. func (c *Context) ExecuteInternal(ctx context.Context, sql string, args ...interface{}) (sqlexec.RecordSet, error) { return nil, errors.Errorf("Not Supported.") diff --git a/util/sqlexec/restricted_sql_executor.go b/util/sqlexec/restricted_sql_executor.go index cc23bda405b9f..597873d050151 100644 --- a/util/sqlexec/restricted_sql_executor.go +++ b/util/sqlexec/restricted_sql_executor.go @@ -89,6 +89,7 @@ type SQLExecutor interface { Execute(ctx context.Context, sql string) ([]RecordSet, error) // ExecuteInternal means execute sql as the internal sql. ExecuteInternal(ctx context.Context, sql string, args ...interface{}) (RecordSet, error) + ExecuteStmt(ctx context.Context, stmtNode ast.StmtNode) (RecordSet, error) } // SQLParser is an interface provides parsing sql statement. From e889a5cc40fcb3616f88a1d6da319e55ed7c6cf6 Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Thu, 18 Mar 2021 18:23:37 +0800 Subject: [PATCH 35/85] executor: fix get var expr when session var is hex literal (#23241) (#23372) --- expression/builtin_control_test.go | 2 +- expression/builtin_string_test.go | 2 +- expression/constant_test.go | 4 +- planner/core/common_plans.go | 22 +++++++++ planner/core/expression_rewriter.go | 3 +- planner/core/integration_test.go | 70 +++++++++++++++++++++++++++++ planner/core/prepare_test.go | 12 +++++ types/field_type.go | 4 +- 8 files changed, 112 insertions(+), 7 deletions(-) diff --git a/expression/builtin_control_test.go b/expression/builtin_control_test.go index 6ea1655e1c874..7f6e35aaa8626 100644 --- a/expression/builtin_control_test.go +++ b/expression/builtin_control_test.go @@ -116,7 +116,7 @@ func (s *testEvaluatorSuite) TestIfNull(c *C) { {tm, nil, tm, false, false}, {nil, duration, duration, false, false}, {nil, types.NewDecFromFloatForTest(123.123), types.NewDecFromFloatForTest(123.123), false, false}, - {nil, types.NewBinaryLiteralFromUint(0x01, -1), uint64(1), false, false}, + {nil, types.NewBinaryLiteralFromUint(0x01, -1), "\x01", false, false}, {nil, types.Set{Value: 1, Name: "abc"}, "abc", false, false}, {nil, jsonInt.GetMysqlJSON(), jsonInt.GetMysqlJSON(), false, false}, {"abc", nil, "abc", false, false}, diff --git a/expression/builtin_string_test.go b/expression/builtin_string_test.go index 8821133082836..67d81a0ae277c 100644 --- a/expression/builtin_string_test.go +++ b/expression/builtin_string_test.go @@ -1149,7 +1149,7 @@ func (s *testEvaluatorSuite) TestHexFunc(c *C) { {-1, false, false, "FFFFFFFFFFFFFFFF"}, {-12.3, false, false, "FFFFFFFFFFFFFFF4"}, {-12.8, false, false, "FFFFFFFFFFFFFFF3"}, - {types.NewBinaryLiteralFromUint(0xC, -1), false, false, "C"}, + {types.NewBinaryLiteralFromUint(0xC, -1), false, false, "0C"}, {0x12, false, false, "12"}, {nil, true, false, ""}, {errors.New("must err"), false, true, ""}, diff --git a/expression/constant_test.go b/expression/constant_test.go index 50ea87d010f7b..49929497069cf 100644 --- a/expression/constant_test.go +++ b/expression/constant_test.go @@ -271,8 +271,8 @@ func (*testExpressionSuite) TestDeferredParamNotNull(c *C) { c.Assert(mysql.TypeTimestamp, Equals, cstTime.GetType().Tp) c.Assert(mysql.TypeDuration, Equals, cstDuration.GetType().Tp) c.Assert(mysql.TypeBlob, Equals, cstBytes.GetType().Tp) - c.Assert(mysql.TypeBit, Equals, cstBinary.GetType().Tp) - c.Assert(mysql.TypeBit, Equals, cstBit.GetType().Tp) + c.Assert(mysql.TypeVarString, Equals, cstBinary.GetType().Tp) + c.Assert(mysql.TypeVarString, Equals, cstBit.GetType().Tp) c.Assert(mysql.TypeFloat, Equals, cstFloat32.GetType().Tp) c.Assert(mysql.TypeDouble, Equals, cstFloat64.GetType().Tp) c.Assert(mysql.TypeEnum, Equals, cstEnum.GetType().Tp) diff --git a/planner/core/common_plans.go b/planner/core/common_plans.go index 83d97ebecfd40..a19304db7c883 100644 --- a/planner/core/common_plans.go +++ b/planner/core/common_plans.go @@ -188,6 +188,21 @@ type Execute struct { Plan Plan } +// Check if result of GetVar expr is BinaryLiteral +// Because GetVar use String to represent BinaryLiteral, here we need to convert string back to BinaryLiteral. +func isGetVarBinaryLiteral(sctx sessionctx.Context, expr expression.Expression) (res bool) { + scalarFunc, ok := expr.(*expression.ScalarFunction) + if ok && scalarFunc.FuncName.L == ast.GetVar { + name, isNull, err := scalarFunc.GetArgs()[0].EvalString(sctx, chunk.Row{}) + if err != nil || isNull { + res = false + } else if dt, ok2 := sctx.GetSessionVars().Users[name]; ok2 { + res = (dt.Kind() == types.KindBinaryLiteral) + } + } + return res +} + // OptimizePreparedPlan optimizes the prepared statement. func (e *Execute) OptimizePreparedPlan(ctx context.Context, sctx sessionctx.Context, is infoschema.InfoSchema) error { vars := sctx.GetSessionVars() @@ -229,6 +244,13 @@ func (e *Execute) OptimizePreparedPlan(ctx context.Context, sctx sessionctx.Cont return err } param := prepared.Params[i].(*driver.ParamMarkerExpr) + if isGetVarBinaryLiteral(sctx, usingVar) { + binVal, convErr := val.ToBytes() + if convErr != nil { + return convErr + } + val.SetBinaryLiteral(types.BinaryLiteral(binVal)) + } param.Datum = val param.InExecute = true vars.PreparedParams = append(vars.PreparedParams, val) diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go index 0cb1317116762..df9b73a778178 100644 --- a/planner/core/expression_rewriter.go +++ b/planner/core/expression_rewriter.go @@ -1356,7 +1356,8 @@ func (er *expressionRewriter) inToExpression(lLen int, not bool, tp *types.Field er.ctxStackAppend(expression.NewNull(), types.EmptyName) return } - if leftEt == types.ETInt { + containMut := expression.ContainMutableConst(er.sctx, args) + if !containMut && leftEt == types.ETInt { for i := 1; i < len(args); i++ { if c, ok := args[i].(*expression.Constant); ok { var isExceptional bool diff --git a/planner/core/integration_test.go b/planner/core/integration_test.go index c83757a4dec6f..3396736350596 100644 --- a/planner/core/integration_test.go +++ b/planner/core/integration_test.go @@ -1864,3 +1864,73 @@ func (s *testIntegrationSuite) TestIndexMergeTableFilter(c *C) { "10 1 1 10", )) } + +// #22949: test HexLiteral Used in GetVar expr +func (s *testIntegrationSuite) TestGetVarExprWithHexLiteral(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test;") + tk.MustExec("drop table if exists t1_no_idx;") + tk.MustExec("create table t1_no_idx(id int, col_bit bit(16));") + tk.MustExec("insert into t1_no_idx values(1, 0x3135);") + tk.MustExec("insert into t1_no_idx values(2, 0x0f);") + + tk.MustExec("prepare stmt from 'select id from t1_no_idx where col_bit = ?';") + tk.MustExec("set @a = 0x3135;") + tk.MustQuery("execute stmt using @a;").Check(testkit.Rows("1")) + tk.MustExec("set @a = 0x0F;") + tk.MustQuery("execute stmt using @a;").Check(testkit.Rows("2")) + + // same test, but use IN expr + tk.MustExec("prepare stmt from 'select id from t1_no_idx where col_bit in (?)';") + tk.MustExec("set @a = 0x3135;") + tk.MustQuery("execute stmt using @a;").Check(testkit.Rows("1")) + tk.MustExec("set @a = 0x0F;") + tk.MustQuery("execute stmt using @a;").Check(testkit.Rows("2")) + + // same test, but use table with index on col_bit + tk.MustExec("drop table if exists t2_idx;") + tk.MustExec("create table t2_idx(id int, col_bit bit(16), key(col_bit));") + tk.MustExec("insert into t2_idx values(1, 0x3135);") + tk.MustExec("insert into t2_idx values(2, 0x0f);") + + tk.MustExec("prepare stmt from 'select id from t2_idx where col_bit = ?';") + tk.MustExec("set @a = 0x3135;") + tk.MustQuery("execute stmt using @a;").Check(testkit.Rows("1")) + tk.MustExec("set @a = 0x0F;") + tk.MustQuery("execute stmt using @a;").Check(testkit.Rows("2")) + + // same test, but use IN expr + tk.MustExec("prepare stmt from 'select id from t2_idx where col_bit in (?)';") + tk.MustExec("set @a = 0x3135;") + tk.MustQuery("execute stmt using @a;").Check(testkit.Rows("1")) + tk.MustExec("set @a = 0x0F;") + tk.MustQuery("execute stmt using @a;").Check(testkit.Rows("2")) + + // test col varchar with GetVar + tk.MustExec("drop table if exists t_varchar;") + tk.MustExec("create table t_varchar(id int, col_varchar varchar(100), key(col_varchar));") + tk.MustExec("insert into t_varchar values(1, '15');") + tk.MustExec("prepare stmt from 'select id from t_varchar where col_varchar = ?';") + tk.MustExec("set @a = 0x3135;") + tk.MustQuery("execute stmt using @a;").Check(testkit.Rows("1")) +} + +// test BitLiteral used with GetVar +func (s *testIntegrationSuite) TestGetVarExprWithBitLiteral(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test;") + tk.MustExec("drop table if exists t1_no_idx;") + tk.MustExec("create table t1_no_idx(id int, col_bit bit(16));") + tk.MustExec("insert into t1_no_idx values(1, 0x3135);") + tk.MustExec("insert into t1_no_idx values(2, 0x0f);") + + tk.MustExec("prepare stmt from 'select id from t1_no_idx where col_bit = ?';") + // 0b11000100110101 is 0x3135 + tk.MustExec("set @a = 0b11000100110101;") + tk.MustQuery("execute stmt using @a;").Check(testkit.Rows("1")) + + // same test, but use IN expr + tk.MustExec("prepare stmt from 'select id from t1_no_idx where col_bit in (?)';") + tk.MustExec("set @a = 0b11000100110101;") + tk.MustQuery("execute stmt using @a;").Check(testkit.Rows("1")) +} diff --git a/planner/core/prepare_test.go b/planner/core/prepare_test.go index 92d1bef41d81f..ca1ddb8424768 100644 --- a/planner/core/prepare_test.go +++ b/planner/core/prepare_test.go @@ -626,6 +626,18 @@ func (s *testPrepareSerialSuite) TestConstPropAndPPDWithCache(c *C) { tk.MustQuery("execute stmt using @p0").Check(testkit.Rows( "0", )) + + // Need to check if contain mutable before RefineCompareConstant() in inToExpression(). + // Otherwise may hit wrong plan. + tk.MustExec("use test;") + tk.MustExec("drop table if exists t1;") + tk.MustExec("create table t1(c1 tinyint unsigned);") + tk.MustExec("insert into t1 values(111);") + tk.MustExec("prepare stmt from 'select 1 from t1 where c1 in (?)';") + tk.MustExec("set @a = '1.1';") + tk.MustQuery("execute stmt using @a;").Check(testkit.Rows()) + tk.MustExec("set @a = '111';") + tk.MustQuery("execute stmt using @a;").Check(testkit.Rows("1")) } func (s *testPlanSerialSuite) TestPlanCacheUnionScan(c *C) { diff --git a/types/field_type.go b/types/field_type.go index 9e470892875a1..0cc65b60f2e53 100644 --- a/types/field_type.go +++ b/types/field_type.go @@ -260,8 +260,8 @@ func DefaultTypeForValue(value interface{}, tp *FieldType, char string, collate tp.Flag |= mysql.UnsignedFlag SetBinChsClnFlag(tp) case BinaryLiteral: - tp.Tp = mysql.TypeBit - tp.Flen = len(x) * 8 + tp.Tp = mysql.TypeVarString + tp.Flen = len(x) tp.Decimal = 0 SetBinChsClnFlag(tp) tp.Flag &= ^mysql.BinaryFlag From f3791518d1660a988df83d86ed882ce4c4550463 Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Thu, 18 Mar 2021 19:01:36 +0800 Subject: [PATCH 36/85] expression: do not adjust int when it is null and compared year (#22821) (#22844) --- executor/executor_test.go | 9 +++++++++ expression/builtin_compare.go | 4 ++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/executor/executor_test.go b/executor/executor_test.go index e310abc017c96..0e77b813485c1 100644 --- a/executor/executor_test.go +++ b/executor/executor_test.go @@ -6454,6 +6454,15 @@ func (s *testSuite) TestIssue20305(c *C) { tk.MustQuery("SELECT * FROM `t3` where y <= a").Check(testkit.Rows("2155 2156")) } +func (s *testSuite) TestIssue22817(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t3") + tk.MustExec("create table t3 (a year)") + tk.MustExec("insert into t3 values (1991), (\"1992\"), (\"93\"), (94)") + tk.MustQuery("select * from t3 where a >= NULL").Check(testkit.Rows()) +} + func (s *testSuite) TestIssue13953(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") diff --git a/expression/builtin_compare.go b/expression/builtin_compare.go index 671158a5e0138..72b71de7ac5ca 100644 --- a/expression/builtin_compare.go +++ b/expression/builtin_compare.go @@ -1373,7 +1373,7 @@ func (c *compareFunctionClass) refineArgs(ctx sessionctx.Context, args []Express } } // int constant [cmp] year type - if arg0IsCon && arg0IsInt && arg1Type.Tp == mysql.TypeYear { + if arg0IsCon && arg0IsInt && arg1Type.Tp == mysql.TypeYear && !arg0.Value.IsNull() { adjusted, failed := types.AdjustYear(arg0.Value.GetInt64(), false) if failed == nil { arg0.Value.SetInt64(adjusted) @@ -1381,7 +1381,7 @@ func (c *compareFunctionClass) refineArgs(ctx sessionctx.Context, args []Express } } // year type [cmp] int constant - if arg1IsCon && arg1IsInt && arg0Type.Tp == mysql.TypeYear { + if arg1IsCon && arg1IsInt && arg0Type.Tp == mysql.TypeYear && !arg1.Value.IsNull() { adjusted, failed := types.AdjustYear(arg1.Value.GetInt64(), false) if failed == nil { arg1.Value.SetInt64(adjusted) From bb06d5dd240e1a87aef5d93a2e955d942b2d87a1 Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Thu, 18 Mar 2021 19:37:36 +0800 Subject: [PATCH 37/85] executor: wrong result of nullif expr when used with is null expr. (#23170) (#23279) --- expression/builtin_control.go | 3 +++ expression/integration_test.go | 11 +++++++++++ 2 files changed, 14 insertions(+) diff --git a/expression/builtin_control.go b/expression/builtin_control.go index 91e3cc526694e..2f0b409e6f5e7 100644 --- a/expression/builtin_control.go +++ b/expression/builtin_control.go @@ -67,6 +67,8 @@ func InferType4ControlFuncs(lhs, rhs *types.FieldType) *types.FieldType { resultFieldType := &types.FieldType{} if lhs.Tp == mysql.TypeNull { *resultFieldType = *rhs + // If any of arg is NULL, result type need unset NotNullFlag. + types.SetTypeFlag(&resultFieldType.Flag, mysql.NotNullFlag, false) // If both arguments are NULL, make resulting type BINARY(0). if rhs.Tp == mysql.TypeNull { resultFieldType.Tp = mysql.TypeString @@ -75,6 +77,7 @@ func InferType4ControlFuncs(lhs, rhs *types.FieldType) *types.FieldType { } } else if rhs.Tp == mysql.TypeNull { *resultFieldType = *lhs + types.SetTypeFlag(&resultFieldType.Flag, mysql.NotNullFlag, false) } else { resultFieldType = types.AggFieldType([]*types.FieldType{lhs, rhs}) evalType := types.AggregateEvalType([]*types.FieldType{lhs, rhs}, &resultFieldType.Flag) diff --git a/expression/integration_test.go b/expression/integration_test.go index 1cebcc6cbedbb..328d1306e0fbb 100755 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -3733,6 +3733,17 @@ func (s *testIntegrationSuite) TestCompareBuiltin(c *C) { result.Check(testkit.Rows("0")) } +// #23157: make sure if Nullif expr is correct combined with IsNull expr. +func (s *testIntegrationSuite) TestNullifWithIsNull(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a int not null);") + tk.MustExec("insert into t values(1),(2);") + rows := tk.MustQuery("select * from t where nullif(a,a) is null;") + rows.Check(testkit.Rows("1", "2")) +} + func (s *testIntegrationSuite) TestAggregationBuiltin(c *C) { defer s.cleanEnv(c) tk := testkit.NewTestKit(c, s.store) From 6969242b33cebaa0ad27a102a3a6913dba6de41a Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Thu, 18 Mar 2021 19:51:36 +0800 Subject: [PATCH 38/85] planner: fix range partition prune bug for IN expr (#22894) (#22938) (#23074) --- executor/executor_test.go | 17 +++++++++++++++++ planner/core/rule_partition_processor.go | 11 ++++++++++- 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/executor/executor_test.go b/executor/executor_test.go index 0e77b813485c1..1c053b33c1c7d 100644 --- a/executor/executor_test.go +++ b/executor/executor_test.go @@ -4064,6 +4064,23 @@ func (s *testSuiteP1) TestSelectPartition(c *C) { tk.MustQuery("select a, b from th where b>10").Check(testkit.Rows("11 11")) tk.MustExec("commit") tk.MustQuery("select a, b from th where b>10").Check(testkit.Rows("11 11")) + + // test partition function is scalar func + tk.MustExec("drop table if exists tscalar") + tk.MustExec(`create table tscalar (c1 int) partition by range (c1 % 30) ( + partition p0 values less than (0), + partition p1 values less than (10), + partition p2 values less than (20), + partition pm values less than (maxvalue));`) + tk.MustExec("insert into tscalar values(0), (10), (40), (50), (55)") + // test IN expression + tk.MustExec("insert into tscalar values(-0), (-10), (-40), (-50), (-55)") + tk.MustQuery("select * from tscalar where c1 in (55, 55)").Check(testkit.Rows("55")) + tk.MustQuery("select * from tscalar where c1 in (40, 40)").Check(testkit.Rows("40")) + tk.MustQuery("select * from tscalar where c1 in (40)").Check(testkit.Rows("40")) + tk.MustQuery("select * from tscalar where c1 in (-40)").Check(testkit.Rows("-40")) + tk.MustQuery("select * from tscalar where c1 in (-40, -40)").Check(testkit.Rows("-40")) + tk.MustQuery("select * from tscalar where c1 in (-1)").Check(testkit.Rows()) } func (s *testSuiteP1) TestDeletePartition(c *C) { diff --git a/planner/core/rule_partition_processor.go b/planner/core/rule_partition_processor.go index 12b7d53ab9f80..cc570d2577e84 100644 --- a/planner/core/rule_partition_processor.go +++ b/planner/core/rule_partition_processor.go @@ -509,7 +509,16 @@ func partitionRangeForInExpr(sctx sessionctx.Context, args []expression.Expressi default: return pruner.fullRange() } - val, err := constExpr.Value.ToInt64(sctx.GetSessionVars().StmtCtx) + + var val int64 + var err error + if pruner.partFn != nil { + // replace fn(col) to fn(const) + partFnConst := replaceColumnWithConst(pruner.partFn, constExpr) + val, _, err = partFnConst.EvalInt(sctx, chunk.Row{}) + } else { + val, err = constExpr.Value.ToInt64(sctx.GetSessionVars().StmtCtx) + } if err != nil { return pruner.fullRange() } From a3648ac3e0ca3523c8037f941213663ed9842c49 Mon Sep 17 00:00:00 2001 From: tiancaiamao Date: Thu, 18 Mar 2021 20:25:36 +0800 Subject: [PATCH 39/85] planner: fix query range partition table got wrong result and TiDB panic #22953 (#23325) --- cmd/explaintest/r/partition_pruning.result | 158 +++++---------------- expression/partition_pruner_test.go | 11 ++ planner/core/rule_partition_processor.go | 9 +- 3 files changed, 57 insertions(+), 121 deletions(-) diff --git a/cmd/explaintest/r/partition_pruning.result b/cmd/explaintest/r/partition_pruning.result index d8d024a551dd6..921ed0012593f 100644 --- a/cmd/explaintest/r/partition_pruning.result +++ b/cmd/explaintest/r/partition_pruning.result @@ -3859,16 +3859,9 @@ TableReader_8 10.00 root data:Selection_7 └─TableFullScan_6 10000.00 cop[tikv] table:t1, partition:p0 keep order:false, stats:pseudo explain select * from t2 where a=0; id estRows task access object operator info -PartitionUnion_9 30.00 root -├─TableReader_12 10.00 root data:Selection_11 -│ └─Selection_11 10.00 cop[tikv] eq(test.t2.a, 0) -│ └─TableFullScan_10 10000.00 cop[tikv] table:t2, partition:p0 keep order:false, stats:pseudo -├─TableReader_15 10.00 root data:Selection_14 -│ └─Selection_14 10.00 cop[tikv] eq(test.t2.a, 0) -│ └─TableFullScan_13 10000.00 cop[tikv] table:t2, partition:p1 keep order:false, stats:pseudo -└─TableReader_18 10.00 root data:Selection_17 - └─Selection_17 10.00 cop[tikv] eq(test.t2.a, 0) - └─TableFullScan_16 10000.00 cop[tikv] table:t2, partition:p2 keep order:false, stats:pseudo +TableReader_8 10.00 root data:Selection_7 +└─Selection_7 10.00 cop[tikv] eq(test.t2.a, 0) + └─TableFullScan_6 10000.00 cop[tikv] table:t2, partition:p0 keep order:false, stats:pseudo explain select * from t1 where a=0xFE; id estRows task access object operator info TableReader_8 10.00 root data:Selection_7 @@ -3876,31 +3869,17 @@ TableReader_8 10.00 root data:Selection_7 └─TableFullScan_6 10000.00 cop[tikv] table:t1, partition:p2 keep order:false, stats:pseudo explain select * from t2 where a=0xFE; id estRows task access object operator info -PartitionUnion_9 30.00 root -├─TableReader_12 10.00 root data:Selection_11 -│ └─Selection_11 10.00 cop[tikv] eq(test.t2.a, 254) -│ └─TableFullScan_10 10000.00 cop[tikv] table:t2, partition:p0 keep order:false, stats:pseudo -├─TableReader_15 10.00 root data:Selection_14 -│ └─Selection_14 10.00 cop[tikv] eq(test.t2.a, 254) -│ └─TableFullScan_13 10000.00 cop[tikv] table:t2, partition:p1 keep order:false, stats:pseudo -└─TableReader_18 10.00 root data:Selection_17 - └─Selection_17 10.00 cop[tikv] eq(test.t2.a, 254) - └─TableFullScan_16 10000.00 cop[tikv] table:t2, partition:p2 keep order:false, stats:pseudo +TableReader_8 10.00 root data:Selection_7 +└─Selection_7 10.00 cop[tikv] eq(test.t2.a, 254) + └─TableFullScan_6 10000.00 cop[tikv] table:t2, partition:p2 keep order:false, stats:pseudo explain select * from t1 where a > 0xFE AND a <= 0xFF; id estRows task access object operator info TableDual_6 0.00 root rows:0 explain select * from t2 where a > 0xFE AND a <= 0xFF; id estRows task access object operator info -PartitionUnion_9 750.00 root -├─TableReader_12 250.00 root data:Selection_11 -│ └─Selection_11 250.00 cop[tikv] gt(test.t2.a, 254), le(test.t2.a, 255) -│ └─TableFullScan_10 10000.00 cop[tikv] table:t2, partition:p0 keep order:false, stats:pseudo -├─TableReader_15 250.00 root data:Selection_14 -│ └─Selection_14 250.00 cop[tikv] gt(test.t2.a, 254), le(test.t2.a, 255) -│ └─TableFullScan_13 10000.00 cop[tikv] table:t2, partition:p1 keep order:false, stats:pseudo -└─TableReader_18 250.00 root data:Selection_17 - └─Selection_17 250.00 cop[tikv] gt(test.t2.a, 254), le(test.t2.a, 255) - └─TableFullScan_16 10000.00 cop[tikv] table:t2, partition:p2 keep order:false, stats:pseudo +TableReader_8 250.00 root data:Selection_7 +└─Selection_7 250.00 cop[tikv] gt(test.t2.a, 254), le(test.t2.a, 255) + └─TableFullScan_6 10000.00 cop[tikv] table:t2, partition:p2 keep order:false, stats:pseudo explain select * from t1 where a >= 0xFE AND a <= 0xFF; id estRows task access object operator info TableReader_8 250.00 root data:Selection_7 @@ -3908,16 +3887,9 @@ TableReader_8 250.00 root data:Selection_7 └─TableFullScan_6 10000.00 cop[tikv] table:t1, partition:p2 keep order:false, stats:pseudo explain select * from t2 where a >= 0xFE AND a <= 0xFF; id estRows task access object operator info -PartitionUnion_9 750.00 root -├─TableReader_12 250.00 root data:Selection_11 -│ └─Selection_11 250.00 cop[tikv] ge(test.t2.a, 254), le(test.t2.a, 255) -│ └─TableFullScan_10 10000.00 cop[tikv] table:t2, partition:p0 keep order:false, stats:pseudo -├─TableReader_15 250.00 root data:Selection_14 -│ └─Selection_14 250.00 cop[tikv] ge(test.t2.a, 254), le(test.t2.a, 255) -│ └─TableFullScan_13 10000.00 cop[tikv] table:t2, partition:p1 keep order:false, stats:pseudo -└─TableReader_18 250.00 root data:Selection_17 - └─Selection_17 250.00 cop[tikv] ge(test.t2.a, 254), le(test.t2.a, 255) - └─TableFullScan_16 10000.00 cop[tikv] table:t2, partition:p2 keep order:false, stats:pseudo +TableReader_8 250.00 root data:Selection_7 +└─Selection_7 250.00 cop[tikv] ge(test.t2.a, 254), le(test.t2.a, 255) + └─TableFullScan_6 10000.00 cop[tikv] table:t2, partition:p2 keep order:false, stats:pseudo explain select * from t1 where a < 64 AND a >= 63; id estRows task access object operator info TableReader_8 250.00 root data:Selection_7 @@ -3925,16 +3897,13 @@ TableReader_8 250.00 root data:Selection_7 └─TableFullScan_6 10000.00 cop[tikv] table:t1, partition:p0 keep order:false, stats:pseudo explain select * from t2 where a < 64 AND a >= 63; id estRows task access object operator info -PartitionUnion_9 750.00 root -├─TableReader_12 250.00 root data:Selection_11 -│ └─Selection_11 250.00 cop[tikv] ge(test.t2.a, 63), lt(test.t2.a, 64) -│ └─TableFullScan_10 10000.00 cop[tikv] table:t2, partition:p0 keep order:false, stats:pseudo -├─TableReader_15 250.00 root data:Selection_14 -│ └─Selection_14 250.00 cop[tikv] ge(test.t2.a, 63), lt(test.t2.a, 64) -│ └─TableFullScan_13 10000.00 cop[tikv] table:t2, partition:p1 keep order:false, stats:pseudo -└─TableReader_18 250.00 root data:Selection_17 - └─Selection_17 250.00 cop[tikv] ge(test.t2.a, 63), lt(test.t2.a, 64) - └─TableFullScan_16 10000.00 cop[tikv] table:t2, partition:p2 keep order:false, stats:pseudo +PartitionUnion_8 500.00 root +├─TableReader_11 250.00 root data:Selection_10 +│ └─Selection_10 250.00 cop[tikv] ge(test.t2.a, 63), lt(test.t2.a, 64) +│ └─TableFullScan_9 10000.00 cop[tikv] table:t2, partition:p0 keep order:false, stats:pseudo +└─TableReader_14 250.00 root data:Selection_13 + └─Selection_13 250.00 cop[tikv] ge(test.t2.a, 63), lt(test.t2.a, 64) + └─TableFullScan_12 10000.00 cop[tikv] table:t2, partition:p1 keep order:false, stats:pseudo explain select * from t1 where a <= 64 AND a >= 63; id estRows task access object operator info PartitionUnion_8 500.00 root @@ -3946,16 +3915,13 @@ PartitionUnion_8 500.00 root └─TableFullScan_12 10000.00 cop[tikv] table:t1, partition:p1 keep order:false, stats:pseudo explain select * from t2 where a <= 64 AND a >= 63; id estRows task access object operator info -PartitionUnion_9 750.00 root -├─TableReader_12 250.00 root data:Selection_11 -│ └─Selection_11 250.00 cop[tikv] ge(test.t2.a, 63), le(test.t2.a, 64) -│ └─TableFullScan_10 10000.00 cop[tikv] table:t2, partition:p0 keep order:false, stats:pseudo -├─TableReader_15 250.00 root data:Selection_14 -│ └─Selection_14 250.00 cop[tikv] ge(test.t2.a, 63), le(test.t2.a, 64) -│ └─TableFullScan_13 10000.00 cop[tikv] table:t2, partition:p1 keep order:false, stats:pseudo -└─TableReader_18 250.00 root data:Selection_17 - └─Selection_17 250.00 cop[tikv] ge(test.t2.a, 63), le(test.t2.a, 64) - └─TableFullScan_16 10000.00 cop[tikv] table:t2, partition:p2 keep order:false, stats:pseudo +PartitionUnion_8 500.00 root +├─TableReader_11 250.00 root data:Selection_10 +│ └─Selection_10 250.00 cop[tikv] ge(test.t2.a, 63), le(test.t2.a, 64) +│ └─TableFullScan_9 10000.00 cop[tikv] table:t2, partition:p0 keep order:false, stats:pseudo +└─TableReader_14 250.00 root data:Selection_13 + └─Selection_13 250.00 cop[tikv] ge(test.t2.a, 63), le(test.t2.a, 64) + └─TableFullScan_12 10000.00 cop[tikv] table:t2, partition:p1 keep order:false, stats:pseudo drop table t1; drop table t2; create table t1(a bigint unsigned not null) partition by range(a+0) ( @@ -3969,35 +3935,13 @@ insert into t1 values (9),(19),(0xFFFF0000FFFF000-1), (0xFFFF0000FFFFFFF-1); explain select * from t1 where a >= 2305561538531885056-10 and a <= 2305561538531885056-8; id estRows task access object operator info -PartitionUnion_10 1000.00 root -├─TableReader_13 250.00 root data:Selection_12 -│ └─Selection_12 250.00 cop[tikv] ge(test.t1.a, 2305561538531885046), le(test.t1.a, 2305561538531885048) -│ └─TableFullScan_11 10000.00 cop[tikv] table:t1, partition:p1 keep order:false, stats:pseudo -├─TableReader_16 250.00 root data:Selection_15 -│ └─Selection_15 250.00 cop[tikv] ge(test.t1.a, 2305561538531885046), le(test.t1.a, 2305561538531885048) -│ └─TableFullScan_14 10000.00 cop[tikv] table:t1, partition:p2 keep order:false, stats:pseudo -├─TableReader_19 250.00 root data:Selection_18 -│ └─Selection_18 250.00 cop[tikv] ge(test.t1.a, 2305561538531885046), le(test.t1.a, 2305561538531885048) -│ └─TableFullScan_17 10000.00 cop[tikv] table:t1, partition:p3 keep order:false, stats:pseudo -└─TableReader_22 250.00 root data:Selection_21 - └─Selection_21 250.00 cop[tikv] ge(test.t1.a, 2305561538531885046), le(test.t1.a, 2305561538531885048) - └─TableFullScan_20 10000.00 cop[tikv] table:t1, partition:p4 keep order:false, stats:pseudo +TableReader_8 250.00 root data:Selection_7 +└─Selection_7 250.00 cop[tikv] ge(test.t1.a, 2305561538531885046), le(test.t1.a, 2305561538531885048) + └─TableFullScan_6 10000.00 cop[tikv] table:t1, partition:p3 keep order:false, stats:pseudo explain select * from t1 where a > 0xFFFFFFFFFFFFFFEC and a < 0xFFFFFFFFFFFFFFEE; id estRows task access object operator info -PartitionUnion_10 1000.00 root -├─TableReader_13 250.00 root data:Selection_12 -│ └─Selection_12 250.00 cop[tikv] gt(test.t1.a, 18446744073709551596), lt(test.t1.a, 18446744073709551598) -│ └─TableFullScan_11 10000.00 cop[tikv] table:t1, partition:p1 keep order:false, stats:pseudo -├─TableReader_16 250.00 root data:Selection_15 -│ └─Selection_15 250.00 cop[tikv] gt(test.t1.a, 18446744073709551596), lt(test.t1.a, 18446744073709551598) -│ └─TableFullScan_14 10000.00 cop[tikv] table:t1, partition:p2 keep order:false, stats:pseudo -├─TableReader_19 250.00 root data:Selection_18 -│ └─Selection_18 250.00 cop[tikv] gt(test.t1.a, 18446744073709551596), lt(test.t1.a, 18446744073709551598) -│ └─TableFullScan_17 10000.00 cop[tikv] table:t1, partition:p3 keep order:false, stats:pseudo -└─TableReader_22 250.00 root data:Selection_21 - └─Selection_21 250.00 cop[tikv] gt(test.t1.a, 18446744073709551596), lt(test.t1.a, 18446744073709551598) - └─TableFullScan_20 10000.00 cop[tikv] table:t1, partition:p4 keep order:false, stats:pseudo +TableDual_6 0.00 root rows:0 explain select * from t1 where a>=0 and a <= 0xFFFFFFFFFFFFFFFF; id estRows task access object operator info PartitionUnion_10 13293.33 root @@ -4023,19 +3967,9 @@ partition p4 values less than (1000) insert into t1 values (-15),(-5),(5),(15),(-15),(-5),(5),(15); explain select * from t1 where a>-2 and a <=0; id estRows task access object operator info -PartitionUnion_10 1000.00 root -├─TableReader_13 250.00 root data:Selection_12 -│ └─Selection_12 250.00 cop[tikv] gt(test.t1.a, -2), le(test.t1.a, 0) -│ └─TableFullScan_11 10000.00 cop[tikv] table:t1, partition:p1 keep order:false, stats:pseudo -├─TableReader_16 250.00 root data:Selection_15 -│ └─Selection_15 250.00 cop[tikv] gt(test.t1.a, -2), le(test.t1.a, 0) -│ └─TableFullScan_14 10000.00 cop[tikv] table:t1, partition:p2 keep order:false, stats:pseudo -├─TableReader_19 250.00 root data:Selection_18 -│ └─Selection_18 250.00 cop[tikv] gt(test.t1.a, -2), le(test.t1.a, 0) -│ └─TableFullScan_17 10000.00 cop[tikv] table:t1, partition:p3 keep order:false, stats:pseudo -└─TableReader_22 250.00 root data:Selection_21 - └─Selection_21 250.00 cop[tikv] gt(test.t1.a, -2), le(test.t1.a, 0) - └─TableFullScan_20 10000.00 cop[tikv] table:t1, partition:p4 keep order:false, stats:pseudo +TableReader_8 250.00 root data:Selection_7 +└─Selection_7 250.00 cop[tikv] gt(test.t1.a, -2), le(test.t1.a, 0) + └─TableFullScan_6 10000.00 cop[tikv] table:t1, partition:p3 keep order:false, stats:pseudo drop table t1; CREATE TABLE t1 ( recdate DATETIME NOT NULL ) PARTITION BY RANGE( TO_DAYS(recdate) ) ( @@ -4086,28 +4020,14 @@ partition p2 values less than (255) insert into t1 select A.a + 10*B.a from t0 A, t0 B; explain select * from t1 where a between 10 and 13; id estRows task access object operator info -PartitionUnion_9 750.00 root -├─TableReader_12 250.00 root data:Selection_11 -│ └─Selection_11 250.00 cop[tikv] ge(test.t1.a, 10), le(test.t1.a, 13) -│ └─TableFullScan_10 10000.00 cop[tikv] table:t1, partition:p0 keep order:false, stats:pseudo -├─TableReader_15 250.00 root data:Selection_14 -│ └─Selection_14 250.00 cop[tikv] ge(test.t1.a, 10), le(test.t1.a, 13) -│ └─TableFullScan_13 10000.00 cop[tikv] table:t1, partition:p1 keep order:false, stats:pseudo -└─TableReader_18 250.00 root data:Selection_17 - └─Selection_17 250.00 cop[tikv] ge(test.t1.a, 10), le(test.t1.a, 13) - └─TableFullScan_16 10000.00 cop[tikv] table:t1, partition:p2 keep order:false, stats:pseudo +TableReader_8 250.00 root data:Selection_7 +└─Selection_7 250.00 cop[tikv] ge(test.t1.a, 10), le(test.t1.a, 13) + └─TableFullScan_6 10000.00 cop[tikv] table:t1, partition:p0 keep order:false, stats:pseudo explain select * from t1 where a between 10 and 10+33; id estRows task access object operator info -PartitionUnion_9 750.00 root -├─TableReader_12 250.00 root data:Selection_11 -│ └─Selection_11 250.00 cop[tikv] ge(test.t1.a, 10), le(test.t1.a, 43) -│ └─TableFullScan_10 10000.00 cop[tikv] table:t1, partition:p0 keep order:false, stats:pseudo -├─TableReader_15 250.00 root data:Selection_14 -│ └─Selection_14 250.00 cop[tikv] ge(test.t1.a, 10), le(test.t1.a, 43) -│ └─TableFullScan_13 10000.00 cop[tikv] table:t1, partition:p1 keep order:false, stats:pseudo -└─TableReader_18 250.00 root data:Selection_17 - └─Selection_17 250.00 cop[tikv] ge(test.t1.a, 10), le(test.t1.a, 43) - └─TableFullScan_16 10000.00 cop[tikv] table:t1, partition:p2 keep order:false, stats:pseudo +TableReader_8 250.00 root data:Selection_7 +└─Selection_7 250.00 cop[tikv] ge(test.t1.a, 10), le(test.t1.a, 43) + └─TableFullScan_6 10000.00 cop[tikv] table:t1, partition:p0 keep order:false, stats:pseudo drop table t0, t1; drop table if exists t; create table t(a timestamp) partition by range(unix_timestamp(a)) (partition p0 values less than(unix_timestamp('2019-02-16 14:20:00')), partition p1 values less than (maxvalue)); diff --git a/expression/partition_pruner_test.go b/expression/partition_pruner_test.go index a132b6b84377b..d54f3b3e5f05d 100644 --- a/expression/partition_pruner_test.go +++ b/expression/partition_pruner_test.go @@ -86,3 +86,14 @@ func (s *testSuite2) TestHashPartitionPruner(c *C) { tk.MustQuery(tt).Check(testkit.Rows(output[i].Result...)) } } + +func (s *testSuite2) TestIssue22898(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("USE test;") + tk.MustExec("DROP TABLE IF EXISTS test;") + tk.MustExec("CREATE TABLE NT_RP3763 (COL1 TINYINT(8) SIGNED COMMENT \"NUMERIC NO INDEX\" DEFAULT 41,COL2 VARCHAR(20),COL3 DATETIME,COL4 BIGINT,COL5 FLOAT) PARTITION BY RANGE (COL1 * COL3) (PARTITION P0 VALUES LESS THAN (0),PARTITION P1 VALUES LESS THAN (10),PARTITION P2 VALUES LESS THAN (20),PARTITION P3 VALUES LESS THAN (30),PARTITION P4 VALUES LESS THAN (40),PARTITION P5 VALUES LESS THAN (50),PARTITION PMX VALUES LESS THAN MAXVALUE);") + tk.MustExec("insert into NT_RP3763 (COL1,COL2,COL3,COL4,COL5) values(-82,\"夐齏醕皆磹漋甓崘潮嵙燷渏艂朼洛炷鉢儝鱈肇\",\"5748\\-06\\-26\\ 20:48:49\",-3133527360541070260,-2.624880003397658e+38);") + tk.MustExec("insert into NT_RP3763 (COL1,COL2,COL3,COL4,COL5) values(48,\"簖鹩筈匹眜赖泽騈爷詵赺玡婙Ɇ郝鮙廛賙疼舢\",\"7228\\-12\\-13\\ 02:59:54\",-6181009269190017937,2.7731105531290494e+38);") + tk.MustQuery("select * from `NT_RP3763` where `COL1` in (10, 48, -82);").Check(testkit.Rows("-82 夐齏醕皆磹漋甓崘潮嵙燷渏艂朼洛炷鉢儝鱈肇 5748-06-26 20:48:49 -3133527360541070260 -262488000000000000000000000000000000000", "48 簖鹩筈匹眜赖泽騈爷詵赺玡婙Ɇ郝鮙廛賙疼舢 7228-12-13 02:59:54 -6181009269190017937 277311060000000000000000000000000000000")) + tk.MustQuery("select * from `NT_RP3763` where `COL1` in (48);").Check(testkit.Rows("48 簖鹩筈匹眜赖泽騈爷詵赺玡婙Ɇ郝鮙廛賙疼舢 7228-12-13 02:59:54 -6181009269190017937 277311060000000000000000000000000000000")) +} diff --git a/planner/core/rule_partition_processor.go b/planner/core/rule_partition_processor.go index cc570d2577e84..83536239f47e7 100644 --- a/planner/core/rule_partition_processor.go +++ b/planner/core/rule_partition_processor.go @@ -396,8 +396,10 @@ func makePartitionByFnCol(sctx sessionctx.Context, columns []*expression.Column, args := fn.GetArgs() if len(args) > 0 { arg0 := args[0] - if c, ok1 := arg0.(*expression.Column); ok1 { - col = c + if expression.ExtractColumnSet(args).Len() == 1 { + if c, ok1 := arg0.(*expression.Column); ok1 { + col = c + } } } } @@ -533,6 +535,9 @@ func partitionRangeForInExpr(sctx sessionctx.Context, args []expression.Expressi var monotoneIncFuncs = map[string]struct{}{ ast.ToDays: {}, ast.UnixTimestamp: {}, + // Only when the function form is fn(column, const) + ast.Plus: {}, + ast.Minus: {}, } // f(x) op const, op is > = < From aad2f7211466638f7c08b30556a94be21bcca924 Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Thu, 18 Mar 2021 20:41:36 +0800 Subject: [PATCH 40/85] planner: fix the bug that wrong collation is used when try fast path for enum or set (#23217) (#23292) --- expression/integration_test.go | 9 +++++++++ planner/core/point_get_plan.go | 3 ++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/expression/integration_test.go b/expression/integration_test.go index 328d1306e0fbb..614f06a51b838 100755 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -6097,6 +6097,15 @@ func (s *testIntegrationSerialSuite) TestCollationBasic(c *C) { tk.MustQuery("select * from t_ci where a='A'").Check(testkit.Rows("a")) tk.MustQuery("select * from t_ci where a='a '").Check(testkit.Rows("a")) tk.MustQuery("select * from t_ci where a='a '").Check(testkit.Rows("a")) + + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(c set('A', 'B') collate utf8mb4_general_ci);") + tk.MustExec("insert into t values('a');") + tk.MustExec("insert into t values('B');") + tk.MustQuery("select c from t where c = 'a';").Check(testkit.Rows("A")) + tk.MustQuery("select c from t where c = 'A';").Check(testkit.Rows("A")) + tk.MustQuery("select c from t where c = 'b';").Check(testkit.Rows("B")) + tk.MustQuery("select c from t where c = 'B';").Check(testkit.Rows("B")) } func (s *testIntegrationSerialSuite) TestWeightString(c *C) { diff --git a/planner/core/point_get_plan.go b/planner/core/point_get_plan.go index d78400cbc9648..d499a6cf81922 100644 --- a/planner/core/point_get_plan.go +++ b/planner/core/point_get_plan.go @@ -1040,7 +1040,8 @@ func getNameValuePairs(stmtCtx *stmtctx.StatementContext, tbl *model.TableInfo, } } // The converted result must be same as original datum. - cmp, err := d.CompareDatum(stmtCtx, &dVal) + // Compare them based on the dVal's type. + cmp, err := dVal.CompareDatum(stmtCtx, &d) if err != nil { return nil, false } else if cmp != 0 { From 11a9254bec522c1991033a6098f6dccadc6eb7c1 Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Thu, 18 Mar 2021 22:27:36 +0800 Subject: [PATCH 41/85] executor: fix wrong key range of index scan when filter is comparing year column with NULL (#23079) (#23104) --- executor/executor_test.go | 26 +++++++ executor/testdata/executor_suite_in.json | 12 ++++ executor/testdata/executor_suite_out.json | 84 +++++++++++++++++++++++ types/datum.go | 4 ++ 4 files changed, 126 insertions(+) diff --git a/executor/executor_test.go b/executor/executor_test.go index 1c053b33c1c7d..8d77b5652acc1 100644 --- a/executor/executor_test.go +++ b/executor/executor_test.go @@ -1134,6 +1134,32 @@ func (s *testSuiteP1) TestIssue5055(c *C) { result.Check(testkit.Rows("1 1")) } +// issue-23038: wrong key range of index scan for year column +func (s *testSuiteWithData) TestIndexScanWithYearCol(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test;") + tk.MustExec("drop table if exists t;") + tk.MustExec("create table t (c1 year(4), c2 int, key(c1));") + tk.MustExec("insert into t values(2001, 1);") + + var input []string + var output []struct { + SQL string + Plan []string + Res []string + } + s.testData.GetTestCases(c, &input, &output) + for i, tt := range input { + s.testData.OnRecord(func() { + output[i].SQL = tt + output[i].Plan = s.testData.ConvertRowsToStrings(tk.MustQuery("explain " + tt).Rows()) + output[i].Res = s.testData.ConvertRowsToStrings(tk.MustQuery(tt).Sort().Rows()) + }) + tk.MustQuery("explain " + tt).Check(testkit.Rows(output[i].Plan...)) + tk.MustQuery(tt).Sort().Check(testkit.Rows(output[i].Res...)) + } +} + func (s *testSuiteP2) TestUnion(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") diff --git a/executor/testdata/executor_suite_in.json b/executor/testdata/executor_suite_in.json index 960471c958458..cb15278ee9b9f 100644 --- a/executor/testdata/executor_suite_in.json +++ b/executor/testdata/executor_suite_in.json @@ -7,5 +7,17 @@ "select * from t1 natural right join t2 order by a", "SELECT * FROM t1 NATURAL LEFT JOIN t2 WHERE not(t1.a <=> t2.a)" ] + }, + { + "name": "TestIndexScanWithYearCol", + "cases": [ + "select t1.c1, t2.c1 from t as t1 inner join t as t2 on t1.c1 = t2.c1 where t1.c1 != NULL", + "select * from t as t1 inner join t as t2 on t1.c1 = t2.c1 where t1.c1 != NULL", + "select count(*) from t as t1 inner join t as t2 on t1.c1 = t2.c1 where t1.c1 != NULL", + "select t1.c1, t2.c1 from t as t1 left join t as t2 on t1.c1 = t2.c1 where t1.c1 != NULL", + "select * from t as t1 left join t as t2 on t1.c1 = t2.c1 where t1.c1 != NULL", + "select count(*) from t as t1 left join t as t2 on t1.c1 = t2.c1 where t1.c1 != NULL", + "select * from t as t1 left join t as t2 on t1.c1 = t2.c1 where t1.c1 is not NULL" + ] } ] diff --git a/executor/testdata/executor_suite_out.json b/executor/testdata/executor_suite_out.json index ef43a96afef6b..bb635bc2951f6 100644 --- a/executor/testdata/executor_suite_out.json +++ b/executor/testdata/executor_suite_out.json @@ -70,5 +70,89 @@ ] } ] + }, + { + "Name": "TestIndexScanWithYearCol", + "Cases": [ + { + "SQL": "select t1.c1, t2.c1 from t as t1 inner join t as t2 on t1.c1 = t2.c1 where t1.c1 != NULL", + "Plan": [ + "MergeJoin_8 0.00 root inner join, left key:test.t.c1, right key:test.t.c1", + "├─TableDual_24(Build) 0.00 root rows:0", + "└─TableDual_23(Probe) 0.00 root rows:0" + ], + "Res": [ + ] + }, + { + "SQL": "select * from t as t1 inner join t as t2 on t1.c1 = t2.c1 where t1.c1 != NULL", + "Plan": [ + "MergeJoin_8 0.00 root inner join, left key:test.t.c1, right key:test.t.c1", + "├─TableDual_26(Build) 0.00 root rows:0", + "└─TableDual_25(Probe) 0.00 root rows:0" + ], + "Res": [ + ] + }, + { + "SQL": "select count(*) from t as t1 inner join t as t2 on t1.c1 = t2.c1 where t1.c1 != NULL", + "Plan": [ + "StreamAgg_10 1.00 root funcs:count(1)->Column#7", + "└─MergeJoin_11 0.00 root inner join, left key:test.t.c1, right key:test.t.c1", + " ├─TableDual_27(Build) 0.00 root rows:0", + " └─TableDual_26(Probe) 0.00 root rows:0" + ], + "Res": [ + "0" + ] + }, + { + "SQL": "select t1.c1, t2.c1 from t as t1 left join t as t2 on t1.c1 = t2.c1 where t1.c1 != NULL", + "Plan": [ + "MergeJoin_7 0.00 root left outer join, left key:test.t.c1, right key:test.t.c1", + "├─TableDual_17(Build) 0.00 root rows:0", + "└─TableDual_16(Probe) 0.00 root rows:0" + ], + "Res": [ + ] + }, + { + "SQL": "select * from t as t1 left join t as t2 on t1.c1 = t2.c1 where t1.c1 != NULL", + "Plan": [ + "MergeJoin_7 0.00 root left outer join, left key:test.t.c1, right key:test.t.c1", + "├─TableDual_18(Build) 0.00 root rows:0", + "└─TableDual_17(Probe) 0.00 root rows:0" + ], + "Res": [ + ] + }, + { + "SQL": "select count(*) from t as t1 left join t as t2 on t1.c1 = t2.c1 where t1.c1 != NULL", + "Plan": [ + "StreamAgg_9 1.00 root funcs:count(1)->Column#7", + "└─MergeJoin_10 0.00 root left outer join, left key:test.t.c1, right key:test.t.c1", + " ├─TableDual_20(Build) 0.00 root rows:0", + " └─TableDual_19(Probe) 0.00 root rows:0" + ], + "Res": [ + "0" + ] + }, + { + "SQL": "select * from t as t1 left join t as t2 on t1.c1 = t2.c1 where t1.c1 is not NULL", + "Plan": [ + "HashJoin_15 12487.50 root left outer join, equal:[eq(test.t.c1, test.t.c1)]", + "├─TableReader_33(Build) 9990.00 root data:Selection_32", + "│ └─Selection_32 9990.00 cop[tikv] not(isnull(test.t.c1))", + "│ └─TableFullScan_31 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo", + "└─TableReader_27(Probe) 9990.00 root data:Selection_26", + " └─Selection_26 9990.00 cop[tikv] not(isnull(test.t.c1))", + " └─TableFullScan_25 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo" + ], + "Res": [ + "2001 1 2001 1" + ] + } + ] } ] diff --git a/types/datum.go b/types/datum.go index 7a72024488006..715ebcb4d315f 100644 --- a/types/datum.go +++ b/types/datum.go @@ -1427,6 +1427,10 @@ func (d *Datum) convertToMysqlFloatYear(sc *stmtctx.StatementContext, target *Fi y = float64(d.GetMysqlTime().Year()) case KindMysqlDuration: y = float64(time.Now().Year()) + case KindNull: + // if datum is NULL, we should keep it as it is, instead of setting it to zero or any other value. + ret = *d + return ret, nil default: ret, err = d.convertToFloat(sc, NewFieldType(mysql.TypeDouble)) if err != nil { From 30738358090700448b19d8c3f47e0d7feacede4d Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Fri, 30 Apr 2021 12:25:50 +0800 Subject: [PATCH 42/85] planner: fix a bug that point get plan returns wrong column name (#23365) (#23970) --- planner/core/point_get_plan.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/planner/core/point_get_plan.go b/planner/core/point_get_plan.go index d499a6cf81922..a2a4477dd2e74 100644 --- a/planner/core/point_get_plan.go +++ b/planner/core/point_get_plan.go @@ -919,7 +919,7 @@ func buildSchemaFromFields( if col == nil { return nil, nil } - asName := col.Name + asName := colNameExpr.Name.Name if field.AsName.L != "" { asName = field.AsName } @@ -927,6 +927,7 @@ func buildSchemaFromFields( DBName: dbName, OrigTblName: tbl.Name, TblName: tblName, + OrigColName: col.Name, ColName: asName, }) columns = append(columns, colInfoToColumn(col, len(columns))) From a8155caa54f663762bcffee2abb69c8390439f4f Mon Sep 17 00:00:00 2001 From: HuaiyuXu <391585975@qq.com> Date: Fri, 30 Apr 2021 16:01:50 +0800 Subject: [PATCH 43/85] planner, executor: IndexMerge supports reading extraHandleCol in partialTableReader (#23572) --- executor/builder.go | 3 +++ executor/index_merge_reader.go | 10 +++++--- executor/index_merge_reader_test.go | 37 +++++++++++++++++++++++++++++ planner/core/physical_plans.go | 3 +++ planner/core/resolve_indices.go | 8 +++++++ planner/core/task.go | 2 +- 6 files changed, 59 insertions(+), 4 deletions(-) diff --git a/executor/builder.go b/executor/builder.go index e8f4bca522a8f..2462293a8d69c 100644 --- a/executor/builder.go +++ b/executor/builder.go @@ -2610,6 +2610,9 @@ func buildNoRangeIndexMergeReader(b *executorBuilder, v *plannercore.PhysicalInd } collectTable := false e.tableRequest.CollectRangeCounts = &collectTable + if v.ExtraHandleCol != nil { + e.extraHandleIdx = v.ExtraHandleCol.Index + } return e, nil } diff --git a/executor/index_merge_reader.go b/executor/index_merge_reader.go index edf868913afd7..8e5e0c6a49f0d 100644 --- a/executor/index_merge_reader.go +++ b/executor/index_merge_reader.go @@ -102,6 +102,10 @@ type IndexMergeReaderExecutor struct { corColInAccess bool idxCols [][]*expression.Column colLens [][]int + + // extraHandleIdx indicates the index of extraHandleCol when the partial + // reader is TableReader. + extraHandleIdx int } // Open implements the Executor Open interface @@ -280,7 +284,7 @@ func (e *IndexMergeReaderExecutor) startPartialTableWorker(ctx context.Context, var err error util.WithRecovery( func() { - _, err = worker.fetchHandles(ctx1, exitCh, fetchCh, e.resultCh, e.finished) + _, err = worker.fetchHandles(ctx1, exitCh, fetchCh, e.resultCh, e.finished, e.extraHandleIdx) }, e.handleHandlesFetcherPanic(ctx, e.resultCh, "partialTableWorker"), ) @@ -305,7 +309,7 @@ type partialTableWorker struct { } func (w *partialTableWorker) fetchHandles(ctx context.Context, exitCh <-chan struct{}, fetchCh chan<- *lookupTableTask, resultCh chan<- *lookupTableTask, - finished <-chan struct{}) (count int64, err error) { + finished <-chan struct{}, extraHandleIdx int) (count int64, err error) { var chk *chunk.Chunk handleOffset := -1 if w.tableInfo.PKIsHandle { @@ -318,7 +322,7 @@ func (w *partialTableWorker) fetchHandles(ctx context.Context, exitCh <-chan str } } } else { - return 0, errors.Errorf("cannot find the column for handle") + handleOffset = extraHandleIdx } chk = chunk.NewChunkWithCapacity(retTypes(w.tableReader), w.maxChunkSize) diff --git a/executor/index_merge_reader_test.go b/executor/index_merge_reader_test.go index 647ab6911e358..06a39841cd7f9 100644 --- a/executor/index_merge_reader_test.go +++ b/executor/index_merge_reader_test.go @@ -14,8 +14,10 @@ package executor_test import ( + "fmt" . "github.com/pingcap/check" "github.com/pingcap/tidb/util/testkit" + "strings" ) func (s *testSuite1) TestSingleTableRead(c *C) { @@ -78,3 +80,38 @@ func (s *testSuite1) TestIssue16910(c *C) { tk.MustExec("insert into t2 values (0, 0), (1, 1), (2, 2), (3, 3), (4, 4), (5, 5), (6, 6), (7, 7), (8, 8), (9, 9), (10, 10), (11, 11), (12, 12), (13, 13), (14, 14), (15, 15), (16, 16), (17, 17), (18, 18), (19, 19), (20, 20), (21, 21), (22, 22), (23, 23);") tk.MustQuery("select /*+ USE_INDEX_MERGE(t1, a, b) */ * from t1 partition (p0) join t2 partition (p1) on t1.a = t2.a where t1.a < 40 or t1.b < 30;").Check(testkit.Rows("1 1 1 1")) } + +func (s *testSuite1) TestIssue23569(c *C) { + tk := testkit.NewTestKitWithInit(c, s.store) + tk.MustExec("use test;") + tk.MustExec("drop table if exists tt;") + tk.MustExec(`create table tt(id bigint(20) NOT NULL,create_time bigint(20) NOT NULL DEFAULT '0' ,driver varchar(64), PRIMARY KEY (id,create_time)) + PARTITION BY RANGE ( create_time ) ( + PARTITION p201901 VALUES LESS THAN (1577808000), + PARTITION p202001 VALUES LESS THAN (1585670400), + PARTITION p202002 VALUES LESS THAN (1593532800), + PARTITION p202003 VALUES LESS THAN (1601481600), + PARTITION p202004 VALUES LESS THAN (1609430400), + PARTITION p202101 VALUES LESS THAN (1617206400), + PARTITION p202102 VALUES LESS THAN (1625068800), + PARTITION p202103 VALUES LESS THAN (1633017600), + PARTITION p202104 VALUES LESS THAN (1640966400), + PARTITION p202201 VALUES LESS THAN (1648742400), + PARTITION p202202 VALUES LESS THAN (1656604800), + PARTITION p202203 VALUES LESS THAN (1664553600), + PARTITION p202204 VALUES LESS THAN (1672502400), + PARTITION p202301 VALUES LESS THAN (1680278400) + );`) + tk.MustExec("insert tt value(1, 1577807000, 'jack'), (2, 1577809000, 'mike'), (3, 1585670500, 'right'), (4, 1601481500, 'hello');") + tk.MustExec("set @@tidb_enable_index_merge=true;") + rows := tk.MustQuery("explain select count(*) from tt partition(p202003) where _tidb_rowid is null or (_tidb_rowid>=1 and _tidb_rowid<100);").Rows() + containsIndexMerge := false + for _, r := range rows { + if strings.Contains(fmt.Sprintf("%s", r[0]), "IndexMerge") { + containsIndexMerge = true + break + } + } + c.Assert(containsIndexMerge, IsTrue) + tk.MustQuery("select count(*) from tt partition(p202003) where _tidb_rowid is null or (_tidb_rowid>=1 and _tidb_rowid<100);").Check(testkit.Rows("1")) +} diff --git a/planner/core/physical_plans.go b/planner/core/physical_plans.go index 1220254d90e77..3f2ebbb42e4c7 100644 --- a/planner/core/physical_plans.go +++ b/planner/core/physical_plans.go @@ -189,6 +189,9 @@ type PhysicalIndexMergeReader struct { partialPlans []PhysicalPlan // tablePlan is a PhysicalTableScan to get the table tuples. Current, it must be not nil. tablePlan PhysicalPlan + // ExtraHandleCol indicates the index of extraHandleCol when the partial + // reader is TableReader. + ExtraHandleCol *expression.Column } // PhysicalIndexScan represents an index scan plan. diff --git a/planner/core/resolve_indices.go b/planner/core/resolve_indices.go index e37d51e57c1e0..974592a2701cc 100644 --- a/planner/core/resolve_indices.go +++ b/planner/core/resolve_indices.go @@ -365,6 +365,14 @@ func (p *PhysicalIndexMergeReader) ResolveIndices() (err error) { } } for i := 0; i < len(p.partialPlans); i++ { + switch x := p.partialPlans[i].(type) { + case *PhysicalTableReader: + newCol, err := p.ExtraHandleCol.ResolveIndices(x.Schema()) + if err != nil { + return err + } + p.ExtraHandleCol = newCol.(*expression.Column) + } err = p.partialPlans[i].ResolveIndices() if err != nil { return err diff --git a/planner/core/task.go b/planner/core/task.go index 93ba40e83e30d..3789a9f00dd50 100644 --- a/planner/core/task.go +++ b/planner/core/task.go @@ -725,7 +725,7 @@ func finishCopTask(ctx sessionctx.Context, task task) task { cst: t.cst, } if t.idxMergePartPlans != nil { - p := PhysicalIndexMergeReader{partialPlans: t.idxMergePartPlans, tablePlan: t.tablePlan}.Init(ctx, t.idxMergePartPlans[0].SelectBlockOffset()) + p := PhysicalIndexMergeReader{partialPlans: t.idxMergePartPlans, tablePlan: t.tablePlan, ExtraHandleCol: t.extraHandleCol}.Init(ctx, t.idxMergePartPlans[0].SelectBlockOffset()) setTableScanToTableRowIDScan(p.tablePlan) newTask.p = p return newTask From fd70039381164fb50d6eb17371ad6df799976756 Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Fri, 30 Apr 2021 19:13:50 +0800 Subject: [PATCH 44/85] planner: do not push down to TiFlash if the table scan require to scan data in desc order (#23948) (#23974) --- planner/core/find_best_task.go | 3 ++ planner/core/integration_test.go | 37 +++++++++++++++++++ .../testdata/integration_serial_suite_in.json | 7 ++++ .../integration_serial_suite_out.json | 25 +++++++++++++ 4 files changed, 72 insertions(+) diff --git a/planner/core/find_best_task.go b/planner/core/find_best_task.go index 7e8f0587f9902..18d1e9b29170f 100644 --- a/planner/core/find_best_task.go +++ b/planner/core/find_best_task.go @@ -1184,6 +1184,9 @@ func (ds *DataSource) convertToTableScan(prop *property.PhysicalProperty, candid return invalidTask, nil } ts, cost, _ := ds.getOriginalPhysicalTableScan(prop, candidate.path, candidate.isMatchProp) + if ts.KeepOrder && ts.Desc && ts.StoreType == kv.TiFlash { + return invalidTask, nil + } copTask := &copTask{ tablePlan: ts, indexPlanFinished: true, diff --git a/planner/core/integration_test.go b/planner/core/integration_test.go index 3396736350596..48c96d008499d 100644 --- a/planner/core/integration_test.go +++ b/planner/core/integration_test.go @@ -349,6 +349,43 @@ func (s *testIntegrationSerialSuite) TestSelPushDownTiFlash(c *C) { } } +func (s *testIntegrationSerialSuite) TestPushDownToTiFlashWithKeepOrder(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a int primary key, b varchar(20))") + + // Create virtual tiflash replica info. + dom := domain.GetDomain(tk.Se) + is := dom.InfoSchema() + db, exists := is.SchemaByName(model.NewCIStr("test")) + c.Assert(exists, IsTrue) + for _, tblInfo := range db.Tables { + if tblInfo.Name.L == "t" { + tblInfo.TiFlashReplica = &model.TiFlashReplicaInfo{ + Count: 1, + Available: true, + } + } + } + + tk.MustExec("set @@session.tidb_isolation_read_engines = 'tiflash'") + var input []string + var output []struct { + SQL string + Plan []string + } + s.testData.GetTestCases(c, &input, &output) + for i, tt := range input { + s.testData.OnRecord(func() { + output[i].SQL = tt + output[i].Plan = s.testData.ConvertRowsToStrings(tk.MustQuery(tt).Rows()) + }) + res := tk.MustQuery(tt) + res.Check(testkit.Rows(output[i].Plan...)) + } +} + func (s *testIntegrationSerialSuite) TestBroadcastJoin(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") diff --git a/planner/core/testdata/integration_serial_suite_in.json b/planner/core/testdata/integration_serial_suite_in.json index f7a561e28c73f..9b39df74067b0 100644 --- a/planner/core/testdata/integration_serial_suite_in.json +++ b/planner/core/testdata/integration_serial_suite_in.json @@ -8,6 +8,13 @@ "explain select * from t where b > 'a' order by b limit 2" ] }, + { + "name": "TestPushDownToTiFlashWithKeepOrder", + "cases": [ + "explain select max(a) from t", + "explain select min(a) from t" + ] + }, { "name": "TestBroadcastJoin", "cases": [ diff --git a/planner/core/testdata/integration_serial_suite_out.json b/planner/core/testdata/integration_serial_suite_out.json index 99027be774531..9dff901e19d3f 100644 --- a/planner/core/testdata/integration_serial_suite_out.json +++ b/planner/core/testdata/integration_serial_suite_out.json @@ -42,6 +42,31 @@ } ] }, + { + "Name": "TestPushDownToTiFlashWithKeepOrder", + "Cases": [ + { + "SQL": "explain select max(a) from t", + "Plan": [ + "StreamAgg_9 1.00 root funcs:max(test.t.a)->Column#3", + "└─TopN_10 1.00 root test.t.a:desc, offset:0, count:1", + " └─TableReader_18 1.00 root data:TopN_17", + " └─TopN_17 1.00 batchCop[tiflash] test.t.a:desc, offset:0, count:1", + " └─TableFullScan_16 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" + ] + }, + { + "SQL": "explain select min(a) from t", + "Plan": [ + "StreamAgg_9 1.00 root funcs:min(test.t.a)->Column#3", + "└─Limit_13 1.00 root offset:0, count:1", + " └─TableReader_23 1.00 root data:Limit_22", + " └─Limit_22 1.00 cop[tiflash] offset:0, count:1", + " └─TableFullScan_21 1.00 cop[tiflash] table:t keep order:true, stats:pseudo" + ] + } + ] + }, { "Name": "TestBroadcastJoin", "Cases": [ From ae80ec245442050b354b19918d34f0dc0aca9d28 Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Fri, 30 Apr 2021 20:01:50 +0800 Subject: [PATCH 45/85] executor: skip lock key if lock key is empty (#23188) (#24312) --- executor/point_get.go | 3 +++ session/pessimistic_test.go | 7 +++++++ 2 files changed, 10 insertions(+) diff --git a/executor/point_get.go b/executor/point_get.go index ae8bbcfbb37d5..f9d784e95ce6c 100644 --- a/executor/point_get.go +++ b/executor/point_get.go @@ -271,6 +271,9 @@ func (e *PointGetExecutor) getAndLock(ctx context.Context, key kv.Key) (val []by } func (e *PointGetExecutor) lockKeyIfNeeded(ctx context.Context, key []byte) error { + if len(key) == 0 { + return nil + } if e.lock { seVars := e.ctx.GetSessionVars() lockCtx := newLockCtx(seVars, e.lockWaitTime) diff --git a/session/pessimistic_test.go b/session/pessimistic_test.go index d68c9b06327a4..5d05ac08042a1 100644 --- a/session/pessimistic_test.go +++ b/session/pessimistic_test.go @@ -299,6 +299,13 @@ func (s *testPessimisticSuite) TestInsertOnDup(c *C) { tk.MustQuery("select * from dup").Check(testkit.Rows("1 2")) } +func (s *testPessimisticSuite) TestPointGetOverflow(c *C) { + tk := testkit.NewTestKitWithInit(c, s.store) + tk.MustExec("create table t(k tinyint, v int, unique key(k))") + tk.MustExec("begin pessimistic") + tk.MustExec("update t set v = 100 where k = -200;") +} + func (s *testPessimisticSuite) TestPointGetKeyLock(c *C) { tk := testkit.NewTestKitWithInit(c, s.store) tk2 := testkit.NewTestKitWithInit(c, s.store) From c067ad1166c75ebc613f4f0f26065b8544b4983e Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Thu, 6 May 2021 11:23:53 +0800 Subject: [PATCH 46/85] expression: fix wrong collation for `concat` function (#24297) (#24300) --- expression/builtin_cast.go | 6 +++++- expression/integration_test.go | 7 +++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/expression/builtin_cast.go b/expression/builtin_cast.go index 3f21d4303c448..e10a7f5d43c74 100644 --- a/expression/builtin_cast.go +++ b/expression/builtin_cast.go @@ -1899,7 +1899,11 @@ func WrapWithCastAsString(ctx sessionctx.Context, expr Expression) Expression { argLen = -1 } tp := types.NewFieldType(mysql.TypeVarString) - tp.Charset, tp.Collate = expr.CharsetAndCollation(ctx) + if expr.Coercibility() == CoercibilityExplicit { + tp.Charset, tp.Collate = expr.CharsetAndCollation(ctx) + } else { + tp.Charset, tp.Collate = ctx.GetSessionVars().GetCharsetInfo() + } tp.Flen, tp.Decimal = argLen, types.UnspecifiedLength return BuildCastFunction(ctx, expr, tp) } diff --git a/expression/integration_test.go b/expression/integration_test.go index 614f06a51b838..f3eb9e8d578e4 100755 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -8069,6 +8069,7 @@ func (s *testIntegrationSerialSuite) TestIssue19116(c *C) { defer collate.SetNewCollationEnabledForTest(false) tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") tk.MustExec("set names utf8mb4 collate utf8mb4_general_ci;") tk.MustQuery("select collation(concat(1 collate `binary`));").Check(testkit.Rows("binary")) tk.MustQuery("select coercibility(concat(1 collate `binary`));").Check(testkit.Rows("0")) @@ -8079,4 +8080,10 @@ func (s *testIntegrationSerialSuite) TestIssue19116(c *C) { tk.MustQuery("select collation(1);").Check(testkit.Rows("binary")) tk.MustQuery("select coercibility(1);").Check(testkit.Rows("5")) tk.MustQuery("select coercibility(1=1);").Check(testkit.Rows("5")) + + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a datetime)") + tk.MustExec("insert into t values ('2020-02-02')") + tk.MustQuery("select collation(concat(unix_timestamp(a))) from t;").Check(testkit.Rows("utf8mb4_general_ci")) + tk.MustQuery("select coercibility(concat(unix_timestamp(a))) from t;").Check(testkit.Rows("4")) } From af24b2faefa4e37947ccb3815797ba73ff2e6a1c Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Thu, 6 May 2021 11:39:53 +0800 Subject: [PATCH 47/85] statistics: fix the panic when analyze with collation enabled (#21262) (#21299) --- executor/analyze.go | 5 +++++ executor/analyze_test.go | 24 ++++++++++++++++++++++++ executor/executor_test.go | 5 +++++ statistics/histogram.go | 8 ++++++++ 4 files changed, 42 insertions(+) diff --git a/executor/analyze.go b/executor/analyze.go index bcea8925221ce..3e894b99c776f 100755 --- a/executor/analyze.go +++ b/executor/analyze.go @@ -523,6 +523,11 @@ func (e *AnalyzeColumnsExec) buildStats(ranges []*ranger.Range) (hists []*statis if err != nil { return nil, nil, err } + // When collation is enabled, we store the Key representation of the sampling data. So we set it to kind `Bytes` here + // to avoid to convert it to its Key representation once more. + if collectors[i].Samples[j].Value.Kind() == types.KindString { + collectors[i].Samples[j].Value.SetBytes(collectors[i].Samples[j].Value.GetBytes()) + } } hg, err := statistics.BuildColumn(e.ctx, int64(e.opts[ast.AnalyzeOptNumBuckets]), col.ID, collectors[i], &col.FieldType) if err != nil { diff --git a/executor/analyze_test.go b/executor/analyze_test.go index b12646a4cb279..764b1c38ce692 100644 --- a/executor/analyze_test.go +++ b/executor/analyze_test.go @@ -42,6 +42,7 @@ import ( "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/codec" + "github.com/pingcap/tidb/util/collate" "github.com/pingcap/tidb/util/testkit" ) @@ -646,3 +647,26 @@ func (s *testSuite1) TestDefaultValForAnalyze(c *C) { tk.MustQuery("explain select * from t where a = 1").Check(testkit.Rows("IndexReader_6 1.00 root index:IndexRangeScan_5", "└─IndexRangeScan_5 1.00 cop[tikv] table:t, index:a(a) range:[1,1], keep order:false")) } + +func (s *testSerialSuite2) TestIssue20874(c *C) { + collate.SetNewCollationEnabledForTest(true) + defer collate.SetNewCollationEnabledForTest(false) + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t (a char(10) collate utf8mb4_unicode_ci not null, b char(20) collate utf8mb4_general_ci not null)") + tk.MustExec("insert into t values ('#', 'C'), ('$', 'c'), ('a', 'a')") + tk.MustExec("analyze table t") + tk.MustQuery("show stats_buckets where db_name = 'test' and table_name = 't'").Sort().Check(testkit.Rows( + "test t a 0 0 1 1 \x02\xd2 \x02\xd2", + "test t a 0 1 2 1 \x0e\x0f \x0e\x0f", + "test t a 0 2 3 1 \x0e3 \x0e3", + "test t b 0 0 1 1 \x00A \x00A", + "test t b 0 1 3 2 \x00C \x00C", + )) + + tk.MustExec("drop table if exists t") + tk.MustExec("create table t (a char(10) collate utf8mb4_general_ci not null)") + tk.MustExec("insert into t values ('汉字'), ('中文'), ('汉字'), ('中文'), ('汉字'), ('中文'), ('汉字'), ('中文'), ('汉字'), ('中文'), ('汉字'), ('中文'), ('汉字'), ('中文')") + tk.MustExec("analyze table t") +} diff --git a/executor/executor_test.go b/executor/executor_test.go index 8d77b5652acc1..2240bdc280bc0 100644 --- a/executor/executor_test.go +++ b/executor/executor_test.go @@ -104,6 +104,7 @@ var _ = Suite(&testSuite{&baseTestSuite{}}) var _ = Suite(&testSuiteP1{&baseTestSuite{}}) var _ = Suite(&testSuiteP2{&baseTestSuite{}}) var _ = Suite(&testSuite1{}) +var _ = SerialSuites(&testSerialSuite2{}) var _ = Suite(&testSuite2{&baseTestSuite{}}) var _ = Suite(&testSuite3{&baseTestSuite{}}) var _ = Suite(&testSuite4{&baseTestSuite{}}) @@ -3033,6 +3034,10 @@ type testSuite1 struct { testSuiteWithCliBase } +type testSerialSuite2 struct { + testSuiteWithCliBase +} + func (s *testSuiteWithCliBase) SetUpSuite(c *C) { cli := &checkRequestClient{} hijackClient := func(c tikv.Client) tikv.Client { diff --git a/statistics/histogram.go b/statistics/histogram.go index 66e8f9f3f3677..76cbfe1732d07 100644 --- a/statistics/histogram.go +++ b/statistics/histogram.go @@ -22,6 +22,7 @@ import ( "time" "github.com/pingcap/errors" + "github.com/pingcap/parser/charset" "github.com/pingcap/parser/model" "github.com/pingcap/parser/mysql" "github.com/pingcap/parser/terror" @@ -89,6 +90,13 @@ type scalar struct { // NewHistogram creates a new histogram. func NewHistogram(id, ndv, nullCount int64, version uint64, tp *types.FieldType, bucketSize int, totColSize int64) *Histogram { + if tp.EvalType() == types.ETString { + // The histogram will store the string value's 'sort key' representation of its collation. + // If we directly set the field type's collation to its original one. We would decode the Key representation using its collation. + // This would cause panic. So we apply a little trick here to avoid decoding it by explicitly changing the collation to 'CollationBin'. + tp = tp.Clone() + tp.Collate = charset.CollationBin + } return &Histogram{ ID: id, NDV: ndv, From c21f6f27d2f478759e2b132cce279cac51a848d8 Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Thu, 6 May 2021 12:43:52 +0800 Subject: [PATCH 48/85] functions: fix some string function has wrong collation and flag (#23835) (#23878) --- expression/builtin_string.go | 22 +++++++++++----------- expression/typeinfer_test.go | 7 +++++-- 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/expression/builtin_string.go b/expression/builtin_string.go index d07536c42bee5..749d3453d0775 100644 --- a/expression/builtin_string.go +++ b/expression/builtin_string.go @@ -180,6 +180,11 @@ func SetBinFlagOrBinStr(argTp *types.FieldType, resTp *types.FieldType) { } } +// addBinFlag add the binary flag to `tp` if its charset is binary +func addBinFlag(tp *types.FieldType) { + SetBinFlagOrBinStr(tp, tp) +} + type lengthFunctionClass struct { baseFunctionClass } @@ -275,10 +280,10 @@ func (c *concatFunctionClass) getFunction(ctx sessionctx.Context, args []Express if err != nil { return nil, err } + addBinFlag(bf.tp) bf.tp.Flen = 0 for i := range args { argType := args[i].GetType() - SetBinFlagOrBinStr(argType, bf.tp) if argType.Flen < 0 { bf.tp.Flen = mysql.MaxBlobWidth @@ -350,9 +355,9 @@ func (c *concatWSFunctionClass) getFunction(ctx sessionctx.Context, args []Expre } bf.tp.Flen = 0 + addBinFlag(bf.tp) for i := range args { argType := args[i].GetType() - SetBinFlagOrBinStr(argType, bf.tp) // skip separator param if i != 0 { @@ -2000,8 +2005,7 @@ func (c *lpadFunctionClass) getFunction(ctx sessionctx.Context, args []Expressio return nil, err } bf.tp.Flen = getFlen4LpadAndRpad(bf.ctx, args[1]) - SetBinFlagOrBinStr(args[0].GetType(), bf.tp) - SetBinFlagOrBinStr(args[2].GetType(), bf.tp) + addBinFlag(bf.tp) valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket) maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64) @@ -2133,8 +2137,7 @@ func (c *rpadFunctionClass) getFunction(ctx sessionctx.Context, args []Expressio return nil, err } bf.tp.Flen = getFlen4LpadAndRpad(bf.ctx, args[1]) - SetBinFlagOrBinStr(args[0].GetType(), bf.tp) - SetBinFlagOrBinStr(args[2].GetType(), bf.tp) + addBinFlag(bf.tp) valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket) maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64) @@ -2668,9 +2671,7 @@ func (c *makeSetFunctionClass) getFunction(ctx sessionctx.Context, args []Expres if err != nil { return nil, err } - for i, length := 0, len(args); i < length; i++ { - SetBinFlagOrBinStr(args[i].GetType(), bf.tp) - } + addBinFlag(bf.tp) bf.tp.Flen = c.getFlen(bf.ctx, args) if bf.tp.Flen > mysql.MaxBlobWidth { bf.tp.Flen = mysql.MaxBlobWidth @@ -3589,8 +3590,7 @@ func (c *insertFunctionClass) getFunction(ctx sessionctx.Context, args []Express return nil, err } bf.tp.Flen = mysql.MaxBlobWidth - SetBinFlagOrBinStr(args[0].GetType(), bf.tp) - SetBinFlagOrBinStr(args[3].GetType(), bf.tp) + addBinFlag(bf.tp) valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket) maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64) diff --git a/expression/typeinfer_test.go b/expression/typeinfer_test.go index 8a46b32664c24..7cafe7fb1c46b 100644 --- a/expression/typeinfer_test.go +++ b/expression/typeinfer_test.go @@ -240,8 +240,9 @@ func (s *testInferTypeSuite) createTestCase4StrFuncs() []typeInferTestCase { {"strcmp(c_char, c_char)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 2, 0}, {"space(c_int_d)", mysql.TypeLongBlob, mysql.DefaultCharset, 0, mysql.MaxBlobWidth, types.UnspecifiedLength}, {"CONCAT(c_binary, c_int_d)", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 40, types.UnspecifiedLength}, - {"CONCAT(c_bchar, c_int_d)", mysql.TypeVarString, charset.CharsetUTF8MB4, mysql.BinaryFlag, 40, types.UnspecifiedLength}, {"CONCAT('T', 'i', 'DB')", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 4, types.UnspecifiedLength}, + {"CONCAT(c_bchar, c_int_d)", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 40, types.UnspecifiedLength}, + {"CONCAT(c_bchar, 0x80)", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 23, types.UnspecifiedLength}, {"CONCAT('T', 'i', 'DB', c_binary)", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 24, types.UnspecifiedLength}, {"CONCAT_WS('-', 'T', 'i', 'DB')", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 6, types.UnspecifiedLength}, {"CONCAT_WS(',', 'TiDB', c_binary)", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 25, types.UnspecifiedLength}, @@ -451,8 +452,9 @@ func (s *testInferTypeSuite) createTestCase4StrFuncs() []typeInferTestCase { {"find_in_set(c_set , c_text_d)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 3, 0}, {"find_in_set(c_enum , c_text_d)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 3, 0}, - {"make_set(c_int_d , c_text_d)", mysql.TypeVarString, charset.CharsetUTF8MB4, mysql.BinaryFlag, 65535, types.UnspecifiedLength}, + {"make_set(c_int_d , c_text_d)", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 65535, types.UnspecifiedLength}, {"make_set(c_bigint_d , c_text_d, c_binary)", mysql.TypeMediumBlob, charset.CharsetBin, mysql.BinaryFlag, 65556, types.UnspecifiedLength}, + {"make_set(1 , c_text_d, 0x40)", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 65535, types.UnspecifiedLength}, {"quote(c_int_d )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 42, types.UnspecifiedLength}, {"quote(c_bigint_d )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 42, types.UnspecifiedLength}, @@ -465,6 +467,7 @@ func (s *testInferTypeSuite) createTestCase4StrFuncs() []typeInferTestCase { {"convert(c_text_d using 'binary')", mysql.TypeLongBlob, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxBlobWidth, types.UnspecifiedLength}, {"insert(c_varchar, c_int_d, c_int_d, c_varchar)", mysql.TypeLongBlob, charset.CharsetUTF8MB4, 0, mysql.MaxBlobWidth, types.UnspecifiedLength}, + {"insert(c_varchar, c_int_d, c_int_d, 0x40)", mysql.TypeLongBlob, charset.CharsetUTF8MB4, 0, mysql.MaxBlobWidth, types.UnspecifiedLength}, {"insert(c_varchar, c_int_d, c_int_d, c_binary)", mysql.TypeLongBlob, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxBlobWidth, types.UnspecifiedLength}, {"insert(c_binary, c_int_d, c_int_d, c_varchar)", mysql.TypeLongBlob, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxBlobWidth, types.UnspecifiedLength}, {"insert(c_binary, c_int_d, c_int_d, c_binary)", mysql.TypeLongBlob, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxBlobWidth, types.UnspecifiedLength}, From 04bf13a68180d02a7404bd58fb320a46781c6e4a Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Thu, 6 May 2021 13:15:52 +0800 Subject: [PATCH 49/85] planner: not pruning column used by union scan condition (#21640) (#22624) --- planner/core/integration_test.go | 21 +++++++++++++++++++++ planner/core/rule_column_pruning.go | 2 ++ 2 files changed, 23 insertions(+) diff --git a/planner/core/integration_test.go b/planner/core/integration_test.go index 48c96d008499d..26472565e749e 100644 --- a/planner/core/integration_test.go +++ b/planner/core/integration_test.go @@ -1864,6 +1864,27 @@ func (s *testIntegrationSuite) TestReorderSimplifiedOuterJoins(c *C) { } } +// Test for issue https://github.com/pingcap/tidb/issues/21607. +func (s *testIntegrationSuite) TestConditionColPruneInPhysicalUnionScan(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test;") + tk.MustExec("drop table if exists t;") + tk.MustExec("create table t (a int, b int);") + tk.MustExec("begin;") + tk.MustExec("insert into t values (1, 2);") + tk.MustQuery("select count(*) from t where b = 1 and b in (3);"). + Check(testkit.Rows("0")) + + tk.MustExec("drop table t;") + tk.MustExec("create table t (a int, b int as (a + 1), c int as (b + 1));") + tk.MustExec("begin;") + tk.MustExec("insert into t (a) values (1);") + tk.MustQuery("select count(*) from t where b = 1 and b in (3);"). + Check(testkit.Rows("0")) + tk.MustQuery("select count(*) from t where c = 1 and c in (3);"). + Check(testkit.Rows("0")) +} + func (s *testIntegrationSuite) TestIssue22071(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") diff --git a/planner/core/rule_column_pruning.go b/planner/core/rule_column_pruning.go index ea0edad342c12..ac277dc5206f5 100644 --- a/planner/core/rule_column_pruning.go +++ b/planner/core/rule_column_pruning.go @@ -217,6 +217,8 @@ func (p *LogicalUnionAll) PruneColumns(parentUsedCols []*expression.Column) erro // PruneColumns implements LogicalPlan interface. func (p *LogicalUnionScan) PruneColumns(parentUsedCols []*expression.Column) error { parentUsedCols = append(parentUsedCols, p.handleCol) + condCols := expression.ExtractColumnsFromExpressions(nil, p.conditions, nil) + parentUsedCols = append(parentUsedCols, condCols...) return p.children[0].PruneColumns(parentUsedCols) } From 06b0f290e31523e103e13ec3ac5b72c056c19eb7 Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Thu, 6 May 2021 15:41:52 +0800 Subject: [PATCH 50/85] Privileges: fix delete privilege check wrongly (#22971) (#23215) --- planner/core/integration_test.go | 14 +++++ planner/core/logical_plan_builder.go | 76 +++++++++++++++---------- privilege/privileges/privileges_test.go | 25 ++++++++ 3 files changed, 85 insertions(+), 30 deletions(-) diff --git a/planner/core/integration_test.go b/planner/core/integration_test.go index 26472565e749e..83b9032d5f924 100644 --- a/planner/core/integration_test.go +++ b/planner/core/integration_test.go @@ -1864,6 +1864,20 @@ func (s *testIntegrationSuite) TestReorderSimplifiedOuterJoins(c *C) { } } +func (s *testIntegrationSerialSuite) TestDeleteStmt(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("create table t(a int)") + tk.MustExec("delete t from t;") + tk.MustExec("delete t from test.t as t;") + tk.MustGetErrCode("delete test.t from test.t as t;", mysql.ErrUnknownTable) + tk.MustExec("delete test.t from t;") + tk.MustExec("create database db1") + tk.MustExec("use db1") + tk.MustExec("create table t(a int)") + tk.MustGetErrCode("delete test.t from t;", mysql.ErrUnknownTable) +} + // Test for issue https://github.com/pingcap/tidb/issues/21607. func (s *testIntegrationSuite) TestConditionColPruneInPhysicalUnionScan(c *C) { tk := testkit.NewTestKit(c, s.store) diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index ca66fd458c536..f608ec87c9037 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -4296,53 +4296,50 @@ func (b *PlanBuilder) buildDelete(ctx context.Context, delete *ast.DeleteStmt) ( return nil, err } - var tableList []*ast.TableName - tableList = extractTableList(delete.TableRefs.TableRefs, tableList, true) - // Collect visitInfo. if delete.Tables != nil { // Delete a, b from a, b, c, d... add a and b. + updatableList := make(map[string]bool) + tbInfoList := make(map[string]*ast.TableName) + collectTableName(delete.TableRefs.TableRefs, &updatableList, &tbInfoList) for _, tn := range delete.Tables.Tables { - foundMatch := false - for _, v := range tableList { - dbName := v.Schema - if dbName.L == "" { - dbName = model.NewCIStr(b.ctx.GetSessionVars().CurrentDB) - } - if (tn.Schema.L == "" || tn.Schema.L == dbName.L) && tn.Name.L == v.Name.L { - tn.Schema = dbName - tn.DBInfo = v.DBInfo - tn.TableInfo = v.TableInfo - foundMatch = true - break - } + var canUpdate, foundMatch = false, false + name := tn.Name.L + if tn.Schema.L == "" { + canUpdate, foundMatch = updatableList[name] } + if !foundMatch { - var asNameList []string - asNameList = extractTableSourceAsNames(delete.TableRefs.TableRefs, asNameList, false) - for _, asName := range asNameList { - tblName := tn.Name.L - if tn.Schema.L != "" { - tblName = tn.Schema.L + "." + tblName - } - if asName == tblName { - // check sql like: `delete a from (select * from t) as a, t` - return nil, ErrNonUpdatableTable.GenWithStackByArgs(tn.Name.O, "DELETE") - } + if tn.Schema.L == "" { + name = model.NewCIStr(b.ctx.GetSessionVars().CurrentDB).L + "." + tn.Name.L + } else { + name = tn.Schema.L + "." + tn.Name.L } - // check sql like: `delete b from (select * from t) as a, t` + canUpdate, foundMatch = updatableList[name] + } + // check sql like: `delete b from (select * from t) as a, t` + if !foundMatch { return nil, ErrUnknownTable.GenWithStackByArgs(tn.Name.O, "MULTI DELETE") } + // check sql like: `delete a from (select * from t) as a, t` + if !canUpdate { + return nil, ErrNonUpdatableTable.GenWithStackByArgs(tn.Name.O, "DELETE") + } + tb := tbInfoList[name] + tn.DBInfo = tb.DBInfo + tn.TableInfo = tb.TableInfo if tn.TableInfo.IsView() { return nil, errors.Errorf("delete view %s is not supported now.", tn.Name.O) } if tn.TableInfo.IsSequence() { return nil, errors.Errorf("delete sequence %s is not supported now.", tn.Name.O) } - b.visitInfo = appendVisitInfo(b.visitInfo, mysql.DeletePriv, tn.Schema.L, tn.TableInfo.Name.L, "", nil) + b.visitInfo = appendVisitInfo(b.visitInfo, mysql.DeletePriv, tb.DBInfo.Name.L, tb.Name.L, "", nil) } } else { // Delete from a, b, c, d. + var tableList []*ast.TableName + tableList = extractTableList(delete.TableRefs.TableRefs, tableList, false) for _, v := range tableList { if v.TableInfo.IsView() { return nil, errors.Errorf("delete view %s is not supported now.", v.Name.O) @@ -4423,7 +4420,7 @@ func (p *Delete) cleanTblID2HandleMap( // matchingDeletingTable checks whether this column is from the table which is in the deleting list. func (p *Delete) matchingDeletingTable(names []*ast.TableName, name *types.FieldName) bool { for _, n := range names { - if (name.DBName.L == "" || name.DBName.L == n.Schema.L) && name.TblName.L == n.Name.L { + if (name.DBName.L == "" || name.DBName.L == n.DBInfo.Name.L) && name.TblName.L == n.Name.L { return true } } @@ -5097,6 +5094,25 @@ func extractTableList(node ast.ResultSetNode, input []*ast.TableName, asName boo return input } +func collectTableName(node ast.ResultSetNode, updatableName *map[string]bool, info *map[string]*ast.TableName) { + switch x := node.(type) { + case *ast.Join: + collectTableName(x.Left, updatableName, info) + collectTableName(x.Right, updatableName, info) + case *ast.TableSource: + name := x.AsName.L + var canUpdate bool + var s *ast.TableName + if s, canUpdate = x.Source.(*ast.TableName); canUpdate { + if name == "" { + name = s.Schema.L + "." + s.Name.L + } + (*info)[name] = s + } + (*updatableName)[name] = canUpdate + } +} + // extractTableSourceAsNames extracts TableSource.AsNames from node. // if onlySelectStmt is set to be true, only extracts AsNames when TableSource.Source.(type) == *ast.SelectStmt func extractTableSourceAsNames(node ast.ResultSetNode, input []string, onlySelectStmt bool) []string { diff --git a/privilege/privileges/privileges_test.go b/privilege/privileges/privileges_test.go index 812e39be1b2aa..8c4f5cd8c9eaa 100644 --- a/privilege/privileges/privileges_test.go +++ b/privilege/privileges/privileges_test.go @@ -156,6 +156,31 @@ func (s *testPrivilegeSuite) TestCheckPointGetDBPrivilege(c *C) { c.Assert(terror.ErrorEqual(err, core.ErrTableaccessDenied), IsTrue) } +func (s *testPrivilegeSuite) TestIssue22946(c *C) { + rootSe := newSession(c, s.store, s.dbName) + mustExec(c, rootSe, "create database db1;") + mustExec(c, rootSe, "create database db2;") + mustExec(c, rootSe, "use test;") + mustExec(c, rootSe, "create table a(id int);") + mustExec(c, rootSe, "use db1;") + mustExec(c, rootSe, "create table a(id int primary key,name varchar(20));") + mustExec(c, rootSe, "use db2;") + mustExec(c, rootSe, "create table b(id int primary key,address varchar(50));") + mustExec(c, rootSe, "CREATE USER 'delTest'@'localhost';") + mustExec(c, rootSe, "grant all on db1.* to delTest@'localhost';") + mustExec(c, rootSe, "grant all on db2.* to delTest@'localhost';") + mustExec(c, rootSe, "grant select on test.* to delTest@'localhost';") + mustExec(c, rootSe, "flush privileges;") + + se := newSession(c, s.store, s.dbName) + c.Assert(se.Auth(&auth.UserIdentity{Username: "delTest", Hostname: "localhost"}, nil, nil), IsTrue) + _, err := se.ExecuteInternal(context.Background(), `delete from db1.a as A where exists(select 1 from db2.b as B where A.id = B.id);`) + c.Assert(err, IsNil) + mustExec(c, rootSe, "use db1;") + _, err = se.ExecuteInternal(context.Background(), "delete from test.a as A;") + c.Assert(terror.ErrorEqual(err, core.ErrPrivilegeCheckFail), IsTrue) +} + func (s *testPrivilegeSuite) TestCheckTablePrivilege(c *C) { rootSe := newSession(c, s.store, s.dbName) mustExec(c, rootSe, `CREATE USER 'test1'@'localhost';`) From 77c32ba9c882e6dc34087b8b3b7715f498561983 Mon Sep 17 00:00:00 2001 From: Zhi Qi <30543181+LittleFall@users.noreply.github.com> Date: Thu, 6 May 2021 20:07:52 +0800 Subject: [PATCH 51/85] tikv: distinguish server timeout for TiKV and TiFlash (#23700) --- store/tikv/batch_coprocessor.go | 3 ++- store/tikv/region_request.go | 5 ++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/store/tikv/batch_coprocessor.go b/store/tikv/batch_coprocessor.go index 118f3dffebdb6..dfa1b58b87eea 100644 --- a/store/tikv/batch_coprocessor.go +++ b/store/tikv/batch_coprocessor.go @@ -387,7 +387,8 @@ func (b *batchCopIterator) handleStreamedBatchCopResponse(ctx context.Context, b return nil } - if err1 := bo.Backoff(boTiKVRPC, errors.Errorf("recv stream response error: %v, task store addr: %s", err, task.storeAddr)); err1 != nil { + // Currently this function is only used in TiFlash batch cop. + if err1 := bo.Backoff(boTiFlashRPC, errors.Errorf("recv stream response error: %v, task store addr: %s", err, task.storeAddr)); err1 != nil { return errors.Trace(err) } diff --git a/store/tikv/region_request.go b/store/tikv/region_request.go index 85cf978c85454..bcb071da5022e 100644 --- a/store/tikv/region_request.go +++ b/store/tikv/region_request.go @@ -199,7 +199,10 @@ func (ss *RegionBatchRequestSender) onSendFail(bo *Backoffer, ctxs []copTaskAndR // When a store is not available, the leader of related region should be elected quickly. // TODO: the number of retry time should be limited:since region may be unavailable // when some unrecoverable disaster happened. - err = bo.Backoff(boTiKVRPC, errors.Errorf("send tikv request error: %v, ctxs: %v, try next peer later", err, ctxs)) + + // Currently this function is only used in TiFlash batch cop. + // TODO: need code refactoring for these functions. + err = bo.Backoff(boTiFlashRPC, errors.Errorf("send tiflash request error: %v, ctxs: %v, try next peer later", err, ctxs)) return errors.Trace(err) } From eb4113c05edb5433cd63420b8b99dfba6ebc19ac Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Thu, 6 May 2021 22:05:52 +0800 Subject: [PATCH 52/85] expression: fix approx_percent panic on bit column (#23687) (#23702) --- executor/aggfuncs/builder.go | 6 +++++- expression/integration_test.go | 9 +++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/executor/aggfuncs/builder.go b/executor/aggfuncs/builder.go index 45e153aed93e9..3681be119fed1 100644 --- a/executor/aggfuncs/builder.go +++ b/executor/aggfuncs/builder.go @@ -150,9 +150,13 @@ func buildApproxPercentile(sctx sessionctx.Context, aggFuncDesc *aggregation.Agg base := basePercentile{percent: int(percent), baseAggFunc: baseAggFunc{args: aggFuncDesc.Args, ordinal: ordinal}} + evalType := aggFuncDesc.Args[0].GetType().EvalType() + if aggFuncDesc.Args[0].GetType().Tp == mysql.TypeBit { + evalType = types.ETString // same as other aggregate function + } switch aggFuncDesc.Mode { case aggregation.CompleteMode, aggregation.Partial1Mode, aggregation.FinalMode: - switch aggFuncDesc.Args[0].GetType().EvalType() { + switch evalType { case types.ETInt: return &percentileOriginal4Int{base} case types.ETReal: diff --git a/expression/integration_test.go b/expression/integration_test.go index f3eb9e8d578e4..7052673ea4a73 100755 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -8087,3 +8087,12 @@ func (s *testIntegrationSerialSuite) TestIssue19116(c *C) { tk.MustQuery("select collation(concat(unix_timestamp(a))) from t;").Check(testkit.Rows("utf8mb4_general_ci")) tk.MustQuery("select coercibility(concat(unix_timestamp(a))) from t;").Check(testkit.Rows("4")) } + +func (s *testIntegrationSuite) TestApproximatePercentile(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t (a bit(10))") + tk.MustExec("insert into t values(b'1111')") + tk.MustQuery("select approx_percentile(a, 10) from t").Check(testkit.Rows("")) +} From 71b0df94fca120a4f7d8d9ef37aa215028a2c7c1 Mon Sep 17 00:00:00 2001 From: Howie Date: Fri, 7 May 2021 00:05:52 +0800 Subject: [PATCH 53/85] ddl: rollingback add index meets panic leads json unmarshal object error (#23848) --- ddl/db_test.go | 20 ++++++++++++++++++++ ddl/ddl_worker.go | 7 +++++++ 2 files changed, 27 insertions(+) diff --git a/ddl/db_test.go b/ddl/db_test.go index 109d39aff9be5..7f90528d723a6 100644 --- a/ddl/db_test.go +++ b/ddl/db_test.go @@ -5039,3 +5039,23 @@ func (s *testSerialDBSuite) TestDDLExitWhenCancelMeetPanic(c *C) { c.Assert(job.ErrorCount, Equals, int64(4)) c.Assert(job.Error.Error(), Equals, "[ddl:-1]panic in handling DDL logic and error count beyond the limitation 3, cancelled") } + +// Close issue #23321. +// See https://github.com/pingcap/tidb/issues/23321 +func (s *testSerialDBSuite) TestJsonUnmarshalErrWhenPanicInCancellingPath(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + + tk.MustExec("drop table if exists test_add_index_after_add_col") + tk.MustExec("create table test_add_index_after_add_col(a int, b int not null default '0');") + tk.MustExec("insert into test_add_index_after_add_col values(1, 2),(2,2);") + tk.MustExec("alter table test_add_index_after_add_col add column c int not null default '0';") + + c.Assert(failpoint.Enable("github.com/pingcap/tidb/ddl/mockExceedErrorLimit", `return(true)`), IsNil) + defer func() { + c.Assert(failpoint.Disable("github.com/pingcap/tidb/ddl/mockExceedErrorLimit"), IsNil) + }() + + _, err := tk.Exec("alter table test_add_index_after_add_col add unique index cc(c);") + c.Assert(err.Error(), Equals, "[kv:1062]DDL job cancelled by panic in rollingback, error msg: Duplicate entry '0' for key 'cc'") +} diff --git a/ddl/ddl_worker.go b/ddl/ddl_worker.go index 5870624dde5ae..0980a0e159391 100644 --- a/ddl/ddl_worker.go +++ b/ddl/ddl_worker.go @@ -585,6 +585,13 @@ func chooseLeaseTime(t, max time.Duration) time.Duration { // countForPanic records the error count for DDL job. func (w *worker) countForPanic(job *model.Job) { // If run DDL job panic, just cancel the DDL jobs. + if job.State == model.JobStateRollingback { + job.State = model.JobStateCancelled + msg := fmt.Sprintf("DDL job cancelled by panic in rollingback, error msg: %s", terror.ToSQLError(job.Error).Message) + job.Error = terror.GetErrClass(job.Error).Synthesize(terror.ErrCode(job.Error.Code()), msg) + logutil.Logger(w.logCtx).Warn(msg) + return + } job.State = model.JobStateCancelling job.ErrorCount++ From b8070e6640a6e5daf795880480b5e0caf037dd86 Mon Sep 17 00:00:00 2001 From: Shenghui Wu <793703860@qq.com> Date: Fri, 7 May 2021 02:05:52 +0800 Subject: [PATCH 54/85] executor: fix resource leak of Shuffle Executor. (#23888) --- executor/shuffle.go | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/executor/shuffle.go b/executor/shuffle.go index 7de58e7139880..2395819726c9e 100644 --- a/executor/shuffle.go +++ b/executor/shuffle.go @@ -127,6 +127,7 @@ func (e *ShuffleExec) Open(ctx context.Context) error { // Close implements the Executor Close interface. func (e *ShuffleExec) Close() error { + var firstErr error if !e.prepared { for _, w := range e.workers { close(w.inputHolderCh) @@ -139,6 +140,9 @@ func (e *ShuffleExec) Close() error { for _, w := range e.workers { for range w.inputCh { } + if err := w.childExec.Close(); err != nil && firstErr == nil { + firstErr = err + } } for range e.outputCh { // workers exit before `e.outputCh` is closed. } @@ -150,12 +154,13 @@ func (e *ShuffleExec) Close() error { e.ctx.GetSessionVars().StmtCtx.RuntimeStatsColl.RegisterStats(e.id, runtimeStats) } - err := e.dataSource.Close() - err1 := e.baseExecutor.Close() - if err != nil { - return errors.Trace(err) + if err := e.dataSource.Close(); err != nil && firstErr == nil { + firstErr = err + } + if err := e.baseExecutor.Close(); err != nil && firstErr == nil { + firstErr = err } - return errors.Trace(err1) + return errors.Trace(firstErr) } func (e *ShuffleExec) prepare4ParallelExec(ctx context.Context) { From 0d4f5de0063af8e5dd908db49971c540f0dbde6d Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Fri, 7 May 2021 04:25:53 +0800 Subject: [PATCH 55/85] executor: fix `show table status` for the database with upper-cased name (#23896) (#23958) --- executor/show.go | 2 +- executor/show_test.go | 36 ++++++++++++++++++++++++++++++------ 2 files changed, 31 insertions(+), 7 deletions(-) diff --git a/executor/show.go b/executor/show.go index 8d4095756b137..85b4bbf31e784 100644 --- a/executor/show.go +++ b/executor/show.go @@ -423,7 +423,7 @@ func (e *ShowExec) fetchShowTableStatus() error { data_free, auto_increment, create_time, update_time, check_time, table_collation, IFNULL(checksum,''), create_options, table_comment FROM information_schema.tables - WHERE table_schema=%? ORDER BY table_name`, e.DBName.L) + WHERE lower(table_schema)=%? ORDER BY table_name`, e.DBName.L) if err != nil { return errors.Trace(err) } diff --git a/executor/show_test.go b/executor/show_test.go index 38c8e0ee85837..251b0efba8567 100644 --- a/executor/show_test.go +++ b/executor/show_test.go @@ -19,7 +19,6 @@ import ( "strings" . "github.com/pingcap/check" - "github.com/pingcap/errors" "github.com/pingcap/failpoint" "github.com/pingcap/parser/auth" "github.com/pingcap/parser/model" @@ -491,12 +490,13 @@ func (s *testSuite5) TestShowTableStatus(c *C) { // It's not easy to test the result contents because every time the test runs, "Create_time" changed. tk.MustExec("show table status;") rs, err := tk.Exec("show table status;") - c.Assert(errors.ErrorStack(err), Equals, "") + c.Assert(err, IsNil) c.Assert(rs, NotNil) rows, err := session.GetRows4Test(context.Background(), tk.Se, rs) - c.Assert(errors.ErrorStack(err), Equals, "") + c.Assert(err, IsNil) err = rs.Close() - c.Assert(errors.ErrorStack(err), Equals, "") + c.Assert(err, IsNil) + c.Assert(len(rows), Equals, 1) for i := range rows { row := rows[i] @@ -513,10 +513,34 @@ func (s *testSuite5) TestShowTableStatus(c *C) { partition p2 values less than (maxvalue) );`) rs, err = tk.Exec("show table status from test like 'tp';") - c.Assert(errors.ErrorStack(err), Equals, "") + c.Assert(err, IsNil) rows, err = session.GetRows4Test(context.Background(), tk.Se, rs) - c.Assert(errors.ErrorStack(err), Equals, "") + c.Assert(err, IsNil) c.Assert(rows[0].GetString(16), Equals, "partitioned") + + tk.MustExec("create database UPPER_CASE") + tk.MustExec("use UPPER_CASE") + tk.MustExec("create table t (i int)") + rs, err = tk.Exec("show table status") + c.Assert(err, IsNil) + c.Assert(rs, NotNil) + rows, err = session.GetRows4Test(context.Background(), tk.Se, rs) + c.Assert(err, IsNil) + err = rs.Close() + c.Assert(err, IsNil) + c.Assert(len(rows), Equals, 1) + + tk.MustExec("use upper_case") + rs, err = tk.Exec("show table status") + c.Assert(err, IsNil) + c.Assert(rs, NotNil) + rows, err = session.GetRows4Test(context.Background(), tk.Se, rs) + c.Assert(err, IsNil) + err = rs.Close() + c.Assert(err, IsNil) + c.Assert(len(rows), Equals, 1) + + tk.MustExec("drop database UPPER_CASE") } func (s *testSuite5) TestShowSlow(c *C) { From 9bc8258d618d4f3b7b7992f0876f5222e886c8b8 Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Fri, 7 May 2021 06:45:52 +0800 Subject: [PATCH 56/85] expression: don't propagateColumnEQ joinCondition when nullSensitive (#23989) (#24022) --- expression/constant_propagation.go | 6 +++--- expression/integration_test.go | 14 ++++++++++++++ 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/expression/constant_propagation.go b/expression/constant_propagation.go index d44fa533283fe..e25b64983316b 100644 --- a/expression/constant_propagation.go +++ b/expression/constant_propagation.go @@ -533,6 +533,9 @@ func (s *propOuterJoinConstSolver) deriveConds(outerCol, innerCol *Column, schem // 'expression(..., outerCol, ...)' does not reference columns outside children schemas of join node. // Derived new expressions must be appended into join condition, not filter condition. func (s *propOuterJoinConstSolver) propagateColumnEQ() { + if s.nullSensitive { + return + } visited := make([]bool, 2*len(s.joinConds)+len(s.filterConds)) s.unionSet = disjointset.NewIntSet(len(s.columns)) var outerCol, innerCol *Column @@ -553,9 +556,6 @@ func (s *propOuterJoinConstSolver) propagateColumnEQ() { // `select *, t1.a in (select t2.b from t t2) from t t1` // rows with t2.b is null would impact whether LeftOuterSemiJoin should output 0 or null if there // is no row satisfying t2.b = t1.a - if s.nullSensitive { - continue - } childCol := s.innerSchema.RetrieveColumn(innerCol) if !mysql.HasNotNullFlag(childCol.RetType.Flag) { notNullExpr := BuildNotNullExpr(s.ctx, childCol) diff --git a/expression/integration_test.go b/expression/integration_test.go index 7052673ea4a73..54e16fe417db0 100755 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -8096,3 +8096,17 @@ func (s *testIntegrationSuite) TestApproximatePercentile(c *C) { tk.MustExec("insert into t values(b'1111')") tk.MustQuery("select approx_percentile(a, 10) from t").Check(testkit.Rows("")) } + +func (s *testIntegrationSuite) TestIssue23889(c *C) { + defer s.cleanEnv(c) + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists test_decimal,test_t;") + tk.MustExec("create table test_decimal(col_decimal decimal(10,0));") + tk.MustExec("insert into test_decimal values(null),(8);") + tk.MustExec("create table test_t(a int(11), b decimal(32,0));") + tk.MustExec("insert into test_t values(1,4),(2,4),(5,4),(7,4),(9,4);") + + tk.MustQuery("SELECT ( test_decimal . `col_decimal` , test_decimal . `col_decimal` ) IN ( select * from test_t ) as field1 FROM test_decimal;").Check( + testkit.Rows("", "0")) +} From b5f2926d05579b320c7d9afaaf1e8c244c430a87 Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Fri, 7 May 2021 08:25:52 +0800 Subject: [PATCH 57/85] planner: change descScanFactor to scanFactor when ExpectedCount is small. (#23972) (#24078) --- planner/core/find_best_task.go | 12 ++++++--- planner/core/integration_test.go | 21 +++++++++++++++ .../testdata/integration_serial_suite_in.json | 7 +++++ .../integration_serial_suite_out.json | 27 +++++++++++++++++++ 4 files changed, 63 insertions(+), 4 deletions(-) diff --git a/planner/core/find_best_task.go b/planner/core/find_best_task.go index 18d1e9b29170f..59136c557b1a6 100644 --- a/planner/core/find_best_task.go +++ b/planner/core/find_best_task.go @@ -40,6 +40,10 @@ const ( // of a Selection or a JoinCondition, we can use this default value. SelectionFactor = 0.8 distinctFactor = 0.8 + + // If the actual row count is much more than the limit count, the unordered scan may cost much more than keep order. + // So when a limit exists, we don't apply the DescScanFactor. + smallScanThreshold = 10000 ) var aggFuncFactor = map[string]float64{ @@ -1426,8 +1430,8 @@ func (ds *DataSource) getOriginalPhysicalTableScan(prop *property.PhysicalProper cost += rowCount * sessVars.NetworkFactor * rowSize } if isMatchProp { - if prop.Items[0].Desc { - ts.Desc = true + ts.Desc = prop.Items[0].Desc + if prop.Items[0].Desc && prop.ExpectedCnt >= smallScanThreshold { cost = rowCount * rowSize * sessVars.DescScanFactor } ts.KeepOrder = true @@ -1475,8 +1479,8 @@ func (ds *DataSource) getOriginalPhysicalIndexScan(prop *property.PhysicalProper sessVars := ds.ctx.GetSessionVars() cost := rowCount * rowSize * sessVars.ScanFactor if isMatchProp { - if prop.Items[0].Desc { - is.Desc = true + is.Desc = prop.Items[0].Desc + if prop.Items[0].Desc && prop.ExpectedCnt >= smallScanThreshold { cost = rowCount * rowSize * sessVars.DescScanFactor } is.KeepOrder = true diff --git a/planner/core/integration_test.go b/planner/core/integration_test.go index 83b9032d5f924..98fc2eda92bc2 100644 --- a/planner/core/integration_test.go +++ b/planner/core/integration_test.go @@ -1840,6 +1840,27 @@ func (s *testIntegrationSuite) TestIssue22105(c *C) { } } +func (s *testIntegrationSerialSuite) TestLimitIndexLookUpKeepOrder(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t;") + tk.MustExec("create table t(a int, b int, c int, d int, index idx(a,b,c));") + + var input []string + var output []struct { + SQL string + Plan []string + } + s.testData.GetTestCases(c, &input, &output) + for i, tt := range input { + s.testData.OnRecord(func() { + output[i].SQL = tt + output[i].Plan = s.testData.ConvertRowsToStrings(tk.MustQuery(tt).Rows()) + }) + tk.MustQuery(tt).Check(testkit.Rows(output[i].Plan...)) + } +} + func (s *testIntegrationSuite) TestReorderSimplifiedOuterJoins(c *C) { tk := testkit.NewTestKit(c, s.store) diff --git a/planner/core/testdata/integration_serial_suite_in.json b/planner/core/testdata/integration_serial_suite_in.json index 9b39df74067b0..3103392597700 100644 --- a/planner/core/testdata/integration_serial_suite_in.json +++ b/planner/core/testdata/integration_serial_suite_in.json @@ -96,5 +96,12 @@ "explain select /*+ inl_hash_join(s) */ * from t join s on t.a=s.a and t.b = s.a", "explain select /*+ inl_hash_join(s) */ * from t join s on t.a=s.a and t.a = s.b" ] + }, + { + "name": "TestLimitIndexLookUpKeepOrder", + "cases": [ + "desc select * from t where a = 1 and b > 2 and b < 10 and d = 10 order by b,c limit 10", + "desc select * from t where a = 1 and b > 2 and b < 10 and d = 10 order by b desc, c desc limit 10" + ] } ] diff --git a/planner/core/testdata/integration_serial_suite_out.json b/planner/core/testdata/integration_serial_suite_out.json index 9dff901e19d3f..529feb4c091c0 100644 --- a/planner/core/testdata/integration_serial_suite_out.json +++ b/planner/core/testdata/integration_serial_suite_out.json @@ -513,5 +513,32 @@ ] } ] + }, + { + "Name": "TestLimitIndexLookUpKeepOrder", + "Cases": [ + { + "SQL": "desc select * from t where a = 1 and b > 2 and b < 10 and d = 10 order by b,c limit 10", + "Plan": [ + "Limit_12 0.00 root offset:0, count:10", + "└─Projection_34 0.00 root test.t.a, test.t.b, test.t.c, test.t.d", + " └─IndexLookUp_33 0.00 root ", + " ├─IndexRangeScan_30(Build) 2.50 cop[tikv] table:t, index:idx(a, b, c) range:(1 2,1 10), keep order:true, stats:pseudo", + " └─Selection_32(Probe) 0.00 cop[tikv] eq(test.t.d, 10)", + " └─TableRowIDScan_31 2.50 cop[tikv] table:t keep order:false, stats:pseudo" + ] + }, + { + "SQL": "desc select * from t where a = 1 and b > 2 and b < 10 and d = 10 order by b desc, c desc limit 10", + "Plan": [ + "Limit_12 0.00 root offset:0, count:10", + "└─Projection_34 0.00 root test.t.a, test.t.b, test.t.c, test.t.d", + " └─IndexLookUp_33 0.00 root ", + " ├─IndexRangeScan_30(Build) 2.50 cop[tikv] table:t, index:idx(a, b, c) range:(1 2,1 10), keep order:true, desc, stats:pseudo", + " └─Selection_32(Probe) 0.00 cop[tikv] eq(test.t.d, 10)", + " └─TableRowIDScan_31 2.50 cop[tikv] table:t keep order:false, stats:pseudo" + ] + } + ] } ] From c2cfcd3eb2528ad7b603e9a8fc40f85a8b139651 Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Fri, 7 May 2021 11:53:52 +0800 Subject: [PATCH 58/85] planner: fix inappropriate null flag of null constants (#23457) (#23474) --- expression/scalar_function.go | 1 + expression/scalar_function_test.go | 16 ++++++++++++++++ 2 files changed, 17 insertions(+) diff --git a/expression/scalar_function.go b/expression/scalar_function.go index 47aa3f4da5037..7b950813e8d3f 100755 --- a/expression/scalar_function.go +++ b/expression/scalar_function.go @@ -170,6 +170,7 @@ func typeInferForNull(args []Expression) { for _, arg := range args { if isNull(arg) { *arg.GetType() = *retFieldTp + arg.GetType().Flag &= ^mysql.NotNullFlag // Remove NotNullFlag of NullConst } } } diff --git a/expression/scalar_function_test.go b/expression/scalar_function_test.go index f3349a87c34fb..c5f52e3309532 100755 --- a/expression/scalar_function_test.go +++ b/expression/scalar_function_test.go @@ -21,6 +21,7 @@ import ( "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/util/chunk" ) func (s *testEvaluatorSuite) TestScalarFunction(c *C) { @@ -47,6 +48,21 @@ func (s *testEvaluatorSuite) TestScalarFunction(c *C) { c.Assert(ok, IsTrue) } +func (s *testEvaluatorSuite) TestIssue23309(c *C) { + a := &Column{ + UniqueID: 1, + RetType: types.NewFieldType(mysql.TypeDouble), + } + a.RetType.Flag |= mysql.NotNullFlag + null := NewNull() + null.RetType = types.NewFieldType(mysql.TypeNull) + sf, _ := newFunction(ast.NE, a, null).(*ScalarFunction) + v, err := sf.GetArgs()[1].Eval(chunk.Row{}) + c.Assert(err, IsNil) + c.Assert(v.IsNull(), IsTrue) + c.Assert(mysql.HasNotNullFlag(sf.GetArgs()[1].GetType().Flag), IsFalse) +} + func (s *testEvaluatorSuite) TestScalarFuncs2Exprs(c *C) { a := &Column{ UniqueID: 1, From 4195c2b1b84f931682f6d04db0b634a9d52c7a60 Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Fri, 7 May 2021 21:04:41 +0800 Subject: [PATCH 59/85] planner: fix set not null flag for outer join (#23727) (#23756) --- executor/executor_test.go | 9 +++++---- executor/index_lookup_merge_join_test.go | 1 - planner/core/rule_column_pruning.go | 19 +------------------ 3 files changed, 6 insertions(+), 23 deletions(-) diff --git a/executor/executor_test.go b/executor/executor_test.go index 2240bdc280bc0..d12591ec9e7cf 100644 --- a/executor/executor_test.go +++ b/executor/executor_test.go @@ -6715,17 +6715,18 @@ func (s *testSuiteP1) TestIssue22941(c *C) { PRIMARY KEY (mid), KEY ind_bm_parent (ParentId,mid) )`) - + // mp should have more columns than m tk.MustExec(`CREATE TABLE mp ( mpid bigint(20) unsigned NOT NULL DEFAULT '0', mid varchar(50) DEFAULT NULL COMMENT '模块主键', + sid int, PRIMARY KEY (mpid) );`) - tk.MustExec(`insert into mp values("1","1");`) + tk.MustExec(`insert into mp values("1","1","0");`) tk.MustExec(`insert into m values("0", "0");`) - rs := tk.MustQuery(`SELECT ( SELECT COUNT(1) FROM m WHERE ParentId = c.mid ) expand, bmp.mpid, bmp.mpid IS NULL,bmp.mpid IS NOT NULL FROM m c LEFT JOIN mp bmp ON c.mid = bmp.mid WHERE c.ParentId = '0'`) - rs.Check(testkit.Rows("1 1 0")) + rs := tk.MustQuery(`SELECT ( SELECT COUNT(1) FROM m WHERE ParentId = c.mid ) expand, bmp.mpid, bmp.mpid IS NULL,bmp.mpid IS NOT NULL, sid FROM m c LEFT JOIN mp bmp ON c.mid = bmp.mid WHERE c.ParentId = '0'`) + rs.Check(testkit.Rows("1 1 0 ")) rs = tk.MustQuery(`SELECT bmp.mpid, bmp.mpid IS NULL,bmp.mpid IS NOT NULL FROM m c LEFT JOIN mp bmp ON c.mid = bmp.mid WHERE c.ParentId = '0'`) rs.Check(testkit.Rows(" 1 0")) diff --git a/executor/index_lookup_merge_join_test.go b/executor/index_lookup_merge_join_test.go index 93d1d9799d58b..7958867ca4f5c 100644 --- a/executor/index_lookup_merge_join_test.go +++ b/executor/index_lookup_merge_join_test.go @@ -9,7 +9,6 @@ import ( ) // TODO: reopen the index merge join in future. - //func (s *testSuite9) TestIndexLookupMergeJoinHang(c *C) { // c.Assert(failpoint.Enable("github.com/pingcap/tidb/executor/IndexMergeJoinMockOOM", `return(true)`), IsNil) // defer func() { diff --git a/planner/core/rule_column_pruning.go b/planner/core/rule_column_pruning.go index ac277dc5206f5..7bc633bdfdb3d 100644 --- a/planner/core/rule_column_pruning.go +++ b/planner/core/rule_column_pruning.go @@ -303,24 +303,7 @@ func (p *LogicalJoin) extractUsedCols(parentUsedCols []*expression.Column) (left } func (p *LogicalJoin) mergeSchema() { - lChild := p.children[0] - rChild := p.children[1] - if p.JoinType == SemiJoin || p.JoinType == AntiSemiJoin { - p.schema = lChild.Schema().Clone() - } else if p.JoinType == LeftOuterSemiJoin || p.JoinType == AntiLeftOuterSemiJoin { - joinCol := p.schema.Columns[len(p.schema.Columns)-1] - p.schema = lChild.Schema().Clone() - p.schema.Append(joinCol) - } else { - p.schema = expression.MergeSchema(lChild.Schema(), rChild.Schema()) - switch p.JoinType { - case LeftOuterJoin: - resetNotNullFlag(p.schema, p.children[1].Schema().Len(), p.schema.Len()) - case RightOuterJoin: - resetNotNullFlag(p.schema, 0, p.children[0].Schema().Len()) - default: - } - } + p.schema = buildLogicalJoinSchema(p.JoinType, p) } // PruneColumns implements LogicalPlan interface. From b8fb03cac4e7e432f357a0dc9473a5a952583ed6 Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Sat, 8 May 2021 13:44:41 +0800 Subject: [PATCH 60/85] executor: fix batchget overflow lock panic (#23774) (#23778) --- executor/batch_point_get.go | 3 +++ session/pessimistic_test.go | 1 + 2 files changed, 4 insertions(+) diff --git a/executor/batch_point_get.go b/executor/batch_point_get.go index d6c1257651697..71c7f2f574fce 100644 --- a/executor/batch_point_get.go +++ b/executor/batch_point_get.go @@ -180,6 +180,9 @@ func (e *BatchPointGetExec) initialize(ctx context.Context) error { if err1 != nil && !kv.ErrNotExist.Equal(err1) { return err1 } + if idxKey == nil { + continue + } s := hack.String(idxKey) if _, found := dedup[s]; found { continue diff --git a/session/pessimistic_test.go b/session/pessimistic_test.go index 5d05ac08042a1..536b7feaab31d 100644 --- a/session/pessimistic_test.go +++ b/session/pessimistic_test.go @@ -304,6 +304,7 @@ func (s *testPessimisticSuite) TestPointGetOverflow(c *C) { tk.MustExec("create table t(k tinyint, v int, unique key(k))") tk.MustExec("begin pessimistic") tk.MustExec("update t set v = 100 where k = -200;") + tk.MustExec("update t set v = 100 where k in (-200, -400);") } func (s *testPessimisticSuite) TestPointGetKeyLock(c *C) { From 083c2997c921294b3de8e08671fcf50e387e0fc4 Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Sat, 8 May 2021 14:30:41 +0800 Subject: [PATCH 61/85] statistics: fix auto analyze log information incomplete (#23522) (#23543) --- statistics/handle/update.go | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/statistics/handle/update.go b/statistics/handle/update.go index 0ee940b559f05..7c56cef58d2d1 100644 --- a/statistics/handle/update.go +++ b/statistics/handle/update.go @@ -678,7 +678,7 @@ func NeedAnalyzeTable(tbl *statistics.Table, limit time.Duration, autoAnalyzeRat if !analyzed { t := time.Unix(0, oracle.ExtractPhysical(tbl.Version)*int64(time.Millisecond)) dur := time.Since(t) - return dur >= limit, fmt.Sprintf("table unanalyzed, time since last updated %vs", dur) + return dur >= limit, fmt.Sprintf("table unanalyzed, time since last updated %v", dur) } // Auto analyze is disabled. if autoAnalyzeRatio == 0 { @@ -773,14 +773,24 @@ func (h *Handle) autoAnalyzeTable(tblInfo *model.TableInfo, statsTbl *statistics return false } if needAnalyze, reason := NeedAnalyzeTable(statsTbl, 20*h.Lease(), ratio, start, end, time.Now()); needAnalyze { - logutil.BgLogger().Info("[stats] auto analyze triggered", zap.String("sql", sql), zap.String("reason", reason)) + escaped, err := sqlexec.EscapeSQL(sql, params...) + if err != nil { + return false + } + logutil.BgLogger().Info("[stats] auto analyze triggered", zap.String("sql", escaped), zap.String("reason", reason)) h.execAutoAnalyze(sql, params...) return true } for _, idx := range tblInfo.Indices { if _, ok := statsTbl.Indices[idx.ID]; !ok && idx.State == model.StatePublic { - logutil.BgLogger().Info("[stats] auto analyze for unanalyzed", zap.String("sql", sql)) - h.execAutoAnalyze(sql+" index %n", append(params, idx.Name.O)...) + sqlWithIdx := sql + " index %n" + paramsWithIdx := append(params, idx.Name.O) + escaped, err := sqlexec.EscapeSQL(sqlWithIdx, paramsWithIdx...) + if err != nil { + return false + } + logutil.BgLogger().Info("[stats] auto analyze for unanalyzed", zap.String("sql", escaped)) + h.execAutoAnalyze(sqlWithIdx, paramsWithIdx...) return true } } From e431c0692d7a10110845088cdf0412a114d227e5 Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Sat, 8 May 2021 17:00:41 +0800 Subject: [PATCH 62/85] executor: fix update panic on join having statement (#23554) (#23575) --- executor/update_test.go | 9 +++++++++ planner/core/rule_eliminate_projection.go | 3 +++ 2 files changed, 12 insertions(+) diff --git a/executor/update_test.go b/executor/update_test.go index 1e6559dcf9ccd..cd28590a97d4a 100644 --- a/executor/update_test.go +++ b/executor/update_test.go @@ -259,3 +259,12 @@ func (s *testPointGetSuite) TestIssue21447(c *C) { tk1.MustQuery("select * from t1 where id in (1, 2) for update").Check(testkit.Rows("1 xyz")) tk1.MustExec("commit") } + +func (s *testPointGetSuite) TestIssue23553(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec(`use test`) + tk.MustExec(`drop table if exists tt`) + tk.MustExec(`create table tt (m0 varchar(64), status tinyint not null)`) + tk.MustExec(`insert into tt values('1',0),('1',0),('1',0)`) + tk.MustExec(`update tt a inner join (select m0 from tt where status!=1 group by m0 having count(*)>1) b on a.m0=b.m0 set a.status=1`) +} diff --git a/planner/core/rule_eliminate_projection.go b/planner/core/rule_eliminate_projection.go index 404c5e01dd77c..51ab76a34dc7c 100644 --- a/planner/core/rule_eliminate_projection.go +++ b/planner/core/rule_eliminate_projection.go @@ -87,6 +87,9 @@ func doPhysicalProjectionElimination(p PhysicalPlan) PhysicalPlan { return p } child := p.Children()[0] + if childProj, ok := child.(*PhysicalProjection); ok { + childProj.SetSchema(p.Schema()) + } return child } From f159b6231faf000eea7079750ad36f9fdfdfa0a0 Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Sat, 8 May 2021 18:58:41 +0800 Subject: [PATCH 63/85] executor: fix index join on prefix column index (#23678) (#23691) --- executor/index_lookup_join.go | 4 ++-- executor/index_lookup_join_test.go | 25 +++++++++++++++++++++++++ 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/executor/index_lookup_join.go b/executor/index_lookup_join.go index 52db90a83d587..d6e335c030522 100644 --- a/executor/index_lookup_join.go +++ b/executor/index_lookup_join.go @@ -553,8 +553,8 @@ func (iw *innerWorker) constructLookupContent(task *lookUpJoinTask) ([]*indexJoi if iw.hasPrefixCol { for i := range iw.outerCtx.keyCols { // If it's a prefix column. Try to fix it. - if iw.colLens[i] != types.UnspecifiedLength { - ranger.CutDatumByPrefixLen(&dLookUpKey[i], iw.colLens[i], iw.rowTypes[iw.keyCols[i]]) + if iw.colLens[iw.keyCols[i]] != types.UnspecifiedLength { + ranger.CutDatumByPrefixLen(&dLookUpKey[i], iw.colLens[iw.keyCols[i]], iw.rowTypes[iw.keyCols[i]]) } } // dLookUpKey is sorted and deduplicated at sortAndDedupLookUpContents. diff --git a/executor/index_lookup_join_test.go b/executor/index_lookup_join_test.go index 0e3e11316848d..48b4177b45fb6 100644 --- a/executor/index_lookup_join_test.go +++ b/executor/index_lookup_join_test.go @@ -252,3 +252,28 @@ func (s *testSuite5) TestIndexJoinEnumSetIssue19233(c *C) { } } } + +func (s *testSuite5) TestIssue23653(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t1, t2") + tk.MustExec("create table t1 (c_int int, c_str varchar(40), primary key(c_str), unique key(c_int), unique key(c_str))") + tk.MustExec("create table t2 (c_int int, c_str varchar(40), primary key(c_int, c_str(4)), key(c_int), unique key(c_str))") + tk.MustExec("insert into t1 values (1, 'cool buck'), (2, 'reverent keller')") + tk.MustExec("insert into t2 select * from t1") + tk.MustQuery("select /*+ inl_join(t2) */ * from t1, t2 where t1.c_str = t2.c_str and t1.c_int = t2.c_int and t1.c_int = 2").Check(testkit.Rows( + "2 reverent keller 2 reverent keller")) +} + +func (s *testSuite5) TestIssue23656(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t1, t2") + tk.MustExec("create table t1 (c_int int, c_str varchar(40), primary key(c_int, c_str(4)))") + tk.MustExec("create table t2 like t1") + tk.MustExec("insert into t1 values (1, 'clever jang'), (2, 'blissful aryabhata')") + tk.MustExec("insert into t2 select * from t1") + tk.MustQuery("select /*+ inl_join(t2) */ * from t1 join t2 on t1.c_str = t2.c_str where t1.c_int = t2.c_int;").Check(testkit.Rows( + "1 clever jang 1 clever jang", + "2 blissful aryabhata 2 blissful aryabhata")) +} From 65ee2b9c64fe1e8d28ff9fb78087711093223a00 Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Mon, 10 May 2021 10:22:42 +0800 Subject: [PATCH 64/85] types: fix collation for binary literal (#23591) (#23598) --- expression/integration_test.go | 12 ++++++++++++ types/datum.go | 1 + 2 files changed, 13 insertions(+) diff --git a/expression/integration_test.go b/expression/integration_test.go index 54e16fe417db0..d71e5eba35fd8 100755 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -8110,3 +8110,15 @@ func (s *testIntegrationSuite) TestIssue23889(c *C) { tk.MustQuery("SELECT ( test_decimal . `col_decimal` , test_decimal . `col_decimal` ) IN ( select * from test_t ) as field1 FROM test_decimal;").Check( testkit.Rows("", "0")) } + +func (s *testIntegrationSerialSuite) TestCollationForBinaryLiteral(c *C) { + tk := testkit.NewTestKit(c, s.store) + collate.SetNewCollationEnabledForTest(true) + defer collate.SetNewCollationEnabledForTest(false) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("CREATE TABLE t (`COL1` tinyblob NOT NULL, `COL2` binary(1) NOT NULL, `COL3` bigint(11) NOT NULL, PRIMARY KEY (`COL1`(5),`COL2`,`COL3`) /*T![clustered_index] CLUSTERED */)") + tk.MustExec("insert into t values(0x1E,0xEC,6966939640596047133);") + tk.MustQuery("select * from t where col1 not in (0x1B,0x20) order by col1").Check(testkit.Rows("\x1e \xec 6966939640596047133")) + tk.MustExec("drop table t") +} diff --git a/types/datum.go b/types/datum.go index 715ebcb4d315f..7512ff60f0d7f 100644 --- a/types/datum.go +++ b/types/datum.go @@ -252,6 +252,7 @@ func (d *Datum) GetMysqlBit() BinaryLiteral { func (d *Datum) SetBinaryLiteral(b BinaryLiteral) { d.k = KindBinaryLiteral d.b = b + d.collation = charset.CollationBin } // SetMysqlBit sets MysqlBit value From ff9af52f97ef09ecc791f07fe1112a3fa1c8f493 Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Mon, 10 May 2021 10:48:43 +0800 Subject: [PATCH 65/85] ddl: fix the covert job to rollingback job (#23903) (#24445) --- ddl/db_test.go | 2 +- ddl/rollingback.go | 71 +++++++++++++++++++--------- ddl/rollingback_test.go | 101 ++++++++++++++++++++++++++++++++++++++++ ddl/serial_test.go | 4 +- 4 files changed, 152 insertions(+), 26 deletions(-) create mode 100644 ddl/rollingback_test.go diff --git a/ddl/db_test.go b/ddl/db_test.go index 7f90528d723a6..c8afb966a8714 100644 --- a/ddl/db_test.go +++ b/ddl/db_test.go @@ -4733,7 +4733,7 @@ func (s *testSerialDBSuite) TestAddIndexFailOnCaseWhenCanExit(c *C) { tk.MustExec("insert into t values(1, 1)") _, err := tk.Exec("alter table t add index idx(b)") c.Assert(err, NotNil) - c.Assert(err.Error(), Equals, "[ddl:8214]Cancelled DDL job") + c.Assert(err.Error(), Equals, "[ddl:-1]DDL job rollback, error msg: job.ErrCount:512, mock unknown type: ast.whenClause.") tk.MustExec("drop table if exists t") } diff --git a/ddl/rollingback.go b/ddl/rollingback.go index adb0355cc7395..4e205ad7cbfdf 100644 --- a/ddl/rollingback.go +++ b/ddl/rollingback.go @@ -14,11 +14,16 @@ package ddl import ( + "fmt" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" "github.com/pingcap/parser/ast" "github.com/pingcap/parser/model" "github.com/pingcap/parser/mysql" + "github.com/pingcap/parser/terror" "github.com/pingcap/tidb/meta" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/util/logutil" "go.uber.org/zap" ) @@ -37,8 +42,11 @@ func updateColsNull2NotNull(tblInfo *model.TableInfo, indexInfo *model.IndexInfo } func convertAddIdxJob2RollbackJob(t *meta.Meta, job *model.Job, tblInfo *model.TableInfo, indexInfo *model.IndexInfo, err error) (int64, error) { - job.State = model.JobStateRollingback - + failpoint.Inject("mockConvertAddIdxJob2RollbackJobError", func(val failpoint.Value) { + if val.(bool) { + failpoint.Return(0, errors.New("mock convert add index job to rollback job error")) + } + }) if indexInfo.Primary { nullCols, err := getNullColInfos(tblInfo, indexInfo) if err != nil { @@ -65,7 +73,7 @@ func convertAddIdxJob2RollbackJob(t *meta.Meta, job *model.Job, tblInfo *model.T if err1 != nil { return ver, errors.Trace(err1) } - + job.State = model.JobStateRollingback return ver, errors.Trace(err) } @@ -99,7 +107,6 @@ func convertNotStartAddIdxJob2RollbackJob(t *meta.Meta, job *model.Job, occuredE } func rollingbackAddColumn(t *meta.Meta, job *model.Job) (ver int64, err error) { - job.State = model.JobStateRollingback tblInfo, columnInfo, col, _, _, err := checkAddColumn(t, job) if err != nil { return ver, errors.Trace(err) @@ -118,11 +125,13 @@ func rollingbackAddColumn(t *meta.Meta, job *model.Job) (ver int64, err error) { if err != nil { return ver, errors.Trace(err) } + + job.State = model.JobStateRollingback return ver, errCancelledDDLJob } func rollingbackDropColumn(t *meta.Meta, job *model.Job) (ver int64, err error) { - tblInfo, colInfo, err := checkDropColumn(t, job) + _, colInfo, err := checkDropColumn(t, job) if err != nil { return ver, errors.Trace(err) } @@ -130,7 +139,6 @@ func rollingbackDropColumn(t *meta.Meta, job *model.Job) (ver int64, err error) // StatePublic means when the job is not running yet. if colInfo.State == model.StatePublic { job.State = model.JobStateCancelled - job.FinishTableJob(model.JobStateRollbackDone, model.StatePublic, ver, tblInfo) return ver, errCancelledDDLJob } // In the state of drop column `write only -> delete only -> reorganization`, @@ -140,12 +148,11 @@ func rollingbackDropColumn(t *meta.Meta, job *model.Job) (ver int64, err error) } func rollingbackDropIndex(t *meta.Meta, job *model.Job) (ver int64, err error) { - tblInfo, indexInfo, err := checkDropIndex(t, job) + _, indexInfo, err := checkDropIndex(t, job) if err != nil { return ver, errors.Trace(err) } - originalState := indexInfo.State switch indexInfo.State { case model.StateWriteOnly, model.StateDeleteOnly, model.StateDeleteReorganization, model.StateNone: // We can not rollback now, so just continue to drop index. @@ -153,20 +160,11 @@ func rollingbackDropIndex(t *meta.Meta, job *model.Job) (ver int64, err error) { job.State = model.JobStateRunning return ver, nil case model.StatePublic: - job.State = model.JobStateRollbackDone - indexInfo.State = model.StatePublic + job.State = model.JobStateCancelled + return ver, errCancelledDDLJob default: return ver, ErrInvalidDDLState.GenWithStackByArgs("index", indexInfo.State) } - - job.SchemaState = indexInfo.State - job.Args = []interface{}{indexInfo.Name} - ver, err = updateVersionAndTableInfo(t, job, tblInfo, originalState != indexInfo.State) - if err != nil { - return ver, errors.Trace(err) - } - job.FinishTableJob(model.JobStateRollbackDone, model.StatePublic, ver, tblInfo) - return ver, errCancelledDDLJob } func rollingbackAddIndex(w *worker, d *ddlCtx, t *meta.Meta, job *model.Job, isPK bool) (ver int64, err error) { @@ -184,7 +182,6 @@ func rollingbackAddIndex(w *worker, d *ddlCtx, t *meta.Meta, job *model.Job, isP } func convertAddTablePartitionJob2RollbackJob(t *meta.Meta, job *model.Job, otherwiseErr error, tblInfo *model.TableInfo) (ver int64, err error) { - job.State = model.JobStateRollingback addingDefinitions := tblInfo.Partition.AddingDefinitions partNames := make([]string, 0, len(addingDefinitions)) for _, pd := range addingDefinitions { @@ -195,6 +192,7 @@ func convertAddTablePartitionJob2RollbackJob(t *meta.Meta, job *model.Job, other if err != nil { return ver, errors.Trace(err) } + job.State = model.JobStateRollingback return ver, errors.Trace(otherwiseErr) } @@ -321,14 +319,41 @@ func convertJob2RollbackJob(w *worker, d *ddlCtx, t *meta.Meta, job *model.Job) } if err != nil { + if job.Error == nil { + job.Error = toTError(err) + } + job.ErrorCount++ + + if errCancelledDDLJob.Equal(err) { + // The job is normally cancelled. + if !job.Error.Equal(errCancelledDDLJob) { + job.Error = terror.GetErrClass(job.Error).Synthesize(terror.ErrCode(job.Error.Code()), + fmt.Sprintf("DDL job rollback, error msg: %s", terror.ToSQLError(job.Error).Message)) + } + } else { + // A job canceling meet other error. + // + // Once `convertJob2RollbackJob` meets an error, the job state can't be set as `JobStateRollingback` since + // job state and args may not be correctly overwritten. The job will be fetched to run with the cancelling + // state again. So we should check the error count here. + if err1 := loadDDLVars(w); err1 != nil { + logutil.Logger(w.logCtx).Error("[ddl] load DDL global variable failed", zap.Error(err1)) + } + errorCount := variable.GetDDLErrorCountLimit() + if job.ErrorCount > errorCount { + logutil.Logger(w.logCtx).Warn("[ddl] rollback DDL job error count exceed the limit, cancelled it now", zap.Int64("jobID", job.ID), zap.Int64("errorCountLimit", errorCount)) + job.Error = toTError(errors.Errorf("rollback DDL job error count exceed the limit %d, cancelled it now", errorCount)) + job.State = model.JobStateCancelled + } + } + if job.State != model.JobStateRollingback && job.State != model.JobStateCancelled { logutil.Logger(w.logCtx).Error("[ddl] run DDL job failed", zap.String("job", job.String()), zap.Error(err)) } else { logutil.Logger(w.logCtx).Info("[ddl] the DDL job is cancelled normally", zap.String("job", job.String()), zap.Error(err)) + // If job is cancelled, we shouldn't return an error. + return ver, nil } - - job.Error = toTError(err) - job.ErrorCount++ } return } diff --git a/ddl/rollingback_test.go b/ddl/rollingback_test.go new file mode 100644 index 0000000000000..c7a188ecbd012 --- /dev/null +++ b/ddl/rollingback_test.go @@ -0,0 +1,101 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package ddl_test + +import ( + "context" + "strconv" + + . "github.com/pingcap/check" + errors2 "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/parser/model" + "github.com/pingcap/tidb/ddl" + "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/meta" + "github.com/pingcap/tidb/util/sqlexec" + "github.com/pingcap/tidb/util/testkit" +) + +var _ = SerialSuites(&testRollingBackSuite{&testDBSuite{}}) + +type testRollingBackSuite struct{ *testDBSuite } + +// TestCancelJobMeetError is used to test canceling ddl job failure when convert ddl job to a rollingback job. +func (s *testRollingBackSuite) TestCancelAddIndexJobError(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk1 := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk1.MustExec("use test") + + tk.MustExec("create table t_cancel_add_index (a int)") + tk.MustExec("insert into t_cancel_add_index values(1),(2),(3)") + tk.MustExec("set @@global.tidb_ddl_error_count_limit=3") + + c.Assert(failpoint.Enable("github.com/pingcap/tidb/ddl/mockConvertAddIdxJob2RollbackJobError", `return(true)`), IsNil) + defer func() { + c.Assert(failpoint.Disable("github.com/pingcap/tidb/ddl/mockConvertAddIdxJob2RollbackJobError"), IsNil) + }() + + tbl := testGetTableByName(c, tk.Se, "test", "t_cancel_add_index") + c.Assert(tbl, NotNil) + + d := s.dom.DDL() + hook := &ddl.TestDDLCallback{Do: s.dom} + var ( + checkErr error + jobID int64 + res sqlexec.RecordSet + ) + hook.OnJobUpdatedExported = func(job *model.Job) { + if job.TableID != tbl.Meta().ID { + return + } + if job.Type != model.ActionAddIndex { + return + } + if job.SchemaState == model.StateDeleteOnly { + jobID = job.ID + res, checkErr = tk1.Exec("admin cancel ddl jobs " + strconv.Itoa(int(job.ID))) + // drain the result set here, otherwise the cancel action won't take effect immediately. + chk := res.NewChunk() + if err := res.Next(context.Background(), chk); err != nil { + checkErr = err + return + } + if err := res.Close(); err != nil { + checkErr = err + } + } + } + d.(ddl.DDLForTest).SetHook(hook) + + // This will hang on stateDeleteOnly, and the job will be canceled. + _, err := tk.Exec("alter table t_cancel_add_index add index idx(a)") + c.Assert(err, NotNil) + c.Assert(checkErr, IsNil) + c.Assert(err.Error(), Equals, "[ddl:-1]rollback DDL job error count exceed the limit 3, cancelled it now") + + // Verification of the history job state. + var job *model.Job + err = kv.RunInNewTxn(s.store, false, func(txn kv.Transaction) error { + t := meta.NewMeta(txn) + var err1 error + job, err1 = t.GetHistoryDDLJob(jobID) + return errors2.Trace(err1) + }) + c.Assert(err, IsNil) + c.Assert(job.ErrorCount, Equals, int64(4)) + c.Assert(job.Error.Error(), Equals, "[ddl:-1]rollback DDL job error count exceed the limit 3, cancelled it now") +} diff --git a/ddl/serial_test.go b/ddl/serial_test.go index 684e795482c4b..ea2e94f347873 100644 --- a/ddl/serial_test.go +++ b/ddl/serial_test.go @@ -745,7 +745,7 @@ func (s *testSerialSuite) TestCancelJobByErrorCountLimit(c *C) { _, err = tk.Exec("create table t (a int)") c.Assert(err, NotNil) - c.Assert(err.Error(), Equals, "[ddl:8214]Cancelled DDL job") + c.Assert(err.Error(), Equals, "[ddl:-1]DDL job rollback, error msg: mock do job error") } func (s *testSerialSuite) TestTruncateTableUpdateSchemaVersionErr(c *C) { @@ -763,7 +763,7 @@ func (s *testSerialSuite) TestTruncateTableUpdateSchemaVersionErr(c *C) { tk.MustExec("create table t (a int)") _, err = tk.Exec("truncate table t") c.Assert(err, NotNil) - c.Assert(err.Error(), Equals, "[ddl:8214]Cancelled DDL job") + c.Assert(err.Error(), Equals, "[ddl:-1]DDL job rollback, error msg: mock update version error") // Disable fail point. c.Assert(failpoint.Disable("github.com/pingcap/tidb/ddl/mockTruncateTableUpdateVersionError"), IsNil) tk.MustExec("truncate table t") From 47deb7affe5824b0cc45d13b3831135a93b57a30 Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Mon, 10 May 2021 12:44:43 +0800 Subject: [PATCH 66/85] plugin: fix audit plugin will cause tidb panic (#23803) (#23819) --- server/server.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/server/server.go b/server/server.go index 26b6f310848a6..8f34ef9ef61a7 100644 --- a/server/server.go +++ b/server/server.go @@ -449,6 +449,10 @@ func (s *Server) onConn(conn *clientConn) { conn.Run(ctx) err = plugin.ForeachPlugin(plugin.Audit, func(p *plugin.Plugin) error { + // Audit plugin may be disabled before a conn is created, leading no connectionInfo in sessionVars. + if sessionVars.ConnectionInfo == nil { + sessionVars.ConnectionInfo = conn.connectInfo() + } authPlugin := plugin.DeclareAuditManifest(p.Manifest) if authPlugin.OnConnectionEvent != nil { sessionVars.ConnectionInfo.Duration = float64(time.Since(connectedTime)) / float64(time.Millisecond) From 842e1803b639604e23c1cca62314a02764492971 Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Mon, 10 May 2021 13:14:43 +0800 Subject: [PATCH 67/85] executor: group_concat aggr panic when session.group_concat_max_len is small (#23131) (#23257) --- executor/aggfuncs/func_group_concat.go | 31 +++++++++++++++++--------- executor/aggregate_test.go | 17 ++++++++++++++ 2 files changed, 37 insertions(+), 11 deletions(-) diff --git a/executor/aggfuncs/func_group_concat.go b/executor/aggfuncs/func_group_concat.go index d13bc754b0842..c2d3b745d12fc 100644 --- a/executor/aggfuncs/func_group_concat.go +++ b/executor/aggfuncs/func_group_concat.go @@ -236,6 +236,11 @@ type topNRows struct { currSize uint64 limitSize uint64 sepSize uint64 + // If sep is truncated, we need to append part of sep to result. + // In the following example, session.group_concat_max_len is 10 and sep is '---'. + // ('---', 'ccc') should be poped from heap, so '-' should be appended to result. + // eg: 'aaa---bbb---ccc' -> 'aaa---bbb-' + isSepTruncated bool } func (h topNRows) Len() int { @@ -296,6 +301,7 @@ func (h *topNRows) tryToAdd(row sortRow) (truncated bool) { } else { h.currSize -= uint64(h.rows[0].buffer.Len()) + h.sepSize heap.Pop(h) + h.isSepTruncated = true } } return true @@ -316,10 +322,11 @@ func (h *topNRows) concat(sep string, truncated bool) string { } buffer.Write(row.buffer.Bytes()) } - if truncated && uint64(buffer.Len()) < h.limitSize { - // append the last separator, because the last separator may be truncated in tryToAdd. + if h.isSepTruncated { buffer.WriteString(sep) - buffer.Truncate(int(h.limitSize)) + if uint64(buffer.Len()) > h.limitSize { + buffer.Truncate(int(h.limitSize)) + } } return buffer.String() } @@ -349,10 +356,11 @@ func (e *groupConcatOrder) AllocPartialResult() PartialResult { } p := &partialResult4GroupConcatOrder{ topN: &topNRows{ - desc: desc, - currSize: 0, - limitSize: e.maxLen, - sepSize: uint64(len(e.sep)), + desc: desc, + currSize: 0, + limitSize: e.maxLen, + sepSize: uint64(len(e.sep)), + isSepTruncated: false, }, } return PartialResult(p) @@ -449,10 +457,11 @@ func (e *groupConcatDistinctOrder) AllocPartialResult() PartialResult { } p := &partialResult4GroupConcatOrderDistinct{ topN: &topNRows{ - desc: desc, - currSize: 0, - limitSize: e.maxLen, - sepSize: uint64(len(e.sep)), + desc: desc, + currSize: 0, + limitSize: e.maxLen, + sepSize: uint64(len(e.sep)), + isSepTruncated: false, }, valSet: set.NewStringSet(), } diff --git a/executor/aggregate_test.go b/executor/aggregate_test.go index ad48bc68ae314..842688795e44b 100644 --- a/executor/aggregate_test.go +++ b/executor/aggregate_test.go @@ -630,6 +630,23 @@ func (s *testSuiteAgg) TestGroupConcatAggr(c *C) { // issue #9920 tk.MustQuery("select group_concat(123, null)").Check(testkit.Rows("")) + + // issue #23129 + tk.MustExec("drop table if exists t1;") + tk.MustExec("create table t1(cid int, sname varchar(100));") + tk.MustExec("insert into t1 values(1, 'Bob'), (1, 'Alice');") + tk.MustExec("insert into t1 values(3, 'Ace');") + tk.MustExec("set @@group_concat_max_len=5;") + rows := tk.MustQuery("select group_concat(sname order by sname) from t1 group by cid;") + rows.Check(testkit.Rows("Alice", "Ace")) + + tk.MustExec("drop table if exists t1;") + tk.MustExec("create table t1(c1 varchar(10));") + tk.MustExec("insert into t1 values('0123456789');") + tk.MustExec("insert into t1 values('12345');") + tk.MustExec("set @@group_concat_max_len=8;") + rows = tk.MustQuery("select group_concat(c1 order by c1) from t1 group by c1;") + rows.Check(testkit.Rows("01234567", "12345")) } func (s *testSuiteAgg) TestSelectDistinct(c *C) { From 1a601d8b2f5ea2653614b6198578722264cd58db Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Mon, 10 May 2021 14:32:44 +0800 Subject: [PATCH 68/85] expression: fix unexpected constant fold when year compare string (#23281) (#23335) --- expression/builtin_compare.go | 8 ++++---- expression/integration_test.go | 10 ++++++++++ 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/expression/builtin_compare.go b/expression/builtin_compare.go index 72b71de7ac5ca..957d9f366c346 100644 --- a/expression/builtin_compare.go +++ b/expression/builtin_compare.go @@ -1305,15 +1305,15 @@ func RefineComparedConstant(ctx sessionctx.Context, targetFieldType types.FieldT // We try to convert the string constant to double. // If the double result equals the int result, we can return the int result; // otherwise, the compare function will be false. + // **Notice** + // we can not compare double result to int result directly, because year type will change its value, like + // 2 to 2002, here we just check whether double value equal int(double value). We can assert the int(string) var doubleDatum types.Datum doubleDatum, err = dt.ConvertTo(sc, types.NewFieldType(mysql.TypeDouble)) if err != nil { return con, false } - if c, err = doubleDatum.CompareDatum(sc, &intDatum); err != nil { - return con, false - } - if c != 0 { + if doubleDatum.GetFloat64() != math.Trunc(doubleDatum.GetFloat64()) { return con, true } return &Constant{ diff --git a/expression/integration_test.go b/expression/integration_test.go index d71e5eba35fd8..dae8046c33b7f 100755 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -8047,6 +8047,16 @@ func (s *testSuite2) TestIssue12205(c *C) { testkit.Rows("Warning 1292 Truncated incorrect time value: '18446744072635875000'")) } +func (s *testIntegrationSuite) Test23262(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a year)") + tk.MustExec("insert into t values(2002)") + tk.MustQuery("select * from t where a=2").Check(testkit.Rows("2002")) + tk.MustQuery("select * from t where a='2'").Check(testkit.Rows("2002")) +} + func (s *testIntegrationSuite) TestIssue11333(c *C) { defer s.cleanEnv(c) tk := testkit.NewTestKit(c, s.store) From aec654c9615445f2232a9341cb60ec272e36670f Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Mon, 10 May 2021 16:58:44 +0800 Subject: [PATCH 69/85] executor, expression: fix the incorrect result of AVG function (#23285) (#23368) --- executor/aggregate_test.go | 31 +++++++++++++++++++ expression/aggregation/base_func.go | 7 ++++- .../transformation_rules_suite_out.json | 4 +-- .../testdata/plan_suite_unexported_out.json | 24 +++++++------- 4 files changed, 51 insertions(+), 15 deletions(-) diff --git a/executor/aggregate_test.go b/executor/aggregate_test.go index 842688795e44b..fdf643a4a0782 100644 --- a/executor/aggregate_test.go +++ b/executor/aggregate_test.go @@ -1172,3 +1172,34 @@ func (s *testSuiteAgg) TestIssue19426(c *C) { tk.MustQuery("select a, b, sum(case when a < 1000 then b else 0.0 end) over (order by a) from t"). Check(testkit.Rows("1 11 11.0", "2 22 33.0", "3 33 66.0", "4 44 110.0")) } + +func (s *testSuiteAgg) TestIssue23277(c *C) { + tk := testkit.NewTestKitWithInit(c, s.store) + tk.MustExec("use test;") + tk.MustExec("drop table if exists t;") + + tk.MustExec("create table t(a tinyint(1));") + tk.MustExec("insert into t values (-120), (127);") + tk.MustQuery("select avg(a) from t group by a").Sort().Check(testkit.Rows("-120.0000", "127.0000")) + tk.MustExec("drop table t;") + + tk.MustExec("create table t(a smallint(1));") + tk.MustExec("insert into t values (-120), (127);") + tk.MustQuery("select avg(a) from t group by a").Sort().Check(testkit.Rows("-120.0000", "127.0000")) + tk.MustExec("drop table t;") + + tk.MustExec("create table t(a mediumint(1));") + tk.MustExec("insert into t values (-120), (127);") + tk.MustQuery("select avg(a) from t group by a").Sort().Check(testkit.Rows("-120.0000", "127.0000")) + tk.MustExec("drop table t;") + + tk.MustExec("create table t(a int(1));") + tk.MustExec("insert into t values (-120), (127);") + tk.MustQuery("select avg(a) from t group by a").Sort().Check(testkit.Rows("-120.0000", "127.0000")) + tk.MustExec("drop table t;") + + tk.MustExec("create table t(a bigint(1));") + tk.MustExec("insert into t values (-120), (127);") + tk.MustQuery("select avg(a) from t group by a").Sort().Check(testkit.Rows("-120.0000", "127.0000")) + tk.MustExec("drop table t;") +} diff --git a/expression/aggregation/base_func.go b/expression/aggregation/base_func.go index 1c54dec75b6e5..1a1a7543ef29a 100644 --- a/expression/aggregation/base_func.go +++ b/expression/aggregation/base_func.go @@ -212,7 +212,12 @@ func (a *baseFuncDesc) typeInfer4Sum(ctx sessionctx.Context) { // Because child returns integer or decimal type. func (a *baseFuncDesc) typeInfer4Avg(ctx sessionctx.Context) { switch a.Args[0].GetType().Tp { - case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong, mysql.TypeYear, mysql.TypeNewDecimal: + case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong: + a.RetTp = types.NewFieldType(mysql.TypeNewDecimal) + a.RetTp.Decimal = types.DivFracIncr + flen, _ := mysql.GetDefaultFieldLengthAndDecimal(a.Args[0].GetType().Tp) + a.RetTp.Flen = flen + types.DivFracIncr + case mysql.TypeYear, mysql.TypeNewDecimal: a.RetTp = types.NewFieldType(mysql.TypeNewDecimal) if a.Args[0].GetType().Decimal < 0 { a.RetTp.Decimal = mysql.MaxDecimalScale diff --git a/planner/cascades/testdata/transformation_rules_suite_out.json b/planner/cascades/testdata/transformation_rules_suite_out.json index 36fe4b319c1c8..645fe6ac0ea97 100644 --- a/planner/cascades/testdata/transformation_rules_suite_out.json +++ b/planner/cascades/testdata/transformation_rules_suite_out.json @@ -2324,7 +2324,7 @@ "Group#2 Schema:[Column#13,Column#14,Column#15,test.t.b,Column#16,Column#17,Column#18,Column#19,Column#20,Column#14,Column#13]", " Selection_4 input:[Group#3], ge(Column#13, 0), ge(Column#14, 0)", "Group#3 Schema:[Column#13,Column#14,Column#15,test.t.b,Column#16,Column#17,Column#18,Column#19,Column#20,Column#14,Column#13]", - " Projection_8 input:[Group#4], 1->Column#13, cast(test.t.b, decimal(65,0) BINARY)->Column#14, cast(test.t.b, decimal(65,30) BINARY)->Column#15, test.t.b, cast(test.t.b, int(11))->Column#16, cast(test.t.b, int(11))->Column#17, ifnull(cast(test.t.b, bigint(21) UNSIGNED BINARY), 18446744073709551615)->Column#18, ifnull(cast(test.t.b, bigint(21) UNSIGNED BINARY), 0)->Column#19, ifnull(cast(test.t.b, bigint(21) UNSIGNED BINARY), 0)->Column#20, cast(test.t.b, decimal(65,0) BINARY)->Column#14, 1->Column#13", + " Projection_8 input:[Group#4], 1->Column#13, cast(test.t.b, decimal(65,0) BINARY)->Column#14, cast(test.t.b, decimal(15,4) BINARY)->Column#15, test.t.b, cast(test.t.b, int(11))->Column#16, cast(test.t.b, int(11))->Column#17, ifnull(cast(test.t.b, bigint(21) UNSIGNED BINARY), 18446744073709551615)->Column#18, ifnull(cast(test.t.b, bigint(21) UNSIGNED BINARY), 0)->Column#19, ifnull(cast(test.t.b, bigint(21) UNSIGNED BINARY), 0)->Column#20, cast(test.t.b, decimal(65,0) BINARY)->Column#14, 1->Column#13", "Group#4 Schema:[test.t.a,test.t.b], UniqueKey:[test.t.a]", " DataSource_1 table:t" ] @@ -2333,7 +2333,7 @@ "SQL": "select count(b), sum(b), avg(b), f, max(c), min(c), bit_and(c), bit_or(d), bit_xor(g) from t group by a", "Result": [ "Group#0 Schema:[Column#13,Column#14,Column#15,test.t.f,Column#16,Column#17,Column#18,Column#19,Column#20]", - " Projection_5 input:[Group#1], 1->Column#13, cast(test.t.b, decimal(65,0) BINARY)->Column#14, cast(test.t.b, decimal(65,30) BINARY)->Column#15, test.t.f, cast(test.t.c, int(11))->Column#16, cast(test.t.c, int(11))->Column#17, ifnull(cast(test.t.c, bigint(21) UNSIGNED BINARY), 18446744073709551615)->Column#18, ifnull(cast(test.t.d, bigint(21) UNSIGNED BINARY), 0)->Column#19, ifnull(cast(test.t.g, bigint(21) UNSIGNED BINARY), 0)->Column#20", + " Projection_5 input:[Group#1], 1->Column#13, cast(test.t.b, decimal(65,0) BINARY)->Column#14, cast(test.t.b, decimal(15,4) BINARY)->Column#15, test.t.f, cast(test.t.c, int(11))->Column#16, cast(test.t.c, int(11))->Column#17, ifnull(cast(test.t.c, bigint(21) UNSIGNED BINARY), 18446744073709551615)->Column#18, ifnull(cast(test.t.d, bigint(21) UNSIGNED BINARY), 0)->Column#19, ifnull(cast(test.t.g, bigint(21) UNSIGNED BINARY), 0)->Column#20", "Group#1 Schema:[test.t.a,test.t.b,test.t.c,test.t.d,test.t.f,test.t.g], UniqueKey:[test.t.f,test.t.f,test.t.g,test.t.a]", " DataSource_1 table:t" ] diff --git a/planner/core/testdata/plan_suite_unexported_out.json b/planner/core/testdata/plan_suite_unexported_out.json index fd7cc4aa95ae4..bc5dbe6d230d3 100644 --- a/planner/core/testdata/plan_suite_unexported_out.json +++ b/planner/core/testdata/plan_suite_unexported_out.json @@ -183,11 +183,11 @@ { "Name": "TestWindowFunction", "Cases": [ - "TableReader(Table(t))->Window(avg(cast(test.t.a, decimal(65,30) BINARY))->Column#14 over(partition by test.t.a))->Projection", - "TableReader(Table(t))->Sort->Window(avg(cast(test.t.a, decimal(65,30) BINARY))->Column#14 over(partition by test.t.b))->Projection", + "TableReader(Table(t))->Window(avg(cast(test.t.a, decimal(15,4) BINARY))->Column#14 over(partition by test.t.a))->Projection", + "TableReader(Table(t))->Sort->Window(avg(cast(test.t.a, decimal(15,4) BINARY))->Column#14 over(partition by test.t.b))->Projection", "IndexReader(Index(t.f)[[NULL,+inf]])->Projection->Sort->Window(avg(cast(Column#16, decimal(24,4) BINARY))->Column#17 over(partition by Column#15))->Projection", - "TableReader(Table(t))->Sort->Window(avg(cast(test.t.a, decimal(65,30) BINARY))->Column#14 over(order by test.t.a asc, test.t.b desc range between unbounded preceding and current row))->Projection", - "TableReader(Table(t))->Window(avg(cast(test.t.a, decimal(65,30) BINARY))->Column#14 over(partition by test.t.a))->Projection", + "TableReader(Table(t))->Sort->Window(avg(cast(test.t.a, decimal(15,4) BINARY))->Column#14 over(order by test.t.a asc, test.t.b desc range between unbounded preceding and current row))->Projection", + "TableReader(Table(t))->Window(avg(cast(test.t.a, decimal(15,4) BINARY))->Column#14 over(partition by test.t.a))->Projection", "[planner:1054]Unknown column 'z' in 'field list'", "TableReader(Table(t))->Window(sum(cast(test.t.b, decimal(65,0) BINARY))->Column#14 over())->Sort->Projection", "IndexReader(Index(t.f)[[NULL,+inf]]->StreamAgg)->StreamAgg->Window(sum(Column#13)->Column#15 over())->Projection", @@ -206,7 +206,7 @@ "IndexReader(Index(t.f)[[NULL,+inf]])->Window(sum(cast(test.t.a, decimal(65,0) BINARY))->Column#14 over(rows between 1 preceding and 1 following))->Projection", "[planner:3583]Window '' cannot inherit 'w' since both contain an ORDER BY clause.", "[planner:3591]Window 'w1' is defined twice.", - "TableReader(Table(t))->Window(avg(cast(test.t.a, decimal(65,30) BINARY))->Column#14 over(partition by test.t.a))->Projection", + "TableReader(Table(t))->Window(avg(cast(test.t.a, decimal(15,4) BINARY))->Column#14 over(partition by test.t.a))->Projection", "TableReader(Table(t))->Window(sum(cast(test.t.a, decimal(65,0) BINARY))->Column#14 over(partition by test.t.a))->Sort->Projection", "[planner:1235]This version of TiDB doesn't yet support 'GROUPS'", "[planner:3584]Window '': frame start cannot be UNBOUNDED FOLLOWING.", @@ -227,7 +227,7 @@ "[planner:1210]Incorrect arguments to nth_value", "[planner:1210]Incorrect arguments to ntile", "IndexReader(Index(t.f)[[NULL,+inf]])->Window(ntile()->Column#14 over())->Projection", - "TableReader(Table(t))->Sort->Window(avg(cast(test.t.a, decimal(65,30) BINARY))->Column#14 over(partition by test.t.b))->Projection", + "TableReader(Table(t))->Sort->Window(avg(cast(test.t.a, decimal(15,4) BINARY))->Column#14 over(partition by test.t.b))->Projection", "TableReader(Table(t))->Window(nth_value(test.t.i_date, 1)->Column#14 over())->Projection", "TableReader(Table(t))->Window(sum(cast(test.t.b, decimal(65,0) BINARY))->Column#15, sum(cast(test.t.c, decimal(65,0) BINARY))->Column#16 over(order by test.t.a asc range between unbounded preceding and current row))->Projection", "[planner:3593]You cannot use the window function 'sum' in this context.'", @@ -256,11 +256,11 @@ { "Name": "TestWindowParallelFunction", "Cases": [ - "TableReader(Table(t))->Window(avg(cast(test.t.a, decimal(65,30) BINARY))->Column#14 over(partition by test.t.a))->Projection", - "TableReader(Table(t))->Sort->Window(avg(cast(test.t.a, decimal(65,30) BINARY))->Column#14 over(partition by test.t.b))->Partition(execution info: concurrency:4, data source:TableReader_10)->Projection", + "TableReader(Table(t))->Window(avg(cast(test.t.a, decimal(15,4) BINARY))->Column#14 over(partition by test.t.a))->Projection", + "TableReader(Table(t))->Sort->Window(avg(cast(test.t.a, decimal(15,4) BINARY))->Column#14 over(partition by test.t.b))->Partition(execution info: concurrency:4, data source:TableReader_10)->Projection", "IndexReader(Index(t.f)[[NULL,+inf]])->Projection->Sort->Window(avg(cast(Column#16, decimal(24,4) BINARY))->Column#17 over(partition by Column#15))->Partition(execution info: concurrency:4, data source:Projection_8)->Projection", - "TableReader(Table(t))->Sort->Window(avg(cast(test.t.a, decimal(65,30) BINARY))->Column#14 over(order by test.t.a asc, test.t.b desc range between unbounded preceding and current row))->Projection", - "TableReader(Table(t))->Window(avg(cast(test.t.a, decimal(65,30) BINARY))->Column#14 over(partition by test.t.a))->Projection", + "TableReader(Table(t))->Sort->Window(avg(cast(test.t.a, decimal(15,4) BINARY))->Column#14 over(order by test.t.a asc, test.t.b desc range between unbounded preceding and current row))->Projection", + "TableReader(Table(t))->Window(avg(cast(test.t.a, decimal(15,4) BINARY))->Column#14 over(partition by test.t.a))->Projection", "[planner:1054]Unknown column 'z' in 'field list'", "TableReader(Table(t))->Window(sum(cast(test.t.b, decimal(65,0) BINARY))->Column#14 over())->Sort->Projection", "IndexReader(Index(t.f)[[NULL,+inf]]->StreamAgg)->StreamAgg->Window(sum(Column#13)->Column#15 over())->Projection", @@ -279,7 +279,7 @@ "IndexReader(Index(t.f)[[NULL,+inf]])->Window(sum(cast(test.t.a, decimal(65,0) BINARY))->Column#14 over(rows between 1 preceding and 1 following))->Projection", "[planner:3583]Window '' cannot inherit 'w' since both contain an ORDER BY clause.", "[planner:3591]Window 'w1' is defined twice.", - "TableReader(Table(t))->Window(avg(cast(test.t.a, decimal(65,30) BINARY))->Column#14 over(partition by test.t.a))->Projection", + "TableReader(Table(t))->Window(avg(cast(test.t.a, decimal(15,4) BINARY))->Column#14 over(partition by test.t.a))->Projection", "TableReader(Table(t))->Window(sum(cast(test.t.a, decimal(65,0) BINARY))->Column#14 over(partition by test.t.a))->Sort->Projection", "[planner:1235]This version of TiDB doesn't yet support 'GROUPS'", "[planner:3584]Window '': frame start cannot be UNBOUNDED FOLLOWING.", @@ -300,7 +300,7 @@ "[planner:1210]Incorrect arguments to nth_value", "[planner:1210]Incorrect arguments to ntile", "IndexReader(Index(t.f)[[NULL,+inf]])->Window(ntile()->Column#14 over())->Projection", - "TableReader(Table(t))->Sort->Window(avg(cast(test.t.a, decimal(65,30) BINARY))->Column#14 over(partition by test.t.b))->Partition(execution info: concurrency:4, data source:TableReader_10)->Projection", + "TableReader(Table(t))->Sort->Window(avg(cast(test.t.a, decimal(15,4) BINARY))->Column#14 over(partition by test.t.b))->Partition(execution info: concurrency:4, data source:TableReader_10)->Projection", "TableReader(Table(t))->Window(nth_value(test.t.i_date, 1)->Column#14 over())->Projection", "TableReader(Table(t))->Window(sum(cast(test.t.b, decimal(65,0) BINARY))->Column#15, sum(cast(test.t.c, decimal(65,0) BINARY))->Column#16 over(order by test.t.a asc range between unbounded preceding and current row))->Projection", "[planner:3593]You cannot use the window function 'sum' in this context.'", From b118f7063ae14f43a8c59b6744980818707e936c Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Mon, 10 May 2021 18:13:17 +0800 Subject: [PATCH 70/85] executor: refineArgs() bug fix when compare int with very small decimal (#23694) (#23705) --- expression/builtin_compare.go | 15 +++++++++++---- expression/builtin_compare_test.go | 2 +- expression/expression_test.go | 7 +++++-- expression/integration_test.go | 9 +++++++++ planner/core/cbo_test.go | 2 +- 5 files changed, 27 insertions(+), 8 deletions(-) diff --git a/expression/builtin_compare.go b/expression/builtin_compare.go index 957d9f366c346..671073dfcf37e 100644 --- a/expression/builtin_compare.go +++ b/expression/builtin_compare.go @@ -1344,8 +1344,13 @@ func (c *compareFunctionClass) refineArgs(ctx sessionctx.Context, args []Express // int non-constant [cmp] non-int constant if arg0IsInt && !arg0IsCon && !arg1IsInt && arg1IsCon { arg1, isExceptional = RefineComparedConstant(ctx, *arg0Type, arg1, c.op) - finalArg1 = arg1 - if isExceptional && arg1.GetType().EvalType() == types.ETInt { + // Why check not null flag + // eg: int_col > const_val(which is less than min_int32) + // If int_col got null, compare result cannot be true + if !isExceptional || (isExceptional && mysql.HasNotNullFlag(arg0Type.Flag)) { + finalArg1 = arg1 + } + if isExceptional && arg1.GetType().EvalType() == types.ETInt && mysql.HasNotNullFlag(arg0Type.Flag) { // Judge it is inf or -inf // For int: // inf: 01111111 & 1 == 1 @@ -1363,8 +1368,10 @@ func (c *compareFunctionClass) refineArgs(ctx sessionctx.Context, args []Express // non-int constant [cmp] int non-constant if arg1IsInt && !arg1IsCon && !arg0IsInt && arg0IsCon { arg0, isExceptional = RefineComparedConstant(ctx, *arg1Type, arg0, symmetricOp[c.op]) - finalArg0 = arg0 - if isExceptional && arg0.GetType().EvalType() == types.ETInt { + if !isExceptional || (isExceptional && mysql.HasNotNullFlag(arg1Type.Flag)) { + finalArg0 = arg0 + } + if isExceptional && arg0.GetType().EvalType() == types.ETInt && mysql.HasNotNullFlag(arg1Type.Flag) { if arg0.Value.GetInt64()&1 == 1 { isNegativeInfinite = true } else { diff --git a/expression/builtin_compare_test.go b/expression/builtin_compare_test.go index f8b8e6f08843c..1e3908b856030 100644 --- a/expression/builtin_compare_test.go +++ b/expression/builtin_compare_test.go @@ -27,7 +27,7 @@ import ( ) func (s *testEvaluatorSuite) TestCompareFunctionWithRefine(c *C) { - tblInfo := newTestTableBuilder("").add("a", mysql.TypeLong).build() + tblInfo := newTestTableBuilder("").add("a", mysql.TypeLong, mysql.NotNullFlag).build() tests := []struct { exprStr string result string diff --git a/expression/expression_test.go b/expression/expression_test.go index 2ecdc12f465eb..40eed2c946207 100644 --- a/expression/expression_test.go +++ b/expression/expression_test.go @@ -35,7 +35,7 @@ func (s *testEvaluatorSuite) TestNewValuesFunc(c *C) { } func (s *testEvaluatorSuite) TestEvaluateExprWithNull(c *C) { - tblInfo := newTestTableBuilder("").add("col0", mysql.TypeLonglong).add("col1", mysql.TypeLonglong).build() + tblInfo := newTestTableBuilder("").add("col0", mysql.TypeLonglong, 0).add("col1", mysql.TypeLonglong, 0).build() schema := tableInfoToSchemaForTest(tblInfo) col0 := schema.Columns[0] col1 := schema.Columns[1] @@ -142,15 +142,17 @@ type testTableBuilder struct { tableName string columnNames []string tps []byte + flags []uint } func newTestTableBuilder(tableName string) *testTableBuilder { return &testTableBuilder{tableName: tableName} } -func (builder *testTableBuilder) add(name string, tp byte) *testTableBuilder { +func (builder *testTableBuilder) add(name string, tp byte, flag uint) *testTableBuilder { builder.columnNames = append(builder.columnNames, name) builder.tps = append(builder.tps, tp) + builder.flags = append(builder.flags, flag) return builder } @@ -165,6 +167,7 @@ func (builder *testTableBuilder) build() *model.TableInfo { fieldType := types.NewFieldType(tp) fieldType.Flen, fieldType.Decimal = mysql.GetDefaultFieldLengthAndDecimal(tp) fieldType.Charset, fieldType.Collate = types.DefaultCharsetForType(tp) + fieldType.Flag = builder.flags[i] ti.Columns = append(ti.Columns, &model.ColumnInfo{ ID: int64(i + 1), Name: model.NewCIStr(colName), diff --git a/expression/integration_test.go b/expression/integration_test.go index dae8046c33b7f..ae7f4eddea2de 100755 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -8121,6 +8121,15 @@ func (s *testIntegrationSuite) TestIssue23889(c *C) { testkit.Rows("", "0")) } +func (s *testIntegrationSuite) TestIssue23623(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t1;") + tk.MustExec("create table t1(c1 int);") + tk.MustExec("insert into t1 values(-2147483648), (-2147483648), (null);") + tk.MustQuery("select count(*) from t1 where c1 > (select sum(c1) from t1);").Check(testkit.Rows("2")) +} + func (s *testIntegrationSerialSuite) TestCollationForBinaryLiteral(c *C) { tk := testkit.NewTestKit(c, s.store) collate.SetNewCollationEnabledForTest(true) diff --git a/planner/core/cbo_test.go b/planner/core/cbo_test.go index 3d42bcf1fc8b1..1cd6e4d0037c0 100644 --- a/planner/core/cbo_test.go +++ b/planner/core/cbo_test.go @@ -444,7 +444,7 @@ func (s *testAnalyzeSuite) TestPreparedNullParam(c *C) { testKit := testkit.NewTestKit(c, store) testKit.MustExec("use test") testKit.MustExec("drop table if exists t") - testKit.MustExec("create table t (id int, KEY id (id))") + testKit.MustExec("create table t (id int not null, KEY id (id))") testKit.MustExec("insert into t values (1), (2), (3)") sql := "select * from t where id = ?" From 81db417733c7f23f48ac5b877d0ad4977f90ff3d Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Mon, 10 May 2021 19:24:50 +0800 Subject: [PATCH 71/85] types: fix type merge about bit type (#23857) (#24026) --- executor/executor_test.go | 6 +++ types/field_type.go | 89 ++++++++++++++++++++------------------- types/field_type_test.go | 4 +- 3 files changed, 54 insertions(+), 45 deletions(-) diff --git a/executor/executor_test.go b/executor/executor_test.go index d12591ec9e7cf..d3027057f3ff5 100644 --- a/executor/executor_test.go +++ b/executor/executor_test.go @@ -1399,6 +1399,12 @@ func (s *testSuiteP2) TestUnion(c *C) { tk.MustExec("create table t(a int, b decimal(6, 3))") tk.MustExec("insert into t values(1, 1.000)") tk.MustQuery("select count(distinct a), sum(distinct a), avg(distinct a) from (select a from t union all select b from t) tmp;").Check(testkit.Rows("1 1.000 1.0000000")) + + // #issue 23832 + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a bit(20), b float, c double, d int)") + tk.MustExec("insert into t values(10, 10, 10, 10), (1, -1, 2, -2), (2, -2, 1, 1), (2, 1.1, 2.1, 10.1)") + tk.MustQuery("select a from t union select 10 order by a").Check(testkit.Rows("1", "2", "10")) } func (s *testSuite2) TestUnionLimit(c *C) { diff --git a/types/field_type.go b/types/field_type.go index 0cc65b60f2e53..5804e38c3069e 100644 --- a/types/field_type.go +++ b/types/field_type.go @@ -330,7 +330,7 @@ func DefaultCharsetForType(tp byte) (string, string) { // This is used in hybrid field type expression. // For example "select case c when 1 then 2 when 2 then 'tidb' from t;" // The result field type of the case expression is the merged type of the two when clause. -// See https://github.com/mysql/mysql-server/blob/5.7/sql/field.cc#L1042 +// See https://github.com/mysql/mysql-server/blob/8.0/sql/field.cc#L1042 func MergeFieldType(a byte, b byte) byte { ia := getFieldTypeIndex(a) ib := getFieldTypeIndex(b) @@ -358,6 +358,7 @@ const ( fieldTypeNum = fieldTypeTearFrom + (255 - fieldTypeTearTo) ) +// https://github.com/mysql/mysql-server/blob/8.0/sql/field.cc#L248 var fieldTypeMergeRules = [fieldTypeNum][fieldTypeNum]byte{ /* mysql.TypeDecimal -> */ { @@ -410,9 +411,9 @@ var fieldTypeMergeRules = [fieldTypeNum][fieldTypeNum]byte{ mysql.TypeVarchar, mysql.TypeTiny, //mysql.TypeNewDate mysql.TypeVarchar mysql.TypeVarchar, mysql.TypeVarchar, - //mysql.TypeBit <16>-<244> - mysql.TypeVarchar, - //mysql.TypeJSON + // mysql.TypeBit <16>-<244> + mysql.TypeLonglong, + // mysql.TypeJSON mysql.TypeVarchar, //mysql.TypeNewDecimal mysql.TypeEnum mysql.TypeNewDecimal, mysql.TypeVarchar, @@ -443,9 +444,9 @@ var fieldTypeMergeRules = [fieldTypeNum][fieldTypeNum]byte{ mysql.TypeVarchar, mysql.TypeShort, //mysql.TypeNewDate mysql.TypeVarchar mysql.TypeVarchar, mysql.TypeVarchar, - //mysql.TypeBit <16>-<244> - mysql.TypeVarchar, - //mysql.TypeJSON + // mysql.TypeBit <16>-<244> + mysql.TypeLonglong, + // mysql.TypeJSON mysql.TypeVarchar, //mysql.TypeNewDecimal mysql.TypeEnum mysql.TypeNewDecimal, mysql.TypeVarchar, @@ -476,9 +477,9 @@ var fieldTypeMergeRules = [fieldTypeNum][fieldTypeNum]byte{ mysql.TypeVarchar, mysql.TypeLong, //mysql.TypeNewDate mysql.TypeVarchar mysql.TypeVarchar, mysql.TypeVarchar, - //mysql.TypeBit <16>-<244> - mysql.TypeVarchar, - //mysql.TypeJSON + // mysql.TypeBit <16>-<244> + mysql.TypeLonglong, + // mysql.TypeJSON mysql.TypeVarchar, //mysql.TypeNewDecimal mysql.TypeEnum mysql.TypeNewDecimal, mysql.TypeVarchar, @@ -509,9 +510,9 @@ var fieldTypeMergeRules = [fieldTypeNum][fieldTypeNum]byte{ mysql.TypeVarchar, mysql.TypeFloat, //mysql.TypeNewDate mysql.TypeVarchar mysql.TypeVarchar, mysql.TypeVarchar, - //mysql.TypeBit <16>-<244> - mysql.TypeVarchar, - //mysql.TypeJSON + // mysql.TypeBit <16>-<244> + mysql.TypeDouble, + // mysql.TypeJSON mysql.TypeVarchar, //mysql.TypeNewDecimal mysql.TypeEnum mysql.TypeDouble, mysql.TypeVarchar, @@ -542,9 +543,9 @@ var fieldTypeMergeRules = [fieldTypeNum][fieldTypeNum]byte{ mysql.TypeVarchar, mysql.TypeDouble, //mysql.TypeNewDate mysql.TypeVarchar mysql.TypeVarchar, mysql.TypeVarchar, - //mysql.TypeBit <16>-<244> - mysql.TypeVarchar, - //mysql.TypeJSON + // mysql.TypeBit <16>-<244> + mysql.TypeDouble, + // mysql.TypeJSON mysql.TypeVarchar, //mysql.TypeNewDecimal mysql.TypeEnum mysql.TypeDouble, mysql.TypeVarchar, @@ -641,9 +642,9 @@ var fieldTypeMergeRules = [fieldTypeNum][fieldTypeNum]byte{ mysql.TypeVarchar, mysql.TypeLonglong, //mysql.TypeNewDate mysql.TypeVarchar mysql.TypeNewDate, mysql.TypeVarchar, - //mysql.TypeBit <16>-<244> - mysql.TypeVarchar, - //mysql.TypeJSON + // mysql.TypeBit <16>-<244> + mysql.TypeLonglong, + // mysql.TypeJSON mysql.TypeVarchar, //mysql.TypeNewDecimal mysql.TypeEnum mysql.TypeNewDecimal, mysql.TypeVarchar, @@ -674,9 +675,9 @@ var fieldTypeMergeRules = [fieldTypeNum][fieldTypeNum]byte{ mysql.TypeVarchar, mysql.TypeInt24, //mysql.TypeNewDate mysql.TypeVarchar mysql.TypeNewDate, mysql.TypeVarchar, - //mysql.TypeBit <16>-<244> - mysql.TypeVarchar, - //mysql.TypeJSON + // mysql.TypeBit <16>-<244> + mysql.TypeLonglong, + // mysql.TypeJSON mysql.TypeVarchar, //mysql.TypeNewDecimal mysql.TypeEnum mysql.TypeNewDecimal, mysql.TypeVarchar, @@ -806,9 +807,9 @@ var fieldTypeMergeRules = [fieldTypeNum][fieldTypeNum]byte{ mysql.TypeVarchar, mysql.TypeYear, //mysql.TypeNewDate mysql.TypeVarchar mysql.TypeVarchar, mysql.TypeVarchar, - //mysql.TypeBit <16>-<244> - mysql.TypeVarchar, - //mysql.TypeJSON + // mysql.TypeBit <16>-<244> + mysql.TypeLonglong, + // mysql.TypeJSON mysql.TypeVarchar, //mysql.TypeNewDecimal mysql.TypeEnum mysql.TypeNewDecimal, mysql.TypeVarchar, @@ -889,29 +890,29 @@ var fieldTypeMergeRules = [fieldTypeNum][fieldTypeNum]byte{ }, /* mysql.TypeBit -> */ { - //mysql.TypeDecimal mysql.TypeTiny - mysql.TypeVarchar, mysql.TypeVarchar, - //mysql.TypeShort mysql.TypeLong - mysql.TypeVarchar, mysql.TypeVarchar, - //mysql.TypeFloat mysql.TypeDouble - mysql.TypeVarchar, mysql.TypeVarchar, - //mysql.TypeNull mysql.TypeTimestamp + // mysql.TypeUnspecified mysql.TypeTiny + mysql.TypeVarchar, mysql.TypeLonglong, + // mysql.TypeShort mysql.TypeLong + mysql.TypeLonglong, mysql.TypeLonglong, + // mysql.TypeFloat mysql.TypeDouble + mysql.TypeDouble, mysql.TypeDouble, + // mysql.TypeNull mysql.TypeTimestamp mysql.TypeBit, mysql.TypeVarchar, - //mysql.TypeLonglong mysql.TypeInt24 - mysql.TypeVarchar, mysql.TypeVarchar, - //mysql.TypeDate mysql.TypeTime - mysql.TypeVarchar, mysql.TypeVarchar, - //mysql.TypeDatetime mysql.TypeYear + // mysql.TypeLonglong mysql.TypeInt24 + mysql.TypeLonglong, mysql.TypeLonglong, + // mysql.TypeDate mysql.TypeTime mysql.TypeVarchar, mysql.TypeVarchar, - //mysql.TypeNewDate mysql.TypeVarchar + // mysql.TypeDatetime mysql.TypeYear + mysql.TypeVarchar, mysql.TypeLonglong, + // mysql.TypeNewDate mysql.TypeVarchar mysql.TypeVarchar, mysql.TypeVarchar, //mysql.TypeBit <16>-<244> mysql.TypeBit, //mysql.TypeJSON mysql.TypeVarchar, - //mysql.TypeNewDecimal mysql.TypeEnum - mysql.TypeVarchar, mysql.TypeVarchar, - //mysql.TypeSet mysql.TypeTinyBlob + // mysql.TypeNewDecimal mysql.TypeEnum + mysql.TypeNewDecimal, mysql.TypeVarchar, + // mysql.TypeSet mysql.TypeTinyBlob mysql.TypeVarchar, mysql.TypeTinyBlob, //mysql.TypeMediumBlob mysql.TypeLongBlob mysql.TypeMediumBlob, mysql.TypeLongBlob, @@ -971,9 +972,9 @@ var fieldTypeMergeRules = [fieldTypeNum][fieldTypeNum]byte{ mysql.TypeVarchar, mysql.TypeNewDecimal, //mysql.TypeNewDate mysql.TypeVarchar mysql.TypeVarchar, mysql.TypeVarchar, - //mysql.TypeBit <16>-<244> - mysql.TypeVarchar, - //mysql.TypeJSON + // mysql.TypeBit <16>-<244> + mysql.TypeNewDecimal, + // mysql.TypeJSON mysql.TypeVarchar, //mysql.TypeNewDecimal mysql.TypeEnum mysql.TypeNewDecimal, mysql.TypeVarchar, diff --git a/types/field_type_test.go b/types/field_type_test.go index 4d2583ec97d97..821ef86b36899 100644 --- a/types/field_type_test.go +++ b/types/field_type_test.go @@ -270,9 +270,11 @@ func (s *testFieldTypeSuite) TestAggFieldType(c *C) { c.Assert(aggTp.Tp, Equals, mysql.TypeDouble) case mysql.TypeTimestamp, mysql.TypeDate, mysql.TypeDuration, mysql.TypeDatetime, mysql.TypeNewDate, mysql.TypeVarchar, - mysql.TypeBit, mysql.TypeJSON, mysql.TypeEnum, mysql.TypeSet, + mysql.TypeJSON, mysql.TypeEnum, mysql.TypeSet, mysql.TypeVarString, mysql.TypeGeometry: c.Assert(aggTp.Tp, Equals, mysql.TypeVarchar) + case mysql.TypeBit: + c.Assert(aggTp.Tp, Equals, mysql.TypeLonglong) case mysql.TypeString: c.Assert(aggTp.Tp, Equals, mysql.TypeString) case mysql.TypeDecimal, mysql.TypeNewDecimal: From 01119cb2a22e4279f9c3523d7ee94c7f640fe709 Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Mon, 10 May 2021 21:04:21 +0800 Subject: [PATCH 72/85] planner,privilege: requires extra privileges for REPLACE and INSERT ON DUPLICATE statements (#23911) (#23938) --- planner/core/planbuilder.go | 22 ++++++++++++-- privilege/privileges/privileges_test.go | 39 +++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 3 deletions(-) diff --git a/planner/core/planbuilder.go b/planner/core/planbuilder.go index 8761aadbdb4bf..3b2a6193575a5 100644 --- a/planner/core/planbuilder.go +++ b/planner/core/planbuilder.go @@ -2237,15 +2237,31 @@ func (b *PlanBuilder) buildInsert(ctx context.Context, insert *ast.InsertStmt) ( return nil, ErrPartitionClauseOnNonpartitioned } + user := b.ctx.GetSessionVars().User var authErr error - if b.ctx.GetSessionVars().User != nil { - authErr = ErrTableaccessDenied.GenWithStackByArgs("INSERT", b.ctx.GetSessionVars().User.AuthUsername, - b.ctx.GetSessionVars().User.AuthHostname, tableInfo.Name.L) + if user != nil { + authErr = ErrTableaccessDenied.GenWithStackByArgs("INSERT", user.AuthUsername, user.AuthHostname, tableInfo.Name.L) } b.visitInfo = appendVisitInfo(b.visitInfo, mysql.InsertPriv, tn.DBInfo.Name.L, tableInfo.Name.L, "", authErr) + // `REPLACE INTO` requires both INSERT + DELETE privilege + // `ON DUPLICATE KEY UPDATE` requires both INSERT + UPDATE privilege + var extraPriv mysql.PrivilegeType + if insert.IsReplace { + extraPriv = mysql.DeletePriv + } else if insert.OnDuplicate != nil { + extraPriv = mysql.UpdatePriv + } + if extraPriv != 0 { + if user != nil { + cmd := strings.ToUpper(mysql.Priv2Str[extraPriv]) + authErr = ErrTableaccessDenied.GenWithStackByArgs(cmd, user.AuthUsername, user.AuthHostname, tableInfo.Name.L) + } + b.visitInfo = appendVisitInfo(b.visitInfo, extraPriv, tn.DBInfo.Name.L, tableInfo.Name.L, "", authErr) + } + mockTablePlan := LogicalTableDual{}.Init(b.ctx, b.getSelectOffset()) mockTablePlan.SetSchema(insertPlan.tableSchema) mockTablePlan.names = insertPlan.tableColNames diff --git a/privilege/privileges/privileges_test.go b/privilege/privileges/privileges_test.go index 8c4f5cd8c9eaa..247e77e7bbcc5 100644 --- a/privilege/privileges/privileges_test.go +++ b/privilege/privileges/privileges_test.go @@ -902,6 +902,45 @@ func (s *testPrivilegeSuite) TestShowCreateTable(c *C) { mustExec(c, se, `SHOW CREATE TABLE mysql.user`) } +func (s *testPrivilegeSuite) TestReplaceAndInsertOnDuplicate(c *C) { + se := newSession(c, s.store, s.dbName) + mustExec(c, se, `CREATE USER tr_insert`) + mustExec(c, se, `CREATE USER tr_update`) + mustExec(c, se, `CREATE USER tr_delete`) + mustExec(c, se, `CREATE TABLE t1 (a int primary key, b int)`) + mustExec(c, se, `GRANT INSERT ON t1 TO tr_insert`) + mustExec(c, se, `GRANT UPDATE ON t1 TO tr_update`) + mustExec(c, se, `GRANT DELETE ON t1 TO tr_delete`) + + // Restrict the permission to INSERT only. + c.Assert(se.Auth(&auth.UserIdentity{Username: "tr_insert", Hostname: "localhost", AuthUsername: "tr_insert", AuthHostname: "%"}, nil, nil), IsTrue) + + // REPLACE requires INSERT + DELETE privileges, having INSERT alone is insufficient. + _, err := se.ExecuteInternal(context.Background(), `REPLACE INTO t1 VALUES (1, 2)`) + c.Assert(terror.ErrorEqual(err, core.ErrTableaccessDenied), IsTrue) + c.Assert(err.Error(), Equals, "[planner:1142]DELETE command denied to user 'tr_insert'@'%' for table 't1'") + + // INSERT ON DUPLICATE requires INSERT + UPDATE privileges, having INSERT alone is insufficient. + _, err = se.ExecuteInternal(context.Background(), `INSERT INTO t1 VALUES (3, 4) ON DUPLICATE KEY UPDATE b = 5`) + c.Assert(terror.ErrorEqual(err, core.ErrTableaccessDenied), IsTrue) + c.Assert(err.Error(), Equals, "[planner:1142]UPDATE command denied to user 'tr_insert'@'%' for table 't1'") + + // Plain INSERT should work. + mustExec(c, se, `INSERT INTO t1 VALUES (6, 7)`) + + // Also check that having DELETE alone is insufficient for REPLACE. + c.Assert(se.Auth(&auth.UserIdentity{Username: "tr_delete", Hostname: "localhost", AuthUsername: "tr_delete", AuthHostname: "%"}, nil, nil), IsTrue) + _, err = se.ExecuteInternal(context.Background(), `REPLACE INTO t1 VALUES (8, 9)`) + c.Assert(terror.ErrorEqual(err, core.ErrTableaccessDenied), IsTrue) + c.Assert(err.Error(), Equals, "[planner:1142]INSERT command denied to user 'tr_delete'@'%' for table 't1'") + + // Also check that having UPDATE alone is insufficient for INSERT ON DUPLICATE. + c.Assert(se.Auth(&auth.UserIdentity{Username: "tr_update", Hostname: "localhost", AuthUsername: "tr_update", AuthHostname: "%"}, nil, nil), IsTrue) + _, err = se.ExecuteInternal(context.Background(), `INSERT INTO t1 VALUES (10, 11) ON DUPLICATE KEY UPDATE b = 12`) + c.Assert(terror.ErrorEqual(err, core.ErrTableaccessDenied), IsTrue) + c.Assert(err.Error(), Equals, "[planner:1142]INSERT command denied to user 'tr_update'@'%' for table 't1'") +} + func (s *testPrivilegeSuite) TestAnalyzeTable(c *C) { se := newSession(c, s.store, s.dbName) From a5ccf15802fb1d84dfdce5625f7bb1b3a7ecc2cb Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Mon, 10 May 2021 21:33:23 +0800 Subject: [PATCH 73/85] executor: fix 2nd index dup check after insert ignore on dup update primary (#23814) (#23825) --- executor/write.go | 8 +++++ executor/write_test.go | 22 ++++++++++++++ table/tables/index.go | 9 ++++++ table/tables/tables.go | 68 ++++++++++++++++++++++++++++++++++++++++-- 4 files changed, 104 insertions(+), 3 deletions(-) diff --git a/executor/write.go b/executor/write.go index d607167ac4fe7..06fed54c90e36 100644 --- a/executor/write.go +++ b/executor/write.go @@ -186,6 +186,14 @@ func updateRecord(ctx context.Context, sctx sessionctx.Context, h int64, oldData } return false, handleChanged, newHandle, err } + err = tables.CheckUniqueKeyExistForUpdateIgnoreOrInsertOnDupIgnore(ctx, sctx, t, newHandle, newData, modified) + if err != nil { + if terr, ok := errors.Cause(err).(*terror.Error); sctx.GetSessionVars().StmtCtx.IgnoreNoPartition && ok && terr.Code() == errno.ErrNoPartitionForGivenValue { + //return false, nil + return false, handleChanged, newHandle, nil + } + return false, handleChanged, newHandle, err + } } if err = t.RemoveRecord(sctx, h, oldData); err != nil { return false, false, 0, err diff --git a/executor/write_test.go b/executor/write_test.go index 4120fafea4a67..1dc1c9b389ae9 100644 --- a/executor/write_test.go +++ b/executor/write_test.go @@ -761,6 +761,28 @@ func (s *testSuite4) TestInsertIgnoreOnDup(c *C) { testSQL = `select * from t;` r = tk.MustQuery(testSQL) r.Check(testkit.Rows("1 1", "2 2")) + + tk.MustExec("drop table if exists t4") + tk.MustExec("create table t4(id int primary key, k int, v int, unique key uk1(k))") + tk.MustExec("insert into t4 values (1, 10, 100), (3, 30, 300)") + tk.MustExec("insert ignore into t4 (id, k, v) values(1, 0, 0) on duplicate key update id = 2, k = 30") + tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1062 Duplicate entry '30' for key 'uk1'")) + tk.MustQuery("select * from t4").Check(testkit.Rows("1 10 100", "3 30 300")) + + tk.MustExec("drop table if exists t5") + tk.MustExec("create table t5(k2 int primary key, uk1 int, v int, unique key ukk1(uk1), unique key ukk2(v))") + tk.MustExec("insert into t5(k2, uk1, v) values(1, 1, '100'), (3, 2, '200')") + tk.MustExec("update ignore t5 set k2 = '2', uk1 = 2 where k2 = 1") + tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1062 Duplicate entry '2' for key 'ukk1'")) + tk.MustQuery("select * from t5").Check(testkit.Rows("1 1 100", "3 2 200")) + + tk.MustExec("drop table if exists t6") + tk.MustExec("create table t6 (a int, b int, c int, primary key(a), unique key idx_14(b), unique key idx_15(b), unique key idx_16(a, b))") + tk.MustExec("insert into t6 select 10, 10, 20") + tk.MustExec("insert ignore into t6 set a = 20, b = 10 on duplicate key update a = 100") + tk.MustQuery("select * from t6").Check(testkit.Rows("100 10 20")) + tk.MustExec("insert ignore into t6 set a = 200, b= 10 on duplicate key update c = 1000") + tk.MustQuery("select * from t6").Check(testkit.Rows("100 10 1000")) } func (s *testSuite4) TestInsertSetWithDefault(c *C) { diff --git a/table/tables/index.go b/table/tables/index.go index 5e4ad9ee4c1bc..4536414f019be 100644 --- a/table/tables/index.go +++ b/table/tables/index.go @@ -376,3 +376,12 @@ func (c *index) FetchValues(r []types.Datum, vals []types.Datum) ([]types.Datum, } return vals, nil } + +// IsIndexWritable check whether the index is writable. +func IsIndexWritable(idx table.Index) bool { + s := idx.Meta().State + if s != model.StateDeleteOnly && s != model.StateDeleteReorganization { + return true + } + return false +} diff --git a/table/tables/tables.go b/table/tables/tables.go index 51230c2446875..622a35d0ea8a3 100644 --- a/table/tables/tables.go +++ b/table/tables/tables.go @@ -606,7 +606,7 @@ func (t *TableCommon) AddRecord(ctx sessionctx.Context, r []types.Datum, opts .. } // genIndexKeyStr generates index content string representation. -func (t *TableCommon) genIndexKeyStr(colVals []types.Datum) (string, error) { +func genIndexKeyStr(colVals []types.Datum) (string, error) { // Pass pre-composed error to txn. strVals := make([]string, 0, len(colVals)) for _, cv := range colVals { @@ -658,7 +658,7 @@ func (t *TableCommon) addIndices(sctx sessionctx.Context, recordID int64, r []ty } var dupErr error if !skipCheck && v.Meta().Unique { - entryKey, err := t.genIndexKeyStr(indexVals) + entryKey, err := genIndexKeyStr(indexVals) if err != nil { return 0, err } @@ -913,7 +913,7 @@ func (t *TableCommon) buildIndexForRow(ctx sessionctx.Context, rm kv.RetrieverMu if _, err := idx.Create(ctx, rm, vals, h, opts...); err != nil { if kv.ErrKeyExists.Equal(err) { // Make error message consistent with MySQL. - entryKey, err1 := t.genIndexKeyStr(vals) + entryKey, err1 := genIndexKeyStr(vals) if err1 != nil { // if genIndexKeyStr failed, return the original error. return err @@ -1232,6 +1232,68 @@ func CheckHandleExists(ctx context.Context, sctx sessionctx.Context, t table.Tab return nil } +// CheckUniqueKeyExistForUpdateIgnoreOrInsertOnDupIgnore check whether recordID key or unique index key exists. if not exists, return nil, +// otherwise return kv.ErrKeyExists error. +func CheckUniqueKeyExistForUpdateIgnoreOrInsertOnDupIgnore(ctx context.Context, sctx sessionctx.Context, t table.Table, recordID int64, data []types.Datum, modified []bool) error { + if pt, ok := t.(*partitionedTable); ok { + info := t.Meta().GetPartitionInfo() + pid, err := pt.locatePartition(sctx, info, data) + if err != nil { + return err + } + t = pt.GetPartition(pid) + } + txn, err := sctx.Txn(true) + if err != nil { + return err + } + shouldSkipIgnoreCheck := func(idx table.Index) bool { + if !IsIndexWritable(idx) || !idx.Meta().Unique { + return true + } + for _, c := range idx.Meta().Columns { + if modified[c.Offset] { + return false + } + } + return true + } + for _, idx := range t.Indices() { + if shouldSkipIgnoreCheck(idx) { + continue + } + vals, err := idx.FetchValues(data, nil) + if err != nil { + return err + } + key, _, err := idx.GenIndexKey(sctx.GetSessionVars().StmtCtx, vals, recordID, nil) + if err != nil { + return err + } + entryKey, err := genIndexKeyStr(vals) + if err != nil { + return err + } + err = func() error { + existErrInfo := kv.NewExistErrInfo(idx.Meta().Name.String(), entryKey) + txn.SetOption(kv.PresumeKeyNotExistsError, existErrInfo) + txn.SetOption(kv.CheckExists, sctx.GetSessionVars().StmtCtx.CheckKeyExists) + defer txn.DelOption(kv.PresumeKeyNotExistsError) + _, err = txn.Get(ctx, key) + if err == nil { + return existErrInfo.Err() + } else if !kv.ErrNotExist.Equal(err) { + return err + } + return nil + }() + if err != nil { + return err + } + } + return nil +} + func init() { table.TableFromMeta = TableFromMeta table.MockTableFromMeta = MockTableFromMeta From 877fc4cd59cfb91678d99ffcd51cd4a5ff103a4c Mon Sep 17 00:00:00 2001 From: ShuNing Date: Mon, 10 May 2021 21:57:21 +0800 Subject: [PATCH 74/85] store/tikv: increase batch split region limit (#24508) Signed-off-by: nolouch --- store/tikv/split_region.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/store/tikv/split_region.go b/store/tikv/split_region.go index a9ce30affd95a..6a22945603032 100644 --- a/store/tikv/split_region.go +++ b/store/tikv/split_region.go @@ -33,7 +33,7 @@ import ( "go.uber.org/zap" ) -const splitBatchRegionLimit = 16 +const splitBatchRegionLimit = 2048 func equalRegionStartKey(key, regionStartKey []byte) bool { return bytes.Equal(key, regionStartKey) From fa54cb2df8f5f0bc0857080c6a722f72501d5421 Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Mon, 10 May 2021 22:16:35 +0800 Subject: [PATCH 75/85] *: Adapt ScanDetailV2 in KvGet and KvBatchGet Response (#21562) (#24289) --- distsql/select_result.go | 14 +- executor/adapter.go | 8 +- executor/explainfor_test.go | 4 +- sessionctx/stmtctx/stmtctx.go | 40 ++++-- sessionctx/stmtctx/stmtctx_test.go | 6 +- sessionctx/variable/session_test.go | 16 ++- store/tikv/coprocessor.go | 23 +++- store/tikv/snapshot.go | 34 +++++ store/tikv/snapshot_test.go | 22 +++ util/execdetails/execdetails.go | 148 +++++++++++++++++---- util/execdetails/execdetails_test.go | 85 ++++++++++-- util/stmtsummary/statement_summary.go | 29 ++-- util/stmtsummary/statement_summary_test.go | 84 +++++++----- 13 files changed, 397 insertions(+), 116 deletions(-) diff --git a/distsql/select_result.go b/distsql/select_result.go index 6952a9e792f6f..825882e2887f6 100644 --- a/distsql/select_result.go +++ b/distsql/select_result.go @@ -277,6 +277,10 @@ func (r *selectResult) updateCopRuntimeStats(ctx context.Context, copStats *tikv } r.stats.mergeCopRuntimeStats(copStats, respTime) + if copStats.ScanDetail != nil && len(r.copPlanIDs) > 0 { + r.ctx.GetSessionVars().StmtCtx.RuntimeStatsColl.RecordScanDetail(r.copPlanIDs[len(r.copPlanIDs)-1], copStats.ScanDetail) + } + for i, detail := range r.selectResp.GetExecutionSummaries() { if detail != nil && detail.TimeProcessedNs != nil && detail.NumProducedRows != nil && detail.NumIterations != nil { @@ -338,13 +342,17 @@ type selectResultRuntimeStats struct { func (s *selectResultRuntimeStats) mergeCopRuntimeStats(copStats *tikv.CopRuntimeStats, respTime time.Duration) { s.copRespTime = append(s.copRespTime, respTime) - s.procKeys = append(s.procKeys, copStats.ProcessedKeys) + if copStats.ScanDetail != nil { + s.procKeys = append(s.procKeys, copStats.ScanDetail.ProcessedKeys) + } else { + s.procKeys = append(s.procKeys, 0) + } for k, v := range copStats.BackoffSleep { s.backoffSleep[k] += v } - s.totalProcessTime += copStats.ProcessTime - s.totalWaitTime += copStats.WaitTime + s.totalProcessTime += copStats.TimeDetail.ProcessTime + s.totalWaitTime += copStats.TimeDetail.WaitTime s.rpcStat.Merge(copStats.RegionRequestRuntimeStats) if copStats.CoprCacheHit { s.CoprCacheHitNum++ diff --git a/executor/adapter.go b/executor/adapter.go index 35476e1d22e75..1831f3c4d02ab 100644 --- a/executor/adapter.go +++ b/executor/adapter.go @@ -944,12 +944,12 @@ func (a *ExecStmt) LogSlowQuery(txnTS uint64, succ bool, hasMoreResults bool) { logutil.SlowQueryLogger.Warn(sessVars.SlowLogFormat(slowItems)) if sessVars.InRestrictedSQL { totalQueryProcHistogramInternal.Observe(costTime.Seconds()) - totalCopProcHistogramInternal.Observe(execDetail.ProcessTime.Seconds()) - totalCopWaitHistogramInternal.Observe(execDetail.WaitTime.Seconds()) + totalCopProcHistogramInternal.Observe(execDetail.TimeDetail.ProcessTime.Seconds()) + totalCopWaitHistogramInternal.Observe(execDetail.TimeDetail.WaitTime.Seconds()) } else { totalQueryProcHistogramGeneral.Observe(costTime.Seconds()) - totalCopProcHistogramGeneral.Observe(execDetail.ProcessTime.Seconds()) - totalCopWaitHistogramGeneral.Observe(execDetail.WaitTime.Seconds()) + totalCopProcHistogramGeneral.Observe(execDetail.TimeDetail.ProcessTime.Seconds()) + totalCopWaitHistogramGeneral.Observe(execDetail.TimeDetail.WaitTime.Seconds()) } var userString string if sessVars.User != nil { diff --git a/executor/explainfor_test.go b/executor/explainfor_test.go index bdadf3e0fac92..efa72e7e2ca2d 100644 --- a/executor/explainfor_test.go +++ b/executor/explainfor_test.go @@ -104,8 +104,8 @@ func (s *testSuite9) TestExplainFor(c *C) { } } c.Assert(buf.String(), Matches, ""+ - "TableReader_5 10000.00 0 root time:.*, loops:1, cop_task:.*num: 1, max:.*, proc_keys: 0, rpc_num: 1, rpc_time: .*data:TableFullScan_4 N/A N/A\n"+ - "└─TableFullScan_4 10000.00 0 cop.* table:t1 .*keep order:false, stats:pseudo N/A N/A") + "TableReader_5 10000.00 0 root time:.*, loops:1, cop_task: {num:.*, max:.*, proc_keys: 0, rpc_num: 1, rpc_time:.*} data:TableFullScan_4 N/A N/A\n"+ + "└─TableFullScan_4 10000.00 0 cop.* table:t1 tikv_task:{time:.*, loops:.*} keep order:false, stats:pseudo N/A N/A") } tkRoot.MustQuery("select * from t1;") check() diff --git a/sessionctx/stmtctx/stmtctx.go b/sessionctx/stmtctx/stmtctx.go index fbba707a9cc77..7832d717d7e35 100644 --- a/sessionctx/stmtctx/stmtctx.go +++ b/sessionctx/stmtctx/stmtctx.go @@ -511,12 +511,10 @@ func (sc *StatementContext) MergeExecDetails(details *execdetails.ExecDetails, c sc.mu.Lock() if details != nil { sc.mu.execDetails.CopTime += details.CopTime - sc.mu.execDetails.ProcessTime += details.ProcessTime - sc.mu.execDetails.WaitTime += details.WaitTime sc.mu.execDetails.BackoffTime += details.BackoffTime sc.mu.execDetails.RequestCount++ - sc.mu.execDetails.TotalKeys += details.TotalKeys - sc.mu.execDetails.ProcessedKeys += details.ProcessedKeys + sc.MergeScanDetail(details.ScanDetail) + sc.MergeTimeDetail(details.TimeDetail) sc.mu.allExecDetails = append(sc.mu.allExecDetails, details) } if commitDetails != nil { @@ -529,6 +527,24 @@ func (sc *StatementContext) MergeExecDetails(details *execdetails.ExecDetails, c sc.mu.Unlock() } +// MergeScanDetail merges scan details into self. +func (sc *StatementContext) MergeScanDetail(scanDetail *execdetails.ScanDetail) { + // Currently TiFlash cop task does not fill scanDetail, so need to skip it if scanDetail is nil + if scanDetail == nil { + return + } + if sc.mu.execDetails.ScanDetail == nil { + sc.mu.execDetails.ScanDetail = &execdetails.ScanDetail{} + } + sc.mu.execDetails.ScanDetail.Merge(scanDetail) +} + +// MergeTimeDetail merges time details into self. +func (sc *StatementContext) MergeTimeDetail(timeDetail execdetails.TimeDetail) { + sc.mu.execDetails.TimeDetail.ProcessTime += timeDetail.ProcessTime + sc.mu.execDetails.TimeDetail.WaitTime += timeDetail.WaitTime +} + // MergeLockKeysExecDetails merges lock keys execution details into self. func (sc *StatementContext) MergeLockKeysExecDetails(lockKeys *execdetails.LockKeysDetails) { sc.mu.Lock() @@ -615,21 +631,21 @@ func (sc *StatementContext) CopTasksDetails() *CopTasksDetails { if n == 0 { return d } - d.AvgProcessTime = sc.mu.execDetails.ProcessTime / time.Duration(n) - d.AvgWaitTime = sc.mu.execDetails.WaitTime / time.Duration(n) + d.AvgProcessTime = sc.mu.execDetails.TimeDetail.ProcessTime / time.Duration(n) + d.AvgWaitTime = sc.mu.execDetails.TimeDetail.WaitTime / time.Duration(n) sort.Slice(sc.mu.allExecDetails, func(i, j int) bool { - return sc.mu.allExecDetails[i].ProcessTime < sc.mu.allExecDetails[j].ProcessTime + return sc.mu.allExecDetails[i].TimeDetail.ProcessTime < sc.mu.allExecDetails[j].TimeDetail.ProcessTime }) - d.P90ProcessTime = sc.mu.allExecDetails[n*9/10].ProcessTime - d.MaxProcessTime = sc.mu.allExecDetails[n-1].ProcessTime + d.P90ProcessTime = sc.mu.allExecDetails[n*9/10].TimeDetail.ProcessTime + d.MaxProcessTime = sc.mu.allExecDetails[n-1].TimeDetail.ProcessTime d.MaxProcessAddress = sc.mu.allExecDetails[n-1].CalleeAddress sort.Slice(sc.mu.allExecDetails, func(i, j int) bool { - return sc.mu.allExecDetails[i].WaitTime < sc.mu.allExecDetails[j].WaitTime + return sc.mu.allExecDetails[i].TimeDetail.WaitTime < sc.mu.allExecDetails[j].TimeDetail.WaitTime }) - d.P90WaitTime = sc.mu.allExecDetails[n*9/10].WaitTime - d.MaxWaitTime = sc.mu.allExecDetails[n-1].WaitTime + d.P90WaitTime = sc.mu.allExecDetails[n*9/10].TimeDetail.WaitTime + d.MaxWaitTime = sc.mu.allExecDetails[n-1].TimeDetail.WaitTime d.MaxWaitAddress = sc.mu.allExecDetails[n-1].CalleeAddress // calculate backoff details diff --git a/sessionctx/stmtctx/stmtctx_test.go b/sessionctx/stmtctx/stmtctx_test.go index c786d5e10fcbf..cd0a51800a471 100644 --- a/sessionctx/stmtctx/stmtctx_test.go +++ b/sessionctx/stmtctx/stmtctx_test.go @@ -37,10 +37,12 @@ func (s *stmtctxSuit) TestCopTasksDetails(c *C) { for i := 0; i < 100; i++ { d := &execdetails.ExecDetails{ CalleeAddress: fmt.Sprintf("%v", i+1), - ProcessTime: time.Second * time.Duration(i+1), - WaitTime: time.Millisecond * time.Duration(i+1), BackoffSleep: make(map[string]time.Duration), BackoffTimes: make(map[string]int), + TimeDetail: execdetails.TimeDetail{ + ProcessTime: time.Second * time.Duration(i+1), + WaitTime: time.Millisecond * time.Duration(i+1), + }, } for _, backoff := range backoffs { d.BackoffSleep[backoff] = time.Millisecond * 100 * time.Duration(i+1) diff --git a/sessionctx/variable/session_test.go b/sessionctx/variable/session_test.go index 3e13fd0b0ba3e..7419c4d6d07f2 100644 --- a/sessionctx/variable/session_test.go +++ b/sessionctx/variable/session_test.go @@ -136,12 +136,16 @@ func (*testSessionSuite) TestSlowLogFormat(c *C) { txnTS := uint64(406649736972468225) costTime := time.Second execDetail := execdetails.ExecDetails{ - ProcessTime: time.Second * time.Duration(2), - WaitTime: time.Minute, - BackoffTime: time.Millisecond, - RequestCount: 2, - TotalKeys: 10000, - ProcessedKeys: 20001, + BackoffTime: time.Millisecond, + RequestCount: 2, + ScanDetail: &execdetails.ScanDetail{ + ProcessedKeys: 20001, + TotalKeys: 10000, + }, + TimeDetail: execdetails.TimeDetail{ + ProcessTime: time.Second * time.Duration(2), + WaitTime: time.Minute, + }, } statsInfos := make(map[string]uint64) statsInfos["t1"] = 0 diff --git a/store/tikv/coprocessor.go b/store/tikv/coprocessor.go index 7dbdcda6e0309..2846159e89331 100644 --- a/store/tikv/coprocessor.go +++ b/store/tikv/coprocessor.go @@ -1149,18 +1149,29 @@ func (worker *copIteratorWorker) handleCopResponse(bo *Backoffer, rpcCtx *RPCCon resp.detail.CalleeAddress = rpcCtx.Addr } resp.respTime = costTime - if pbDetails := resp.pbResp.ExecDetails; pbDetails != nil { + sd := &execdetails.ScanDetail{} + td := execdetails.TimeDetail{} + if pbDetails := resp.pbResp.ExecDetailsV2; pbDetails != nil { + // Take values in `ExecDetailsV2` first. if timeDetail := pbDetails.TimeDetail; timeDetail != nil { - resp.detail.WaitTime = time.Duration(timeDetail.WaitWallTimeMs) * time.Millisecond - resp.detail.ProcessTime = time.Duration(timeDetail.ProcessWallTimeMs) * time.Millisecond + td.MergeFromTimeDetail(timeDetail) + } + if scanDetailV2 := pbDetails.ScanDetailV2; scanDetailV2 != nil { + sd.MergeFromScanDetailV2(scanDetailV2) + } + } else if pbDetails := resp.pbResp.ExecDetails; pbDetails != nil { + if timeDetail := pbDetails.TimeDetail; timeDetail != nil { + td.MergeFromTimeDetail(timeDetail) } if scanDetail := pbDetails.ScanDetail; scanDetail != nil { if scanDetail.Write != nil { - resp.detail.TotalKeys += scanDetail.Write.Total - resp.detail.ProcessedKeys += scanDetail.Write.Processed + sd.ProcessedKeys = scanDetail.Write.Processed + sd.TotalKeys = scanDetail.Write.Total } } } + resp.detail.ScanDetail = sd + resp.detail.TimeDetail = td if resp.pbResp.IsCacheHit { if cacheValue == nil { return nil, errors.New("Internal error: received illegal TiKV response") @@ -1173,7 +1184,7 @@ func (worker *copIteratorWorker) handleCopResponse(bo *Backoffer, rpcCtx *RPCCon } else { // Cache not hit or cache hit but not valid: update the cache if the response can be cached. if cacheKey != nil && resp.pbResp.CanBeCached && resp.pbResp.CacheLastVersion > 0 { - if worker.store.coprCache.CheckResponseAdmission(resp.pbResp.Data.Size(), resp.detail.ProcessTime) { + if worker.store.coprCache.CheckResponseAdmission(resp.pbResp.Data.Size(), resp.detail.TimeDetail.ProcessTime) { data := make([]byte, len(resp.pbResp.Data)) copy(data, resp.pbResp.Data) diff --git a/store/tikv/snapshot.go b/store/tikv/snapshot.go index d735527919088..d6b798970665e 100644 --- a/store/tikv/snapshot.go +++ b/store/tikv/snapshot.go @@ -324,6 +324,9 @@ func (s *tikvSnapshot) batchGetSingleRegion(bo *Backoffer, batch batchKeys, coll lockedKeys = append(lockedKeys, lock.Key) locks = append(locks, lock) } + if batchGetResp.ExecDetailsV2 != nil { + s.mergeExecDetail(batchGetResp.ExecDetailsV2) + } if len(lockedKeys) > 0 { msBeforeExpired, err := cli.ResolveLocks(bo, s.version.Ver, locks) if err != nil { @@ -436,6 +439,9 @@ func (s *tikvSnapshot) get(ctx context.Context, bo *Backoffer, k kv.Key) ([]byte return nil, errors.Trace(ErrBodyMissing) } cmdGetResp := resp.Resp.(*pb.GetResponse) + if cmdGetResp.ExecDetailsV2 != nil { + s.mergeExecDetail(cmdGetResp.ExecDetailsV2) + } val := cmdGetResp.GetValue() if keyErr := cmdGetResp.GetError(); keyErr != nil { lock, err := extractLockFromKeyErr(keyErr) @@ -458,6 +464,22 @@ func (s *tikvSnapshot) get(ctx context.Context, bo *Backoffer, k kv.Key) ([]byte } } +func (s *tikvSnapshot) mergeExecDetail(detail *pb.ExecDetailsV2) { + s.mu.Lock() + defer s.mu.Unlock() + if detail == nil || s.mu.stats == nil { + return + } + if s.mu.stats.scanDetail == nil { + s.mu.stats.scanDetail = &execdetails.ScanDetail{} + } + if s.mu.stats.timeDetail == nil { + s.mu.stats.timeDetail = &execdetails.TimeDetail{} + } + s.mu.stats.scanDetail.MergeFromScanDetailV2(detail.ScanDetailV2) + s.mu.stats.timeDetail.MergeFromTimeDetail(detail.TimeDetail) +} + // Iter return a list of key-value pair after `k`. func (s *tikvSnapshot) Iter(k kv.Key, upperBound kv.Key) (kv.Iterator, error) { scanner, err := newScanner(s, k, upperBound, scanBatchSize, false) @@ -659,6 +681,8 @@ type SnapshotRuntimeStats struct { rpcStats RegionRequestRuntimeStats backoffSleepMS map[backoffType]int backoffTimes map[backoffType]int + scanDetail *execdetails.ScanDetail + timeDetail *execdetails.TimeDetail } // Tp implements the RuntimeStats interface. @@ -727,5 +751,15 @@ func (rs *SnapshotRuntimeStats) String() string { d := time.Duration(ms) * time.Millisecond buf.WriteString(fmt.Sprintf("%s_backoff:{num:%d, total_time:%s}", k.String(), v, execdetails.FormatDuration(d))) } + timeDetail := rs.timeDetail.String() + if timeDetail != "" { + buf.WriteString(", ") + buf.WriteString(timeDetail) + } + scanDetail := rs.scanDetail.String() + if scanDetail != "" { + buf.WriteString(", ") + buf.WriteString(scanDetail) + } return buf.String() } diff --git a/store/tikv/snapshot_test.go b/store/tikv/snapshot_test.go index 87e883c7f595a..a3b51d8b86ab5 100644 --- a/store/tikv/snapshot_test.go +++ b/store/tikv/snapshot_test.go @@ -317,4 +317,26 @@ func (s *testSnapshotSuite) TestSnapshotRuntimeStats(c *C) { snapshot.recordBackoffInfo(bo) expect := "Get:{num_rpc:4, total_time:2s},txnLockFast_backoff:{num:2, total_time:60ms}" c.Assert(snapshot.mu.stats.String(), Equals, expect) + detail := &pb.ExecDetailsV2{ + TimeDetail: &pb.TimeDetail{ + WaitWallTimeMs: 100, + ProcessWallTimeMs: 100, + }, + ScanDetailV2: &pb.ScanDetailV2{ + ProcessedVersions: 10, + TotalVersions: 15, + }, + } + snapshot.mergeExecDetail(detail) + expect = "Get:{num_rpc:4, total_time:2s},txnLockFast_backoff:{num:2, total_time:60ms}, " + + "total_process_time: 100ms, total_wait_time: 100ms, " + + "scan_detail: {total_process_keys: 10, " + + "total_keys: 15}" + c.Assert(snapshot.mu.stats.String(), Equals, expect) + snapshot.mergeExecDetail(detail) + expect = "Get:{num_rpc:4, total_time:2s},txnLockFast_backoff:{num:2, total_time:60ms}, " + + "total_process_time: 200ms, total_wait_time: 200ms, " + + "scan_detail: {total_process_keys: 20, " + + "total_keys: 30}" + c.Assert(snapshot.mu.stats.String(), Equals, expect) } diff --git a/util/execdetails/execdetails.go b/util/execdetails/execdetails.go index 1d8588982364e..537af6599d699 100644 --- a/util/execdetails/execdetails.go +++ b/util/execdetails/execdetails.go @@ -24,6 +24,7 @@ import ( "sync/atomic" "time" + "github.com/pingcap/kvproto/pkg/kvrpcpb" "github.com/pingcap/tipb/go-tipb" "go.uber.org/zap" ) @@ -43,17 +44,15 @@ var ( type ExecDetails struct { CalleeAddress string CopTime time.Duration - ProcessTime time.Duration - WaitTime time.Duration BackoffTime time.Duration LockKeysDuration time.Duration BackoffSleep map[string]time.Duration BackoffTimes map[string]int RequestCount int - TotalKeys int64 - ProcessedKeys int64 CommitDetail *CommitDetails LockKeysDetail *LockKeysDetails + ScanDetail *ScanDetail + TimeDetail TimeDetail } type stmtExecDetailKeyType struct{} @@ -169,6 +168,89 @@ func (ld *LockKeysDetails) Clone() *LockKeysDetails { return lock } +// TimeDetail contains coprocessor time detail information. +type TimeDetail struct { + // WaitWallTimeMs is the off-cpu wall time which is elapsed in TiKV side. Usually this includes queue waiting time and + // other kind of waitings in series. + ProcessTime time.Duration + // Off-cpu and on-cpu wall time elapsed to actually process the request payload. It does not + // include `wait_wall_time`. + // This field is very close to the CPU time in most cases. Some wait time spend in RocksDB + // cannot be excluded for now, like Mutex wait time, which is included in this field, so that + // this field is called wall time instead of CPU time. + WaitTime time.Duration +} + +// String implements the fmt.Stringer interface. +func (td *TimeDetail) String() string { + if td == nil { + return "" + } + buf := bytes.NewBuffer(make([]byte, 0, 16)) + if td.ProcessTime > 0 { + buf.WriteString("total_process_time: ") + buf.WriteString(FormatDuration(td.ProcessTime)) + } + if td.WaitTime > 0 { + if buf.Len() > 0 { + buf.WriteString(", ") + } + buf.WriteString("total_wait_time: ") + buf.WriteString(FormatDuration(td.WaitTime)) + } + return buf.String() +} + +// MergeFromTimeDetail merges time detail from pb into itself. +func (td *TimeDetail) MergeFromTimeDetail(timeDetail *kvrpcpb.TimeDetail) { + if timeDetail != nil { + td.WaitTime += time.Duration(timeDetail.WaitWallTimeMs) * time.Millisecond + td.ProcessTime += time.Duration(timeDetail.ProcessWallTimeMs) * time.Millisecond + } +} + +// ScanDetail contains coprocessor scan detail information. +type ScanDetail struct { + // TotalKeys is the approximate number of MVCC keys meet during scanning. It includes + // deleted versions, but does not include RocksDB tombstone keys. + TotalKeys int64 + // ProcessedKeys is the number of user keys scanned from the storage. + // It does not include deleted version or RocksDB tombstone keys. + // For Coprocessor requests, it includes keys that has been filtered out by Selection. + ProcessedKeys int64 +} + +// Merge merges scan detail execution details into self. +func (sd *ScanDetail) Merge(scanDetail *ScanDetail) { + atomic.AddInt64(&sd.TotalKeys, scanDetail.TotalKeys) + atomic.AddInt64(&sd.ProcessedKeys, scanDetail.ProcessedKeys) +} + +var zeroScanDetail = ScanDetail{} + +// String implements the fmt.Stringer interface. +func (sd *ScanDetail) String() string { + if sd == nil || *sd == zeroScanDetail { + return "" + } + buf := bytes.NewBuffer(make([]byte, 0, 16)) + buf.WriteString("scan_detail: {") + buf.WriteString("total_process_keys: ") + buf.WriteString(strconv.FormatInt(sd.ProcessedKeys, 10)) + buf.WriteString(", total_keys: ") + buf.WriteString(strconv.FormatInt(sd.TotalKeys, 10)) + buf.WriteString("}") + return buf.String() +} + +// MergeFromScanDetailV2 merges scan detail from pb into itself. +func (sd *ScanDetail) MergeFromScanDetailV2(scanDetail *kvrpcpb.ScanDetailV2) { + if scanDetail != nil { + sd.TotalKeys += int64(scanDetail.TotalVersions) + sd.ProcessedKeys += int64(scanDetail.ProcessedVersions) + } +} + const ( // CopTimeStr represents the sum of cop-task time spend in TiDB distSQL. CopTimeStr = "Cop_time" @@ -218,11 +300,11 @@ func (d ExecDetails) String() string { if d.CopTime > 0 { parts = append(parts, CopTimeStr+": "+strconv.FormatFloat(d.CopTime.Seconds(), 'f', -1, 64)) } - if d.ProcessTime > 0 { - parts = append(parts, ProcessTimeStr+": "+strconv.FormatFloat(d.ProcessTime.Seconds(), 'f', -1, 64)) + if d.TimeDetail.ProcessTime > 0 { + parts = append(parts, ProcessTimeStr+": "+strconv.FormatFloat(d.TimeDetail.ProcessTime.Seconds(), 'f', -1, 64)) } - if d.WaitTime > 0 { - parts = append(parts, WaitTimeStr+": "+strconv.FormatFloat(d.WaitTime.Seconds(), 'f', -1, 64)) + if d.TimeDetail.WaitTime > 0 { + parts = append(parts, WaitTimeStr+": "+strconv.FormatFloat(d.TimeDetail.WaitTime.Seconds(), 'f', -1, 64)) } if d.BackoffTime > 0 { parts = append(parts, BackoffTimeStr+": "+strconv.FormatFloat(d.BackoffTime.Seconds(), 'f', -1, 64)) @@ -233,11 +315,14 @@ func (d ExecDetails) String() string { if d.RequestCount > 0 { parts = append(parts, RequestCountStr+": "+strconv.FormatInt(int64(d.RequestCount), 10)) } - if d.TotalKeys > 0 { - parts = append(parts, TotalKeysStr+": "+strconv.FormatInt(d.TotalKeys, 10)) - } - if d.ProcessedKeys > 0 { - parts = append(parts, ProcessKeysStr+": "+strconv.FormatInt(d.ProcessedKeys, 10)) + scanDetail := d.ScanDetail + if scanDetail != nil { + if scanDetail.TotalKeys > 0 { + parts = append(parts, TotalKeysStr+": "+strconv.FormatInt(scanDetail.TotalKeys, 10)) + } + if scanDetail.ProcessedKeys > 0 { + parts = append(parts, ProcessKeysStr+": "+strconv.FormatInt(scanDetail.ProcessedKeys, 10)) + } } commitDetails := d.CommitDetail if commitDetails != nil { @@ -292,11 +377,11 @@ func (d ExecDetails) ToZapFields() (fields []zap.Field) { if d.CopTime > 0 { fields = append(fields, zap.String(strings.ToLower(CopTimeStr), strconv.FormatFloat(d.CopTime.Seconds(), 'f', -1, 64)+"s")) } - if d.ProcessTime > 0 { - fields = append(fields, zap.String(strings.ToLower(ProcessTimeStr), strconv.FormatFloat(d.ProcessTime.Seconds(), 'f', -1, 64)+"s")) + if d.TimeDetail.ProcessTime > 0 { + fields = append(fields, zap.String(strings.ToLower(ProcessTimeStr), strconv.FormatFloat(d.TimeDetail.ProcessTime.Seconds(), 'f', -1, 64)+"s")) } - if d.WaitTime > 0 { - fields = append(fields, zap.String(strings.ToLower(WaitTimeStr), strconv.FormatFloat(d.WaitTime.Seconds(), 'f', -1, 64)+"s")) + if d.TimeDetail.WaitTime > 0 { + fields = append(fields, zap.String(strings.ToLower(WaitTimeStr), strconv.FormatFloat(d.TimeDetail.WaitTime.Seconds(), 'f', -1, 64)+"s")) } if d.BackoffTime > 0 { fields = append(fields, zap.String(strings.ToLower(BackoffTimeStr), strconv.FormatFloat(d.BackoffTime.Seconds(), 'f', -1, 64)+"s")) @@ -304,11 +389,11 @@ func (d ExecDetails) ToZapFields() (fields []zap.Field) { if d.RequestCount > 0 { fields = append(fields, zap.String(strings.ToLower(RequestCountStr), strconv.FormatInt(int64(d.RequestCount), 10))) } - if d.TotalKeys > 0 { - fields = append(fields, zap.String(strings.ToLower(TotalKeysStr), strconv.FormatInt(d.TotalKeys, 10))) + if d.ScanDetail != nil && d.ScanDetail.TotalKeys > 0 { + fields = append(fields, zap.String(strings.ToLower(TotalKeysStr), strconv.FormatInt(d.ScanDetail.TotalKeys, 10))) } - if d.ProcessedKeys > 0 { - fields = append(fields, zap.String(strings.ToLower(ProcessKeysStr), strconv.FormatInt(d.ProcessedKeys, 10))) + if d.ScanDetail != nil && d.ScanDetail.ProcessedKeys > 0 { + fields = append(fields, zap.String(strings.ToLower(ProcessKeysStr), strconv.FormatInt(d.ScanDetail.ProcessedKeys, 10))) } commitDetails := d.CommitDetail if commitDetails != nil { @@ -363,7 +448,8 @@ type CopRuntimeStats struct { // have many region leaders, several coprocessor tasks can be sent to the // same tikv-server instance. We have to use a list to maintain all tasks // executed on each instance. - stats map[string][]*BasicRuntimeStats + stats map[string][]*BasicRuntimeStats + scanDetail *ScanDetail } // RecordOneCopTask records a specific cop tasks's execution detail. @@ -412,6 +498,13 @@ func (crs *CopRuntimeStats) String() string { FormatDuration(procTimes[n-1]), FormatDuration(procTimes[0]), FormatDuration(procTimes[n*4/5]), FormatDuration(procTimes[n*19/20]), totalIters, totalTasks)) } + if crs.scanDetail != nil { + detail := crs.scanDetail.String() + if detail != "" { + buf.WriteString(", ") + buf.WriteString(detail) + } + } return buf.String() } @@ -624,7 +717,10 @@ func (e *RuntimeStatsColl) GetCopStats(planID int) *CopRuntimeStats { defer e.mu.Unlock() copStats, ok := e.copStats[planID] if !ok { - copStats = &CopRuntimeStats{stats: make(map[string][]*BasicRuntimeStats)} + copStats = &CopRuntimeStats{ + stats: make(map[string][]*BasicRuntimeStats), + scanDetail: &ScanDetail{}, + } e.copStats[planID] = copStats } return copStats @@ -651,6 +747,12 @@ func (e *RuntimeStatsColl) RecordOneCopTask(planID int, address string, summary copStats.RecordOneCopTask(address, summary) } +// RecordScanDetail records a specific cop tasks's cop detail. +func (e *RuntimeStatsColl) RecordScanDetail(planID int, detail *ScanDetail) { + copStats := e.GetCopStats(planID) + copStats.scanDetail.Merge(detail) +} + // ExistsRootStats checks if the planID exists in the rootStats collection. func (e *RuntimeStatsColl) ExistsRootStats(planID int) bool { e.mu.Lock() diff --git a/util/execdetails/execdetails_test.go b/util/execdetails/execdetails_test.go index af18e52b9be23..0e29f0a7850ef 100644 --- a/util/execdetails/execdetails_test.go +++ b/util/execdetails/execdetails_test.go @@ -15,6 +15,7 @@ package execdetails import ( "fmt" + "strconv" "sync" "testing" "time" @@ -30,13 +31,9 @@ func TestT(t *testing.T) { func TestString(t *testing.T) { detail := &ExecDetails{ - CopTime: time.Second + 3*time.Millisecond, - ProcessTime: 2*time.Second + 5*time.Millisecond, - WaitTime: time.Second, - BackoffTime: time.Second, - RequestCount: 1, - TotalKeys: 100, - ProcessedKeys: 10, + CopTime: time.Second + 3*time.Millisecond, + BackoffTime: time.Second, + RequestCount: 1, CommitDetail: &CommitDetails{ GetCommitTsTime: time.Second, PrewriteTime: time.Second, @@ -60,6 +57,14 @@ func TestString(t *testing.T) { PrewriteRegionNum: 1, TxnRetry: 1, }, + ScanDetail: &ScanDetail{ + ProcessedKeys: 10, + TotalKeys: 100, + }, + TimeDetail: TimeDetail{ + ProcessTime: 2*time.Second + 5*time.Millisecond, + WaitTime: time.Second, + }, } expected := "Cop_time: 1.003 Process_time: 2.005 Wait_time: 1 Backoff_time: 1 Request_count: 1 Total_keys: 100 Process_keys: 10 Prewrite_time: 1 Commit_time: 1 " + "Get_commit_ts_time: 1 Commit_backoff_time: 1 Backoff_types: [backoff1 backoff2] Resolve_lock_time: 1 Local_latch_wait_time: 1 Write_keys: 1 Write_size: 1 Prewrite_region: 1 Txn_retry: 1" @@ -91,12 +96,74 @@ func TestCopRuntimeStats(t *testing.T) { stats.RecordOneCopTask(tableScanID, "8.8.8.9", mockExecutorExecutionSummary(2, 2, 2)) stats.RecordOneCopTask(aggID, "8.8.8.8", mockExecutorExecutionSummary(3, 3, 3)) stats.RecordOneCopTask(aggID, "8.8.8.9", mockExecutorExecutionSummary(4, 4, 4)) + scanDetail := &ScanDetail{ + TotalKeys: 15, + ProcessedKeys: 10, + } + stats.RecordScanDetail(tableScanID, scanDetail) + if stats.ExistsCopStats(tableScanID) != true { + t.Fatal("exist") + } + cop := stats.GetCopStats(tableScanID) + if cop.String() != "tikv_task:{proc max:2ns, min:1ns, p80:2ns, p95:2ns, iters:3, tasks:2}, "+ + "scan_detail: {total_process_keys: 10, total_keys: 15}" { + t.Fatalf(cop.String()) + } + copStats := cop.stats["8.8.8.8"] + if copStats == nil { + t.Fatal("cop stats is nil") + } + copStats[0].SetRowNum(10) + copStats[0].Record(time.Second, 10) + if copStats[0].String() != "time:1s, loops:2" { + t.Fatalf("cop stats string is not expect, got: %v", copStats[0].String()) + } + + if stats.GetCopStats(aggID).String() != "tikv_task:{proc max:4ns, min:3ns, p80:4ns, p95:4ns, iters:7, tasks:2}" { + t.Fatalf("agg cop stats string is not as expected, got: %v", stats.GetCopStats(aggID).String()) + } + rootStats := stats.GetRootStats(tableReaderID) + if rootStats == nil { + t.Fatal("table_reader") + } + if stats.ExistsRootStats(tableReaderID) == false { + t.Fatal("table_reader not exists") + } + + cop.scanDetail.ProcessedKeys = 0 + // Print all fields even though the value of some fields is 0. + if cop.String() != "tikv_task:{proc max:1s, min:2ns, p80:1s, p95:1s, iters:4, tasks:2}, "+ + "scan_detail: {total_process_keys: 0, total_keys: 15}" { + t.Fatalf(cop.String()) + } + + zeroScanDetail := ScanDetail{} + if zeroScanDetail.String() != "" { + t.Fatalf(zeroScanDetail.String()) + } +} + +func TestCopRuntimeStatsForTiFlash(t *testing.T) { + stats := NewRuntimeStatsColl() + tableScanID := 1 + aggID := 2 + tableReaderID := 3 + stats.RecordOneCopTask(aggID, "8.8.8.8", mockExecutorExecutionSummaryForTiFlash(1, 1, 1, "tablescan_"+strconv.Itoa(tableScanID))) + stats.RecordOneCopTask(aggID, "8.8.8.9", mockExecutorExecutionSummaryForTiFlash(2, 2, 2, "tablescan_"+strconv.Itoa(tableScanID))) + stats.RecordOneCopTask(tableScanID, "8.8.8.8", mockExecutorExecutionSummaryForTiFlash(3, 3, 3, "aggregation_"+strconv.Itoa(aggID))) + stats.RecordOneCopTask(tableScanID, "8.8.8.9", mockExecutorExecutionSummaryForTiFlash(4, 4, 4, "aggregation_"+strconv.Itoa(aggID))) + scanDetail := &ScanDetail{ + TotalKeys: 10, + ProcessedKeys: 10, + } + stats.RecordScanDetail(tableScanID, scanDetail) if stats.ExistsCopStats(tableScanID) != true { t.Fatal("exist") } cop := stats.GetCopStats(tableScanID) - if cop.String() != "tikv_task:{proc max:2ns, min:1ns, p80:2ns, p95:2ns, iters:3, tasks:2}" { - t.Fatal("table_scan") + if cop.String() != "tikv_task:{proc max:2ns, min:1ns, p80:2ns, p95:2ns, iters:3, tasks:2}"+ + ", scan_detail: {total_process_keys: 10, total_keys: 10}" { + t.Fatal(cop.String()) } copStats := cop.stats["8.8.8.8"] if copStats == nil { diff --git a/util/stmtsummary/statement_summary.go b/util/stmtsummary/statement_summary.go index 1250a6a8bf780..502b94c874c4f 100644 --- a/util/stmtsummary/statement_summary.go +++ b/util/stmtsummary/statement_summary.go @@ -701,25 +701,28 @@ func (ssElement *stmtSummaryByDigestElement) add(sei *StmtExecInfo, intervalSeco } // TiKV - ssElement.sumProcessTime += sei.ExecDetail.ProcessTime - if sei.ExecDetail.ProcessTime > ssElement.maxProcessTime { - ssElement.maxProcessTime = sei.ExecDetail.ProcessTime + ssElement.sumProcessTime += sei.ExecDetail.TimeDetail.ProcessTime + if sei.ExecDetail.TimeDetail.ProcessTime > ssElement.maxProcessTime { + ssElement.maxProcessTime = sei.ExecDetail.TimeDetail.ProcessTime } - ssElement.sumWaitTime += sei.ExecDetail.WaitTime - if sei.ExecDetail.WaitTime > ssElement.maxWaitTime { - ssElement.maxWaitTime = sei.ExecDetail.WaitTime + ssElement.sumWaitTime += sei.ExecDetail.TimeDetail.WaitTime + if sei.ExecDetail.TimeDetail.WaitTime > ssElement.maxWaitTime { + ssElement.maxWaitTime = sei.ExecDetail.TimeDetail.WaitTime } ssElement.sumBackoffTime += sei.ExecDetail.BackoffTime if sei.ExecDetail.BackoffTime > ssElement.maxBackoffTime { ssElement.maxBackoffTime = sei.ExecDetail.BackoffTime } - ssElement.sumTotalKeys += sei.ExecDetail.TotalKeys - if sei.ExecDetail.TotalKeys > ssElement.maxTotalKeys { - ssElement.maxTotalKeys = sei.ExecDetail.TotalKeys - } - ssElement.sumProcessedKeys += sei.ExecDetail.ProcessedKeys - if sei.ExecDetail.ProcessedKeys > ssElement.maxProcessedKeys { - ssElement.maxProcessedKeys = sei.ExecDetail.ProcessedKeys + + if sei.ExecDetail.ScanDetail != nil { + ssElement.sumTotalKeys += sei.ExecDetail.ScanDetail.TotalKeys + if sei.ExecDetail.ScanDetail.TotalKeys > ssElement.maxTotalKeys { + ssElement.maxTotalKeys = sei.ExecDetail.ScanDetail.TotalKeys + } + ssElement.sumProcessedKeys += sei.ExecDetail.ScanDetail.ProcessedKeys + if sei.ExecDetail.ScanDetail.ProcessedKeys > ssElement.maxProcessedKeys { + ssElement.maxProcessedKeys = sei.ExecDetail.ScanDetail.ProcessedKeys + } } // txn diff --git a/util/stmtsummary/statement_summary_test.go b/util/stmtsummary/statement_summary_test.go index 51adb3d6f22a9..0bc7b3ab0cedb 100644 --- a/util/stmtsummary/statement_summary_test.go +++ b/util/stmtsummary/statement_summary_test.go @@ -95,16 +95,16 @@ func (s *testStmtSummarySuite) TestAddStatement(c *C) { maxCopProcessAddress: stmtExecInfo1.CopTasks.MaxProcessAddress, maxCopWaitTime: stmtExecInfo1.CopTasks.MaxWaitTime, maxCopWaitAddress: stmtExecInfo1.CopTasks.MaxWaitAddress, - sumProcessTime: stmtExecInfo1.ExecDetail.ProcessTime, - maxProcessTime: stmtExecInfo1.ExecDetail.ProcessTime, - sumWaitTime: stmtExecInfo1.ExecDetail.WaitTime, - maxWaitTime: stmtExecInfo1.ExecDetail.WaitTime, + sumProcessTime: stmtExecInfo1.ExecDetail.TimeDetail.ProcessTime, + maxProcessTime: stmtExecInfo1.ExecDetail.TimeDetail.ProcessTime, + sumWaitTime: stmtExecInfo1.ExecDetail.TimeDetail.WaitTime, + maxWaitTime: stmtExecInfo1.ExecDetail.TimeDetail.WaitTime, sumBackoffTime: stmtExecInfo1.ExecDetail.BackoffTime, maxBackoffTime: stmtExecInfo1.ExecDetail.BackoffTime, - sumTotalKeys: stmtExecInfo1.ExecDetail.TotalKeys, - maxTotalKeys: stmtExecInfo1.ExecDetail.TotalKeys, - sumProcessedKeys: stmtExecInfo1.ExecDetail.ProcessedKeys, - maxProcessedKeys: stmtExecInfo1.ExecDetail.ProcessedKeys, + sumTotalKeys: stmtExecInfo1.ExecDetail.ScanDetail.TotalKeys, + maxTotalKeys: stmtExecInfo1.ExecDetail.ScanDetail.TotalKeys, + sumProcessedKeys: stmtExecInfo1.ExecDetail.ScanDetail.ProcessedKeys, + maxProcessedKeys: stmtExecInfo1.ExecDetail.ScanDetail.ProcessedKeys, sumGetCommitTsTime: stmtExecInfo1.ExecDetail.CommitDetail.GetCommitTsTime, maxGetCommitTsTime: stmtExecInfo1.ExecDetail.CommitDetail.GetCommitTsTime, sumPrewriteTime: stmtExecInfo1.ExecDetail.CommitDetail.PrewriteTime, @@ -176,12 +176,8 @@ func (s *testStmtSummarySuite) TestAddStatement(c *C) { }, ExecDetail: &execdetails.ExecDetails{ CalleeAddress: "202", - ProcessTime: 1500, - WaitTime: 150, BackoffTime: 180, RequestCount: 20, - TotalKeys: 6000, - ProcessedKeys: 1500, CommitDetail: &execdetails.CommitDetails{ GetCommitTsTime: 500, PrewriteTime: 50000, @@ -200,6 +196,14 @@ func (s *testStmtSummarySuite) TestAddStatement(c *C) { PrewriteRegionNum: 100, TxnRetry: 10, }, + ScanDetail: &execdetails.ScanDetail{ + TotalKeys: 6000, + ProcessedKeys: 1500, + }, + TimeDetail: execdetails.TimeDetail{ + ProcessTime: 1500, + WaitTime: 150, + }, }, StmtCtx: &stmtctx.StatementContext{ StmtType: "Select", @@ -224,16 +228,16 @@ func (s *testStmtSummarySuite) TestAddStatement(c *C) { expectedSummaryElement.maxCopProcessAddress = stmtExecInfo2.CopTasks.MaxProcessAddress expectedSummaryElement.maxCopWaitTime = stmtExecInfo2.CopTasks.MaxWaitTime expectedSummaryElement.maxCopWaitAddress = stmtExecInfo2.CopTasks.MaxWaitAddress - expectedSummaryElement.sumProcessTime += stmtExecInfo2.ExecDetail.ProcessTime - expectedSummaryElement.maxProcessTime = stmtExecInfo2.ExecDetail.ProcessTime - expectedSummaryElement.sumWaitTime += stmtExecInfo2.ExecDetail.WaitTime - expectedSummaryElement.maxWaitTime = stmtExecInfo2.ExecDetail.WaitTime + expectedSummaryElement.sumProcessTime += stmtExecInfo2.ExecDetail.TimeDetail.ProcessTime + expectedSummaryElement.maxProcessTime = stmtExecInfo2.ExecDetail.TimeDetail.ProcessTime + expectedSummaryElement.sumWaitTime += stmtExecInfo2.ExecDetail.TimeDetail.WaitTime + expectedSummaryElement.maxWaitTime = stmtExecInfo2.ExecDetail.TimeDetail.WaitTime expectedSummaryElement.sumBackoffTime += stmtExecInfo2.ExecDetail.BackoffTime expectedSummaryElement.maxBackoffTime = stmtExecInfo2.ExecDetail.BackoffTime - expectedSummaryElement.sumTotalKeys += stmtExecInfo2.ExecDetail.TotalKeys - expectedSummaryElement.maxTotalKeys = stmtExecInfo2.ExecDetail.TotalKeys - expectedSummaryElement.sumProcessedKeys += stmtExecInfo2.ExecDetail.ProcessedKeys - expectedSummaryElement.maxProcessedKeys = stmtExecInfo2.ExecDetail.ProcessedKeys + expectedSummaryElement.sumTotalKeys += stmtExecInfo2.ExecDetail.ScanDetail.TotalKeys + expectedSummaryElement.maxTotalKeys = stmtExecInfo2.ExecDetail.ScanDetail.TotalKeys + expectedSummaryElement.sumProcessedKeys += stmtExecInfo2.ExecDetail.ScanDetail.ProcessedKeys + expectedSummaryElement.maxProcessedKeys = stmtExecInfo2.ExecDetail.ScanDetail.ProcessedKeys expectedSummaryElement.sumGetCommitTsTime += stmtExecInfo2.ExecDetail.CommitDetail.GetCommitTsTime expectedSummaryElement.maxGetCommitTsTime = stmtExecInfo2.ExecDetail.CommitDetail.GetCommitTsTime expectedSummaryElement.sumPrewriteTime += stmtExecInfo2.ExecDetail.CommitDetail.PrewriteTime @@ -294,12 +298,8 @@ func (s *testStmtSummarySuite) TestAddStatement(c *C) { }, ExecDetail: &execdetails.ExecDetails{ CalleeAddress: "302", - ProcessTime: 150, - WaitTime: 15, BackoffTime: 18, RequestCount: 2, - TotalKeys: 600, - ProcessedKeys: 150, CommitDetail: &execdetails.CommitDetails{ GetCommitTsTime: 50, PrewriteTime: 5000, @@ -318,6 +318,14 @@ func (s *testStmtSummarySuite) TestAddStatement(c *C) { PrewriteRegionNum: 10, TxnRetry: 1, }, + ScanDetail: &execdetails.ScanDetail{ + TotalKeys: 600, + ProcessedKeys: 150, + }, + TimeDetail: execdetails.TimeDetail{ + ProcessTime: 150, + WaitTime: 15, + }, }, StmtCtx: &stmtctx.StatementContext{ StmtType: "Select", @@ -336,11 +344,11 @@ func (s *testStmtSummarySuite) TestAddStatement(c *C) { expectedSummaryElement.sumParseLatency += stmtExecInfo3.ParseLatency expectedSummaryElement.sumCompileLatency += stmtExecInfo3.CompileLatency expectedSummaryElement.sumNumCopTasks += int64(stmtExecInfo3.CopTasks.NumCopTasks) - expectedSummaryElement.sumProcessTime += stmtExecInfo3.ExecDetail.ProcessTime - expectedSummaryElement.sumWaitTime += stmtExecInfo3.ExecDetail.WaitTime + expectedSummaryElement.sumProcessTime += stmtExecInfo3.ExecDetail.TimeDetail.ProcessTime + expectedSummaryElement.sumWaitTime += stmtExecInfo3.ExecDetail.TimeDetail.WaitTime expectedSummaryElement.sumBackoffTime += stmtExecInfo3.ExecDetail.BackoffTime - expectedSummaryElement.sumTotalKeys += stmtExecInfo3.ExecDetail.TotalKeys - expectedSummaryElement.sumProcessedKeys += stmtExecInfo3.ExecDetail.ProcessedKeys + expectedSummaryElement.sumTotalKeys += stmtExecInfo3.ExecDetail.ScanDetail.TotalKeys + expectedSummaryElement.sumProcessedKeys += stmtExecInfo3.ExecDetail.ScanDetail.ProcessedKeys expectedSummaryElement.sumGetCommitTsTime += stmtExecInfo3.ExecDetail.CommitDetail.GetCommitTsTime expectedSummaryElement.sumPrewriteTime += stmtExecInfo3.ExecDetail.CommitDetail.PrewriteTime expectedSummaryElement.sumCommitTime += stmtExecInfo3.ExecDetail.CommitDetail.CommitTime @@ -541,12 +549,8 @@ func generateAnyExecInfo() *StmtExecInfo { }, ExecDetail: &execdetails.ExecDetails{ CalleeAddress: "129", - ProcessTime: 500, - WaitTime: 50, BackoffTime: 80, RequestCount: 10, - TotalKeys: 1000, - ProcessedKeys: 500, CommitDetail: &execdetails.CommitDetails{ GetCommitTsTime: 100, PrewriteTime: 10000, @@ -565,6 +569,14 @@ func generateAnyExecInfo() *StmtExecInfo { PrewriteRegionNum: 20, TxnRetry: 2, }, + ScanDetail: &execdetails.ScanDetail{ + TotalKeys: 1000, + ProcessedKeys: 500, + }, + TimeDetail: execdetails.TimeDetail{ + ProcessTime: 500, + WaitTime: 50, + }, }, StmtCtx: &stmtctx.StatementContext{ StmtType: "Select", @@ -600,10 +612,10 @@ func (s *testStmtSummarySuite) TestToDatum(c *C) { int64(stmtExecInfo1.ParseLatency), int64(stmtExecInfo1.ParseLatency), int64(stmtExecInfo1.CompileLatency), int64(stmtExecInfo1.CompileLatency), stmtExecInfo1.CopTasks.NumCopTasks, int64(stmtExecInfo1.CopTasks.MaxProcessTime), stmtExecInfo1.CopTasks.MaxProcessAddress, int64(stmtExecInfo1.CopTasks.MaxWaitTime), - stmtExecInfo1.CopTasks.MaxWaitAddress, int64(stmtExecInfo1.ExecDetail.ProcessTime), int64(stmtExecInfo1.ExecDetail.ProcessTime), - int64(stmtExecInfo1.ExecDetail.WaitTime), int64(stmtExecInfo1.ExecDetail.WaitTime), int64(stmtExecInfo1.ExecDetail.BackoffTime), - int64(stmtExecInfo1.ExecDetail.BackoffTime), stmtExecInfo1.ExecDetail.TotalKeys, stmtExecInfo1.ExecDetail.TotalKeys, - stmtExecInfo1.ExecDetail.ProcessedKeys, stmtExecInfo1.ExecDetail.ProcessedKeys, + stmtExecInfo1.CopTasks.MaxWaitAddress, int64(stmtExecInfo1.ExecDetail.TimeDetail.ProcessTime), int64(stmtExecInfo1.ExecDetail.TimeDetail.ProcessTime), + int64(stmtExecInfo1.ExecDetail.TimeDetail.WaitTime), int64(stmtExecInfo1.ExecDetail.TimeDetail.WaitTime), int64(stmtExecInfo1.ExecDetail.BackoffTime), + int64(stmtExecInfo1.ExecDetail.BackoffTime), stmtExecInfo1.ExecDetail.ScanDetail.TotalKeys, stmtExecInfo1.ExecDetail.ScanDetail.TotalKeys, + stmtExecInfo1.ExecDetail.ScanDetail.ProcessedKeys, stmtExecInfo1.ExecDetail.ScanDetail.ProcessedKeys, int64(stmtExecInfo1.ExecDetail.CommitDetail.PrewriteTime), int64(stmtExecInfo1.ExecDetail.CommitDetail.PrewriteTime), int64(stmtExecInfo1.ExecDetail.CommitDetail.CommitTime), int64(stmtExecInfo1.ExecDetail.CommitDetail.CommitTime), int64(stmtExecInfo1.ExecDetail.CommitDetail.GetCommitTsTime), int64(stmtExecInfo1.ExecDetail.CommitDetail.GetCommitTsTime), From 06ec53c2b0ab67a78a4b17f15d8924c772b754e4 Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Mon, 10 May 2021 23:17:37 +0800 Subject: [PATCH 76/85] statistics: skip reading mysql.stats_histograms if cached stats is up-to-date (#24175) (#24352) --- statistics/handle/handle.go | 4 ++++ statistics/handle/handle_test.go | 21 +++++++++++++++++++++ statistics/table.go | 13 ++++++++++--- 3 files changed, 35 insertions(+), 3 deletions(-) diff --git a/statistics/handle/handle.go b/statistics/handle/handle.go index 2d8915a8e0478..9dce7a5c45a7c 100644 --- a/statistics/handle/handle.go +++ b/statistics/handle/handle.go @@ -214,6 +214,9 @@ func (h *Handle) Update(is infoschema.InfoSchema) error { continue } tableInfo := table.Meta() + if oldTbl, ok := oldCache.tables[physicalID]; ok && oldTbl.Version >= version && tableInfo.UpdateTS == oldTbl.TblInfoUpdateTS { + continue + } tbl, err := h.tableStatsFromStorage(tableInfo, physicalID, false, 0) // Error is not nil may mean that there are some ddl changes on this table, we will not update it. if err != nil { @@ -228,6 +231,7 @@ func (h *Handle) Update(is infoschema.InfoSchema) error { tbl.Count = count tbl.ModifyCount = modifyCount tbl.Name = getFullTableName(is, tableInfo) + tbl.TblInfoUpdateTS = tableInfo.UpdateTS tables = append(tables, tbl) } h.updateStatsCache(oldCache.update(tables, deletedTableIDs, lastVersion)) diff --git a/statistics/handle/handle_test.go b/statistics/handle/handle_test.go index 32b491ce7ef35..c8f486efda38f 100644 --- a/statistics/handle/handle_test.go +++ b/statistics/handle/handle_test.go @@ -555,3 +555,24 @@ func (s *testStatsSuite) TestCorrelation(c *C) { c.Assert(len(result.Rows()), Equals, 1) c.Assert(result.Rows()[0][9], Equals, "0") } + +func (s *testStatsSuite) TestStatsCacheUpdateSkip(c *C) { + defer cleanEnv(c, s.store, s.do) + testKit := testkit.NewTestKit(c, s.store) + do := s.do + h := do.StatsHandle() + testKit.MustExec("use test") + testKit.MustExec("create table t (c1 int, c2 int)") + testKit.MustExec("insert into t values(1, 2)") + c.Assert(h.DumpStatsDeltaToKV(handle.DumpAll), IsNil) + testKit.MustExec("analyze table t") + is := do.InfoSchema() + tbl, err := is.TableByName(model.NewCIStr("test"), model.NewCIStr("t")) + c.Assert(err, IsNil) + tableInfo := tbl.Meta() + statsTbl1 := h.GetTableStats(tableInfo) + c.Assert(statsTbl1.Pseudo, IsFalse) + h.Update(is) + statsTbl2 := h.GetTableStats(tableInfo) + c.Assert(statsTbl1, Equals, statsTbl2) +} diff --git a/statistics/table.go b/statistics/table.go index 518da999b6695..59592791af12e 100644 --- a/statistics/table.go +++ b/statistics/table.go @@ -61,6 +61,12 @@ type Table struct { HistColl Version uint64 Name string + // TblInfoUpdateTS is the UpdateTS of the TableInfo used when filling this struct. + // It is the schema version of the corresponding table. It is used to skip redundant + // loading of stats, i.e, if the cached stats is already update-to-date with mysql.stats_xxx tables, + // and the schema of the table does not change, we don't need to load the stats for this + // table again. + TblInfoUpdateTS uint64 } // HistColl is a collection of histogram. It collects enough information for plan to calculate the selectivity. @@ -99,9 +105,10 @@ func (t *Table) Copy() *Table { newHistColl.Indices[id] = idx } nt := &Table{ - HistColl: newHistColl, - Version: t.Version, - Name: t.Name, + HistColl: newHistColl, + Version: t.Version, + Name: t.Name, + TblInfoUpdateTS: t.TblInfoUpdateTS, } return nt } From b130ae38dd91f6f98fe696bd54e9070fc789c49b Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Tue, 11 May 2021 01:35:37 +0800 Subject: [PATCH 77/85] executor: make column default value being aware of NO_ZERO_IN_DATE (#24174) (#24185) --- executor/executor.go | 3 ++- executor/executor_test.go | 26 ++++++++++++++++++++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/executor/executor.go b/executor/executor.go index 7192a1ecc45d9..20cccafac6809 100644 --- a/executor/executor.go +++ b/executor/executor.go @@ -1699,7 +1699,8 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { sc.IgnoreZeroInDate = !vars.SQLMode.HasNoZeroInDateMode() || !vars.SQLMode.HasNoZeroDateMode() || !vars.StrictSQLMode || stmt.IgnoreErr || sc.AllowInvalidDate sc.Priority = stmt.Priority case *ast.CreateTableStmt, *ast.AlterTableStmt: - // Make sure the sql_mode is strict when checking column default value. + sc.AllowInvalidDate = vars.SQLMode.HasAllowInvalidDatesMode() + sc.IgnoreZeroInDate = !vars.SQLMode.HasNoZeroInDateMode() || !vars.SQLMode.HasNoZeroDateMode() || !vars.StrictSQLMode || sc.AllowInvalidDate case *ast.LoadDataStmt: sc.DupKeyAsWarning = true sc.BadNullAsWarning = true diff --git a/executor/executor_test.go b/executor/executor_test.go index d3027057f3ff5..35457203eb736 100644 --- a/executor/executor_test.go +++ b/executor/executor_test.go @@ -6550,6 +6550,32 @@ func (s *testSuite) TestZeroDateTimeCompatibility(c *C) { } } +// https://github.com/pingcap/tidb/issues/24165. +func (s *testSuite) TestInvalidDateValueInCreateTable(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test;") + tk.MustExec("drop table if exists t;") + tk.MustExec("set @@sql_mode='STRICT_TRANS_TABLES,NO_ZERO_IN_DATE,NO_ZERO_DATE';") + tk.MustGetErrCode("create table t (a datetime default '2999-00-00 00:00:00');", mysql.ErrInvalidDefault) + tk.MustGetErrCode("create table t (a datetime default '2999-02-30 00:00:00');", mysql.ErrInvalidDefault) + tk.MustExec("create table t (a datetime);") + tk.MustGetErrCode("alter table t modify column a datetime default '2999-00-00 00:00:00';", mysql.ErrInvalidDefault) + tk.MustExec("drop table if exists t;") + + tk.MustExec("set @@sql_mode = (select replace(@@sql_mode,'NO_ZERO_IN_DATE',''));") + tk.MustExec("set @@sql_mode = (select replace(@@sql_mode,'NO_ZERO_DATE',''));") + tk.MustExec("set @@sql_mode=(select concat(@@sql_mode, ',ALLOW_INVALID_DATES'));") + // Test create table with zero datetime as a default value. + tk.MustExec("create table t (a datetime default '2999-00-00 00:00:00');") + tk.MustExec("drop table if exists t;") + // Test create table with invalid datetime(02-30) as a default value. + tk.MustExec("create table t (a datetime default '2999-02-30 00:00:00');") + tk.MustExec("drop table if exists t;") + tk.MustExec("create table t (a datetime);") + tk.MustExec("alter table t modify column a datetime default '2999-00-00 00:00:00';") + tk.MustExec("drop table if exists t;") +} + func (s *testSuite) TestOOMActionPriority(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") From f6135765eb24268b4d98ee9585f979e14e6d475b Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Tue, 11 May 2021 02:55:39 +0800 Subject: [PATCH 78/85] planner, type: remove the prefix 0 in the bit array when we get the BinaryLiteral (#23523) (#23655) --- planner/core/point_get_plan_test.go | 21 +++++++++++++++++++++ types/datum.go | 24 ++++++++++++++++++++---- 2 files changed, 41 insertions(+), 4 deletions(-) diff --git a/planner/core/point_get_plan_test.go b/planner/core/point_get_plan_test.go index b6140f1ada2e7..840ce7ce15735 100644 --- a/planner/core/point_get_plan_test.go +++ b/planner/core/point_get_plan_test.go @@ -522,3 +522,24 @@ func (s *testPointGetSuite) TestCBOShouldNotUsePointGet(c *C) { res.Check(testkit.Rows(output[i].Res...)) } } + +func (s *testPointGetSuite) TestIssue23511(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t1, t2;") + tk.MustExec("CREATE TABLE `t1` (`COL1` bit(11) NOT NULL,PRIMARY KEY (`COL1`));") + tk.MustExec("CREATE TABLE `t2` (`COL1` bit(11) NOT NULL);") + tk.MustExec("insert into t1 values(b'00000111001'), (b'00000000000');") + tk.MustExec("insert into t2 values(b'00000111001');") + tk.MustQuery("select * from t1 where col1 = 0x39;").Check(testkit.Rows("\x009")) + tk.MustQuery("select * from t2 where col1 = 0x39;").Check(testkit.Rows("\x009")) + tk.MustQuery("select * from t1 where col1 = 0x00;").Check(testkit.Rows("\x00\x00")) + tk.MustQuery("select * from t1 where col1 = 0x0000;").Check(testkit.Rows("\x00\x00")) + tk.MustQuery("explain select * from t1 where col1 = 0x39;"). + Check(testkit.Rows("Point_Get_1 1.00 root table:t1, index:PRIMARY(COL1) ")) + tk.MustQuery("select * from t1 where col1 = 0x0039;").Check(testkit.Rows("\x009")) + tk.MustQuery("select * from t2 where col1 = 0x0039;").Check(testkit.Rows("\x009")) + tk.MustQuery("select * from t1 where col1 = 0x000039;").Check(testkit.Rows("\x009")) + tk.MustQuery("select * from t2 where col1 = 0x000039;").Check(testkit.Rows("\x009")) + tk.MustExec("drop table t1, t2;") +} diff --git a/types/datum.go b/types/datum.go index 7512ff60f0d7f..246c87ac2eba2 100644 --- a/types/datum.go +++ b/types/datum.go @@ -238,6 +238,22 @@ func (d *Datum) SetMinNotNull() { d.x = nil } +// GetBinaryLiteral4Cmp gets Bit value, and remove it's prefix 0 for comparison. +func (d *Datum) GetBinaryLiteral4Cmp() BinaryLiteral { + bitLen := len(d.b) + if bitLen == 0 { + return d.b + } + for i := 0; i < bitLen; i++ { + // Remove the prefix 0 in the bit array. + if d.b[i] != 0 { + return d.b[i:] + } + } + // The result is 0x000...00, we just the return 0x00. + return d.b[bitLen-1:] +} + // GetBinaryLiteral gets Bit value func (d *Datum) GetBinaryLiteral() BinaryLiteral { return d.b @@ -573,7 +589,7 @@ func (d *Datum) CompareDatum(sc *stmtctx.StatementContext, ad *Datum) (int, erro case KindMysqlEnum: return d.compareMysqlEnum(sc, ad.GetMysqlEnum()) case KindBinaryLiteral, KindMysqlBit: - return d.compareBinaryLiteral(sc, ad.GetBinaryLiteral()) + return d.compareBinaryLiteral(sc, ad.GetBinaryLiteral4Cmp()) case KindMysqlSet: return d.compareMysqlSet(sc, ad.GetMysqlSet()) case KindMysqlJSON: @@ -642,7 +658,7 @@ func (d *Datum) compareFloat64(sc *stmtctx.StatementContext, f float64) (int, er fVal := d.GetMysqlEnum().ToNumber() return CompareFloat64(fVal, f), nil case KindBinaryLiteral, KindMysqlBit: - val, err := d.GetBinaryLiteral().ToInt(sc) + val, err := d.GetBinaryLiteral4Cmp().ToInt(sc) fVal := float64(val) return CompareFloat64(fVal, f), errors.Trace(err) case KindMysqlSet: @@ -679,7 +695,7 @@ func (d *Datum) compareString(sc *stmtctx.StatementContext, s string, retCollati case KindMysqlEnum: return CompareString(d.GetMysqlEnum().String(), s, d.collation), nil case KindBinaryLiteral, KindMysqlBit: - return CompareString(d.GetBinaryLiteral().ToString(), s, d.collation), nil + return CompareString(d.GetBinaryLiteral4Cmp().ToString(), s, d.collation), nil default: fVal, err := StrToFloat(sc, s, false) if err != nil { @@ -753,7 +769,7 @@ func (d *Datum) compareBinaryLiteral(sc *stmtctx.StatementContext, b BinaryLiter case KindString, KindBytes: return CompareString(d.GetString(), b.ToString(), d.collation), nil case KindBinaryLiteral, KindMysqlBit: - return CompareString(d.GetBinaryLiteral().ToString(), b.ToString(), d.collation), nil + return CompareString(d.GetBinaryLiteral4Cmp().ToString(), b.ToString(), d.collation), nil default: val, err := b.ToInt(sc) if err != nil { From 60b89fdb112a37e9cdee81c2b10f9da18ee0a511 Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Tue, 11 May 2021 04:15:39 +0800 Subject: [PATCH 79/85] planner: fix the panic when we calculate the partition range (#23651) (#23689) --- planner/core/partition_pruner_test.go | 66 ++++++++++++++++++++++++ planner/core/rule_partition_processor.go | 4 ++ 2 files changed, 70 insertions(+) create mode 100644 planner/core/partition_pruner_test.go diff --git a/planner/core/partition_pruner_test.go b/planner/core/partition_pruner_test.go new file mode 100644 index 0000000000000..0bc478e6e742a --- /dev/null +++ b/planner/core/partition_pruner_test.go @@ -0,0 +1,66 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package core_test + +import ( + "fmt" + + . "github.com/pingcap/check" + "github.com/pingcap/tidb/domain" + "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/util/mock" + "github.com/pingcap/tidb/util/testkit" + "github.com/pingcap/tidb/util/testutil" +) + +var _ = Suite(&testPartitionPruneSuit{}) + +type testPartitionPruneSuit struct { + store kv.Storage + dom *domain.Domain + ctx sessionctx.Context + testData testutil.TestData +} + +func (s *testPartitionPruneSuit) cleanEnv(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test_partition") + r := tk.MustQuery("show tables") + for _, tb := range r.Rows() { + tableName := tb[0] + tk.MustExec(fmt.Sprintf("drop table %v", tableName)) + } +} + +func (s *testPartitionPruneSuit) SetUpSuite(c *C) { + var err error + s.store, s.dom, err = newStoreWithBootstrap() + c.Assert(err, IsNil) + s.ctx = mock.NewContext() +} + +func (s *testPartitionPruneSuit) TearDownSuite(c *C) { + s.dom.Close() + s.store.Close() +} + +func (s *testPartitionPruneSuit) TestIssue23622(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("USE test;") + tk.MustExec("drop table if exists t2;") + tk.MustExec("create table t2 (a int, b int) partition by range (a) (partition p0 values less than (0), partition p1 values less than (5));") + tk.MustExec("insert into t2(a) values (-1), (1);") + tk.MustQuery("select * from t2 where a > 10 or b is NULL order by a;").Check(testkit.Rows("-1 ", "1 ")) +} diff --git a/planner/core/rule_partition_processor.go b/planner/core/rule_partition_processor.go index 83536239f47e7..be40d793e77bb 100644 --- a/planner/core/rule_partition_processor.go +++ b/planner/core/rule_partition_processor.go @@ -266,6 +266,10 @@ func (or partitionRangeOR) union(x partitionRangeOR) partitionRangeOR { } func (or partitionRangeOR) simplify() partitionRangeOR { + // if the length of the `or` is zero. We should return early. + if len(or) == 0 { + return or + } // Make the ranges order by start. sort.Sort(or) sorted := or From 59cd283735643ac68c9b268a3975419bef06c5b0 Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Tue, 11 May 2021 06:13:39 +0800 Subject: [PATCH 80/85] executor: fix projection executor panic and add failpoint test (#24231) (#24340) --- executor/aggregate.go | 67 +++++++++++++++++++++++++-------------- executor/executor.go | 11 +++++-- executor/executor_test.go | 37 +++++++++++++++++++++ executor/projection.go | 10 +++++- util/memory/tracker.go | 6 ++++ 5 files changed, 105 insertions(+), 26 deletions(-) diff --git a/executor/aggregate.go b/executor/aggregate.go index bcbb9495470d7..34230f1eb63c2 100644 --- a/executor/aggregate.go +++ b/executor/aggregate.go @@ -166,9 +166,10 @@ type HashAggExec struct { isChildReturnEmpty bool // After we support parallel execution for aggregation functions with distinct, // we can remove this attribute. - isUnparallelExec bool - prepared bool - executed bool + isUnparallelExec bool + parallelExecInitialized bool + prepared bool + executed bool memTracker *memory.Tracker // track memory usage. } @@ -204,36 +205,42 @@ func (d *HashAggIntermData) getPartialResultBatch(sc *stmtctx.StatementContext, // Close implements the Executor Close interface. func (e *HashAggExec) Close() error { if e.isUnparallelExec { - e.memTracker.Consume(-e.childResult.MemoryUsage()) e.childResult = nil e.groupSet = nil e.partialResultMap = nil + if e.memTracker != nil { + e.memTracker.ReplaceBytesUsed(0) + } return e.baseExecutor.Close() } - // `Close` may be called after `Open` without calling `Next` in test. - if !e.prepared { - close(e.inputCh) + if e.parallelExecInitialized { + // `Close` may be called after `Open` without calling `Next` in test. + if !e.prepared { + close(e.inputCh) + for _, ch := range e.partialOutputChs { + close(ch) + } + for _, ch := range e.partialInputChs { + close(ch) + } + close(e.finalOutputCh) + } + close(e.finishCh) for _, ch := range e.partialOutputChs { - close(ch) + for range ch { + } } for _, ch := range e.partialInputChs { - close(ch) + for range ch { + } } - close(e.finalOutputCh) - } - close(e.finishCh) - for _, ch := range e.partialOutputChs { - for range ch { + for range e.finalOutputCh { } - } - for _, ch := range e.partialInputChs { - for chk := range ch { - e.memTracker.Consume(-chk.MemoryUsage()) + e.executed = false + if e.memTracker != nil { + e.memTracker.ReplaceBytesUsed(0) } } - for range e.finalOutputCh { - } - e.executed = false if e.runtimeStats != nil { var partialConcurrency, finalConcurrency int @@ -258,6 +265,11 @@ func (e *HashAggExec) Open(ctx context.Context) error { if err := e.baseExecutor.Open(ctx); err != nil { return err } + failpoint.Inject("mockHashAggExecBaseExecutorOpenReturnedError", func(val failpoint.Value) { + if val.(bool) { + failpoint.Return(errors.New("mock HashAggExec.baseExecutor.Open returned error")) + } + }) e.prepared = false e.memTracker = memory.NewTracker(e.id, -1) @@ -340,6 +352,8 @@ func (e *HashAggExec) initForParallelExec(ctx sessionctx.Context) { } e.finalWorkers[i].finalResultHolderCh <- newFirstChunk(e) } + + e.parallelExecInitialized = true } func (w *HashAggPartialWorker) getChildInput() bool { @@ -838,6 +852,11 @@ func (e *StreamAggExec) Open(ctx context.Context) error { if err := e.baseExecutor.Open(ctx); err != nil { return err } + failpoint.Inject("mockStreamAggExecBaseExecutorOpenReturnedError", func(val failpoint.Value) { + if val.(bool) { + failpoint.Return(errors.New("mock StreamAggExec.baseExecutor.Open returned error")) + } + }) e.childResult = newFirstChunk(e.children[0]) e.executed = false e.isChildReturnEmpty = true @@ -858,8 +877,10 @@ func (e *StreamAggExec) Open(ctx context.Context) error { // Close implements the Executor Close interface. func (e *StreamAggExec) Close() error { - e.memTracker.Consume(-e.childResult.MemoryUsage()) - e.childResult = nil + if e.childResult != nil { + e.memTracker.Consume(-e.childResult.MemoryUsage()) + e.childResult = nil + } e.groupChecker.reset() return e.baseExecutor.Close() } diff --git a/executor/executor.go b/executor/executor.go index 20cccafac6809..818c7099634ba 100644 --- a/executor/executor.go +++ b/executor/executor.go @@ -1188,6 +1188,11 @@ func (e *SelectionExec) Open(ctx context.Context) error { if err := e.baseExecutor.Open(ctx); err != nil { return err } + failpoint.Inject("mockSelectionExecBaseExecutorOpenReturnedError", func(val failpoint.Value) { + if val.(bool) { + failpoint.Return(errors.New("mock SelectionExec.baseExecutor.Open returned error")) + } + }) return e.open(ctx) } @@ -1207,8 +1212,10 @@ func (e *SelectionExec) open(ctx context.Context) error { // Close implements plannercore.Plan Close interface. func (e *SelectionExec) Close() error { - e.memTracker.Consume(-e.childResult.MemoryUsage()) - e.childResult = nil + if e.childResult != nil { + e.memTracker.Consume(-e.childResult.MemoryUsage()) + e.childResult = nil + } e.selected = nil return e.baseExecutor.Close() } diff --git a/executor/executor_test.go b/executor/executor_test.go index 35457203eb736..a7a0a149559b4 100644 --- a/executor/executor_test.go +++ b/executor/executor_test.go @@ -6763,3 +6763,40 @@ func (s *testSuiteP1) TestIssue22941(c *C) { rs = tk.MustQuery(`SELECT bmp.mpid, bmp.mpid IS NULL,bmp.mpid IS NOT NULL FROM m c LEFT JOIN mp bmp ON c.mid = bmp.mid WHERE c.ParentId = '0'`) rs.Check(testkit.Rows(" 1 0")) } + +func (s *testSerialSuite1) TestIssue24210(c *C) { + tk := testkit.NewTestKit(c, s.store) + + // for ProjectionExec + c.Assert(failpoint.Enable("github.com/pingcap/tidb/executor/mockProjectionExecBaseExecutorOpenReturnedError", `return(true)`), IsNil) + _, err := tk.Exec("select a from (select 1 as a, 2 as b) t") + c.Assert(err, NotNil) + c.Assert(err.Error(), Equals, "mock ProjectionExec.baseExecutor.Open returned error") + err = failpoint.Disable("github.com/pingcap/tidb/executor/mockProjectionExecBaseExecutorOpenReturnedError") + c.Assert(err, IsNil) + + // for HashAggExec + c.Assert(failpoint.Enable("github.com/pingcap/tidb/executor/mockHashAggExecBaseExecutorOpenReturnedError", `return(true)`), IsNil) + _, err = tk.Exec("select sum(a) from (select 1 as a, 2 as b) t group by b") + c.Assert(err, NotNil) + c.Assert(err.Error(), Equals, "mock HashAggExec.baseExecutor.Open returned error") + err = failpoint.Disable("github.com/pingcap/tidb/executor/mockHashAggExecBaseExecutorOpenReturnedError") + c.Assert(err, IsNil) + + // for StreamAggExec + c.Assert(failpoint.Enable("github.com/pingcap/tidb/executor/mockStreamAggExecBaseExecutorOpenReturnedError", `return(true)`), IsNil) + _, err = tk.Exec("select sum(a) from (select 1 as a, 2 as b) t") + c.Assert(err, NotNil) + c.Assert(err.Error(), Equals, "mock StreamAggExec.baseExecutor.Open returned error") + err = failpoint.Disable("github.com/pingcap/tidb/executor/mockStreamAggExecBaseExecutorOpenReturnedError") + c.Assert(err, IsNil) + + // for SelectionExec + c.Assert(failpoint.Enable("github.com/pingcap/tidb/executor/mockSelectionExecBaseExecutorOpenReturnedError", `return(true)`), IsNil) + _, err = tk.Exec("select * from (select 1 as a) t where a > 0") + c.Assert(err, NotNil) + c.Assert(err.Error(), Equals, "mock SelectionExec.baseExecutor.Open returned error") + err = failpoint.Disable("github.com/pingcap/tidb/executor/mockSelectionExecBaseExecutorOpenReturnedError") + c.Assert(err, IsNil) + +} diff --git a/executor/projection.go b/executor/projection.go index c36d435ab231a..776502d07554f 100644 --- a/executor/projection.go +++ b/executor/projection.go @@ -21,6 +21,7 @@ import ( "sync/atomic" "github.com/pingcap/errors" + "github.com/pingcap/failpoint" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/util" @@ -85,6 +86,11 @@ func (e *ProjectionExec) Open(ctx context.Context) error { if err := e.baseExecutor.Open(ctx); err != nil { return err } + failpoint.Inject("mockProjectionExecBaseExecutorOpenReturnedError", func(val failpoint.Value) { + if val.(bool) { + failpoint.Return(errors.New("mock ProjectionExec.baseExecutor.Open returned error")) + } + }) return e.open(ctx) } @@ -290,7 +296,9 @@ func (e *ProjectionExec) drainOutputCh(ch chan *projectionOutput) { // Close implements the Executor Close interface. func (e *ProjectionExec) Close() error { - if e.isUnparallelExec() { + // if e.baseExecutor.Open returns error, e.childResult will be nil, see https://github.com/pingcap/tidb/issues/24210 + // for more information + if e.isUnparallelExec() && e.childResult != nil { e.memTracker.Consume(-e.childResult.MemoryUsage()) e.childResult = nil } diff --git a/util/memory/tracker.go b/util/memory/tracker.go index a8736376ecea1..e630a2b730942 100644 --- a/util/memory/tracker.go +++ b/util/memory/tracker.go @@ -430,6 +430,12 @@ func (t *Tracker) DetachFromGlobalTracker() { t.parent = nil } +// ReplaceBytesUsed replace bytesConsume for the tracker +func (t *Tracker) ReplaceBytesUsed(bytes int64) { + t.Consume(-t.BytesConsumed()) + t.Consume(bytes) +} + const ( // LabelForSQLText represents the label of the SQL Text LabelForSQLText int = -1 From b4c66acf09233985ec6b8d8909944f63c32b68d9 Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Tue, 11 May 2021 10:49:39 +0800 Subject: [PATCH 81/85] tablecodec: fix text type decode for old row format (#23751) (#23772) --- tablecodec/tablecodec.go | 6 +++--- tablecodec/tablecodec_test.go | 10 ++++++++++ 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/tablecodec/tablecodec.go b/tablecodec/tablecodec.go index 9179955af674f..7c480d4ab8751 100644 --- a/tablecodec/tablecodec.go +++ b/tablecodec/tablecodec.go @@ -512,11 +512,11 @@ func unflatten(datum types.Datum, ft *types.FieldType, loc *time.Location) (type case mysql.TypeFloat: datum.SetFloat32(float32(datum.GetFloat64())) return datum, nil - case mysql.TypeVarchar, mysql.TypeString, mysql.TypeVarString: + case mysql.TypeVarchar, mysql.TypeString, mysql.TypeVarString, mysql.TypeTinyBlob, + mysql.TypeMediumBlob, mysql.TypeBlob, mysql.TypeLongBlob: datum.SetString(datum.GetString(), ft.Collate) case mysql.TypeTiny, mysql.TypeShort, mysql.TypeYear, mysql.TypeInt24, - mysql.TypeLong, mysql.TypeLonglong, mysql.TypeDouble, mysql.TypeTinyBlob, - mysql.TypeMediumBlob, mysql.TypeBlob, mysql.TypeLongBlob: + mysql.TypeLong, mysql.TypeLonglong, mysql.TypeDouble: return datum, nil case mysql.TypeDate, mysql.TypeDatetime, mysql.TypeTimestamp: t := types.NewTime(types.ZeroCoreTime, ft.Tp, int8(ft.Decimal)) diff --git a/tablecodec/tablecodec_test.go b/tablecodec/tablecodec_test.go index b56c3cd6b9901..3a27b01ae8be9 100644 --- a/tablecodec/tablecodec_test.go +++ b/tablecodec/tablecodec_test.go @@ -175,6 +175,16 @@ func (s *testTableCodecSuite) TestUnflattenDatums(c *C) { cmp, err := input[0].CompareDatum(sc, &output[0]) c.Assert(err, IsNil) c.Assert(cmp, Equals, 0) + + input = []types.Datum{types.NewCollationStringDatum("aaa", "utf8mb4_unicode_ci", 0)} + tps = []*types.FieldType{types.NewFieldType(mysql.TypeBlob)} + tps[0].Collate = "utf8mb4_unicode_ci" + output, err = UnflattenDatums(input, tps, sc.TimeZone) + c.Assert(err, IsNil) + cmp, err = input[0].CompareDatum(sc, &output[0]) + c.Assert(err, IsNil) + c.Assert(cmp, Equals, 0) + c.Assert(output[0].Collation(), Equals, "utf8mb4_unicode_ci") } func (s *testTableCodecSuite) TestTimeCodec(c *C) { From 62dace6a5b76b273344cb1c98c70947322ebb3fc Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Tue, 11 May 2021 14:15:38 +0800 Subject: [PATCH 82/85] server: set connection to TCP socket when unix and TCP used (#23463) (#23513) --- server/server.go | 5 ++++- server/tidb_test.go | 3 +++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/server/server.go b/server/server.go index 8f34ef9ef61a7..ee9ab43f7eadc 100644 --- a/server/server.go +++ b/server/server.go @@ -172,8 +172,11 @@ func (s *Server) newConn(conn net.Conn) *clientConn { return cc } +// isUnixSocket should ideally be a function of clientConnection! +// But currently since unix-socket connections are forwarded to TCP when the server listens on both, it can really only be accurate on a server-level. +// If the server is listening on both, it *must* return FALSE for remote-host authentication to be performed correctly. See #23460. func (s *Server) isUnixSocket() bool { - return s.cfg.Socket != "" + return s.cfg.Socket != "" && s.cfg.Port == 0 } func (s *Server) forwardUnixSocketToTCP() { diff --git a/server/tidb_test.go b/server/tidb_test.go index acb00f06348a1..5bafaa137c382 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -248,6 +248,7 @@ func (ts *tidbTestSuite) TestStatusAPIWithTLS(c *C) { cli.statusPort = getPortFromTCPAddr(server.statusListener.Addr()) go server.Run() time.Sleep(time.Millisecond * 100) + c.Assert(server.isUnixSocket(), IsFalse) // If listening on tcp-only, return FALSE // https connection should work. ts.runTestStatusAPI(c) @@ -348,6 +349,7 @@ func (ts *tidbTestSuite) TestSocketForwarding(c *C) { cli.port = getPortFromTCPAddr(server.listener.Addr()) go server.Run() time.Sleep(time.Millisecond * 100) + c.Assert(server.isUnixSocket(), IsFalse) // If listening on both, return FALSE defer server.Close() cli.runTestRegression(c, func(config *mysql.Config) { @@ -371,6 +373,7 @@ func (ts *tidbTestSuite) TestSocket(c *C) { c.Assert(err, IsNil) go server.Run() time.Sleep(time.Millisecond * 100) + c.Assert(server.isUnixSocket(), IsTrue) // If listening on socket-only, return TRUE defer server.Close() //a fake server client, config is override, just used to run tests From 6c1a625592a1ae52b1621d4760f7182d8a6136f3 Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Tue, 11 May 2021 16:01:39 +0800 Subject: [PATCH 83/85] planner: fix wrong TableDual plans caused by comparing Binary and Bytes incorrectly (#23860) (#23917) --- planner/core/integration_test.go | 9 +++++++++ types/datum.go | 2 +- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/planner/core/integration_test.go b/planner/core/integration_test.go index 98fc2eda92bc2..2b0e487958e2b 100644 --- a/planner/core/integration_test.go +++ b/planner/core/integration_test.go @@ -2027,3 +2027,12 @@ func (s *testIntegrationSuite) TestGetVarExprWithBitLiteral(c *C) { tk.MustExec("set @a = 0b11000100110101;") tk.MustQuery("execute stmt using @a;").Check(testkit.Rows("1")) } + +func (s *testIntegrationSuite) TestIssue23846(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a varbinary(10),UNIQUE KEY(a))") + tk.MustExec("insert into t values(0x00A4EEF4FA55D6706ED5)") + tk.MustQuery("select * from t where a=0x00A4EEF4FA55D6706ED5").Check(testkit.Rows("\x00\xa4\xee\xf4\xfaU\xd6pn\xd5")) // not empty +} diff --git a/types/datum.go b/types/datum.go index 246c87ac2eba2..f0d14c7520eae 100644 --- a/types/datum.go +++ b/types/datum.go @@ -767,7 +767,7 @@ func (d *Datum) compareBinaryLiteral(sc *stmtctx.StatementContext, b BinaryLiter case KindMaxValue: return 1, nil case KindString, KindBytes: - return CompareString(d.GetString(), b.ToString(), d.collation), nil + fallthrough // in this case, d is converted to Binary and then compared with b case KindBinaryLiteral, KindMysqlBit: return CompareString(d.GetBinaryLiteral4Cmp().ToString(), b.ToString(), d.collation), nil default: From 7059d9af18e3a068c2cf79d8442fb324e4361e2d Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Tue, 11 May 2021 16:37:40 +0800 Subject: [PATCH 84/85] *: collect transaction write duration/throughput metrics for SLI/SLO (#23462) (#23658) --- executor/adapter.go | 8 +++ executor/executor_test.go | 88 +++++++++++++++++++++++++++- executor/insert_common.go | 3 + metrics/metrics.go | 2 + metrics/sli.go | 40 +++++++++++++ server/conn.go | 32 +++++----- server/driver.go | 4 ++ server/driver_tidb.go | 6 ++ session/session.go | 6 ++ session/txn.go | 3 + sessionctx/context.go | 3 + util/mock/context.go | 6 ++ util/sli/sli.go | 119 ++++++++++++++++++++++++++++++++++++++ 13 files changed, 305 insertions(+), 15 deletions(-) create mode 100644 metrics/sli.go create mode 100644 util/sli/sli.go diff --git a/executor/adapter.go b/executor/adapter.go index 1831f3c4d02ab..5166cca747b41 100644 --- a/executor/adapter.go +++ b/executor/adapter.go @@ -827,6 +827,14 @@ func (a *ExecStmt) FinishExecuteStmt(txnTS uint64, succ bool, hasMoreResults boo } sessVars.StmtCtx.RuntimeStatsColl.RegisterStats(a.Plan.ID(), statsWithCommit) } + // Record related SLI metrics. + if execDetail.CommitDetail != nil && execDetail.CommitDetail.WriteSize > 0 { + a.Ctx.GetTxnWriteThroughputSLI().AddTxnWriteSize(execDetail.CommitDetail.WriteSize, execDetail.CommitDetail.WriteKeys) + } + if execDetail.ScanDetail != nil && execDetail.ScanDetail.ProcessedKeys > 0 && sessVars.StmtCtx.AffectedRows() > 0 { + // Only record the read keys in write statement which affect row more than 0. + a.Ctx.GetTxnWriteThroughputSLI().AddReadKeys(execDetail.ScanDetail.ProcessedKeys) + } // `LowSlowQuery` and `SummaryStmt` must be called before recording `PrevStmt`. a.LogSlowQuery(txnTS, succ, hasMoreResults) a.SummaryStmt(succ) diff --git a/executor/executor_test.go b/executor/executor_test.go index a7a0a149559b4..c2ccaf7cc2044 100644 --- a/executor/executor_test.go +++ b/executor/executor_test.go @@ -6764,6 +6764,93 @@ func (s *testSuiteP1) TestIssue22941(c *C) { rs.Check(testkit.Rows(" 1 0")) } +func (s *testSerialSuite2) TestTxnWriteThroughputSLI(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t (a int key, b int)") + c.Assert(failpoint.Enable("github.com/pingcap/tidb/util/sli/CheckTxnWriteThroughput", "return(true)"), IsNil) + defer func() { + err := failpoint.Disable("github.com/pingcap/tidb/util/sli/CheckTxnWriteThroughput") + c.Assert(err, IsNil) + }() + + mustExec := func(sql string) { + tk.MustExec(sql) + tk.Se.GetTxnWriteThroughputSLI().FinishExecuteStmt(time.Second, tk.Se.AffectedRows(), tk.Se.GetSessionVars().InTxn()) + } + errExec := func(sql string) { + _, err := tk.Exec(sql) + c.Assert(err, NotNil) + tk.Se.GetTxnWriteThroughputSLI().FinishExecuteStmt(time.Second, tk.Se.AffectedRows(), tk.Se.GetSessionVars().InTxn()) + } + + // Test insert in small txn + mustExec("insert into t values (1,3),(2,4)") + writeSLI := tk.Se.GetTxnWriteThroughputSLI() + c.Assert(writeSLI.IsInvalid(), Equals, false) + c.Assert(writeSLI.IsSmallTxn(), Equals, true) + c.Assert(tk.Se.GetTxnWriteThroughputSLI().String(), Equals, "invalid: false, affectRow: 2, writeSize: 58, readKeys: 0, writeKeys: 2, writeTime: 1s") + tk.Se.GetTxnWriteThroughputSLI().Reset() + + // Test insert ... select ... from + mustExec("insert into t select b, a from t") + c.Assert(writeSLI.IsInvalid(), Equals, true) + c.Assert(writeSLI.IsSmallTxn(), Equals, true) + c.Assert(tk.Se.GetTxnWriteThroughputSLI().String(), Equals, "invalid: true, affectRow: 2, writeSize: 58, readKeys: 0, writeKeys: 2, writeTime: 1s") + tk.Se.GetTxnWriteThroughputSLI().Reset() + + // Test for delete + mustExec("delete from t") + c.Assert(tk.Se.GetTxnWriteThroughputSLI().String(), Equals, "invalid: false, affectRow: 4, writeSize: 76, readKeys: 0, writeKeys: 4, writeTime: 1s") + tk.Se.GetTxnWriteThroughputSLI().Reset() + + // Test insert not in small txn + mustExec("begin") + for i := 0; i < 20; i++ { + mustExec(fmt.Sprintf("insert into t values (%v,%v)", i, i)) + c.Assert(writeSLI.IsSmallTxn(), Equals, true) + } + // The statement which affect rows is 0 shouldn't record into time. + mustExec("select count(*) from t") + mustExec("select * from t") + mustExec("insert into t values (20,20)") + c.Assert(writeSLI.IsSmallTxn(), Equals, false) + mustExec("commit") + c.Assert(writeSLI.IsInvalid(), Equals, false) + c.Assert(tk.Se.GetTxnWriteThroughputSLI().String(), Equals, "invalid: false, affectRow: 21, writeSize: 609, readKeys: 0, writeKeys: 21, writeTime: 22s") + tk.Se.GetTxnWriteThroughputSLI().Reset() + + // Test invalid when transaction has replace ... select ... from ... statement. + mustExec("delete from t") + tk.Se.GetTxnWriteThroughputSLI().Reset() + mustExec("begin") + mustExec("insert into t values (1,3),(2,4)") + mustExec("replace into t select b, a from t") + mustExec("commit") + c.Assert(writeSLI.IsInvalid(), Equals, true) + c.Assert(tk.Se.GetTxnWriteThroughputSLI().String(), Equals, "invalid: true, affectRow: 4, writeSize: 116, readKeys: 0, writeKeys: 4, writeTime: 3s") + tk.Se.GetTxnWriteThroughputSLI().Reset() + + // Test clean last failed transaction information. + err := failpoint.Disable("github.com/pingcap/tidb/util/sli/CheckTxnWriteThroughput") + c.Assert(err, IsNil) + mustExec("begin") + mustExec("insert into t values (1,3),(2,4)") + errExec("commit") + c.Assert(tk.Se.GetTxnWriteThroughputSLI().String(), Equals, "invalid: false, affectRow: 0, writeSize: 0, readKeys: 0, writeKeys: 0, writeTime: 0s") + + c.Assert(failpoint.Enable("github.com/pingcap/tidb/util/sli/CheckTxnWriteThroughput", "return(true)"), IsNil) + mustExec("begin") + mustExec("insert into t values (5, 6)") + mustExec("commit") + c.Assert(tk.Se.GetTxnWriteThroughputSLI().String(), Equals, "invalid: false, affectRow: 1, writeSize: 29, readKeys: 0, writeKeys: 1, writeTime: 2s") + + // Test for reset + tk.Se.GetTxnWriteThroughputSLI().Reset() + c.Assert(tk.Se.GetTxnWriteThroughputSLI().String(), Equals, "invalid: false, affectRow: 0, writeSize: 0, readKeys: 0, writeKeys: 0, writeTime: 0s") +} + func (s *testSerialSuite1) TestIssue24210(c *C) { tk := testkit.NewTestKit(c, s.store) @@ -6798,5 +6885,4 @@ func (s *testSerialSuite1) TestIssue24210(c *C) { c.Assert(err.Error(), Equals, "mock SelectionExec.baseExecutor.Open returned error") err = failpoint.Disable("github.com/pingcap/tidb/executor/mockSelectionExecBaseExecutorOpenReturnedError") c.Assert(err, IsNil) - } diff --git a/executor/insert_common.go b/executor/insert_common.go index 7c4a59fa2b3f3..d6fbcfa74b353 100644 --- a/executor/insert_common.go +++ b/executor/insert_common.go @@ -425,6 +425,9 @@ func insertRowsFromSelect(ctx context.Context, base insertCommon) error { batchSize := sessVars.DMLBatchSize memUsageOfRows := int64(0) memTracker := e.memTracker + // In order to ensure the correctness of the `transaction write throughput` SLI statistics, + // just ignore the transaction which contain `insert|replace into ... select ... from ...` statement. + e.ctx.GetTxnWriteThroughputSLI().SetInvalid() for { err := Next(ctx, selectExec, chk) if err != nil { diff --git a/metrics/metrics.go b/metrics/metrics.go index 560557e723de8..75125d664d8d5 100644 --- a/metrics/metrics.go +++ b/metrics/metrics.go @@ -173,4 +173,6 @@ func RegisterMetrics() { prometheus.MustRegister(ServerInfo) prometheus.MustRegister(TokenGauge) prometheus.MustRegister(ConfigStatus) + prometheus.MustRegister(SmallTxnWriteDuration) + prometheus.MustRegister(TxnWriteThroughput) } diff --git a/metrics/sli.go b/metrics/sli.go new file mode 100644 index 0000000000000..2e926de099997 --- /dev/null +++ b/metrics/sli.go @@ -0,0 +1,40 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package metrics + +import ( + "github.com/prometheus/client_golang/prometheus" +) + +var ( + // SmallTxnWriteDuration uses to collect small transaction write duration. + SmallTxnWriteDuration = prometheus.NewHistogram( + prometheus.HistogramOpts{ + Namespace: "tidb", + Subsystem: "sli", + Name: "small_txn_write_duration_seconds", + Help: "Bucketed histogram of small transaction write time (s).", + Buckets: prometheus.ExponentialBuckets(0.001, 2, 28), // 1ms ~ 74h + }) + + // TxnWriteThroughput uses to collect transaction write throughput which transaction is not small. + TxnWriteThroughput = prometheus.NewHistogram( + prometheus.HistogramOpts{ + Namespace: "tidb", + Subsystem: "sli", + Name: "txn_write_throughput", + Help: "Bucketed histogram of transaction write throughput (bytes/second).", + Buckets: prometheus.ExponentialBuckets(64, 1.3, 40), // 64 bytes/s ~ 2.3MB/s + }) +) diff --git a/server/conn.go b/server/conn.go index b0d46538958e9..e8687a8a54c28 100644 --- a/server/conn.go +++ b/server/conn.go @@ -879,35 +879,39 @@ func (cc *clientConn) addMetrics(cmd byte, startTime time.Time, err error) { sqlType = stmtType } + cost := time.Since(startTime) + sessionVar := cc.ctx.GetSessionVars() + cc.ctx.GetTxnWriteThroughputSLI().FinishExecuteStmt(cost, cc.ctx.AffectedRows(), sessionVar.InTxn()) + switch sqlType { case "Use": - queryDurationHistogramUse.Observe(time.Since(startTime).Seconds()) + queryDurationHistogramUse.Observe(cost.Seconds()) case "Show": - queryDurationHistogramShow.Observe(time.Since(startTime).Seconds()) + queryDurationHistogramShow.Observe(cost.Seconds()) case "Begin": - queryDurationHistogramBegin.Observe(time.Since(startTime).Seconds()) + queryDurationHistogramBegin.Observe(cost.Seconds()) case "Commit": - queryDurationHistogramCommit.Observe(time.Since(startTime).Seconds()) + queryDurationHistogramCommit.Observe(cost.Seconds()) case "Rollback": - queryDurationHistogramRollback.Observe(time.Since(startTime).Seconds()) + queryDurationHistogramRollback.Observe(cost.Seconds()) case "Insert": - queryDurationHistogramInsert.Observe(time.Since(startTime).Seconds()) + queryDurationHistogramInsert.Observe(cost.Seconds()) case "Replace": - queryDurationHistogramReplace.Observe(time.Since(startTime).Seconds()) + queryDurationHistogramReplace.Observe(cost.Seconds()) case "Delete": - queryDurationHistogramDelete.Observe(time.Since(startTime).Seconds()) + queryDurationHistogramDelete.Observe(cost.Seconds()) case "Update": - queryDurationHistogramUpdate.Observe(time.Since(startTime).Seconds()) + queryDurationHistogramUpdate.Observe(cost.Seconds()) case "Select": - queryDurationHistogramSelect.Observe(time.Since(startTime).Seconds()) + queryDurationHistogramSelect.Observe(cost.Seconds()) case "Execute": - queryDurationHistogramExecute.Observe(time.Since(startTime).Seconds()) + queryDurationHistogramExecute.Observe(cost.Seconds()) case "Set": - queryDurationHistogramSet.Observe(time.Since(startTime).Seconds()) + queryDurationHistogramSet.Observe(cost.Seconds()) case metrics.LblGeneral: - queryDurationHistogramGeneral.Observe(time.Since(startTime).Seconds()) + queryDurationHistogramGeneral.Observe(cost.Seconds()) default: - metrics.QueryDurationHistogram.WithLabelValues(sqlType).Observe(time.Since(startTime).Seconds()) + metrics.QueryDurationHistogram.WithLabelValues(sqlType).Observe(cost.Seconds()) } } diff --git a/server/driver.go b/server/driver.go index 389f9f2700173..b8f95984849ac 100644 --- a/server/driver.go +++ b/server/driver.go @@ -25,6 +25,7 @@ import ( "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util" "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/sli" ) // IDriver opens IContext. @@ -100,6 +101,9 @@ type QueryCtx interface { SetCommandValue(command byte) SetSessionManager(util.SessionManager) + + // GetTxnWriteThroughputSLI returns the TxnWriteThroughputSLI. + GetTxnWriteThroughputSLI() *sli.TxnWriteThroughputSLI } // PreparedStatement is the interface to use a prepared statement. diff --git a/server/driver_tidb.go b/server/driver_tidb.go index e677310716668..4959bda9faeb9 100644 --- a/server/driver_tidb.go +++ b/server/driver_tidb.go @@ -33,6 +33,7 @@ import ( "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util" "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/sli" "github.com/pingcap/tidb/util/sqlexec" ) @@ -367,6 +368,11 @@ func (tc *TiDBContext) GetSessionVars() *variable.SessionVars { return tc.session.GetSessionVars() } +// GetTxnWriteThroughputSLI implements QueryCtx GetTxnWriteThroughputSLI method. +func (tc *TiDBContext) GetTxnWriteThroughputSLI() *sli.TxnWriteThroughputSLI { + return tc.session.GetTxnWriteThroughputSLI() +} + type tidbResultSet struct { recordSet sqlexec.RecordSet columns []*ColumnInfo diff --git a/session/session.go b/session/session.go index 1413b50426914..45690f166d689 100644 --- a/session/session.go +++ b/session/session.go @@ -70,6 +70,7 @@ import ( "github.com/pingcap/tidb/util/execdetails" "github.com/pingcap/tidb/util/kvcache" "github.com/pingcap/tidb/util/logutil" + "github.com/pingcap/tidb/util/sli" "github.com/pingcap/tidb/util/sqlexec" "github.com/pingcap/tidb/util/timeutil" "github.com/pingcap/tipb/go-binlog" @@ -2444,3 +2445,8 @@ func (s *session) recordOnTransactionExecution(err error, counter int, duration } } } + +// GetTxnWriteThroughputSLI implements the Context interface. +func (s *session) GetTxnWriteThroughputSLI() *sli.TxnWriteThroughputSLI { + return &s.txn.writeSLI +} diff --git a/session/txn.go b/session/txn.go index b9ce2741f6d57..fca8d75ccaae5 100644 --- a/session/txn.go +++ b/session/txn.go @@ -34,6 +34,7 @@ import ( "github.com/pingcap/tidb/tablecodec" "github.com/pingcap/tidb/util/logutil" "github.com/pingcap/tidb/util/memory" + "github.com/pingcap/tidb/util/sli" "github.com/pingcap/tipb/go-binlog" "go.uber.org/zap" ) @@ -57,6 +58,8 @@ type TxnState struct { // If doNotCommit is not nil, Commit() will not commit the transaction. // doNotCommit flag may be set when StmtCommit fail. doNotCommit error + + writeSLI sli.TxnWriteThroughputSLI } func (st *TxnState) init() { diff --git a/sessionctx/context.go b/sessionctx/context.go index 167fc3bd79803..21cedc5cc2a22 100644 --- a/sessionctx/context.go +++ b/sessionctx/context.go @@ -24,6 +24,7 @@ import ( "github.com/pingcap/tidb/util" "github.com/pingcap/tidb/util/kvcache" "github.com/pingcap/tidb/util/memory" + "github.com/pingcap/tidb/util/sli" "github.com/pingcap/tipb/go-binlog" ) @@ -103,6 +104,8 @@ type Context interface { HasLockedTables() bool // PrepareTSFuture uses to prepare timestamp by future. PrepareTSFuture(ctx context.Context) + // GetTxnWriteThroughputSLI returns the TxnWriteThroughputSLI. + GetTxnWriteThroughputSLI() *sli.TxnWriteThroughputSLI } type basicCtxType int diff --git a/util/mock/context.go b/util/mock/context.go index e2461c7bc8446..6661500665f94 100644 --- a/util/mock/context.go +++ b/util/mock/context.go @@ -30,6 +30,7 @@ import ( "github.com/pingcap/tidb/util/disk" "github.com/pingcap/tidb/util/kvcache" "github.com/pingcap/tidb/util/memory" + "github.com/pingcap/tidb/util/sli" "github.com/pingcap/tidb/util/sqlexec" "github.com/pingcap/tipb/go-binlog" ) @@ -212,6 +213,11 @@ func (c *Context) GoCtx() context.Context { // StoreQueryFeedback stores the query feedback. func (c *Context) StoreQueryFeedback(_ interface{}) {} +// GetTxnWriteThroughputSLI implements the sessionctx.Context interface. +func (c *Context) GetTxnWriteThroughputSLI() *sli.TxnWriteThroughputSLI { + return &sli.TxnWriteThroughputSLI{} +} + // StmtCommit implements the sessionctx.Context interface. func (c *Context) StmtCommit(tracker *memory.Tracker) error { return nil diff --git a/util/sli/sli.go b/util/sli/sli.go new file mode 100644 index 0000000000000..1c48f9a03a510 --- /dev/null +++ b/util/sli/sli.go @@ -0,0 +1,119 @@ +// Copyright 2021 PingCAP, Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package sli + +import ( + "fmt" + "time" + + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/metrics" +) + +// TxnWriteThroughputSLI uses to report transaction write throughput metrics for SLI. +type TxnWriteThroughputSLI struct { + invalid bool + affectRow uint64 + writeSize int + readKeys int + writeKeys int + writeTime time.Duration +} + +// FinishExecuteStmt records the cost for write statement which affect rows more than 0. +// And report metrics when the transaction is committed. +func (t *TxnWriteThroughputSLI) FinishExecuteStmt(cost time.Duration, affectRow uint64, inTxn bool) { + if affectRow > 0 { + t.writeTime += cost + t.affectRow += affectRow + } + + // Currently not in transaction means the last transaction is finish, should report metrics and reset data. + if !inTxn { + if affectRow == 0 { + // AffectRows is 0 when statement is commit. + t.writeTime += cost + } + // Report metrics after commit this transaction. + t.reportMetric() + + // Skip reset for test. + failpoint.Inject("CheckTxnWriteThroughput", func() { + failpoint.Return() + }) + + // Reset for next transaction. + t.Reset() + } +} + +// AddReadKeys adds the read keys. +func (t *TxnWriteThroughputSLI) AddReadKeys(readKeys int64) { + t.readKeys += int(readKeys) +} + +// AddTxnWriteSize adds the transaction write size and keys. +func (t *TxnWriteThroughputSLI) AddTxnWriteSize(size, keys int) { + t.writeSize += size + t.writeKeys += keys +} + +func (t *TxnWriteThroughputSLI) reportMetric() { + if t.IsInvalid() { + return + } + if t.IsSmallTxn() { + metrics.SmallTxnWriteDuration.Observe(t.writeTime.Seconds()) + } else { + metrics.TxnWriteThroughput.Observe(float64(t.writeSize) / t.writeTime.Seconds()) + } +} + +// SetInvalid marks this transaction is invalid to report SLI metrics. +func (t *TxnWriteThroughputSLI) SetInvalid() { + t.invalid = true +} + +// IsInvalid checks the transaction is valid to report SLI metrics. Currently, the following case will cause invalid: +// 1. The transaction contains `insert|replace into ... select ... from ...` statement. +// 2. The write SQL statement has more read keys than write keys. +func (t *TxnWriteThroughputSLI) IsInvalid() bool { + return t.invalid || t.readKeys > t.writeKeys || t.writeSize == 0 || t.writeTime == 0 +} + +const ( + smallTxnAffectRow = 20 + smallTxnWriteSize = 1 * 1024 * 1024 // 1MB +) + +// IsSmallTxn exports for testing. +func (t *TxnWriteThroughputSLI) IsSmallTxn() bool { + return t.affectRow <= smallTxnAffectRow && t.writeSize <= smallTxnWriteSize +} + +// Reset exports for testing. +func (t *TxnWriteThroughputSLI) Reset() { + t.invalid = false + t.affectRow = 0 + t.writeSize = 0 + t.readKeys = 0 + t.writeKeys = 0 + t.writeTime = 0 +} + +// String exports for testing. +func (t *TxnWriteThroughputSLI) String() string { + return fmt.Sprintf("invalid: %v, affectRow: %v, writeSize: %v, readKeys: %v, writeKeys: %v, writeTime: %v", + t.invalid, t.affectRow, t.writeSize, t.readKeys, t.writeKeys, t.writeTime.String()) +} From db798765ef26ef4de8dd0cdfd1eec28119fde624 Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Tue, 11 May 2021 18:35:40 +0800 Subject: [PATCH 85/85] planner: fix incorrect duration between compare (#22830) (#23233) --- expression/builtin_compare.go | 18 +++++++++++------- planner/core/expression_rewriter.go | 5 +++++ planner/core/expression_rewriter_test.go | 18 ++++++++++++++++++ 3 files changed, 34 insertions(+), 7 deletions(-) diff --git a/expression/builtin_compare.go b/expression/builtin_compare.go index 671073dfcf37e..5fe32e958249f 100644 --- a/expression/builtin_compare.go +++ b/expression/builtin_compare.go @@ -385,14 +385,18 @@ func ResolveType4Between(args [3]Expression) types.EvalType { hasTemporal := false if cmpTp == types.ETString { - for _, arg := range args { - if types.IsTypeTemporal(arg.GetType().Tp) { - hasTemporal = true - break + if args[0].GetType().Tp == mysql.TypeDuration { + cmpTp = types.ETDuration + } else { + for _, arg := range args { + if types.IsTypeTemporal(arg.GetType().Tp) { + hasTemporal = true + break + } + } + if hasTemporal { + cmpTp = types.ETDatetime } - } - if hasTemporal { - cmpTp = types.ETDatetime } } diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go index df9b73a778178..e36eb6efac784 100644 --- a/planner/core/expression_rewriter.go +++ b/planner/core/expression_rewriter.go @@ -1548,6 +1548,11 @@ func (er *expressionRewriter) wrapExpWithCast() (expr, lexp, rexp expression.Exp } return expression.WrapWithCastAsString(ctx, e) } + case types.ETDuration: + expr = expression.WrapWithCastAsTime(er.sctx, expr, types.NewFieldType(mysql.TypeDuration)) + lexp = expression.WrapWithCastAsTime(er.sctx, lexp, types.NewFieldType(mysql.TypeDuration)) + rexp = expression.WrapWithCastAsTime(er.sctx, rexp, types.NewFieldType(mysql.TypeDuration)) + return case types.ETDatetime: expr = expression.WrapWithCastAsTime(er.sctx, expr, types.NewFieldType(mysql.TypeDatetime)) lexp = expression.WrapWithCastAsTime(er.sctx, lexp, types.NewFieldType(mysql.TypeDatetime)) diff --git a/planner/core/expression_rewriter_test.go b/planner/core/expression_rewriter_test.go index 874aba58f7686..66bb860a52f0b 100644 --- a/planner/core/expression_rewriter_test.go +++ b/planner/core/expression_rewriter_test.go @@ -377,3 +377,21 @@ func (s *testExpressionRewriterSuite) TestCompareMultiFieldsInSubquery(c *C) { tk.MustQuery("SELECT * FROM t3 WHERE (SELECT c1, c2 FROM t3 LIMIT 1) != ALL(SELECT c1, c2 FROM t4);").Check(testkit.Rows()) } + +func (s *testExpressionRewriterSuite) TestIssue22818(c *C) { + defer testleak.AfterTest(c)() + store, dom, err := newStoreWithBootstrap() + c.Assert(err, IsNil) + tk := testkit.NewTestKit(c, store) + defer func() { + dom.Close() + store.Close() + }() + + tk.MustExec("use test;") + tk.MustExec("drop table if exists t;") + tk.MustExec("create table t(a time);") + tk.MustExec("insert into t values(\"23:22:22\");") + tk.MustQuery("select * from t where a between \"23:22:22\" and \"23:22:22\"").Check( + testkit.Rows("23:22:22")) +}