diff --git a/cmd/ddltest/ddl_test.go b/cmd/ddltest/ddl_test.go index d939e0d0258ee..52009b10de142 100644 --- a/cmd/ddltest/ddl_test.go +++ b/cmd/ddltest/ddl_test.go @@ -41,6 +41,7 @@ import ( "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/session" "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/store" tidbdriver "github.com/pingcap/tidb/store/driver" "github.com/pingcap/tidb/table" @@ -559,7 +560,7 @@ func (s *TestDDLSuite) Bootstrap(c *C) { tk.MustExec("create table test_mixed (c1 int, c2 int, primary key(c1))") tk.MustExec("create table test_inc (c1 int, c2 int, primary key(c1))") - tk.Se.GetSessionVars().EnableClusteredIndex = true + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn tk.MustExec("drop table if exists test_insert_common, test_conflict_insert_common, " + "test_update_common, test_conflict_update_common, test_delete_common, test_conflict_delete_common, " + "test_mixed_common, test_inc_common") @@ -571,7 +572,7 @@ func (s *TestDDLSuite) Bootstrap(c *C) { tk.MustExec("create table test_conflict_delete_common (c1 int, c2 int, primary key(c1, c2))") tk.MustExec("create table test_mixed_common (c1 int, c2 int, primary key(c1, c2))") tk.MustExec("create table test_inc_common (c1 int, c2 int, primary key(c1, c2))") - tk.Se.GetSessionVars().EnableClusteredIndex = false + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeIntOnly } func (s *TestDDLSuite) TestSimple(c *C) { diff --git a/cmd/explaintest/main.go b/cmd/explaintest/main.go index b54539959bedf..fa5265f7af871 100644 --- a/cmd/explaintest/main.go +++ b/cmd/explaintest/main.go @@ -94,7 +94,6 @@ func newTester(name string) *tester { t.enableQueryLog = true t.ctx = mock.NewContext() t.ctx.GetSessionVars().EnableWindowFunction = true - t.ctx.GetSessionVars().IntPrimaryKeyDefaultAsClustered = true return t } @@ -658,7 +657,6 @@ func main() { "set @@tidb_projection_concurrency=4", "set @@tidb_distsql_scan_concurrency=15", "set @@global.tidb_enable_clustered_index=0;", - "set @@tidb_int_primary_key_default_as_clustered=1", } for _, sql := range resets { if _, err = mdb.Exec(sql); err != nil { diff --git a/config/config.go b/config/config.go index ae645526391eb..87e286ab31aef 100644 --- a/config/config.go +++ b/config/config.go @@ -126,6 +126,8 @@ type Config struct { IndexLimit int `toml:"index-limit" json:"index-limit"` TableColumnCountLimit uint32 `toml:"table-column-count-limit" json:"table-column-count-limit"` GracefulWaitBeforeShutdown int `toml:"graceful-wait-before-shutdown" json:"graceful-wait-before-shutdown"` + // AlterPrimaryKey is used to control alter primary key feature. + AlterPrimaryKey bool `toml:"alter-primary-key" json:"alter-primary-key"` // TreatOldVersionUTF8AsUTF8MB4 is use to treat old version table/column UTF8 charset as UTF8MB4. This is for compatibility. // Currently not support dynamic modify, because this need to reload all old version schema. 
TreatOldVersionUTF8AsUTF8MB4 bool `toml:"treat-old-version-utf8-as-utf8mb4" json:"treat-old-version-utf8-as-utf8mb4"` @@ -560,6 +562,7 @@ var defaultConf = Config{ MaxIndexLength: 3072, IndexLimit: 64, TableColumnCountLimit: 1017, + AlterPrimaryKey: false, TreatOldVersionUTF8AsUTF8MB4: true, EnableTableLock: false, DelayCleanTableLock: 0, diff --git a/config/config.toml.example b/config/config.toml.example index a2b55143c4f62..e4ebe7a5defc6 100644 --- a/config/config.toml.example +++ b/config/config.toml.example @@ -86,6 +86,12 @@ delay-clean-table-lock = 0 # Maximum number of the splitting region, which is used by the split region statement. split-region-max-num = 1000 +# alter-primary-key is used to control whether the primary keys are clustered. +# Note that this config is deprecated. Only valid when @@global.tidb_enable_clustered_index = 'int_only'. +# Default is false, only the integer primary keys are clustered. +# If it is true, all types of primary keys are nonclustered. +alter-primary-key = false + # server-version is used to change the version string of TiDB in the following scenarios: # 1. the server version returned by builtin-function `VERSION()`. # 2. the server version filled in handshake packets of MySQL Connection Protocol, see https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake for more details. diff --git a/config/config_test.go b/config/config_test.go index 62cc87d2a713e..25fea0d1c022e 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -184,6 +184,7 @@ unrecognized-option-test = true _, err = f.WriteString(` token-limit = 0 enable-table-lock = true +alter-primary-key = true delay-clean-table-lock = 5 split-region-max-num=10000 enable-batch-dml = true @@ -243,6 +244,7 @@ spilled-file-encryption-method = "plaintext" // Test that the value will be overwritten by the config file. c.Assert(conf.Performance.TxnTotalSizeLimit, Equals, uint64(2000)) + c.Assert(conf.AlterPrimaryKey, Equals, true) c.Assert(conf.Performance.TCPNoDelay, Equals, false) c.Assert(conf.TiKVClient.CommitTimeout, Equals, "41s") diff --git a/ddl/db_change_test.go b/ddl/db_change_test.go index 99a2de6ee3ef5..046f6e07b16cb 100644 --- a/ddl/db_change_test.go +++ b/ddl/db_change_test.go @@ -793,9 +793,10 @@ func (s *testStateChangeSuite) TestWriteOnlyForDropColumn(c *C) { c.Assert(err, IsNil) }() - sqls := make([]sqlWithErr, 2) + sqls := make([]sqlWithErr, 3) sqls[0] = sqlWithErr{"update t set c1='5', c3='2020-03-01';", errors.New("[planner:1054]Unknown column 'c3' in 'field list'")} - sqls[1] = sqlWithErr{"update t t1, tt t2 set t1.c1='5', t1.c3='2020-03-01', t2.c1='10' where t1.c4=t2.c4", + sqls[1] = sqlWithErr{"update t set c1='5', c3='2020-03-01' where c4 = 8;", errors.New("[planner:1054]Unknown column 'c3' in 'field list'")} + sqls[2] = sqlWithErr{"update t t1, tt t2 set t1.c1='5', t1.c3='2020-03-01', t2.c1='10' where t1.c4=t2.c4", errors.New("[planner:1054]Unknown column 'c3' in 'field list'")} // TODO: Fix the case of sqls[2]. 
// sqls[2] = sqlWithErr{"update t set c1='5' where c3='2017-07-01';", errors.New("[planner:1054]Unknown column 'c3' in 'field list'")} diff --git a/ddl/db_integration_test.go b/ddl/db_integration_test.go index 7e27382395dd3..c73685358a9cf 100644 --- a/ddl/db_integration_test.go +++ b/ddl/db_integration_test.go @@ -38,6 +38,7 @@ import ( "github.com/pingcap/tidb/session" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/stmtctx" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/store/mockstore" "github.com/pingcap/tidb/store/tikv/mockstore/cluster" "github.com/pingcap/tidb/store/tikv/oracle" @@ -2638,7 +2639,7 @@ func (s *testIntegrationSuite7) TestDuplicateErrorMessage(c *C) { config.UpdateGlobal(func(conf *config.Config) { conf.EnableGlobalIndex = globalIndex }) - for _, clusteredIndex := range []bool{false, true} { + for _, clusteredIndex := range []variable.ClusteredIndexDefMode{variable.ClusteredIndexDefModeOn, variable.ClusteredIndexDefModeOff, variable.ClusteredIndexDefModeIntOnly} { tk.Se.GetSessionVars().EnableClusteredIndex = clusteredIndex for _, t := range tests { tk.MustExec("drop table if exists t;") diff --git a/ddl/db_test.go b/ddl/db_test.go index d45a29fb84db5..e8abc6674c708 100644 --- a/ddl/db_test.go +++ b/ddl/db_test.go @@ -178,7 +178,7 @@ func (s *testDBSuite7) TestAddIndexWithPK(c *C) { tk.MustExec("use " + s.schemaName) testAddIndexWithPK(tk) - tk.Se.GetSessionVars().EnableClusteredIndex = true + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn testAddIndexWithPK(tk) } @@ -1056,7 +1056,7 @@ func (s *testDBSuite6) TestAddMultiColumnsIndexClusterIndex(c *C) { tk.MustExec("create database test_add_multi_col_index_clustered;") tk.MustExec("use test_add_multi_col_index_clustered;") - tk.Se.GetSessionVars().EnableClusteredIndex = true + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn tk.MustExec("create table t (a int, b varchar(10), c int, primary key (a, b));") tk.MustExec("insert into t values (1, '1', 1), (2, '2', NULL), (3, '3', 3);") tk.MustExec("create index idx on t (a, c);") @@ -1156,7 +1156,7 @@ func testAddIndex(c *C, store kv.Storage, lease time.Duration, tp testAddIndexTy case testPartition: tk.MustExec("set @@session.tidb_enable_table_partition = '1';") case testClusteredIndex: - tk.Se.GetSessionVars().EnableClusteredIndex = true + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn } tk.MustExec("drop table if exists test_add_index") tk.MustExec(createTableSQL) diff --git a/ddl/ddl_api.go b/ddl/ddl_api.go index 25f9fb81d6338..d7c7d214a933b 100644 --- a/ddl/ddl_api.go +++ b/ddl/ddl_api.go @@ -1523,8 +1523,14 @@ func isSingleIntPK(constr *ast.Constraint, lastCol *model.ColumnInfo) bool { // ShouldBuildClusteredIndex is used to determine whether the CREATE TABLE statement should build a clustered index table. 
func ShouldBuildClusteredIndex(ctx sessionctx.Context, opt *ast.IndexOption, isSingleIntPK bool) bool { if opt == nil || opt.PrimaryKeyTp == model.PrimaryKeyTypeDefault { - return ctx.GetSessionVars().EnableClusteredIndex || - (isSingleIntPK && ctx.GetSessionVars().IntPrimaryKeyDefaultAsClustered) + switch ctx.GetSessionVars().EnableClusteredIndex { + case variable.ClusteredIndexDefModeOn: + return true + case variable.ClusteredIndexDefModeIntOnly: + return !config.GetGlobalConfig().AlterPrimaryKey && isSingleIntPK + default: + return false + } } return opt.PrimaryKeyTp == model.PrimaryKeyTypeClustered } diff --git a/ddl/failtest/fail_db_test.go b/ddl/failtest/fail_db_test.go index 58f0f2d79d6ce..805ee67154c54 100644 --- a/ddl/failtest/fail_db_test.go +++ b/ddl/failtest/fail_db_test.go @@ -354,7 +354,7 @@ func (s *testFailDBSuite) TestAddIndexWorkerNum(c *C) { tk.MustExec("use test_db") tk.MustExec("drop table if exists test_add_index") if s.IsCommonHandle { - tk.Se.GetSessionVars().EnableClusteredIndex = true + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn tk.MustExec("create table test_add_index (c1 bigint, c2 bigint, c3 bigint, primary key(c1, c3))") } else { tk.MustExec("create table test_add_index (c1 bigint, c2 bigint, c3 bigint, primary key(c1))") diff --git a/ddl/partition.go b/ddl/partition.go index 16b20581d7dfe..4cc71eb1c8d74 100644 --- a/ddl/partition.go +++ b/ddl/partition.go @@ -129,16 +129,17 @@ func (w *worker) onAddTablePartition(d *ddlCtx, t *meta.Meta, job *model.Job) (v if tblInfo.TiFlashReplica != nil && tblInfo.TiFlashReplica.Available { // For available state, the new added partition should wait it's replica to // be finished. Otherwise the query to this partition will be blocked. - needWait, err := checkPartitionReplica(addingDefinitions, d) + needRetry, err := checkPartitionReplica(tblInfo.TiFlashReplica.Count, addingDefinitions, d) if err != nil { ver, err = convertAddTablePartitionJob2RollbackJob(t, job, err, tblInfo) return ver, err } - if needWait { + if needRetry { // The new added partition hasn't been replicated. // Do nothing to the job this time, wait next worker round. time.Sleep(tiflashCheckTiDBHTTPAPIHalfInterval) - return ver, nil + // Set the error here which will lead this job exit when it's retry times beyond the limitation. + return ver, errors.Errorf("[ddl] add partition wait for tiflash replica to complete") } } @@ -222,13 +223,29 @@ func checkAddPartitionValue(meta *model.TableInfo, part *model.PartitionInfo) er return nil } -func checkPartitionReplica(addingDefinitions []model.PartitionDefinition, d *ddlCtx) (needWait bool, err error) { +func checkPartitionReplica(replicaCount uint64, addingDefinitions []model.PartitionDefinition, d *ddlCtx) (needWait bool, err error) { + failpoint.Inject("mockWaitTiFlashReplica", func(val failpoint.Value) { + if val.(bool) { + failpoint.Return(true, nil) + } + }) + ctx := context.Background() pdCli := d.store.(tikv.Storage).GetRegionCache().PDClient() stores, err := pdCli.GetAllStores(ctx) if err != nil { return needWait, errors.Trace(err) } + // Check whether stores have `count` tiflash engines. 
+ tiFlashStoreCount := uint64(0) + for _, store := range stores { + if storeHasEngineTiFlashLabel(store) { + tiFlashStoreCount++ + } + } + if replicaCount > tiFlashStoreCount { + return false, errors.Errorf("[ddl] the tiflash replica count: %d should be less than the total tiflash server count: %d", replicaCount, tiFlashStoreCount) + } for _, pd := range addingDefinitions { startKey, endKey := tablecodec.GetTableHandleKeyRange(pd.ID) regions, err := pdCli.ScanRegions(ctx, startKey, endKey, -1) diff --git a/ddl/partition_test.go b/ddl/partition_test.go index 6173fe0b5c599..32078133547b6 100644 --- a/ddl/partition_test.go +++ b/ddl/partition_test.go @@ -17,14 +17,16 @@ import ( "context" . "github.com/pingcap/check" + "github.com/pingcap/failpoint" "github.com/pingcap/parser/model" "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/types" ) -var _ = Suite(&testPartitionSuite{}) +var _ = SerialSuites(&testPartitionSuite{}) type testPartitionSuite struct { store kv.Storage @@ -155,3 +157,111 @@ func testTruncatePartition(c *C, ctx sessionctx.Context, d *ddl, dbInfo *model.D checkHistoryJobArgs(c, ctx, job.ID, &historyJobArgs{ver: v, tbl: tblInfo}) return job } + +func testAddPartition(c *C, ctx sessionctx.Context, d *ddl, dbInfo *model.DBInfo, tblInfo *model.TableInfo) error { + ids, err := d.genGlobalIDs(1) + c.Assert(err, IsNil) + partitionInfo := &model.PartitionInfo{ + Type: model.PartitionTypeRange, + Expr: tblInfo.Columns[0].Name.L, + Enable: true, + Definitions: []model.PartitionDefinition{ + { + ID: ids[0], + Name: model.NewCIStr("p2"), + LessThan: []string{"300"}, + }, + }, + } + addPartitionJob := &model.Job{ + SchemaID: dbInfo.ID, + TableID: tblInfo.ID, + Type: model.ActionAddTablePartition, + BinlogInfo: &model.HistoryInfo{}, + Args: []interface{}{partitionInfo}, + } + return d.doDDLJob(ctx, addPartitionJob) +} + +func (s *testPartitionSuite) TestAddPartitionReplicaBiggerThanTiFlashStores(c *C) { + d := testNewDDLAndStart( + context.Background(), + c, + WithStore(s.store), + WithLease(testLease), + ) + defer func() { + err := d.Stop() + c.Assert(err, IsNil) + }() + dbInfo := testSchemaInfo(c, d, "test_partition2") + testCreateSchema(c, testNewContext(d), d, dbInfo) + // Build a tableInfo with replica count = 1 while there is no real tiFlash store. + tblInfo := buildTableInfoWithReplicaInfo(c, d) + ctx := testNewContext(d) + testCreateTable(c, ctx, d, dbInfo, tblInfo) + + err := testAddPartition(c, ctx, d, dbInfo, tblInfo) + // Since there is no real TiFlash store (less than replica count), adding a partition will error here. + c.Assert(err, NotNil) + c.Assert(err.Error(), Equals, "[ddl:-1][ddl] the tiflash replica count: 1 should be less than the total tiflash server count: 0") + + // Test `add partition` waiting TiFlash replica can exit when its retry count is beyond the limitation. 
+ originErrCountLimit := variable.GetDDLErrorCountLimit() + variable.SetDDLErrorCountLimit(3) + defer func() { + variable.SetDDLErrorCountLimit(originErrCountLimit) + }() + c.Assert(failpoint.Enable("github.com/pingcap/tidb/ddl/mockWaitTiFlashReplica", `return(true)`), IsNil) + defer func() { + c.Assert(failpoint.Disable("github.com/pingcap/tidb/ddl/mockWaitTiFlashReplica"), IsNil) + }() + err = testAddPartition(c, ctx, d, dbInfo, tblInfo) + c.Assert(err, NotNil) + c.Assert(err.Error(), Equals, "[ddl:-1]DDL job rollback, error msg: [ddl] add partition wait for tiflash replica to complete") +} + +func buildTableInfoWithReplicaInfo(c *C, d *ddl) *model.TableInfo { + tbl := &model.TableInfo{ + Name: model.NewCIStr("t1"), + } + col := &model.ColumnInfo{ + Name: model.NewCIStr("c"), + Offset: 0, + State: model.StatePublic, + FieldType: *types.NewFieldType(mysql.TypeLong), + ID: allocateColumnID(tbl), + } + genIDs, err := d.genGlobalIDs(1) + c.Assert(err, IsNil) + tbl.ID = genIDs[0] + tbl.Columns = []*model.ColumnInfo{col} + tbl.Charset = "utf8" + tbl.Collate = "utf8_bin" + tbl.TiFlashReplica = &model.TiFlashReplicaInfo{ + Count: 1, + Available: true, + } + + partIDs, err := d.genGlobalIDs(2) + c.Assert(err, IsNil) + partInfo := &model.PartitionInfo{ + Type: model.PartitionTypeRange, + Expr: tbl.Columns[0].Name.L, + Enable: true, + Definitions: []model.PartitionDefinition{ + { + ID: partIDs[0], + Name: model.NewCIStr("p0"), + LessThan: []string{"100"}, + }, + { + ID: partIDs[1], + Name: model.NewCIStr("p1"), + LessThan: []string{"200"}, + }, + }, + } + tbl.Partition = partInfo + return tbl +} diff --git a/ddl/serial_test.go b/ddl/serial_test.go index 43b4773cadc85..35b923df5847e 100644 --- a/ddl/serial_test.go +++ b/ddl/serial_test.go @@ -117,7 +117,7 @@ func (s *testIntegrationSuite9) TestPrimaryKey(c *C) { tk.MustExec("drop database if exists test_primary_key;") tk.MustExec("create database test_primary_key;") tk.MustExec("use test_primary_key;") - tk.Se.GetSessionVars().EnableClusteredIndex = false + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeIntOnly // Test add/drop primary key on a plain table. 
tk.MustExec("drop table if exists t;") @@ -325,7 +325,7 @@ func (s *testIntegrationSuite9) TestMultiRegionGetTableEndCommonHandle(c *C) { tk.MustExec("drop database if exists test_get_endhandle") tk.MustExec("create database test_get_endhandle") tk.MustExec("use test_get_endhandle") - tk.Se.GetSessionVars().EnableClusteredIndex = true + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn tk.MustExec("create table t(a varchar(20), b int, c float, d bigint, primary key (a, b, c))") var builder strings.Builder @@ -369,7 +369,7 @@ func (s *testIntegrationSuite9) TestGetTableEndCommonHandle(c *C) { tk.MustExec("drop database if exists test_get_endhandle") tk.MustExec("create database test_get_endhandle") tk.MustExec("use test_get_endhandle") - tk.Se.GetSessionVars().EnableClusteredIndex = true + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn tk.MustExec("create table t(a varchar(15), b bigint, c int, primary key (a, b))") tk.MustExec("create table t1(a varchar(15), b bigint, c int, primary key (a(2), b))") @@ -1406,7 +1406,7 @@ func (s *testIntegrationSuite9) TestInvisibleIndex(c *C) { func (s *testIntegrationSuite9) TestCreateClusteredIndex(c *C) { tk := testkit.NewTestKitWithInit(c, s.store) - tk.Se.GetSessionVars().EnableClusteredIndex = true + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn tk.MustExec("CREATE TABLE t1 (a int primary key, b int)") tk.MustExec("CREATE TABLE t2 (a varchar(255) primary key, b int)") tk.MustExec("CREATE TABLE t3 (a int, b int, c int, primary key (a, b))") @@ -1447,7 +1447,7 @@ func (s *testIntegrationSuite9) TestCreateClusteredIndex(c *C) { c.Assert(err, IsNil) c.Assert(tbl.Meta().IsCommonHandle, IsTrue) - tk.Se.GetSessionVars().EnableClusteredIndex = false + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeIntOnly tk.MustExec("CREATE TABLE t7 (a varchar(255) primary key, b int)") is = domain.GetDomain(ctx).InfoSchema() tbl, err = is.TableByName(model.NewCIStr("test"), model.NewCIStr("t7")) diff --git a/distsql/select_result.go b/distsql/select_result.go index 6f76058e79327..9e37a02796f4f 100644 --- a/distsql/select_result.go +++ b/distsql/select_result.go @@ -111,30 +111,30 @@ func (r *selectResult) fetchResp(ctx context.Context) error { if r.stats != nil { coprCacheHistogramHit.Observe(float64(r.stats.CoprCacheHitNum)) coprCacheHistogramMiss.Observe(float64(len(r.stats.copRespTime) - int(r.stats.CoprCacheHitNum))) - if len(r.stats.copRespTime) > 0 { + // Ignore internal sql. 
+ if !r.ctx.GetSessionVars().InRestrictedSQL && len(r.stats.copRespTime) > 0 { ratio := float64(r.stats.CoprCacheHitNum) / float64(len(r.stats.copRespTime)) - switch { - case ratio >= 0: - telemetry.CurrentCoprCacheHitRatioGTE0Count.Inc() - fallthrough - case ratio >= 0.01: - telemetry.CurrentCoprCacheHitRatioGTE1Count.Inc() - fallthrough - case ratio >= 0.1: - telemetry.CurrentCoprCacheHitRatioGTE10Count.Inc() - fallthrough - case ratio >= 0.2: - telemetry.CurrentCoprCacheHitRatioGTE20Count.Inc() - fallthrough - case ratio >= 0.4: - telemetry.CurrentCoprCacheHitRatioGTE40Count.Inc() - fallthrough - case ratio >= 0.8: - telemetry.CurrentCoprCacheHitRatioGTE80Count.Inc() - fallthrough - case ratio >= 1: + if ratio >= 1 { telemetry.CurrentCoprCacheHitRatioGTE100Count.Inc() } + if ratio >= 0.8 { + telemetry.CurrentCoprCacheHitRatioGTE80Count.Inc() + } + if ratio >= 0.4 { + telemetry.CurrentCoprCacheHitRatioGTE40Count.Inc() + } + if ratio >= 0.2 { + telemetry.CurrentCoprCacheHitRatioGTE20Count.Inc() + } + if ratio >= 0.1 { + telemetry.CurrentCoprCacheHitRatioGTE10Count.Inc() + } + if ratio >= 0.01 { + telemetry.CurrentCoprCacheHitRatioGTE1Count.Inc() + } + if ratio >= 0 { + telemetry.CurrentCoprCacheHitRatioGTE0Count.Inc() + } } } }() diff --git a/executor/admin_test.go b/executor/admin_test.go index 7241f5cf52f77..c9cda897a4745 100644 --- a/executor/admin_test.go +++ b/executor/admin_test.go @@ -24,6 +24,7 @@ import ( mysql "github.com/pingcap/tidb/errno" "github.com/pingcap/tidb/executor" "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/table/tables" "github.com/pingcap/tidb/types" @@ -194,7 +195,7 @@ func (s *testSuite5) TestClusteredIndexAdminRecoverIndex(c *C) { tk.MustExec("drop database if exists test_cluster_index_admin_recover;") tk.MustExec("create database test_cluster_index_admin_recover;") tk.MustExec("use test_cluster_index_admin_recover;") - tk.Se.GetSessionVars().EnableClusteredIndex = true + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn dbName := model.NewCIStr("test_cluster_index_admin_recover") tblName := model.NewCIStr("t") @@ -310,7 +311,7 @@ func (s *testSuite5) TestAdminRecoverIndex1(c *C) { sc := s.ctx.GetSessionVars().StmtCtx tk.MustExec("use test") tk.MustExec("drop table if exists admin_test") - tk.Se.GetSessionVars().EnableClusteredIndex = false + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeIntOnly tk.MustExec("create table admin_test (c1 varchar(255), c2 int, c3 int default 1, primary key(c1), unique key(c2))") tk.MustExec("insert admin_test (c1, c2) values ('1', 1), ('2', 2), ('3', 3), ('10', 10), ('20', 20)") @@ -515,7 +516,7 @@ func (s *testSuite5) TestAdminCleanupIndexPKNotHandle(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") tk.MustExec("drop table if exists admin_test") - tk.Se.GetSessionVars().EnableClusteredIndex = false + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeIntOnly tk.MustExec("create table admin_test (c1 int, c2 int, c3 int, primary key (c1, c2))") tk.MustExec("insert admin_test (c1, c2) values (1, 2), (3, 4), (-5, 5)") @@ -627,7 +628,7 @@ func (s *testSuite5) TestClusteredAdminCleanupIndex(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") tk.MustExec("drop table if exists admin_test") - tk.Se.GetSessionVars().EnableClusteredIndex = true + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn 
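A note on the theme running through these test hunks: EnableClusteredIndex is no longer a bool but a three-valued variable.ClusteredIndexDefMode (On, Off, IntOnly), and ddl/ddl_api.go's ShouldBuildClusteredIndex above now switches on it, with IntOnly additionally gated by the deprecated alter-primary-key config. The standalone sketch below restates that decision for a primary key declared without an explicit CLUSTERED/NONCLUSTERED keyword; the mode constants and field names here are simplified stand-ins, not the real sessionctx/variable or config identifiers.

package main

import "fmt"

// clusteredMode is a simplified stand-in for variable.ClusteredIndexDefMode.
type clusteredMode int

const (
	modeIntOnly clusteredMode = iota // legacy default: only int PKs are clustered
	modeOn
	modeOff
)

// shouldCluster mirrors the new ShouldBuildClusteredIndex branch for primary
// keys that do not carry an explicit CLUSTERED/NONCLUSTERED keyword.
func shouldCluster(mode clusteredMode, alterPrimaryKey, singleIntPK bool) bool {
	switch mode {
	case modeOn:
		return true
	case modeIntOnly:
		// Clustered only for a single integer primary key, and only while the
		// deprecated alter-primary-key config stays false.
		return !alterPrimaryKey && singleIntPK
	default: // modeOff
		return false
	}
}

func main() {
	fmt.Println(shouldCluster(modeOn, false, false))     // true: every primary key is clustered
	fmt.Println(shouldCluster(modeIntOnly, false, true)) // true: int PK with alter-primary-key=false
	fmt.Println(shouldCluster(modeIntOnly, true, true))  // false: alter-primary-key=true disables it
	fmt.Println(shouldCluster(modeOff, false, true))     // false: nothing is clustered by default
}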
tk.MustExec("create table admin_test (c1 varchar(255), c2 int, c3 char(10) default 'c3', primary key (c1, c3), unique key(c2), key (c3))") tk.MustExec("insert admin_test (c1, c2) values ('c1_1', 2), ('c1_2', 4), ('c1_3', NULL)") tk.MustExec("insert admin_test (c1, c3) values ('c1_4', 'c3_4'), ('c1_5', 'c3_5'), ('c1_6', default)") diff --git a/executor/aggregate_test.go b/executor/aggregate_test.go index e967016ecfb13..0c327fb9c28b7 100644 --- a/executor/aggregate_test.go +++ b/executor/aggregate_test.go @@ -30,6 +30,7 @@ import ( "github.com/pingcap/tidb/executor" plannercore "github.com/pingcap/tidb/planner/core" "github.com/pingcap/tidb/session" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/util/sqlexec" "github.com/pingcap/tidb/util/testkit" "github.com/pingcap/tidb/util/testutil" @@ -919,7 +920,7 @@ func (s *testSuiteAgg) TestAggEliminator(c *C) { func (s *testSuiteAgg) TestClusterIndexMaxMinEliminator(c *C) { tk := testkit.NewTestKitWithInit(c, s.store) tk.MustExec("drop table if exists t;") - tk.Se.GetSessionVars().EnableClusteredIndex = true + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn tk.MustExec("create table t (a int, b int, c int, primary key(a, b));") for i := 0; i < 10+1; i++ { tk.MustExec("insert into t values (?, ?, ?)", i, i, i) diff --git a/executor/analyze_test.go b/executor/analyze_test.go index 623bd09277948..5b5c1b66240f0 100644 --- a/executor/analyze_test.go +++ b/executor/analyze_test.go @@ -129,7 +129,7 @@ func (s *testSuite1) TestClusterIndexAnalyze(c *C) { tk.MustExec("drop database if exists test_cluster_index_analyze;") tk.MustExec("create database test_cluster_index_analyze;") tk.MustExec("use test_cluster_index_analyze;") - tk.Se.GetSessionVars().EnableClusteredIndex = true + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn tk.MustExec("create table t (a int, b int, c int, primary key(a, b));") for i := 0; i < 100; i++ { @@ -832,7 +832,7 @@ func (s *testSuite1) TestNormalAnalyzeOnCommonHandle(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") tk.MustExec("drop table if exists t1, t2, t3, t4") - tk.Se.GetSessionVars().EnableClusteredIndex = true + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn tk.MustExec("CREATE TABLE t1 (a int primary key, b int)") tk.MustExec("insert into t1 values(1,1), (2,2), (3,3)") tk.MustExec("CREATE TABLE t2 (a varchar(255) primary key, b int)") diff --git a/executor/batch_point_get_test.go b/executor/batch_point_get_test.go index c41ac03156243..926834dc9281b 100644 --- a/executor/batch_point_get_test.go +++ b/executor/batch_point_get_test.go @@ -23,6 +23,7 @@ import ( "github.com/pingcap/tidb/domain" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/session" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/store/mockstore" "github.com/pingcap/tidb/util/testkit" ) @@ -177,7 +178,7 @@ func (s *testBatchPointGetSuite) TestBatchPointGetLockExistKey(c *C) { errCh <- tk1.ExecToErr("use test") errCh <- tk2.ExecToErr("use test") - tk1.Se.GetSessionVars().EnableClusteredIndex = false + tk1.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeIntOnly errCh <- tk1.ExecToErr(fmt.Sprintf("drop table if exists %s", tableName)) errCh <- tk1.ExecToErr(fmt.Sprintf("create table %s(id int, v int, k int, %s key0(id, v))", tableName, key)) diff --git a/executor/builder.go b/executor/builder.go index 853d4fcf43d71..eaa345875967c 100644 --- a/executor/builder.go +++ 
b/executor/builder.go @@ -1842,6 +1842,15 @@ func (b *executorBuilder) buildUpdate(v *plannercore.Update) Executor { } base := newBaseExecutor(b.ctx, v.Schema(), v.ID(), selExec) base.initCap = chunk.ZeroCapacity + var assignFlag []int + assignFlag, b.err = getAssignFlag(b.ctx, v, selExec.Schema().Len()) + if b.err != nil { + return nil + } + b.err = plannercore.CheckUpdateList(assignFlag, v) + if b.err != nil { + return nil + } updateExec := &UpdateExec{ baseExecutor: base, OrderedList: v.OrderedList, @@ -1850,10 +1859,29 @@ func (b *executorBuilder) buildUpdate(v *plannercore.Update) Executor { multiUpdateOnSameTable: multiUpdateOnSameTable, tblID2table: tblID2table, tblColPosInfos: v.TblColPosInfos, + assignFlag: assignFlag, } return updateExec } +func getAssignFlag(ctx sessionctx.Context, v *plannercore.Update, schemaLen int) ([]int, error) { + assignFlag := make([]int, schemaLen) + for i := range assignFlag { + assignFlag[i] = -1 + } + for _, assign := range v.OrderedList { + if !ctx.GetSessionVars().AllowWriteRowID && assign.Col.ID == model.ExtraHandleID { + return nil, errors.Errorf("insert, update and replace statements for _tidb_rowid are not supported.") + } + tblIdx, found := v.TblColPosInfos.FindTblIdx(assign.Col.Index) + if found { + colIdx := assign.Col.Index + assignFlag[colIdx] = tblIdx + } + } + return assignFlag, nil +} + func (b *executorBuilder) buildDelete(v *plannercore.Delete) Executor { tblID2table := make(map[int64]table.Table, len(v.TblColPosInfos)) for _, info := range v.TblColPosInfos { diff --git a/executor/delete_test.go b/executor/delete_test.go index 278692acdb717..55ab9b80fd1d7 100644 --- a/executor/delete_test.go +++ b/executor/delete_test.go @@ -18,6 +18,7 @@ import ( "time" . "github.com/pingcap/check" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/util/testkit" ) @@ -82,7 +83,7 @@ func (s *testSuite8) TestDeleteLockKey(c *C) { tk1, tk2 := testkit.NewTestKit(c, s.store), testkit.NewTestKit(c, s.store) tk1.MustExec("use test") tk2.MustExec("use test") - tk1.Se.GetSessionVars().EnableClusteredIndex = false + tk1.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeIntOnly tk1.MustExec(t.ddl) tk1.MustExec(t.pre) tk1.MustExec("begin pessimistic") diff --git a/executor/distsql_test.go b/executor/distsql_test.go index c98f7f9e2f70c..6d36f06bf34b1 100644 --- a/executor/distsql_test.go +++ b/executor/distsql_test.go @@ -151,11 +151,11 @@ func (s *testSuite3) TestCorColToRanges(c *C) { tk.MustExec("insert into t values(1, 1, 1), (2, 2 ,2), (3, 3, 3), (4, 4, 4), (5, 5, 5), (6, 6, 6), (7, 7, 7), (8, 8, 8), (9, 9, 9)") tk.MustExec("analyze table t") // Test single read on table. - tk.MustQuery("select t.c in (select count(*) from t s ignore index(idx), t t1 where s.a = t.a and s.a = t1.a) from t").Check(testkit.Rows("1", "0", "0", "0", "0", "0", "0", "0", "0")) + tk.MustQuery("select t.c in (select count(*) from t s ignore index(idx), t t1 where s.a = t.a and s.a = t1.a) from t order by 1 desc").Check(testkit.Rows("1", "0", "0", "0", "0", "0", "0", "0", "0")) // Test single read on index. - tk.MustQuery("select t.c in (select count(*) from t s use index(idx), t t1 where s.b = t.a and s.a = t1.a) from t").Check(testkit.Rows("1", "0", "0", "0", "0", "0", "0", "0", "0")) + tk.MustQuery("select t.c in (select count(*) from t s use index(idx), t t1 where s.b = t.a and s.a = t1.a) from t order by 1 desc").Check(testkit.Rows("1", "0", "0", "0", "0", "0", "0", "0", "0")) // Test IndexLookUpReader. 
- tk.MustQuery("select t.c in (select count(*) from t s use index(idx), t t1 where s.b = t.a and s.c = t1.a) from t").Check(testkit.Rows("1", "0", "0", "0", "0", "0", "0", "0", "0")) + tk.MustQuery("select t.c in (select count(*) from t s use index(idx), t t1 where s.b = t.a and s.c = t1.a) from t order by 1 desc").Check(testkit.Rows("1", "0", "0", "0", "0", "0", "0", "0", "0")) } func (s *testSuiteP1) TestUniqueKeyNullValueSelect(c *C) { diff --git a/executor/executor_test.go b/executor/executor_test.go index 535c0e832c547..fadaefa920b66 100644 --- a/executor/executor_test.go +++ b/executor/executor_test.go @@ -2351,7 +2351,7 @@ func (s *testSuiteP2) TestClusteredIndexIsPointGet(c *C) { tk.MustExec("create database test_cluster_index_is_point_get;") tk.MustExec("use test_cluster_index_is_point_get;") - tk.Se.GetSessionVars().EnableClusteredIndex = true + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn tk.MustExec("drop table if exists t;") tk.MustExec("create table t (a varchar(255), b int, c char(10), primary key (c, a));") ctx := tk.Se.(sessionctx.Context) @@ -3648,7 +3648,7 @@ func (s *testSuite) TestUnsignedPk(c *C) { func (s *testSuite) TestSignedCommonHandle(c *C) { tk := testkit.NewTestKitWithInit(c, s.store) - tk.Se.GetSessionVars().EnableClusteredIndex = true + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn tk.MustExec("drop table if exists t") tk.MustExec("create table t(k1 int, k2 int, primary key(k1, k2))") tk.MustExec("insert into t(k1, k2) value(-100, 1), (-50, 1), (0, 0), (1, 1), (3, 3)") @@ -3816,7 +3816,7 @@ func (s *testSuite) TestCheckTableClusterIndex(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test;") - tk.Se.GetSessionVars().EnableClusteredIndex = true + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn tk.MustExec("drop table if exists admin_test;") tk.MustExec("create table admin_test (c1 int, c2 int, c3 int default 1, primary key (c1, c2), index (c1), unique key(c2));") tk.MustExec("insert admin_test (c1, c2) values (1, 1), (2, 2), (3, 3);") @@ -4130,6 +4130,173 @@ func (s *testSuiteP1) TestUnionAutoSignedCast(c *C) { Check(testkit.Rows("1 1", "2 -1", "3 -1")) } +func (s *testSuiteP1) TestUpdateClustered(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + + type resultChecker struct { + check string + assert []string + } + + for _, clustered := range []string{"", "clustered"} { + tests := []struct { + initSchema []string + initData []string + dml string + resultCheck []resultChecker + }{ + { // left join + update both + match & unmatched + pk + []string{ + "drop table if exists a, b", + "create table a (k1 int, k2 int, v int)", + fmt.Sprintf("create table b (a int not null, k1 int, k2 int, v int, primary key(k1, k2) %s)", clustered), + }, + []string{ + "insert into a values (1, 1, 1), (2, 2, 2)", // unmatched + matched + "insert into b values (2, 2, 2, 2)", + }, + "update a left join b on a.k1 = b.k1 and a.k2 = b.k2 set a.v = 20, b.v = 100, a.k1 = a.k1 + 1, b.k1 = b.k1 + 1, a.k2 = a.k2 + 2, b.k2 = b.k2 + 2", + []resultChecker{ + { + "select * from b", + []string{"2 3 4 100"}, + }, + { + "select * from a", + []string{"2 3 20", "3 4 20"}, + }, + }, + }, + { // left join + update both + match & unmatched + pk + []string{ + "drop table if exists a, b", + "create table a (k1 int, k2 int, v int)", + fmt.Sprintf("create table b (a int not null, k1 int, k2 int, v int, primary key(k1, k2) %s)", clustered), + }, + []string{ + "insert into a 
values (1, 1, 1), (2, 2, 2)", // unmatched + matched + "insert into b values (2, 2, 2, 2)", + }, + "update a left join b on a.k1 = b.k1 and a.k2 = b.k2 set a.k1 = a.k1 + 1, a.k2 = a.k2 + 2, b.k1 = b.k1 + 1, b.k2 = b.k2 + 2, a.v = 20, b.v = 100", + []resultChecker{ + { + "select * from b", + []string{"2 3 4 100"}, + }, + { + "select * from a", + []string{"2 3 20", "3 4 20"}, + }, + }, + }, + { // left join + update both + match & unmatched + prefix pk + []string{ + "drop table if exists a, b", + "create table a (k1 varchar(100), k2 varchar(100), v varchar(100))", + fmt.Sprintf("create table b (a varchar(100) not null, k1 varchar(100), k2 varchar(100), v varchar(100), primary key(k1(1), k2(1)) %s, key kk1(k1(1), v(1)))", clustered), + }, + []string{ + "insert into a values ('11', '11', '11'), ('22', '22', '22')", // unmatched + matched + "insert into b values ('22', '22', '22', '22')", + }, + "update a left join b on a.k1 = b.k1 and a.k2 = b.k2 set a.k1 = a.k1 + 1, a.k2 = a.k2 + 2, b.k1 = b.k1 + 1, b.k2 = b.k2 + 2, a.v = 20, b.v = 100", + []resultChecker{ + { + "select * from b", + []string{"22 23 24 100"}, + }, + { + "select * from a", + []string{"12 13 20", "23 24 20"}, + }, + }, + }, + { // right join + update both + match & unmatched + prefix pk + []string{ + "drop table if exists a, b", + "create table a (k1 varchar(100), k2 varchar(100), v varchar(100))", + fmt.Sprintf("create table b (a varchar(100) not null, k1 varchar(100), k2 varchar(100), v varchar(100), primary key(k1(1), k2(1)) %s, key kk1(k1(1), v(1)))", clustered), + }, + []string{ + "insert into a values ('11', '11', '11'), ('22', '22', '22')", // unmatched + matched + "insert into b values ('22', '22', '22', '22')", + }, + "update b right join a on a.k1 = b.k1 and a.k2 = b.k2 set a.k1 = a.k1 + 1, a.k2 = a.k2 + 2, b.k1 = b.k1 + 1, b.k2 = b.k2 + 2, a.v = 20, b.v = 100", + []resultChecker{ + { + "select * from b", + []string{"22 23 24 100"}, + }, + { + "select * from a", + []string{"12 13 20", "23 24 20"}, + }, + }, + }, + { // inner join + update both + match & unmatched + prefix pk + []string{ + "drop table if exists a, b", + "create table a (k1 varchar(100), k2 varchar(100), v varchar(100))", + fmt.Sprintf("create table b (a varchar(100) not null, k1 varchar(100), k2 varchar(100), v varchar(100), primary key(k1(1), k2(1)) %s, key kk1(k1(1), v(1)))", clustered), + }, + []string{ + "insert into a values ('11', '11', '11'), ('22', '22', '22')", // unmatched + matched + "insert into b values ('22', '22', '22', '22')", + }, + "update b join a on a.k1 = b.k1 and a.k2 = b.k2 set a.k1 = a.k1 + 1, a.k2 = a.k2 + 2, b.k1 = b.k1 + 1, b.k2 = b.k2 + 2, a.v = 20, b.v = 100", + []resultChecker{ + { + "select * from b", + []string{"22 23 24 100"}, + }, + { + "select * from a", + []string{"11 11 11", "23 24 20"}, + }, + }, + }, + { + []string{ + "drop table if exists a, b", + "create table a (k1 varchar(100), k2 varchar(100), v varchar(100))", + fmt.Sprintf("create table b (a varchar(100) not null, k1 varchar(100), k2 varchar(100), v varchar(100), primary key(k1(1), k2(1)) %s, key kk1(k1(1), v(1)))", clustered), + }, + []string{ + "insert into a values ('11', '11', '11'), ('22', '22', '22')", // unmatched + matched + "insert into b values ('22', '22', '22', '22')", + }, + "update a set a.k1 = a.k1 + 1, a.k2 = a.k2 + 2, a.v = 20 where exists (select 1 from b where a.k1 = b.k1 and a.k2 = b.k2)", + []resultChecker{ + { + "select * from b", + []string{"22 22 22 22"}, + }, + { + "select * from a", + []string{"11 11 11", "23 24 20"}, + }, + }, + }, 
+ } + + for _, test := range tests { + for _, s := range test.initSchema { + tk.MustExec(s) + } + for _, s := range test.initData { + tk.MustExec(s) + } + tk.MustExec(test.dml) + for _, checker := range test.resultCheck { + tk.MustQuery(checker.check).Check(testkit.Rows(checker.assert...)) + } + tk.MustExec("admin check table a") + tk.MustExec("admin check table b") + } + } +} + func (s *testSuite6) TestUpdateJoin(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") @@ -4260,7 +4427,7 @@ func (s *testSuite3) TestRowID(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec(`use test`) tk.MustExec(`drop table if exists t`) - tk.Se.GetSessionVars().EnableClusteredIndex = false + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeIntOnly tk.MustExec(`create table t(a varchar(10), b varchar(10), c varchar(1), index idx(a, b, c));`) tk.MustExec(`insert into t values('a', 'b', 'c');`) tk.MustExec(`insert into t values('a', 'b', 'c');`) @@ -4836,7 +5003,7 @@ func (s *testSplitTable) TestClusterIndexSplitTableIntegration(c *C) { tk.MustExec("drop database if exists test_cluster_index_index_split_table_integration;") tk.MustExec("create database test_cluster_index_index_split_table_integration;") tk.MustExec("use test_cluster_index_index_split_table_integration;") - tk.Se.GetSessionVars().EnableClusteredIndex = true + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn tk.MustExec("create table t (a varchar(255), b double, c int, primary key (a, b));") @@ -4891,7 +5058,7 @@ func (s *testSplitTable) TestClusterIndexShowTableRegion(c *C) { tk.MustExec("drop database if exists cluster_index_regions;") tk.MustExec("create database cluster_index_regions;") tk.MustExec("use cluster_index_regions;") - tk.Se.GetSessionVars().EnableClusteredIndex = true + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn tk.MustExec("create table t (a int, b int, c int, primary key(a, b));") tk.MustExec("insert t values (1, 1, 1), (2, 2, 2);") tk.MustQuery("split table t between (1, 0) and (2, 3) regions 2;").Check(testkit.Rows("1 1")) @@ -4914,7 +5081,7 @@ func (s *testSplitTable) TestClusterIndexShowTableRegion(c *C) { func (s *testSuiteWithData) TestClusterIndexOuterJoinElimination(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") - tk.Se.GetSessionVars().EnableClusteredIndex = true + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn tk.MustExec("create table t (a int, b int, c int, primary key(a,b))") rows := tk.MustQuery(`explain format = 'brief' select t1.a from t t1 left join t t2 on t1.a = t2.a and t1.b = t2.b`).Rows() rowStrs := s.testData.ConvertRowsToStrings(rows) diff --git a/executor/explainfor_test.go b/executor/explainfor_test.go index 955b1605cf727..9aa42dda62fc0 100644 --- a/executor/explainfor_test.go +++ b/executor/explainfor_test.go @@ -26,6 +26,7 @@ import ( "github.com/pingcap/parser/auth" "github.com/pingcap/tidb/planner/core" "github.com/pingcap/tidb/session" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/util" "github.com/pingcap/tidb/util/israce" "github.com/pingcap/tidb/util/kvcache" @@ -445,7 +446,7 @@ func (s *testPrepareSerialSuite) TestPointGetUserVarPlanCache(c *C) { tk.MustExec("use test") tk.MustExec("set @@tidb_enable_collect_execution_info=0;") - tk.Se.GetSessionVars().EnableClusteredIndex = true + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn tk.MustExec("drop table if exists t1") 
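For readers tracing the update path: getAssignFlag in executor/builder.go (earlier in this patch) precomputes, for every column position in the update's join schema, the index of the table that column's assignment targets, or -1 when the column is not assigned at all, and it rejects assignments to the _tidb_rowid extra handle. UpdateExec (later hunks) then consumes that slice to decide per-table updatability instead of recomputing it per row. The sketch below is a minimal standalone restatement of that mapping; colPosInfo and buildAssignFlag are simplified stand-ins for plannercore.TblColPosInfo and FindTblIdx, not the actual API.

package main

import "fmt"

// colPosInfo is a simplified stand-in for plannercore.TblColPosInfo: the
// half-open column range [start, end) owned by one table in the join schema.
type colPosInfo struct {
	tblIdx     int
	start, end int
}

// buildAssignFlag mirrors the idea of getAssignFlag: every schema column
// starts at -1 (not assigned); assigned columns record the table they belong to.
func buildAssignFlag(schemaLen int, infos []colPosInfo, assignedCols []int) []int {
	flag := make([]int, schemaLen)
	for i := range flag {
		flag[i] = -1
	}
	for _, col := range assignedCols {
		for _, info := range infos {
			if col >= info.start && col < info.end {
				flag[col] = info.tblIdx
				break
			}
		}
	}
	return flag
}

func main() {
	// Two joined tables: columns 0-2 belong to table 0, columns 3-6 to table 1.
	infos := []colPosInfo{{tblIdx: 0, start: 0, end: 3}, {tblIdx: 1, start: 3, end: 7}}
	// The UPDATE assigns columns 1 and 4.
	fmt.Println(buildAssignFlag(7, infos, []int{1, 4})) // [-1 0 -1 -1 1 -1 -1]
}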
tk.MustExec("CREATE TABLE t1 (a BIGINT, b VARCHAR(40), PRIMARY KEY (a, b))") tk.MustExec("INSERT INTO t1 VALUES (1,'3'),(2,'4')") diff --git a/executor/infoschema_reader_test.go b/executor/infoschema_reader_test.go index bb3497ac564a8..41bad78197f6e 100644 --- a/executor/infoschema_reader_test.go +++ b/executor/infoschema_reader_test.go @@ -873,10 +873,10 @@ func (s *testInfoschemaTableSuite) TestTablesPKType(c *C) { tk := testkit.NewTestKitWithInit(c, s.store) tk.MustExec("create table t_int (a int primary key, b int)") tk.MustQuery("SELECT TIDB_PK_TYPE FROM information_schema.tables where table_schema = 'test' and table_name = 't_int'").Check(testkit.Rows("CLUSTERED")) - tk.Se.GetSessionVars().EnableClusteredIndex = false + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeIntOnly tk.MustExec("create table t_implicit (a varchar(64) primary key, b int)") tk.MustQuery("SELECT TIDB_PK_TYPE FROM information_schema.tables where table_schema = 'test' and table_name = 't_implicit'").Check(testkit.Rows("NONCLUSTERED")) - tk.Se.GetSessionVars().EnableClusteredIndex = true + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn tk.MustExec("create table t_common (a varchar(64) primary key, b int)") tk.MustQuery("SELECT TIDB_PK_TYPE FROM information_schema.tables where table_schema = 'test' and table_name = 't_common'").Check(testkit.Rows("CLUSTERED")) tk.MustQuery("SELECT TIDB_PK_TYPE FROM information_schema.tables where table_schema = 'INFORMATION_SCHEMA' and table_name = 'TABLES'").Check(testkit.Rows("NONCLUSTERED")) diff --git a/executor/insert_test.go b/executor/insert_test.go index 891a74f495660..32815899ee489 100644 --- a/executor/insert_test.go +++ b/executor/insert_test.go @@ -220,7 +220,7 @@ func (s *testSuite8) TestClusterIndexInsertOnDuplicateKey(c *C) { tk.MustExec("drop database if exists cluster_index_duplicate_entry_error;") tk.MustExec("create database cluster_index_duplicate_entry_error;") tk.MustExec("use cluster_index_duplicate_entry_error;") - tk.Se.GetSessionVars().EnableClusteredIndex = true + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn tk.MustExec("create table t(a char(20), b int, primary key(a));") tk.MustExec("insert into t values('aa', 1), ('bb', 1);") @@ -237,7 +237,7 @@ func (s *testSuite8) TestClusterIndexInsertOnDuplicateKey(c *C) { func (s *testSuite10) TestPaddingCommonHandle(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") - tk.Se.GetSessionVars().EnableClusteredIndex = true + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn tk.MustExec(`create table t1(c1 decimal(6,4), primary key(c1))`) tk.MustExec(`insert into t1 set c1 = 0.1`) tk.MustExec(`insert into t1 set c1 = 0.1 on duplicate key update c1 = 1`) @@ -1295,7 +1295,7 @@ type testSuite10 struct { func (s *testSuite10) TestClusterPrimaryTablePlainInsert(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec(`use test`) - tk.Se.GetSessionVars().EnableClusteredIndex = true + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn tk.MustExec(`drop table if exists t1pk`) tk.MustExec(`create table t1pk(id varchar(200) primary key, v int)`) @@ -1337,7 +1337,7 @@ func (s *testSuite10) TestClusterPrimaryTablePlainInsert(c *C) { func (s *testSuite10) TestClusterPrimaryTableInsertIgnore(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec(`use test`) - tk.Se.GetSessionVars().EnableClusteredIndex = true + tk.Se.GetSessionVars().EnableClusteredIndex = 
variable.ClusteredIndexDefModeOn tk.MustExec(`drop table if exists it1pk`) tk.MustExec(`create table it1pk(id varchar(200) primary key, v int)`) @@ -1363,7 +1363,7 @@ func (s *testSuite10) TestClusterPrimaryTableInsertIgnore(c *C) { func (s *testSuite10) TestClusterPrimaryTableInsertDuplicate(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec(`use test`) - tk.Se.GetSessionVars().EnableClusteredIndex = true + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn tk.MustExec(`drop table if exists dt1pi`) tk.MustExec(`create table dt1pi(id varchar(200) primary key, v int)`) @@ -1395,7 +1395,7 @@ func (s *testSuite10) TestClusterPrimaryTableInsertDuplicate(c *C) { func (s *testSuite10) TestClusterPrimaryKeyForIndexScan(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec(`use test`) - tk.Se.GetSessionVars().EnableClusteredIndex = true + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn tk.MustExec("drop table if exists pkt1;") tk.MustExec("CREATE TABLE pkt1 (a varchar(255), b int, index idx(b), primary key(a,b));") @@ -1444,7 +1444,7 @@ func (s *testSerialSuite) TestDuplicateEntryMessage(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test;") - for _, enable := range []bool{true, false} { + for _, enable := range []variable.ClusteredIndexDefMode{variable.ClusteredIndexDefModeOn, variable.ClusteredIndexDefModeOff, variable.ClusteredIndexDefModeIntOnly} { tk.Se.GetSessionVars().EnableClusteredIndex = enable tk.MustExec("drop table if exists t;") tk.MustExec("create table t(a int, b char(10), unique key(b)) collate utf8mb4_general_ci;") diff --git a/executor/join_test.go b/executor/join_test.go index 76afe571f3d41..bfd0048a63b3d 100644 --- a/executor/join_test.go +++ b/executor/join_test.go @@ -26,6 +26,7 @@ import ( "github.com/pingcap/tidb/executor" plannercore "github.com/pingcap/tidb/planner/core" "github.com/pingcap/tidb/session" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/util" "github.com/pingcap/tidb/util/testkit" ) @@ -1374,7 +1375,7 @@ func (s *testSuiteJoinSerial) TestIndexNestedLoopHashJoin(c *C) { tk.MustExec("set @@tidb_init_chunk_size=2") tk.MustExec("set @@tidb_index_join_batch_size=10") tk.MustExec("DROP TABLE IF EXISTS t, s") - tk.Se.GetSessionVars().EnableClusteredIndex = false + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeIntOnly tk.MustExec("create table t(pk int primary key, a int)") for i := 0; i < 100; i++ { tk.MustExec(fmt.Sprintf("insert into t values(%d, %d)", i, i)) diff --git a/executor/mpp_gather.go b/executor/mpp_gather.go index 69a6b0c9ceb79..de51eba97b1fd 100644 --- a/executor/mpp_gather.go +++ b/executor/mpp_gather.go @@ -24,14 +24,13 @@ import ( plannercore "github.com/pingcap/tidb/planner/core" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/util/chunk" - "github.com/pingcap/tidb/util/collate" "github.com/pingcap/tidb/util/logutil" "github.com/pingcap/tipb/go-tipb" "go.uber.org/zap" ) func useMPPExecution(ctx sessionctx.Context, tr *plannercore.PhysicalTableReader) bool { - if !ctx.GetSessionVars().AllowMPPExecution || collate.NewCollationEnabled() { + if !ctx.GetSessionVars().AllowMPPExecution { return false } _, ok := tr.GetTablePlan().(*plannercore.PhysicalExchangeSender) @@ -139,6 +138,7 @@ func (e *MPPGather) Next(ctx context.Context, chk *chunk.Chunk) error { // Close and release the used resources. 
func (e *MPPGather) Close() error { + e.mppReqs = nil if e.respIter != nil { return e.respIter.Close() } diff --git a/executor/parallel_apply_test.go b/executor/parallel_apply_test.go index 6db533b971fee..b849d3d961043 100644 --- a/executor/parallel_apply_test.go +++ b/executor/parallel_apply_test.go @@ -20,6 +20,7 @@ import ( . "github.com/pingcap/check" "github.com/pingcap/failpoint" "github.com/pingcap/tidb/planner/core" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/util/collate" "github.com/pingcap/tidb/util/testkit" ) @@ -421,7 +422,7 @@ func (s *testSerialSuite) TestApplyWithOtherFeatures(c *C) { core.SetPreparedPlanCache(orgEnable) // cluster index - tk.Se.GetSessionVars().EnableClusteredIndex = true + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn tk.MustExec("drop table if exists t, t2") tk.MustExec("create table t(a int, b int, c int, primary key(a, b))") tk.MustExec("create table t2(a int, b int, c int, primary key(a, c))") @@ -429,7 +430,7 @@ func (s *testSerialSuite) TestApplyWithOtherFeatures(c *C) { tk.MustExec("insert into t2 values (1, 1, 1), (2, 2, 2), (3, 3, 3), (4, 4, 4)") sql = "select * from t where (select min(t2.b) from t2 where t2.a > t.a) > 0" tk.MustQuery(sql).Sort().Check(testkit.Rows("1 1 1", "2 2 2", "3 3 3")) - tk.Se.GetSessionVars().EnableClusteredIndex = false + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeIntOnly // partitioning table tk.MustExec("drop table if exists t1, t2") diff --git a/executor/point_get_test.go b/executor/point_get_test.go index 7fd32eecb2c0f..0360add8228e5 100644 --- a/executor/point_get_test.go +++ b/executor/point_get_test.go @@ -26,6 +26,7 @@ import ( "github.com/pingcap/tidb/domain" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/session" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/store/mockstore" "github.com/pingcap/tidb/store/tikv" "github.com/pingcap/tidb/tablecodec" @@ -549,7 +550,7 @@ func (s *testPointGetSuite) TestReturnValues(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") tk.MustExec("drop table if exists t") - tk.Se.GetSessionVars().EnableClusteredIndex = false + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeIntOnly tk.MustExec("create table t (a varchar(64) primary key, b int)") tk.MustExec("insert t values ('a', 1), ('b', 2), ('c', 3)") tk.MustExec("begin pessimistic") @@ -572,7 +573,7 @@ func (s *testPointGetSuite) TestReturnValues(c *C) { func (s *testPointGetSuite) TestClusterIndexPointGet(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") - tk.Se.GetSessionVars().EnableClusteredIndex = true + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn tk.MustExec("drop table if exists pgt") tk.MustExec("create table pgt (a varchar(64), b varchar(64), uk int, v int, primary key(a, b), unique key uuk(uk))") tk.MustExec("insert pgt values ('a', 'a1', 1, 11), ('b', 'b1', 2, 22), ('c', 'c1', 3, 33)") @@ -595,7 +596,7 @@ func (s *testPointGetSuite) TestClusterIndexPointGet(c *C) { func (s *testPointGetSuite) TestClusterIndexCBOPointGet(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") - tk.Se.GetSessionVars().EnableClusteredIndex = true + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn tk.MustExec("drop table if exists t1, t2") tk.MustExec(`create table t1 (a int, b decimal(10,0), c int, primary key(a,b))`) tk.MustExec(`create table t2 (a varchar(20), b int, 
primary key(a), unique key(b))`) @@ -795,7 +796,7 @@ func (s *testPointGetSuite) TestPointGetLockExistKey(c *C) { errCh <- tk1.ExecToErr("use test") errCh <- tk2.ExecToErr("use test") - tk1.Se.GetSessionVars().EnableClusteredIndex = false + tk1.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeIntOnly errCh <- tk1.ExecToErr(fmt.Sprintf("drop table if exists %s", tableName)) errCh <- tk1.ExecToErr(fmt.Sprintf("create table %s(id int, v int, k int, %s key0(id, v))", tableName, key)) diff --git a/executor/prepared_test.go b/executor/prepared_test.go index 4900a1c77df3a..34cee0306d948 100644 --- a/executor/prepared_test.go +++ b/executor/prepared_test.go @@ -26,6 +26,7 @@ import ( "github.com/pingcap/tidb/domain" plannercore "github.com/pingcap/tidb/planner/core" "github.com/pingcap/tidb/session" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/util" "github.com/pingcap/tidb/util/testkit" ) @@ -200,7 +201,7 @@ func (s *testSerialSuite) TestPlanCacheClusterIndex(c *C) { plannercore.SetPreparedPlanCache(true) tk.MustExec("use test") tk.MustExec("drop table if exists t1") - tk.Se.GetSessionVars().EnableClusteredIndex = true + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn tk.MustExec("set @@tidb_enable_collect_execution_info=0;") tk.MustExec("create table t1(a varchar(20), b varchar(20), c varchar(20), primary key(a, b))") tk.MustExec("insert into t1 values('1','1','111'),('2','2','222'),('3','3','333')") @@ -287,7 +288,7 @@ func (s *testSerialSuite) TestPlanCacheClusterIndex(c *C) { tk.MustQuery(`execute stmt2 using @v2, @v2, @v3, @v3`).Check(testkit.Rows("b b 2 2 2", "c c 3 3 3")) // For issue 19002 - tk.Se.GetSessionVars().EnableClusteredIndex = true + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn tk.MustExec(`drop table if exists t1`) tk.MustExec(`create table t1(a int, b int, c int, primary key(a, b))`) tk.MustExec(`insert into t1 values(1,1,111),(2,2,222),(3,3,333)`) diff --git a/executor/rowid_test.go b/executor/rowid_test.go index 4873c637b9cb7..19ca30791f095 100644 --- a/executor/rowid_test.go +++ b/executor/rowid_test.go @@ -15,6 +15,7 @@ package executor_test import ( . 
"github.com/pingcap/check" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/util/testkit" ) @@ -64,7 +65,7 @@ func (s *testSuite1) TestExportRowID(c *C) { func (s *testSuite1) TestNotAllowWriteRowID(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") - tk.Se.GetSessionVars().EnableClusteredIndex = false + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeIntOnly tk.MustExec("create table tt(id binary(10), c int, primary key(id));") tk.MustExec("insert tt values (1, 10);") // select statement diff --git a/executor/sample_test.go b/executor/sample_test.go index 262df100077bc..6760fc5135d35 100644 --- a/executor/sample_test.go +++ b/executor/sample_test.go @@ -23,6 +23,7 @@ import ( "github.com/pingcap/tidb/errno" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/session" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/store/mockstore" "github.com/pingcap/tidb/store/tikv/mockstore/cluster" "github.com/pingcap/tidb/util/testkit" @@ -70,7 +71,7 @@ func (s *testTableSampleSuite) initSampleTest(c *C) *testkit.TestKit { func (s *testTableSampleSuite) TestTableSampleBasic(c *C) { tk := s.initSampleTest(c) tk.MustExec("create table t (a int);") - tk.Se.GetSessionVars().EnableClusteredIndex = true + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn tk.MustQuery("select * from t tablesample regions();").Check(testkit.Rows()) tk.MustExec("insert into t values (0), (1000), (2000);") @@ -120,7 +121,7 @@ func (s *testTableSampleSuite) TestTableSampleMultiRegions(c *C) { func (s *testTableSampleSuite) TestTableSampleSchema(c *C) { tk := s.initSampleTest(c) - tk.Se.GetSessionVars().EnableClusteredIndex = true + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn // Clustered index tk.MustExec("create table t (a varchar(255) primary key, b bigint);") tk.MustExec("insert into t values ('b', 100), ('y', 100);") diff --git a/executor/set_test.go b/executor/set_test.go index b55ed2328af74..230c81bf788cd 100644 --- a/executor/set_test.go +++ b/executor/set_test.go @@ -537,6 +537,11 @@ func (s *testSerialSuite1) TestSetVar(c *C) { // Test issue #22145 tk.MustExec(`set global sync_relay_log = "'"`) + tk.MustExec(`set @@global.tidb_enable_clustered_index = 'int_only'`) + tk.MustQuery(`show warnings`).Check(testkit.Rows("Warning 1287 'INT_ONLY' is deprecated and will be removed in a future release. 
Please use 'ON' or 'OFF' instead")) + tk.MustExec(`set @@global.tidb_enable_clustered_index = 'off'`) + tk.MustQuery(`show warnings`).Check(testkit.Rows()) + } func (s *testSuite5) TestTruncateIncorrectIntSessionVar(c *C) { diff --git a/executor/table_reader.go b/executor/table_reader.go index 015fe24b65226..dea1d128559a5 100644 --- a/executor/table_reader.go +++ b/executor/table_reader.go @@ -27,7 +27,6 @@ import ( "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/statistics" "github.com/pingcap/tidb/table" - "github.com/pingcap/tidb/table/tables" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/logutil" @@ -139,28 +138,9 @@ func (e *TableReaderExecutor) Open(ctx context.Context) error { } if e.corColInAccess { ts := e.plans[0].(*plannercore.PhysicalTableScan) - access := ts.AccessCondition - if e.table.Meta().IsCommonHandle { - pkIdx := tables.FindPrimaryIndex(ts.Table) - idxCols, idxColLens := expression.IndexInfo2PrefixCols(ts.Columns, ts.Schema().Columns, pkIdx) - for _, cond := range access { - newCond, err1 := expression.SubstituteCorCol2Constant(cond) - if err1 != nil { - return err1 - } - access = append(access, newCond) - } - res, err := ranger.DetachCondAndBuildRangeForIndex(e.ctx, access, idxCols, idxColLens) - if err != nil { - return err - } - e.ranges = res.Ranges - } else { - pkTP := ts.Table.GetPkColInfo().FieldType - e.ranges, err = ranger.BuildTableRange(access, e.ctx.GetSessionVars().StmtCtx, &pkTP) - if err != nil { - return err - } + e.ranges, err = ts.ResolveCorrelatedColumns() + if err != nil { + return err } } diff --git a/executor/tiflash_test.go b/executor/tiflash_test.go index 750293c9f2cec..fe1631016f5d4 100644 --- a/executor/tiflash_test.go +++ b/executor/tiflash_test.go @@ -146,7 +146,7 @@ func (s *tiflashTestSuite) TestMppExecution(c *C) { tk.MustExec("insert into t values(2,0)") tk.MustExec("insert into t values(3,0)") - tk.MustExec("create table t1(a int not null primary key, b int not null)") + tk.MustExec("create table t1(a int primary key, b int not null)") tk.MustExec("alter table t1 set tiflash replica 1") tb = testGetTableByName(c, tk.Se, "test", "t1") err = domain.GetDomain(tk.Se).DDL().UpdateTableReplicaInfo(tk.Se, tb.Meta().ID, true) @@ -162,7 +162,7 @@ func (s *tiflashTestSuite) TestMppExecution(c *C) { tk.MustQuery("select count(*) from t1 , t where t1.a = t.a").Check(testkit.Rows("3")) } // test multi-way join - tk.MustExec("create table t2(a int not null primary key, b int not null)") + tk.MustExec("create table t2(a int primary key, b int not null)") tk.MustExec("alter table t2 set tiflash replica 1") tb = testGetTableByName(c, tk.Se, "test", "t2") err = domain.GetDomain(tk.Se).DDL().UpdateTableReplicaInfo(tk.Se, tb.Meta().ID, true) @@ -190,6 +190,13 @@ func (s *tiflashTestSuite) TestMppExecution(c *C) { tk.MustQuery("select count(*) k, t2.b * t2.a from t2 group by t2.b * t2.a").Check(testkit.Rows("3 0")) tk.MustQuery("select count(*) k, t2.a/2 m from t2 group by t2.a / 2 order by m").Check(testkit.Rows("1 0.5000", "1 1.0000", "1 1.5000")) tk.MustQuery("select count(*) k, t2.a div 2 from t2 group by t2.a div 2 order by k").Check(testkit.Rows("1 0", "2 1")) + tk.MustQuery("select count(*) from ( select * from t2 group by a, b) A group by A.b").Check(testkit.Rows("3")) + tk.MustQuery("select count(*) from t1 where t1.a+100 > ( select count(*) from t2 where t1.a=t2.a and t1.b=t2.b) group by t1.b").Check(testkit.Rows("4")) + + 
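The failpoint.Enable call just below, and the mockWaitTiFlashReplica hook added to ddl/partition.go earlier, follow the usual pingcap/failpoint workflow: the code under test declares an injection marker, and a test switches it on with an expression such as `return(true)` and switches it off again when done. A rough sketch of that shape is given here with a made-up failpoint name; the Inject/Return markers only take effect once the failpoint tooling has rewritten the sources, so treat this as illustrative rather than a drop-in snippet.

package demo

import "github.com/pingcap/failpoint"

// needRetry stands in for a function like checkPartitionReplica: the failpoint
// lets a test force the "retry" branch without any real TiFlash store.
func needRetry() (bool, error) {
	// Inactive unless the failpoint build tooling rewrites this marker and a
	// test enables it, e.g.
	//   failpoint.Enable("<full/package/path>/mockNeedRetry", `return(true)`)
	//   defer failpoint.Disable("<full/package/path>/mockNeedRetry")
	failpoint.Inject("mockNeedRetry", func(val failpoint.Value) {
		if val.(bool) {
			failpoint.Return(true, nil)
		}
	})
	return false, nil
}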
failpoint.Enable("github.com/pingcap/tidb/executor/checkTotalMPPTasks", `return(3)`) + // all the data is related to one store, so there are three tasks. + tk.MustQuery("select avg(t.a) from t join t t1 on t.a = t1.a").Check(testkit.Rows("2.0000")) + failpoint.Disable("github.com/pingcap/tidb/executor/checkTotalMPPTasks") tk.MustExec("drop table if exists t") tk.MustExec("create table t (c1 decimal(8, 5) not null, c2 decimal(9, 5), c3 decimal(9, 4) , c4 decimal(8, 4) not null)") @@ -205,6 +212,26 @@ func (s *tiflashTestSuite) TestMppExecution(c *C) { tk.MustQuery("select t1.c4 from t t1 join t t2 on t1.c4 = t2.c3 order by t1.c4").Check(testkit.Rows("1.0000", "1.0000", "1.0001")) } +func (s *tiflashTestSuite) TestInjectExtraProj(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a bigint(20))") + tk.MustExec("alter table t set tiflash replica 1") + tb := testGetTableByName(c, tk.Se, "test", "t") + err := domain.GetDomain(tk.Se).DDL().UpdateTableReplicaInfo(tk.Se, tb.Meta().ID, true) + c.Assert(err, IsNil) + tk.MustExec("insert into t values (9223372036854775807)") + tk.MustExec("insert into t values (9223372036854775807)") + tk.MustExec("insert into t values (9223372036854775807)") + tk.MustExec("insert into t values (9223372036854775807)") + tk.MustExec("insert into t values (9223372036854775807)") + tk.MustExec("insert into t values (9223372036854775807)") + + tk.MustQuery("select avg(a) from t").Check(testkit.Rows("9223372036854775807.0000")) + tk.MustQuery("select avg(a), a from t group by a").Check(testkit.Rows("9223372036854775807.0000 9223372036854775807")) +} + func (s *tiflashTestSuite) TestPartitionTable(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") @@ -381,3 +408,30 @@ func (s *tiflashTestSuite) TestMppGoroutinesExitFromErrors(c *C) { c.Assert(failpoint.Disable(mppNonRootTaskError), IsNil) c.Assert(failpoint.Disable(hang), IsNil) } + +func (s *tiflashTestSuite) TestMppApply(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists x1") + tk.MustExec("create table x1(a int primary key, b int);") + tk.MustExec("alter table x1 set tiflash replica 1") + tb := testGetTableByName(c, tk.Se, "test", "x1") + err := domain.GetDomain(tk.Se).DDL().UpdateTableReplicaInfo(tk.Se, tb.Meta().ID, true) + c.Assert(err, IsNil) + tk.MustExec("insert into x1 values(1, 1),(2, 10),(0,11);") + + tk.MustExec("create table x2(a int primary key, b int);") + tk.MustExec("alter table x2 set tiflash replica 1") + tb = testGetTableByName(c, tk.Se, "test", "x2") + err = domain.GetDomain(tk.Se).DDL().UpdateTableReplicaInfo(tk.Se, tb.Meta().ID, true) + c.Assert(err, IsNil) + tk.MustExec("insert into x2 values(1,2),(0,1),(2,-3);") + tk.MustExec("analyze table x1, x2;") + + tk.MustExec("set @@session.tidb_isolation_read_engines=\"tiflash\"") + tk.MustExec("set @@session.tidb_allow_mpp=ON") + // table full scan with correlated filter + tk.MustQuery("select /*+ agg_to_cop(), hash_agg()*/ count(*) from x1 where a >= any (select a from x2 where x1.a = x2.a) order by 1;").Check(testkit.Rows("3")) + // table range scan with correlated access conditions + tk.MustQuery("select /*+ agg_to_cop(), hash_agg()*/ count(*) from x1 where b > any (select x2.a from x2 where x1.a = x2.a);").Check(testkit.Rows("2")) +} diff --git a/executor/update.go b/executor/update.go index ac0ebd57f0c19..b8c7e2a985142 100644 --- a/executor/update.go +++ b/executor/update.go @@ -50,6 
+50,7 @@ type UpdateExec struct { // tblColPosInfos stores relationship between column ordinal to its table handle. // the columns ordinals is present in ordinal range format, @see plannercore.TblColPosInfos tblColPosInfos plannercore.TblColPosInfoSlice + assignFlag []int evalBuffer chunk.MutRow allAssignmentsAreConstant bool virtualAssignmentsOffset int @@ -58,24 +59,19 @@ type UpdateExec struct { stats *runtimeStatsWithSnapshot - handles []kv.Handle - updatable []bool - changed []bool - matches []bool - assignFlag []bool + handles []kv.Handle + tableUpdatable []bool + changed []bool + matches []bool } -// prepare `handles`, `updatable`, `changed` and `assignFlag` to avoid re-computations. -func (e *UpdateExec) prepare(ctx context.Context, schema *expression.Schema, row []types.Datum) (err error) { - e.assignFlag, err = plannercore.GetUpdateColumns(e.ctx, e.OrderedList, schema.Len()) - if err != nil { - return err - } +// prepare `handles`, `tableUpdatable`, `changed` to avoid re-computations. +func (e *UpdateExec) prepare(row []types.Datum) (err error) { if e.updatedRowKeys == nil { e.updatedRowKeys = make(map[int]*kv.HandleMap) } e.handles = e.handles[:0] - e.updatable = e.updatable[:0] + e.tableUpdatable = e.tableUpdatable[:0] e.changed = e.changed[:0] e.matches = e.matches[:0] for _, content := range e.tblColPosInfos { @@ -91,12 +87,15 @@ func (e *UpdateExec) prepare(ctx context.Context, schema *expression.Schema, row updatable := false flags := e.assignFlag[content.Start:content.End] for _, flag := range flags { - if flag { + if flag >= 0 { updatable = true break } } - e.updatable = append(e.updatable, updatable) + if e.unmatchedOuterRow(content, row) { + updatable = false + } + e.tableUpdatable = append(e.tableUpdatable, updatable) changed, ok := e.updatedRowKeys[content.Start].Get(handle) if ok { @@ -110,7 +109,7 @@ func (e *UpdateExec) prepare(ctx context.Context, schema *expression.Schema, row return nil } -func (e *UpdateExec) merge(ctx context.Context, row, newData []types.Datum, mergeGenerated bool) error { +func (e *UpdateExec) merge(row, newData []types.Datum, mergeGenerated bool) error { if e.mergedRowData == nil { e.mergedRowData = make(map[int64]*kv.HandleMap) } @@ -121,7 +120,7 @@ func (e *UpdateExec) merge(ctx context.Context, row, newData []types.Datum, merg // No need to merge if not multi-updated continue } - if !e.updatable[i] { + if !e.tableUpdatable[i] { // If there's nothing to update, we can just skip current row continue } @@ -145,7 +144,7 @@ func (e *UpdateExec) merge(ctx context.Context, row, newData []types.Datum, merg continue } mergedData[i].Copy(&oldData[i]) - if flag { + if flag >= 0 { newTableData[i].Copy(&mergedData[i]) } else { mergedData[i].Copy(&newTableData[i]) @@ -161,8 +160,12 @@ func (e *UpdateExec) merge(ctx context.Context, row, newData []types.Datum, merg func (e *UpdateExec) exec(ctx context.Context, schema *expression.Schema, row, newData []types.Datum) error { defer trace.StartRegion(ctx, "UpdateExec").End() + bAssignFlag := make([]bool, len(e.assignFlag)) + for i, flag := range e.assignFlag { + bAssignFlag[i] = flag >= 0 + } for i, content := range e.tblColPosInfos { - if !e.updatable[i] { + if !e.tableUpdatable[i] { // If there's nothing to update, we can just skip current row continue } @@ -179,7 +182,7 @@ func (e *UpdateExec) exec(ctx context.Context, schema *expression.Schema, row, n oldData := row[content.Start:content.End] newTableData := newData[content.Start:content.End] - flags := e.assignFlag[content.Start:content.End] + flags 
:= bAssignFlag[content.Start:content.End] // Update row changed, err1 := updateRecord(ctx, e.ctx, handle, oldData, newTableData, flags, tbl, false, e.memTracker) @@ -198,14 +201,15 @@ func (e *UpdateExec) exec(ctx context.Context, schema *expression.Schema, row, n return nil } -// canNotUpdate checks the handle of a record to decide whether that record +// unmatchedOuterRow checks the tableCols of a record to decide whether that record // can not be updated. The handle is NULL only when it is the inner side of an // outer join: the outer row can not match any inner rows, and in this scenario // the inner handle field is filled with a NULL value. // // This fixes: https://github.com/pingcap/tidb/issues/7176. -func (e *UpdateExec) canNotUpdate(handle types.Datum) bool { - return handle.IsNull() +func (e *UpdateExec) unmatchedOuterRow(tblPos plannercore.TblColPosInfo, waitUpdateRow []types.Datum) bool { + firstHandleIdx := tblPos.HandleCols.GetCol(0) + return waitUpdateRow[firstHandleIdx.Index].IsNull() } // Next implements the Executor Next interface. @@ -264,7 +268,7 @@ func (e *UpdateExec) updateRows(ctx context.Context) (int, error) { chunkRow := chk.GetRow(rowIdx) datumRow := chunkRow.GetDatumRow(fields) // precomputes handles - if err := e.prepare(ctx, e.children[0].Schema(), datumRow); err != nil { + if err := e.prepare(datumRow); err != nil { return 0, err } // compose non-generated columns @@ -273,7 +277,7 @@ func (e *UpdateExec) updateRows(ctx context.Context) (int, error) { return 0, err } // merge non-generated columns - if err := e.merge(ctx, datumRow, newRow, false); err != nil { + if err := e.merge(datumRow, newRow, false); err != nil { return 0, err } if e.virtualAssignmentsOffset < len(e.OrderedList) { @@ -283,7 +287,7 @@ func (e *UpdateExec) updateRows(ctx context.Context) (int, error) { return 0, err } // merge generated columns - if err := e.merge(ctx, datumRow, newRow, true); err != nil { + if err := e.merge(datumRow, newRow, true); err != nil { return 0, err } } @@ -317,11 +321,10 @@ func (e *UpdateExec) handleErr(colName model.CIStr, rowIdx int, err error) error func (e *UpdateExec) fastComposeNewRow(rowIdx int, oldRow []types.Datum, cols []*table.Column) ([]types.Datum, error) { newRowData := types.CloneRow(oldRow) for _, assign := range e.OrderedList { - handleIdx, handleFound := e.tblColPosInfos.FindHandle(assign.Col.Index) - if handleFound && e.canNotUpdate(oldRow[handleIdx]) { + tblIdx := e.assignFlag[assign.Col.Index] + if tblIdx >= 0 && !e.tableUpdatable[tblIdx] { continue } - con := assign.Expr.(*expression.Constant) val, err := con.Eval(emptyRow) if err = e.handleErr(assign.ColName, rowIdx, err); err != nil { @@ -346,8 +349,8 @@ func (e *UpdateExec) composeNewRow(rowIdx int, oldRow []types.Datum, cols []*tab newRowData := types.CloneRow(oldRow) e.evalBuffer.SetDatums(newRowData...) for _, assign := range e.OrderedList[:e.virtualAssignmentsOffset] { - handleIdx, handleFound := e.tblColPosInfos.FindHandle(assign.Col.Index) - if handleFound && e.canNotUpdate(oldRow[handleIdx]) { + tblIdx := e.assignFlag[assign.Col.Index] + if tblIdx >= 0 && !e.tableUpdatable[tblIdx] { continue } val, err := assign.Expr.Eval(e.evalBuffer.ToRow()) @@ -375,8 +378,8 @@ func (e *UpdateExec) composeGeneratedColumns(rowIdx int, newRowData []types.Datu } e.evalBuffer.SetDatums(newRowData...) 
for _, assign := range e.OrderedList[e.virtualAssignmentsOffset:] { - handleIdx, handleFound := e.tblColPosInfos.FindHandle(assign.Col.Index) - if handleFound && e.canNotUpdate(newRowData[handleIdx]) { + tblIdx := e.assignFlag[assign.Col.Index] + if tblIdx >= 0 && !e.tableUpdatable[tblIdx] { continue } val, err := assign.Expr.Eval(e.evalBuffer.ToRow()) diff --git a/executor/update_test.go b/executor/update_test.go index 1811d0f913ffe..0509994da2da6 100644 --- a/executor/update_test.go +++ b/executor/update_test.go @@ -23,6 +23,7 @@ import ( "github.com/pingcap/tidb/errno" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/session" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/store/mockstore" "github.com/pingcap/tidb/store/tikv/mockstore/cluster" "github.com/pingcap/tidb/util/testkit" @@ -335,7 +336,7 @@ type testSuite11 struct { func (s *testSuite11) TestUpdateClusterIndex(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec(`use test`) - tk.Se.GetSessionVars().EnableClusteredIndex = true + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn tk.MustExec(`drop table if exists t`) tk.MustExec(`create table t(id varchar(200) primary key, v int)`) @@ -387,7 +388,7 @@ func (s *testSuite11) TestUpdateClusterIndex(c *C) { func (s *testSuite11) TestDeleteClusterIndex(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec(`use test`) - tk.Se.GetSessionVars().EnableClusteredIndex = true + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn tk.MustExec(`drop table if exists t`) tk.MustExec(`create table t(id varchar(200) primary key, v int)`) @@ -422,7 +423,7 @@ func (s *testSuite11) TestDeleteClusterIndex(c *C) { func (s *testSuite11) TestReplaceClusterIndex(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec(`use test`) - tk.Se.GetSessionVars().EnableClusteredIndex = true + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn tk.MustExec(`drop table if exists rt1pk`) tk.MustExec(`create table rt1pk(id varchar(200) primary key, v int)`) @@ -448,11 +449,12 @@ func (s *testSuite11) TestReplaceClusterIndex(c *C) { func (s *testSuite11) TestPessimisticUpdatePKLazyCheck(c *C) { tk := testkit.NewTestKitWithInit(c, s.store) - s.testUpdatePKLazyCheck(c, tk, true) - s.testUpdatePKLazyCheck(c, tk, false) + s.testUpdatePKLazyCheck(c, tk, variable.ClusteredIndexDefModeOn) + s.testUpdatePKLazyCheck(c, tk, variable.ClusteredIndexDefModeOff) + s.testUpdatePKLazyCheck(c, tk, variable.ClusteredIndexDefModeIntOnly) } -func (s *testSuite11) testUpdatePKLazyCheck(c *C, tk *testkit.TestKit, clusteredIndex bool) { +func (s *testSuite11) testUpdatePKLazyCheck(c *C, tk *testkit.TestKit, clusteredIndex variable.ClusteredIndexDefMode) { tk.Se.GetSessionVars().EnableClusteredIndex = clusteredIndex tk.MustExec(`drop table if exists upk`) tk.MustExec(`create table upk (a int, b int, c int, primary key (a, b))`) @@ -517,3 +519,12 @@ func (s *testPointGetSuite) TestIssue21447(c *C) { tk1.MustQuery("select * from t1 where id in (1, 2) for update").Check(testkit.Rows("1 xyz")) tk1.MustExec("commit") } + +func (s *testSuite11) TestIssue23553(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec(`use test`) + tk.MustExec(`drop table if exists tt`) + tk.MustExec(`create table tt (m0 varchar(64), status tinyint not null)`) + tk.MustExec(`insert into tt values('1',0),('1',0),('1',0)`) + tk.MustExec(`update tt a inner join (select m0 from tt where status!=1 group by m0 having count(*)>1) b on a.m0=b.m0 set 
a.status=1`) +} diff --git a/executor/write_test.go b/executor/write_test.go index 1721d726179b1..a6ef34ce5cf0a 100644 --- a/executor/write_test.go +++ b/executor/write_test.go @@ -30,6 +30,7 @@ import ( "github.com/pingcap/tidb/session" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/stmtctx" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/store/mockstore" "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/table/tables" @@ -3893,7 +3894,7 @@ func (s *testSerialSuite) TestIssue20840(c *C) { tk := testkit.NewTestKitWithInit(c, s.store) tk.MustExec("use test") tk.MustExec("drop table if exists t1") - tk.Se.GetSessionVars().EnableClusteredIndex = false + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeIntOnly tk.MustExec("create table t1 (i varchar(20) unique key) collate=utf8mb4_general_ci") tk.MustExec("insert into t1 values ('a')") tk.MustExec("replace into t1 values ('A')") diff --git a/expression/integration_test.go b/expression/integration_test.go index ea3fbbe150373..fd683d4696fde 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -5396,7 +5396,7 @@ func (s *testIntegrationSuite) TestIssue16973(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") tk.MustExec("drop table if exists t1") - tk.Se.GetSessionVars().EnableClusteredIndex = false + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeIntOnly tk.MustExec("create table t1(id varchar(36) not null primary key, org_id varchar(36) not null, " + "status tinyint default 1 not null, ns varchar(36) default '' not null);") tk.MustExec("create table t2(id varchar(36) not null primary key, order_id varchar(36) not null, " + @@ -6980,7 +6980,7 @@ func (s *testIntegrationSerialSuite) TestNewCollationCheckClusterIndexTable(c *C tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") tk.MustExec("drop table if exists t") - tk.Se.GetSessionVars().EnableClusteredIndex = true + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn tk.MustExec("create table t(name char(255) primary key, b int, c int, index idx(name), unique index uidx(name))") tk.MustExec("insert into t values(\"aaaa\", 1, 1), (\"bbb\", 2, 2), (\"ccc\", 3, 3)") tk.MustExec("admin check table t") @@ -7088,7 +7088,7 @@ func (s *testIntegrationSerialSuite) TestNewCollationWithClusterIndex(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") tk.MustExec("drop table if exists t") - tk.Se.GetSessionVars().EnableClusteredIndex = true + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn tk.MustExec("create table t(d double primary key, a int, name varchar(255), index idx(name(2)), index midx(a, name))") tk.MustExec("insert into t values(2.11, 1, \"aa\"), (-1, 0, \"abcd\"), (9.99, 0, \"aaaa\")") tk.MustQuery("select d from t use index(idx) where name=\"aa\"").Check(testkit.Rows("2.11")) @@ -7946,7 +7946,7 @@ func (s *testIntegrationSerialSuite) TestClusteredIndexAndNewCollationIndexEncod tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") tk.MustExec("drop table if exists t") - tk.Se.GetSessionVars().EnableClusteredIndex = true + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn tk.MustExec("create table t(a int, b char(10) collate utf8mb4_bin, c char(10) collate utf8mb4_general_ci," + "d varchar(10) collate utf8mb4_bin, e varchar(10) collate utf8mb4_general_ci, f char(10) collate utf8mb4_unicode_ci, g varchar(10) collate 
utf8mb4_unicode_ci, " + "primary key(a, b, c, d, e, f, g), key a(a), unique key ua(a), key b(b), unique key ub(b), key c(c), unique key uc(c)," + @@ -8092,7 +8092,7 @@ func (s *testIntegrationSerialSuite) TestClusteredIndexAndNewCollation(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") tk.MustExec("drop table if exists t") - tk.Se.GetSessionVars().EnableClusteredIndex = true + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn tk.MustExec("CREATE TABLE `t` (" + "`a` char(10) COLLATE utf8mb4_unicode_ci NOT NULL," + "`b` char(20) COLLATE utf8mb4_general_ci NOT NULL," + @@ -8523,7 +8523,7 @@ func (s *testIntegrationSerialSuite) TestIssue20876(c *C) { defer collate.SetNewCollationEnabledForTest(false) tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") - tk.Se.GetSessionVars().EnableClusteredIndex = true + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn tk.MustExec("drop table if exists t;") tk.MustExec("CREATE TABLE `t` (" + " `a` char(10) COLLATE utf8mb4_unicode_ci NOT NULL," + @@ -8934,3 +8934,15 @@ func (s *testIntegrationSuite) TestJiraSetInnoDBDefaultRowFormat(c *C) { tk.MustExec("set global innodb_default_row_format = dynamic") tk.MustExec("set global innodb_default_row_format = 'dynamic'") } + +func (s *testIntegrationSerialSuite) TestCollationForBinaryLiteral(c *C) { + tk := testkit.NewTestKit(c, s.store) + collate.SetNewCollationEnabledForTest(true) + defer collate.SetNewCollationEnabledForTest(false) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("CREATE TABLE t (`COL1` tinyblob NOT NULL, `COL2` binary(1) NOT NULL, `COL3` bigint(11) NOT NULL, PRIMARY KEY (`COL1`(5),`COL2`,`COL3`) /*T![clustered_index] CLUSTERED */)") + tk.MustExec("insert into t values(0x1E,0xEC,6966939640596047133);") + tk.MustQuery("select * from t where col1 not in (0x1B,0x20) order by col1").Check(testkit.Rows("\x1e \xec 6966939640596047133")) + tk.MustExec("drop table t") +} diff --git a/kv/interface_mock_test.go b/kv/interface_mock_test.go index 9251091f23b8c..ef49933526ac1 100644 --- a/kv/interface_mock_test.go +++ b/kv/interface_mock_test.go @@ -168,7 +168,7 @@ func (*mockTxn) IsPessimistic() bool { func (s *mockStorage) GetSnapshot(ver Version) Snapshot { return &mockSnapshot{ - store: newMemDB(), + store: newMockMap(), } } @@ -223,7 +223,7 @@ func newMockStorage() Storage { } type mockSnapshot struct { - store MemBuffer + store Retriever } func (s *mockSnapshot) Get(ctx context.Context, k Key) ([]byte, error) { diff --git a/kv/kv.go b/kv/kv.go index 1cb5265ec5000..5ba7c30dce171 100644 --- a/kv/kv.go +++ b/kv/kv.go @@ -49,6 +49,9 @@ var ( TxnTotalSizeLimit uint64 = config.DefTxnTotalSizeLimit ) +// FlagsOp describes KeyFlags modify operation. TODO:remove it when br is ready +type FlagsOp = tikvstore.FlagsOp + // Getter is the interface for the Get method. type Getter interface { // Get gets the value for key k from kv store. @@ -97,15 +100,6 @@ type RetrieverMutator interface { Mutator } -// MemBufferIterator is an Iterator with KeyFlags related functions. -type MemBufferIterator interface { - Iterator - HasValue() bool - Flags() KeyFlags - UpdateFlags(...FlagsOp) - Handle() MemKeyHandle -} - // MemBuffer is an in-memory kv collection, can be used to buffer write operations. type MemBuffer interface { RetrieverMutator @@ -119,26 +113,13 @@ type MemBuffer interface { RUnlock() // GetFlags returns the latest flags associated with key. 
- GetFlags(Key) (KeyFlags, error) - // IterWithFlags returns a MemBufferIterator. - IterWithFlags(k Key, upperBound Key) MemBufferIterator - // IterReverseWithFlags returns a reversed MemBufferIterator. - IterReverseWithFlags(k Key) MemBufferIterator + GetFlags(Key) (tikvstore.KeyFlags, error) // SetWithFlags put key-value into the last active staging buffer with the given KeyFlags. - SetWithFlags(Key, []byte, ...FlagsOp) error + SetWithFlags(Key, []byte, ...tikvstore.FlagsOp) error // UpdateFlags update the flags associated with key. - UpdateFlags(Key, ...FlagsOp) + UpdateFlags(Key, ...tikvstore.FlagsOp) // DeleteWithFlags delete key with the given KeyFlags - DeleteWithFlags(Key, ...FlagsOp) error - - GetKeyByHandle(MemKeyHandle) []byte - GetValueByHandle(MemKeyHandle) ([]byte, bool) - - // Reset reset the MemBuffer to initial states. - Reset() - // DiscardValues releases the memory used by all values. - // NOTE: any operation need value will panic after this function. - DiscardValues() + DeleteWithFlags(Key, ...tikvstore.FlagsOp) error // Staging create a new staging buffer inside the MemBuffer. // Subsequent writes will be temporarily stored in this new staging buffer. @@ -150,21 +131,15 @@ type MemBuffer interface { // If the changes are not published by `Release`, they will be discarded. Cleanup(StagingHandle) // InspectStage used to inspect the value updates in the given stage. - InspectStage(StagingHandle, func(Key, KeyFlags, []byte)) + InspectStage(StagingHandle, func(Key, tikvstore.KeyFlags, []byte)) - // SelectValueHistory select the latest value which makes `predicate` returns true from the modification history. - SelectValueHistory(key Key, predicate func(value []byte) bool) ([]byte, error) // SnapshotGetter returns a Getter for a snapshot of MemBuffer. SnapshotGetter() Getter // SnapshotIter returns a Iterator for a snapshot of MemBuffer. SnapshotIter(k, upperbound Key) Iterator - // Size returns sum of keys and values length. - Size() int // Len returns the number of entries in the DB. Len() int - // Dirty returns whether the root staging buffer is updated. - Dirty() bool } // Transaction defines the interface for operations inside a Transaction. diff --git a/kv/union_store.go b/kv/union_store.go index bff41ff5dd863..e625cd2432c5b 100644 --- a/kv/union_store.go +++ b/kv/union_store.go @@ -14,8 +14,6 @@ package kv import ( - "context" - tikvstore "github.com/pingcap/tidb/store/tikv/kv" ) @@ -62,102 +60,3 @@ type Options interface { // Get gets an option value. Get(opt int) (v interface{}, ok bool) } - -// unionStore is an in-memory Store which contains a buffer for write and a -// snapshot for read. -type unionStore struct { - memBuffer *memdb - snapshot Snapshot - opts options -} - -// NewUnionStore builds a new unionStore. -func NewUnionStore(snapshot Snapshot) UnionStore { - return &unionStore{ - snapshot: snapshot, - memBuffer: newMemDB(), - opts: make(map[int]interface{}), - } -} - -// GetMemBuffer return the MemBuffer binding to this unionStore. -func (us *unionStore) GetMemBuffer() MemBuffer { - return us.memBuffer -} - -// Get implements the Retriever interface. -func (us *unionStore) Get(ctx context.Context, k Key) ([]byte, error) { - v, err := us.memBuffer.Get(ctx, k) - if IsErrNotFound(err) { - v, err = us.snapshot.Get(ctx, k) - } - if err != nil { - return v, err - } - if len(v) == 0 { - return nil, ErrNotExist - } - return v, nil -} - -// Iter implements the Retriever interface. 
-func (us *unionStore) Iter(k Key, upperBound Key) (Iterator, error) { - bufferIt, err := us.memBuffer.Iter(k, upperBound) - if err != nil { - return nil, err - } - retrieverIt, err := us.snapshot.Iter(k, upperBound) - if err != nil { - return nil, err - } - return NewUnionIter(bufferIt, retrieverIt, false) -} - -// IterReverse implements the Retriever interface. -func (us *unionStore) IterReverse(k Key) (Iterator, error) { - bufferIt, err := us.memBuffer.IterReverse(k) - if err != nil { - return nil, err - } - retrieverIt, err := us.snapshot.IterReverse(k) - if err != nil { - return nil, err - } - return NewUnionIter(bufferIt, retrieverIt, true) -} - -// HasPresumeKeyNotExists gets the key exist error info for the lazy check. -func (us *unionStore) HasPresumeKeyNotExists(k Key) bool { - flags, err := us.memBuffer.GetFlags(k) - if err != nil { - return false - } - return flags.HasPresumeKeyNotExists() -} - -// DeleteKeyExistErrInfo deletes the key exist error info for the lazy check. -func (us *unionStore) UnmarkPresumeKeyNotExists(k Key) { - us.memBuffer.UpdateFlags(k, DelPresumeKeyNotExists) -} - -// SetOption implements the unionStore SetOption interface. -func (us *unionStore) SetOption(opt int, val interface{}) { - us.opts[opt] = val -} - -// DelOption implements the unionStore DelOption interface. -func (us *unionStore) DelOption(opt int) { - delete(us.opts, opt) -} - -// GetOption implements the unionStore GetOption interface. -func (us *unionStore) GetOption(opt int) interface{} { - return us.opts[opt] -} - -type options map[int]interface{} - -func (opts options) Get(opt int) (interface{}, bool) { - v, ok := opts[opt] - return v, ok -} diff --git a/kv/utils_test.go b/kv/utils_test.go index b085dda9ae164..03c77a56e8705 100644 --- a/kv/utils_test.go +++ b/kv/utils_test.go @@ -25,8 +25,59 @@ var _ = Suite(testUtilsSuite{}) type testUtilsSuite struct { } +type mockMap struct { + index []Key + value [][]byte +} + +func newMockMap() *mockMap { + return &mockMap{ + index: make([]Key, 0), + value: make([][]byte, 0), + } +} + +func (s *mockMap) Iter(k Key, upperBound Key) (Iterator, error) { + return nil, nil +} +func (s *mockMap) IterReverse(k Key) (Iterator, error) { + return nil, nil +} + +func (s *mockMap) Get(ctx context.Context, k Key) ([]byte, error) { + for i, key := range s.index { + if key.Cmp(k) == 0 { + return s.value[i], nil + } + } + return nil, nil +} + +func (s *mockMap) Set(k Key, v []byte) error { + for i, key := range s.index { + if key.Cmp(k) == 0 { + s.value[i] = v + return nil + } + } + s.index = append(s.index, k) + s.value = append(s.value, v) + return nil +} + +func (s *mockMap) Delete(k Key) error { + for i, key := range s.index { + if key.Cmp(k) == 0 { + s.index = append(s.index[:i], s.index[i+1:]...) + s.value = append(s.value[:i], s.value[i+1:]...) 
+ return nil + } + } + return nil +} + func (s testUtilsSuite) TestIncInt64(c *C) { - mb := newMemDB() + mb := newMockMap() key := Key("key") v, err := IncInt64(mb, key, 1) c.Check(err, IsNil) @@ -51,7 +102,7 @@ func (s testUtilsSuite) TestIncInt64(c *C) { } func (s testUtilsSuite) TestGetInt64(c *C) { - mb := newMemDB() + mb := newMockMap() key := Key("key") v, err := GetInt64(context.TODO(), mb, key) c.Check(v, Equals, int64(0)) diff --git a/planner/core/cbo_test.go b/planner/core/cbo_test.go index ed309de916136..0f085282fe103 100644 --- a/planner/core/cbo_test.go +++ b/planner/core/cbo_test.go @@ -33,6 +33,7 @@ import ( "github.com/pingcap/tidb/planner/core" "github.com/pingcap/tidb/session" "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/statistics" "github.com/pingcap/tidb/statistics/handle" "github.com/pingcap/tidb/store/mockstore" @@ -911,7 +912,7 @@ func (s *testAnalyzeSuite) TestIndexEqualUnknown(c *C) { }() testKit.MustExec("use test") testKit.MustExec("drop table if exists t, t1") - testKit.Se.GetSessionVars().EnableClusteredIndex = false + testKit.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeIntOnly testKit.MustExec("CREATE TABLE t(a bigint(20) NOT NULL, b bigint(20) NOT NULL, c bigint(20) NOT NULL, PRIMARY KEY (a,c,b), KEY (b))") err = s.loadTableStats("analyzeSuiteTestIndexEqualUnknownT.json", dom) c.Assert(err, IsNil) diff --git a/planner/core/common_plans.go b/planner/core/common_plans.go index 9680bc4e5646f..80ba0201c99ca 100644 --- a/planner/core/common_plans.go +++ b/planner/core/common_plans.go @@ -786,6 +786,8 @@ type Update struct { // Used when partition sets are given. // e.g. update t partition(p0) set a = 1; PartitionedTable []table.PartitionedTable + + tblID2Table map[int64]table.Table } // Delete represents a delete plan. diff --git a/planner/core/exhaust_physical_plans.go b/planner/core/exhaust_physical_plans.go index bba801ee275e5..0db7a23086fe9 100644 --- a/planner/core/exhaust_physical_plans.go +++ b/planner/core/exhaust_physical_plans.go @@ -1666,7 +1666,7 @@ func (p *LogicalJoin) exhaustPhysicalPlans(prop *property.PhysicalProperty) ([]P } joins := make([]PhysicalPlan, 0, 8) canPushToTiFlash := p.canPushToCop(kv.TiFlash) - if p.ctx.GetSessionVars().AllowMPPExecution && !collate.NewCollationEnabled() && canPushToTiFlash { + if p.ctx.GetSessionVars().AllowMPPExecution && canPushToTiFlash { if p.shouldUseMPPBCJ() { mppJoins := p.tryToGetMppHashJoin(prop, true) if (p.preferJoinType & preferBCJoin) > 0 { @@ -2306,7 +2306,7 @@ func (la *LogicalAggregation) tryToGetMppHashAggs(prop *property.PhysicalPropert } // TODO: permute various partition columns from group-by columns // 1-phase agg - // If there are no available parititon cols, but still have group by items, that means group by items are all expressions or constants. + // If there are no available partition cols, but still have group by items, that means group by items are all expressions or constants. // To avoid mess, we don't do any one-phase aggregation in this case. 
if len(partitionCols) != 0 { childProp := &property.PhysicalProperty{TaskTp: property.MppTaskType, ExpectedCnt: math.MaxFloat64, PartitionTp: property.HashType, PartitionCols: partitionCols, CanAddEnforcer: true} @@ -2321,6 +2321,7 @@ func (la *LogicalAggregation) tryToGetMppHashAggs(prop *property.PhysicalPropert agg := NewPhysicalHashAgg(la, la.stats.ScaleByExpectCnt(prop.ExpectedCnt), childProp) agg.SetSchema(la.schema.Clone()) agg.MppRunMode = Mpp2Phase + agg.MppPartitionCols = partitionCols hashAggs = append(hashAggs, agg) // agg runs on TiDB with a partial agg on TiFlash if possible @@ -2355,7 +2356,7 @@ func (la *LogicalAggregation) getHashAggs(prop *property.PhysicalProperty) []Phy taskTypes = append(taskTypes, property.CopTiFlashLocalReadTaskType) } canPushDownToTiFlash := la.canPushToCop(kv.TiFlash) - canPushDownToMPP := la.ctx.GetSessionVars().AllowMPPExecution && !collate.NewCollationEnabled() && la.checkCanPushDownToMPP() && canPushDownToTiFlash + canPushDownToMPP := la.ctx.GetSessionVars().AllowMPPExecution && la.checkCanPushDownToMPP() && canPushDownToTiFlash if la.HasDistinct() { // TODO: remove after the cost estimation of distinct pushdown is implemented. if !la.ctx.GetSessionVars().AllowDistinctAggPushDown { diff --git a/planner/core/fragment.go b/planner/core/fragment.go index 1ed688a1a63f7..ef436309d7c16 100644 --- a/planner/core/fragment.go +++ b/planner/core/fragment.go @@ -67,6 +67,37 @@ func (e *mppTaskGenerator) generateMPPTasks(s *PhysicalExchangeSender) ([]*kv.MP return rootTasks, nil } +type mppAddr struct { + addr string +} + +func (m *mppAddr) GetAddress() string { + return m.addr +} + +// for the task without table scan, we construct tasks according to the children's tasks. +// That's to avoid assigning to the failed node repeatedly. We assume that the children nodes must be workable. +func (e *mppTaskGenerator) constructMPPTasksByChildrenTasks(tasks []*kv.MPPTask) []*kv.MPPTask { + addressMap := make(map[string]struct{}) + newTasks := make([]*kv.MPPTask, 0, len(tasks)) + for _, task := range tasks { + addr := task.Meta.GetAddress() + _, ok := addressMap[addr] + if !ok { + *e.allocTaskID++ + mppTask := &kv.MPPTask{ + Meta: &mppAddr{addr: addr}, + ID: *e.allocTaskID, + StartTs: e.startTS, + TableID: -1, + } + newTasks = append(newTasks, mppTask) + addressMap[addr] = struct{}{} + } + } + return newTasks +} + func (f *Fragment) init(p PhysicalPlan) error { switch x := p.(type) { case *PhysicalTableScan: @@ -107,7 +138,11 @@ func (e *mppTaskGenerator) generateMPPTasksForFragment(s *PhysicalExchangeSender if f.TableScan != nil { tasks, err = e.constructMPPTasksImpl(context.Background(), f.TableScan) } else { - tasks, err = e.constructMPPTasksImpl(context.Background(), nil) + childrenTasks := make([]*kv.MPPTask, 0) + for _, r := range f.ExchangeReceivers { + childrenTasks = append(childrenTasks, r.Tasks...) + } + tasks = e.constructMPPTasksByChildrenTasks(childrenTasks) } if err != nil { return nil, errors.Trace(err) @@ -154,39 +189,47 @@ func partitionPruning(ctx sessionctx.Context, tbl table.PartitionedTable, conds // single physical table means a table without partitions or a single partition in a partition table.
func (e *mppTaskGenerator) constructMPPTasksImpl(ctx context.Context, ts *PhysicalTableScan) ([]*kv.MPPTask, error) { - if ts != nil { - splitedRanges, _ := distsql.SplitRangesBySign(ts.Ranges, false, false, ts.Table.IsCommonHandle) - if ts.Table.GetPartitionInfo() != nil { - tmp, _ := e.is.TableByID(ts.Table.ID) - tbl := tmp.(table.PartitionedTable) - partitions, err := partitionPruning(e.ctx, tbl, ts.PartitionInfo.PruningConds, ts.PartitionInfo.PartitionNames, ts.PartitionInfo.Columns, ts.PartitionInfo.ColumnNames) + // update ranges according to correlated columns in access conditions like in the Open() of TableReaderExecutor + for _, cond := range ts.AccessCondition { + if len(expression.ExtractCorColumns(cond)) > 0 { + _, err := ts.ResolveCorrelatedColumns() if err != nil { - return nil, errors.Trace(err) - } - var ret []*kv.MPPTask - for _, p := range partitions { - pid := p.GetPhysicalID() - meta := p.Meta() - kvRanges, err := distsql.TableHandleRangesToKVRanges(e.ctx.GetSessionVars().StmtCtx, []int64{pid}, meta != nil && ts.Table.IsCommonHandle, splitedRanges, nil) - if err != nil { - return nil, errors.Trace(err) - } - tasks, err := e.constructMPPTasksForSinglePartitionTable(ctx, kvRanges, pid) - if err != nil { - return nil, errors.Trace(err) - } - ret = append(ret, tasks...) + return nil, err } - return ret, nil + break } + } - kvRanges, err := distsql.TableHandleRangesToKVRanges(e.ctx.GetSessionVars().StmtCtx, []int64{ts.Table.ID}, ts.Table.IsCommonHandle, splitedRanges, nil) + splitedRanges, _ := distsql.SplitRangesBySign(ts.Ranges, false, false, ts.Table.IsCommonHandle) + if ts.Table.GetPartitionInfo() != nil { + tmp, _ := e.is.TableByID(ts.Table.ID) + tbl := tmp.(table.PartitionedTable) + partitions, err := partitionPruning(e.ctx, tbl, ts.PartitionInfo.PruningConds, ts.PartitionInfo.PartitionNames, ts.PartitionInfo.Columns, ts.PartitionInfo.ColumnNames) if err != nil { return nil, errors.Trace(err) } - return e.constructMPPTasksForSinglePartitionTable(ctx, kvRanges, ts.Table.ID) + var ret []*kv.MPPTask + for _, p := range partitions { + pid := p.GetPhysicalID() + meta := p.Meta() + kvRanges, err := distsql.TableHandleRangesToKVRanges(e.ctx.GetSessionVars().StmtCtx, []int64{pid}, meta != nil && ts.Table.IsCommonHandle, splitedRanges, nil) + if err != nil { + return nil, errors.Trace(err) + } + tasks, err := e.constructMPPTasksForSinglePartitionTable(ctx, kvRanges, pid) + if err != nil { + return nil, errors.Trace(err) + } + ret = append(ret, tasks...) 
+ } + return ret, nil + } + + kvRanges, err := distsql.TableHandleRangesToKVRanges(e.ctx.GetSessionVars().StmtCtx, []int64{ts.Table.ID}, ts.Table.IsCommonHandle, splitedRanges, nil) + if err != nil { + return nil, errors.Trace(err) } - return e.constructMPPTasksForSinglePartitionTable(ctx, nil, -1) + return e.constructMPPTasksForSinglePartitionTable(ctx, kvRanges, ts.Table.ID) } func (e *mppTaskGenerator) constructMPPTasksForSinglePartitionTable(ctx context.Context, kvRanges []kv.KeyRange, tableID int64) ([]*kv.MPPTask, error) { diff --git a/planner/core/integration_test.go b/planner/core/integration_test.go index 5b0235cccffb4..c08d01c62861a 100644 --- a/planner/core/integration_test.go +++ b/planner/core/integration_test.go @@ -626,13 +626,58 @@ func (s *testIntegrationSerialSuite) TestJoinNotSupportedByTiFlash(c *C) { } } -func (s *testIntegrationSerialSuite) TestMPPNotSupportedInNewCollation(c *C) { +func (s *testIntegrationSerialSuite) TestMPPWithHashExchangeUnderNewCollation(c *C) { defer collate.SetNewCollationEnabledForTest(false) tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") tk.MustExec("drop table if exists table_1") - tk.MustExec("create table table_1(id int not null, value int)") - tk.MustExec("insert into table_1 values(1,1),(2,2)") + tk.MustExec("create table table_1(id int not null, value char(10))") + tk.MustExec("insert into table_1 values(1,'1'),(2,'2')") + tk.MustExec("analyze table table_1") + + // Create virtual tiflash replica info. + dom := domain.GetDomain(tk.Se) + is := dom.InfoSchema() + db, exists := is.SchemaByName(model.NewCIStr("test")) + c.Assert(exists, IsTrue) + for _, tblInfo := range db.Tables { + if tblInfo.Name.L == "table_1" { + tblInfo.TiFlashReplica = &model.TiFlashReplicaInfo{ + Count: 1, + Available: true, + } + } + } + + collate.SetNewCollationEnabledForTest(true) + tk.MustExec("set @@session.tidb_isolation_read_engines = 'tiflash'") + tk.MustExec("set @@session.tidb_allow_mpp = 1") + tk.MustExec("set @@session.tidb_opt_broadcast_join = 0") + tk.MustExec("set @@session.tidb_broadcast_join_threshold_count = 0") + tk.MustExec("set @@session.tidb_broadcast_join_threshold_size = 0") + var input []string + var output []struct { + SQL string + Plan []string + } + s.testData.GetTestCases(c, &input, &output) + for i, tt := range input { + s.testData.OnRecord(func() { + output[i].SQL = tt + output[i].Plan = s.testData.ConvertRowsToStrings(tk.MustQuery(tt).Rows()) + }) + res := tk.MustQuery(tt) + res.Check(testkit.Rows(output[i].Plan...)) + } +} + +func (s *testIntegrationSerialSuite) TestMPPWithBroadcastExchangeUnderNewCollation(c *C) { + defer collate.SetNewCollationEnabledForTest(false) + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists table_1") + tk.MustExec("create table table_1(id int not null, value char(10))") + tk.MustExec("insert into table_1 values(1,'1'),(2,'2')") tk.MustExec("analyze table table_1") // Create virtual tiflash replica info. 
@@ -1018,7 +1063,7 @@ func (s *testIntegrationSuite) TestMaxMinEliminate(c *C) { tk.MustExec("use test") tk.MustExec("drop table if exists t") tk.MustExec("create table t(a int primary key)") - tk.Se.GetSessionVars().EnableClusteredIndex = true + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn tk.MustExec("create table cluster_index_t(a int, b int, c int, primary key (a, b));") var input []string @@ -1053,7 +1098,7 @@ func (s *testIntegrationSuite) TestIndexJoinUniqueCompositeIndex(c *C) { tk.MustExec("use test") tk.MustExec("drop table if exists t1, t2") - tk.Se.GetSessionVars().EnableClusteredIndex = false + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeIntOnly tk.MustExec("create table t1(a int not null, c int not null)") tk.MustExec("create table t2(a int not null, b int not null, c int not null, primary key(a,b))") tk.MustExec("insert into t1 values(1,1)") @@ -1792,7 +1837,7 @@ func (s *testIntegrationSuite) TestIssue16935(c *C) { func (s *testIntegrationSuite) TestAccessPathOnClusterIndex(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") - tk.Se.GetSessionVars().EnableClusteredIndex = true + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn tk.MustExec("drop table if exists t1") tk.MustExec("create table t1 (a int, b varchar(20), c decimal(40,10), d int, primary key(a,b), key(c))") tk.MustExec(`insert into t1 values (1,"111",1.1,11), (2,"222",2.2,12), (3,"333",3.3,13)`) @@ -1821,7 +1866,7 @@ func (s *testIntegrationSuite) TestClusterIndexUniqueDoubleRead(c *C) { tk.MustExec("create database cluster_idx_unique_double_read;") tk.MustExec("use cluster_idx_unique_double_read;") defer tk.MustExec("drop database cluster_idx_unique_double_read;") - tk.Se.GetSessionVars().EnableClusteredIndex = true + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn tk.MustExec("drop table if exists t") tk.MustExec("create table t (a varchar(64), b varchar(64), uk int, v int, primary key(a, b), unique key uuk(uk));") @@ -1832,7 +1877,7 @@ func (s *testIntegrationSuite) TestClusterIndexUniqueDoubleRead(c *C) { func (s *testIntegrationSuite) TestIndexJoinOnClusteredIndex(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") - tk.Se.GetSessionVars().EnableClusteredIndex = true + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn tk.MustExec("drop table if exists t1") tk.MustExec("create table t (a int, b varchar(20), c decimal(40,10), d int, primary key(a,b), key(c))") tk.MustExec(`insert into t values (1,"111",1.1,11), (2,"222",2.2,12), (3,"333",3.3,13)`) @@ -1859,7 +1904,7 @@ func (s *testIntegrationSerialSuite) TestIssue18984(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") tk.MustExec("drop table if exists t, t2") - tk.Se.GetSessionVars().EnableClusteredIndex = true + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn tk.MustExec("create table t(a int, b int, c int, primary key(a, b))") tk.MustExec("create table t2(a int, b int, c int, d int, primary key(a,b), index idx(c))") tk.MustExec("insert into t values(1,1,1), (2,2,2), (3,3,3)") @@ -2005,7 +2050,7 @@ func (s *testIntegrationSerialSuite) Test19942(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") - tk.Se.GetSessionVars().EnableClusteredIndex = true + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn tk.MustExec("CREATE TABLE test.`t` (" + " `a` int(11) NOT NULL," + " `b` 
varchar(10) COLLATE utf8_general_ci NOT NULL," + diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index 971e29d0f1a91..b2da4c7fa18fe 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -4150,8 +4150,7 @@ type TblColPosInfo struct { // Start and End represent the ordinal range [Start, End) of the consecutive columns. Start, End int // HandleOrdinal represents the ordinal of the handle column. - HandleCols HandleCols - IsCommonHandle bool // TODO: fix redesign update join table and remove me! + HandleCols HandleCols } // TblColPosInfoSlice attaches the methods of sort.Interface to []TblColPosInfos sorting in increasing order. @@ -4172,8 +4171,8 @@ func (c TblColPosInfoSlice) Less(i, j int) bool { return c[i].Start < c[j].Start } -// FindHandle finds the ordinal of the corresponding handle column. -func (c TblColPosInfoSlice) FindHandle(colOrdinal int) (int, bool) { +// FindTblIdx finds the ordinal of the corresponding access column. +func (c TblColPosInfoSlice) FindTblIdx(colOrdinal int) (int, bool) { if len(c) == 0 { return 0, false } @@ -4183,11 +4182,7 @@ func (c TblColPosInfoSlice) FindHandle(colOrdinal int) (int, bool) { if rangeBehindOrdinal == 0 { return 0, false } - if c[rangeBehindOrdinal-1].IsCommonHandle { - // TODO: fix redesign update join table to fix me. - return 0, false - } - return c[rangeBehindOrdinal-1].HandleCols.GetCol(0).Index, true + return rangeBehindOrdinal - 1, true } // buildColumns2Handle builds columns to handle mapping. @@ -4212,8 +4207,7 @@ func buildColumns2Handle( return nil, err } end := offset + tblLen - cols2Handles = append(cols2Handles, TblColPosInfo{tblID, offset, end, handleCol, tbl.Meta().IsCommonHandle}) - // TODO: fix me for cluster index + cols2Handles = append(cols2Handles, TblColPosInfo{tblID, offset, end, handleCol}) } } sort.Sort(cols2Handles) @@ -4337,46 +4331,28 @@ func (b *PlanBuilder) buildUpdate(ctx context.Context, update *ast.UpdateStmt) ( tblID2table[id], _ = b.is.TableByID(id) } updt.TblColPosInfos, err = buildColumns2Handle(updt.OutputNames(), tblID2Handle, tblID2table, true) - if err == nil { - err = checkUpdateList(b.ctx, tblID2table, updt) - } updt.PartitionedTable = b.partitionedTable + updt.tblID2Table = tblID2table return updt, err } -// GetUpdateColumns gets the columns of updated lists. -func GetUpdateColumns(ctx sessionctx.Context, orderedList []*expression.Assignment, schemaLen int) ([]bool, error) { - assignFlag := make([]bool, schemaLen) - for _, v := range orderedList { - if !ctx.GetSessionVars().AllowWriteRowID && v.Col.ID == model.ExtraHandleID { - return nil, errors.Errorf("insert, update and replace statements for _tidb_rowid are not supported.") - } - idx := v.Col.Index - assignFlag[idx] = true - } - return assignFlag, nil -} - type tblUpdateInfo struct { name string pkUpdated bool } -func checkUpdateList(ctx sessionctx.Context, tblID2table map[int64]table.Table, updt *Update) error { - assignFlags, err := GetUpdateColumns(ctx, updt.OrderedList, updt.SelectPlan.Schema().Len()) - if err != nil { - return err - } +// CheckUpdateList checks all related columns in updatable state. 
+func CheckUpdateList(assignFlags []int, updt *Update) error { updateFromOtherAlias := make(map[int64]tblUpdateInfo) for _, content := range updt.TblColPosInfos { - tbl := tblID2table[content.TblID] + tbl := updt.tblID2Table[content.TblID] flags := assignFlags[content.Start:content.End] var update, updatePK bool for i, col := range tbl.WritableCols() { - if flags[i] && col.State != model.StatePublic { + if flags[i] >= 0 && col.State != model.StatePublic { return ErrUnknownColumn.GenWithStackByArgs(col.Name, clauseMsg[fieldList]) } - if flags[i] { + if flags[i] >= 0 { update = true if mysql.HasPriKeyFlag(col.Flag) { updatePK = true diff --git a/planner/core/partition_pruner_test.go b/planner/core/partition_pruner_test.go index 16f534e926fde..4071089aa7597 100644 --- a/planner/core/partition_pruner_test.go +++ b/planner/core/partition_pruner_test.go @@ -24,6 +24,7 @@ import ( "github.com/pingcap/tidb/domain" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/util/mock" "github.com/pingcap/tidb/util/testkit" "github.com/pingcap/tidb/util/testutil" @@ -68,7 +69,7 @@ func (s *testPartitionPruneSuit) TestHashPartitionPruner(c *C) { tk.MustExec("create database test_partition") tk.MustExec("use test_partition") tk.MustExec("drop table if exists t1, t2;") - tk.Se.GetSessionVars().EnableClusteredIndex = false + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeIntOnly tk.MustExec("create table t2(id int, a int, b int, primary key(id, a)) partition by hash(id + a) partitions 10;") tk.MustExec("create table t1(id int primary key, a int, b int) partition by hash(id) partitions 10;") tk.MustExec("create table t3(id int, a int, b int, primary key(id, a)) partition by hash(id) partitions 10;") @@ -97,7 +98,7 @@ func (s *testPartitionPruneSuit) TestListPartitionPruner(c *C) { tk.MustExec("drop database if exists test_partition;") tk.MustExec("create database test_partition") tk.MustExec("use test_partition") - tk.Se.GetSessionVars().EnableClusteredIndex = false + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeIntOnly tk.MustExec("set @@session.tidb_enable_list_partition = ON") tk.MustExec("create table t1 (id int, a int, b int ) partition by list ( a ) (partition p0 values in (1,2,3,4,5), partition p1 values in (6,7,8,9,10,null));") tk.MustExec("create table t2 (a int, id int, b int) partition by list (a*3 + b - 2*a - b) (partition p0 values in (1,2,3,4,5), partition p1 values in (6,7,8,9,10,null));") diff --git a/planner/core/physical_plans.go b/planner/core/physical_plans.go index 137fd54227c4c..43304971b4680 100644 --- a/planner/core/physical_plans.go +++ b/planner/core/physical_plans.go @@ -28,6 +28,7 @@ import ( "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/statistics" "github.com/pingcap/tidb/table" + "github.com/pingcap/tidb/table/tables" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/ranger" "github.com/pingcap/tipb/go-tipb" @@ -519,6 +520,35 @@ func (ts *PhysicalTableScan) IsPartition() (bool, int64) { return ts.isPartition, ts.physicalTableID } +// ResolveCorrelatedColumns resolves the correlated columns in range access +func (ts *PhysicalTableScan) ResolveCorrelatedColumns() ([]*ranger.Range, error) { + access := ts.AccessCondition + if ts.Table.IsCommonHandle { + pkIdx := tables.FindPrimaryIndex(ts.Table) + idxCols, idxColLens := expression.IndexInfo2PrefixCols(ts.Columns, ts.Schema().Columns, pkIdx) + for _, cond 
:= range access { + newCond, err := expression.SubstituteCorCol2Constant(cond) + if err != nil { + return nil, err + } + access = append(access, newCond) + } + res, err := ranger.DetachCondAndBuildRangeForIndex(ts.SCtx(), access, idxCols, idxColLens) + if err != nil { + return nil, err + } + ts.Ranges = res.Ranges + } else { + var err error + pkTP := ts.Table.GetPkColInfo().FieldType + ts.Ranges, err = ranger.BuildTableRange(access, ts.SCtx().GetSessionVars().StmtCtx, &pkTP) + if err != nil { + return nil, err + } + } + return ts.Ranges, nil +} + // ExpandVirtualColumn expands the virtual column's dependent columns to ts's schema and column. func ExpandVirtualColumn(columns []*model.ColumnInfo, schema *expression.Schema, colsInfo []*model.ColumnInfo) []*model.ColumnInfo { @@ -940,9 +970,10 @@ const ( type basePhysicalAgg struct { physicalSchemaProducer - AggFuncs []*aggregation.AggFuncDesc - GroupByItems []expression.Expression - MppRunMode AggMppRunMode + AggFuncs []*aggregation.AggFuncDesc + GroupByItems []expression.Expression + MppRunMode AggMppRunMode + MppPartitionCols []*expression.Column } func (p *basePhysicalAgg) isFinalAgg() bool { diff --git a/planner/core/point_get_plan.go b/planner/core/point_get_plan.go index 0ec72814b934e..aa0b48df459b5 100644 --- a/planner/core/point_get_plan.go +++ b/planner/core/point_get_plan.go @@ -33,6 +33,7 @@ import ( "github.com/pingcap/tidb/privilege" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/stmtctx" + "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/table/tables" "github.com/pingcap/tidb/types" driver "github.com/pingcap/tidb/types/parser_driver" @@ -1231,17 +1232,21 @@ func buildPointUpdatePlan(ctx sessionctx.Context, pointPlan PhysicalPlan, dbName OrderedList: orderedList, TblColPosInfos: TblColPosInfoSlice{ TblColPosInfo{ - TblID: tbl.ID, - Start: 0, - End: pointPlan.Schema().Len(), - HandleCols: handleCols, - IsCommonHandle: tbl.IsCommonHandle, + TblID: tbl.ID, + Start: 0, + End: pointPlan.Schema().Len(), + HandleCols: handleCols, }, }, AllAssignmentsAreConstant: allAssignmentsAreConstant, VirtualAssignmentsOffset: len(orderedList), }.Init(ctx) updatePlan.names = pointPlan.OutputNames() + is := infoschema.GetInfoSchema(ctx) + t, _ := is.TableByID(tbl.ID) + updatePlan.tblID2Table = map[int64]table.Table{ + tbl.ID: t, + } return updatePlan } @@ -1318,11 +1323,10 @@ func buildPointDeletePlan(ctx sessionctx.Context, pointPlan PhysicalPlan, dbName SelectPlan: pointPlan, TblColPosInfos: TblColPosInfoSlice{ TblColPosInfo{ - TblID: tbl.ID, - Start: 0, - End: pointPlan.Schema().Len(), - HandleCols: handleCols, - IsCommonHandle: tbl.IsCommonHandle, + TblID: tbl.ID, + Start: 0, + End: pointPlan.Schema().Len(), + HandleCols: handleCols, }, }, }.Init(ctx) diff --git a/planner/core/point_get_plan_test.go b/planner/core/point_get_plan_test.go index 643267b24a3c8..5d77e6b724029 100644 --- a/planner/core/point_get_plan_test.go +++ b/planner/core/point_get_plan_test.go @@ -29,6 +29,7 @@ import ( "github.com/pingcap/tidb/planner/core" "github.com/pingcap/tidb/session" "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/util/kvcache" "github.com/pingcap/tidb/util/testkit" "github.com/pingcap/tidb/util/testleak" @@ -320,7 +321,7 @@ func (s *testPointGetSuite) TestCBOPointGet(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") tk.MustExec("drop table if exists t") - tk.Se.GetSessionVars().EnableClusteredIndex = false + 
tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeIntOnly tk.MustExec("create table t (a varchar(20), b int, c int, d int, primary key(a), unique key(b, c))") tk.MustExec("insert into t values('1',4,4,1), ('2',3,3,2), ('3',2,2,3), ('4',1,1,4)") @@ -393,7 +394,7 @@ func (s *testPointGetSuite) TestBatchPointGetPartition(c *C) { c.Assert(err, IsNil) tk.MustExec("use test") - tk.Se.GetSessionVars().EnableClusteredIndex = true + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn tk.MustExec("drop table if exists t") tk.MustExec("create table t(a int primary key, b int) PARTITION BY HASH(a) PARTITIONS 4") tk.MustExec("insert into t values (1, 1), (2, 2), (3, 3), (4, 4)") @@ -566,7 +567,7 @@ func (s *testPointGetSuite) TestBatchPointGetWithInvisibleIndex(c *C) { func (s *testPointGetSuite) TestCBOShouldNotUsePointGet(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") - tk.Se.GetSessionVars().EnableClusteredIndex = true + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn tk.MustExec("drop tables if exists t1, t2, t3, t4, t5") tk.MustExec("create table t1(id varchar(20) primary key)") tk.MustExec("create table t2(id varchar(20), unique(id))") diff --git a/planner/core/rule_eliminate_projection.go b/planner/core/rule_eliminate_projection.go index 832f31435b429..5731495f9c2d2 100644 --- a/planner/core/rule_eliminate_projection.go +++ b/planner/core/rule_eliminate_projection.go @@ -117,6 +117,9 @@ func doPhysicalProjectionElimination(p PhysicalPlan) PhysicalPlan { return p } child := p.Children()[0] + if childProj, ok := child.(*PhysicalProjection); ok { + childProj.SetSchema(p.Schema()) + } return child } diff --git a/planner/core/rule_inject_extra_projection.go b/planner/core/rule_inject_extra_projection.go index f10bc52dc44bc..2896a1dade0ff 100644 --- a/planner/core/rule_inject_extra_projection.go +++ b/planner/core/rule_inject_extra_projection.go @@ -16,6 +16,7 @@ package core import ( "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/expression/aggregation" + "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/planner/util" "github.com/pingcap/tidb/sessionctx" ) @@ -45,6 +46,11 @@ func (pe *projInjector) inject(plan PhysicalPlan) PhysicalPlan { plan.Children()[i] = pe.inject(child) } + if tr, ok := plan.(*PhysicalTableReader); ok && tr.StoreType == kv.TiFlash { + tr.tablePlan = pe.inject(tr.tablePlan) + tr.TablePlans = flattenPushDownPlan(tr.tablePlan) + } + switch p := plan.(type) { case *PhysicalHashAgg: plan = InjectProjBelowAgg(plan, p.AggFuncs, p.GroupByItems) diff --git a/planner/core/task.go b/planner/core/task.go index 15b47b754ea65..24ed01ead04e6 100644 --- a/planner/core/task.go +++ b/planner/core/task.go @@ -31,6 +31,7 @@ import ( "github.com/pingcap/tidb/statistics" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/collate" "github.com/pingcap/tidb/util/plancodec" "github.com/pingcap/tipb/go-tipb" ) @@ -1163,7 +1164,7 @@ func (p *PhysicalTopN) getPushedDownTopN(childPlan PhysicalPlan) *PhysicalTopN { topN := PhysicalTopN{ ByItems: newByItems, Count: newCount, - }.Init(p.ctx, stats, p.blockOffset) + }.Init(p.ctx, stats, p.blockOffset, p.GetChildReqProps(0)) topN.SetChildren(childPlan) return topN } @@ -1433,14 +1434,15 @@ func BuildFinalModeAggregation( partialCursor++ } if aggFunc.Name == ast.AggFuncAvg { - cntAgg := *aggFunc + cntAgg := aggFunc.Clone() cntAgg.Name = ast.AggFuncCount cntAgg.RetTp = 
partial.Schema.Columns[partialCursor-2].GetType() cntAgg.RetTp.Flag = aggFunc.RetTp.Flag - sumAgg := *aggFunc + // we must call deep clone in this case, to avoid sharing the arguments. + sumAgg := aggFunc.Clone() sumAgg.Name = ast.AggFuncSum sumAgg.RetTp = partial.Schema.Columns[partialCursor-1].GetType() - partial.AggFuncs = append(partial.AggFuncs, &cntAgg, &sumAgg) + partial.AggFuncs = append(partial.AggFuncs, cntAgg, sumAgg) } else if aggFunc.Name == ast.AggFuncApproxCountDistinct { approxCountDistinctAgg := *aggFunc approxCountDistinctAgg.Name = ast.AggFuncApproxCountDistinct @@ -1755,17 +1757,23 @@ func (p *PhysicalHashAgg) attach2TaskForMpp(tasks ...task) task { return invalidTask } attachPlan2Task(partialAgg, mpp) - items := finalAgg.(*PhysicalHashAgg).GroupByItems - partitionCols := make([]*expression.Column, 0, len(items)) - for _, expr := range items { - col, ok := expr.(*expression.Column) - if !ok { - return invalidTask + partitionCols := p.MppPartitionCols + if len(partitionCols) == 0 { + items := finalAgg.(*PhysicalHashAgg).GroupByItems + partitionCols = make([]*expression.Column, 0, len(items)) + for _, expr := range items { + col, ok := expr.(*expression.Column) + if !ok { + return invalidTask + } + partitionCols = append(partitionCols, col) } - partitionCols = append(partitionCols, col) } prop := &property.PhysicalProperty{TaskTp: property.MppTaskType, ExpectedCnt: math.MaxFloat64, PartitionTp: property.HashType, PartitionCols: partitionCols} newMpp := mpp.enforceExchangerImpl(prop) + if newMpp.invalid() { + return newMpp + } attachPlan2Task(finalAgg, newMpp) if proj != nil { attachPlan2Task(proj, newMpp) @@ -1967,6 +1975,13 @@ func (t *mppTask) enforceExchanger(prop *property.PhysicalProperty) *mppTask { } func (t *mppTask) enforceExchangerImpl(prop *property.PhysicalProperty) *mppTask { + if collate.NewCollationEnabled() && prop.PartitionTp == property.HashType { + for _, col := range prop.PartitionCols { + if types.IsString(col.RetType.Tp) { + return &mppTask{cst: math.MaxFloat64} + } + } + } ctx := t.p.SCtx() sender := PhysicalExchangeSender{ ExchangeType: tipb.ExchangeType(prop.PartitionTp), diff --git a/planner/core/testdata/integration_serial_suite_in.json b/planner/core/testdata/integration_serial_suite_in.json index 1c15f4fe184de..e1b912b6c3d63 100644 --- a/planner/core/testdata/integration_serial_suite_in.json +++ b/planner/core/testdata/integration_serial_suite_in.json @@ -78,10 +78,19 @@ ] }, { - "name": "TestMPPNotSupportedInNewCollation", + "name": "TestMPPWithHashExchangeUnderNewCollation", "cases": [ "explain format = 'brief' select * from table_1 a, table_1 b where a.id = b.id", - "explain format = 'brief' select /*+ agg_to_cop() */ count(*), id from table_1 group by id" + "explain format = 'brief' select /*+ agg_to_cop() */ count(*), id from table_1 group by id", + "explain format = 'brief' select * from table_1 a, table_1 b where a.value = b.value", + "explain format = 'brief' select /*+ agg_to_cop() */ count(*), value from table_1 group by value" + ] + }, + { + "name": "TestMPPWithBroadcastExchangeUnderNewCollation", + "cases": [ + "explain format = 'brief' select /*+ broadcast_join(a,b) */ * from table_1 a, table_1 b where a.id = b.id", + "explain format = 'brief' select /*+ broadcast_join(a,b) */ * from table_1 a, table_1 b where a.value = b.value" ] }, { @@ -215,7 +224,9 @@ "desc format = 'brief' select count(distinct value),id from t group by id", "desc format = 'brief' select count(distinct value),sum(distinct value),id from t group by id", "desc 
format = 'brief' select * from t join ( select count(distinct value), id from t group by id) as A on A.id = t.id", - "desc format = 'brief' select * from t join ( select count(1/value), id from t group by id) as A on A.id = t.id" + "desc format = 'brief' select * from t join ( select count(1/value), id from t group by id) as A on A.id = t.id", + "desc format = 'brief' select /*+hash_agg()*/ sum(id) from (select value, id from t where id > value group by id, value)A group by value /*the exchange should have only one partition column: test.t.value*/", + "desc format = 'brief' select /*+hash_agg()*/ sum(B.value) from t as B where B.id+1 > (select count(*) from t where t.id= B.id and t.value=B.value) group by B.id /*the exchange should have only one partition column: test.t.id*/" ] }, { diff --git a/planner/core/testdata/integration_serial_suite_out.json b/planner/core/testdata/integration_serial_suite_out.json index 9d94e6e16ae8e..e9a2a2ccfd93f 100644 --- a/planner/core/testdata/integration_serial_suite_out.json +++ b/planner/core/testdata/integration_serial_suite_out.json @@ -22,12 +22,14 @@ "SQL": "explain format = 'brief' select * from t where b > 'a' order by convert(b, unsigned) limit 2", "Plan": [ "Projection 2.00 root test.t.a, test.t.b", - "└─TopN 2.00 root Column#3, offset:0, count:2", - " └─Projection 2.00 root test.t.a, test.t.b, cast(test.t.b, bigint(22) UNSIGNED BINARY)->Column#3", - " └─TableReader 2.00 root data:TopN", - " └─TopN 2.00 batchCop[tiflash] cast(test.t.b, bigint(22) UNSIGNED BINARY), offset:0, count:2", - " └─Selection 3333.33 batchCop[tiflash] gt(test.t.b, \"a\")", - " └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" + "└─TopN 2.00 root Column#4, offset:0, count:2", + " └─Projection 2.00 root test.t.a, test.t.b, cast(test.t.b, bigint(22) UNSIGNED BINARY)->Column#4", + " └─TableReader 2.00 root data:Projection", + " └─Projection 2.00 batchCop[tiflash] test.t.a, test.t.b", + " └─TopN 2.00 batchCop[tiflash] Column#3, offset:0, count:2", + " └─Projection 3333.33 batchCop[tiflash] test.t.a, test.t.b, cast(test.t.b, bigint(22) UNSIGNED BINARY)->Column#3", + " └─Selection 3333.33 batchCop[tiflash] gt(test.t.b, \"a\")", + " └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" ] }, { @@ -823,24 +825,86 @@ ] }, { - "Name": "TestMPPNotSupportedInNewCollation", + "Name": "TestMPPWithHashExchangeUnderNewCollation", "Cases": [ { "SQL": "explain format = 'brief' select * from table_1 a, table_1 b where a.id = b.id", "Plan": [ - "HashJoin 2.00 root inner join, equal:[eq(test.table_1.id, test.table_1.id)]", - "├─TableReader(Build) 2.00 root data:TableFullScan", - "│ └─TableFullScan 2.00 cop[tiflash] table:b keep order:false", - "└─TableReader(Probe) 2.00 root data:TableFullScan", - " └─TableFullScan 2.00 cop[tiflash] table:a keep order:false" + "TableReader 2.00 root data:ExchangeSender", + "└─ExchangeSender 2.00 cop[tiflash] ExchangeType: PassThrough", + " └─HashJoin 2.00 cop[tiflash] inner join, equal:[eq(test.table_1.id, test.table_1.id)]", + " ├─ExchangeReceiver(Build) 2.00 cop[tiflash] ", + " │ └─ExchangeSender 2.00 cop[tiflash] ExchangeType: HashPartition, Hash Cols: test.table_1.id", + " │ └─TableFullScan 2.00 cop[tiflash] table:a keep order:false", + " └─ExchangeReceiver(Probe) 2.00 cop[tiflash] ", + " └─ExchangeSender 2.00 cop[tiflash] ExchangeType: HashPartition, Hash Cols: test.table_1.id", + " └─TableFullScan 2.00 cop[tiflash] table:b keep order:false" ] }, { "SQL": "explain format = 'brief' select /*+ 
agg_to_cop() */ count(*), id from table_1 group by id", "Plan": [ - "HashAgg 2.00 root group by:test.table_1.id, funcs:count(1)->Column#4, funcs:firstrow(test.table_1.id)->test.table_1.id", - "└─TableReader 2.00 root data:TableFullScan", - " └─TableFullScan 2.00 cop[tiflash] table:table_1 keep order:false" + "TableReader 2.00 root data:ExchangeSender", + "└─ExchangeSender 2.00 batchCop[tiflash] ExchangeType: PassThrough", + " └─Projection 2.00 batchCop[tiflash] Column#4, test.table_1.id", + " └─HashAgg 2.00 batchCop[tiflash] group by:test.table_1.id, funcs:sum(Column#7)->Column#4, funcs:firstrow(test.table_1.id)->test.table_1.id", + " └─ExchangeReceiver 2.00 batchCop[tiflash] ", + " └─ExchangeSender 2.00 batchCop[tiflash] ExchangeType: HashPartition, Hash Cols: test.table_1.id", + " └─HashAgg 2.00 batchCop[tiflash] group by:test.table_1.id, funcs:count(1)->Column#7", + " └─TableFullScan 2.00 batchCop[tiflash] table:table_1 keep order:false" + ] + }, + { + "SQL": "explain format = 'brief' select * from table_1 a, table_1 b where a.value = b.value", + "Plan": [ + "HashJoin 2.00 root inner join, equal:[eq(test.table_1.value, test.table_1.value)]", + "├─TableReader(Build) 2.00 root data:Selection", + "│ └─Selection 2.00 cop[tiflash] not(isnull(test.table_1.value))", + "│ └─TableFullScan 2.00 cop[tiflash] table:b keep order:false", + "└─TableReader(Probe) 2.00 root data:Selection", + " └─Selection 2.00 cop[tiflash] not(isnull(test.table_1.value))", + " └─TableFullScan 2.00 cop[tiflash] table:a keep order:false" + ] + }, + { + "SQL": "explain format = 'brief' select /*+ agg_to_cop() */ count(*), value from table_1 group by value", + "Plan": [ + "HashAgg 2.00 root group by:test.table_1.value, funcs:count(Column#9)->Column#4, funcs:firstrow(test.table_1.value)->test.table_1.value", + "└─TableReader 2.00 root data:ExchangeSender", + " └─ExchangeSender 2.00 batchCop[tiflash] ExchangeType: PassThrough", + " └─HashAgg 2.00 batchCop[tiflash] group by:test.table_1.value, funcs:count(1)->Column#9", + " └─TableFullScan 2.00 batchCop[tiflash] table:table_1 keep order:false" + ] + } + ] + }, + { + "Name": "TestMPPWithBroadcastExchangeUnderNewCollation", + "Cases": [ + { + "SQL": "explain format = 'brief' select /*+ broadcast_join(a,b) */ * from table_1 a, table_1 b where a.id = b.id", + "Plan": [ + "TableReader 2.00 root data:ExchangeSender", + "└─ExchangeSender 2.00 cop[tiflash] ExchangeType: PassThrough", + " └─HashJoin 2.00 cop[tiflash] inner join, equal:[eq(test.table_1.id, test.table_1.id)]", + " ├─ExchangeReceiver(Build) 2.00 cop[tiflash] ", + " │ └─ExchangeSender 2.00 cop[tiflash] ExchangeType: Broadcast", + " │ └─TableFullScan 2.00 cop[tiflash] table:a keep order:false", + " └─TableFullScan(Probe) 2.00 cop[tiflash] table:b keep order:false" + ] + }, + { + "SQL": "explain format = 'brief' select /*+ broadcast_join(a,b) */ * from table_1 a, table_1 b where a.value = b.value", + "Plan": [ + "TableReader 2.00 root data:ExchangeSender", + "└─ExchangeSender 2.00 cop[tiflash] ExchangeType: PassThrough", + " └─HashJoin 2.00 cop[tiflash] inner join, equal:[eq(test.table_1.value, test.table_1.value)]", + " ├─ExchangeReceiver(Build) 2.00 cop[tiflash] ", + " │ └─ExchangeSender 2.00 cop[tiflash] ExchangeType: Broadcast", + " │ └─Selection 2.00 cop[tiflash] not(isnull(test.table_1.value))", + " │ └─TableFullScan 2.00 cop[tiflash] table:a keep order:false", + " └─Selection(Probe) 2.00 cop[tiflash] not(isnull(test.table_1.value))", + " └─TableFullScan 2.00 cop[tiflash] table:b keep order:false" ] } ] @@ -853,8 
+917,9 @@ "Plan": [ "StreamAgg 1.00 root funcs:avg(Column#7, Column#8)->Column#4", "└─TableReader 1.00 root data:StreamAgg", - " └─StreamAgg 1.00 batchCop[tiflash] funcs:count(test.t.a)->Column#7, funcs:sum(test.t.a)->Column#8", - " └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" + " └─StreamAgg 1.00 batchCop[tiflash] funcs:count(Column#9)->Column#7, funcs:sum(Column#10)->Column#8", + " └─Projection 10000.00 batchCop[tiflash] test.t.a, cast(test.t.a, decimal(15,4) BINARY)->Column#10", + " └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" ], "Warn": null }, @@ -863,8 +928,9 @@ "Plan": [ "StreamAgg 1.00 root funcs:avg(Column#7, Column#8)->Column#4", "└─TableReader 1.00 root data:StreamAgg", - " └─StreamAgg 1.00 batchCop[tiflash] funcs:count(test.t.a)->Column#7, funcs:sum(test.t.a)->Column#8", - " └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" + " └─StreamAgg 1.00 batchCop[tiflash] funcs:count(Column#9)->Column#7, funcs:sum(Column#10)->Column#8", + " └─Projection 10000.00 batchCop[tiflash] test.t.a, cast(test.t.a, decimal(15,4) BINARY)->Column#10", + " └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" ], "Warn": null }, @@ -873,8 +939,9 @@ "Plan": [ "StreamAgg 1.00 root funcs:sum(Column#6)->Column#4", "└─TableReader 1.00 root data:StreamAgg", - " └─StreamAgg 1.00 batchCop[tiflash] funcs:sum(test.t.a)->Column#6", - " └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" + " └─StreamAgg 1.00 batchCop[tiflash] funcs:sum(Column#7)->Column#6", + " └─Projection 10000.00 batchCop[tiflash] cast(test.t.a, decimal(32,0) BINARY)->Column#7", + " └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" ], "Warn": null }, @@ -883,8 +950,9 @@ "Plan": [ "StreamAgg 1.00 root funcs:sum(Column#6)->Column#4", "└─TableReader 1.00 root data:StreamAgg", - " └─StreamAgg 1.00 batchCop[tiflash] funcs:sum(plus(test.t.a, 1))->Column#6", - " └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" + " └─StreamAgg 1.00 batchCop[tiflash] funcs:sum(Column#7)->Column#6", + " └─Projection 10000.00 batchCop[tiflash] cast(plus(test.t.a, 1), decimal(41,0) BINARY)->Column#7", + " └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" ], "Warn": null }, @@ -893,8 +961,9 @@ "Plan": [ "StreamAgg 1.00 root funcs:sum(Column#6)->Column#4", "└─TableReader 1.00 root data:StreamAgg", - " └─StreamAgg 1.00 batchCop[tiflash] funcs:sum(isnull(test.t.a))->Column#6", - " └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" + " └─StreamAgg 1.00 batchCop[tiflash] funcs:sum(Column#7)->Column#6", + " └─Projection 10000.00 batchCop[tiflash] cast(isnull(test.t.a), decimal(22,0) BINARY)->Column#7", + " └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" ], "Warn": null }, @@ -1203,8 +1272,9 @@ "Plan": [ "HashAgg 1.00 root funcs:count(Column#7)->Column#5", "└─TableReader 1.00 root data:HashAgg", - " └─HashAgg 1.00 batchCop[tiflash] funcs:count(plus(test.t.id, 1))->Column#7", - " └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" + " └─HashAgg 1.00 batchCop[tiflash] funcs:count(Column#9)->Column#7", + " └─Projection 10000.00 batchCop[tiflash] plus(test.t.id, 1)->Column#9", + " └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" ] }, { @@ -1221,8 +1291,9 @@ "Plan": [ "HashAgg 1.00 root 
funcs:sum(Column#7)->Column#5", "└─TableReader 1.00 root data:HashAgg", - " └─HashAgg 1.00 batchCop[tiflash] funcs:sum(plus(test.t.id, 1))->Column#7", - " └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" + " └─HashAgg 1.00 batchCop[tiflash] funcs:sum(Column#9)->Column#7", + " └─Projection 10000.00 batchCop[tiflash] cast(plus(test.t.id, 1), decimal(41,0) BINARY)->Column#9", + " └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" ] }, { @@ -1230,8 +1301,9 @@ "Plan": [ "StreamAgg 1.00 root funcs:count(Column#7)->Column#5", "└─TableReader 1.00 root data:StreamAgg", - " └─StreamAgg 1.00 batchCop[tiflash] funcs:count(plus(test.t.id, 1))->Column#7", - " └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" + " └─StreamAgg 1.00 batchCop[tiflash] funcs:count(Column#9)->Column#7", + " └─Projection 10000.00 batchCop[tiflash] plus(test.t.id, 1)->Column#9", + " └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" ] }, { @@ -1248,8 +1320,9 @@ "Plan": [ "StreamAgg 1.00 root funcs:sum(Column#7)->Column#5", "└─TableReader 1.00 root data:StreamAgg", - " └─StreamAgg 1.00 batchCop[tiflash] funcs:sum(plus(test.t.id, 1))->Column#7", - " └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" + " └─StreamAgg 1.00 batchCop[tiflash] funcs:sum(Column#9)->Column#7", + " └─Projection 10000.00 batchCop[tiflash] cast(plus(test.t.id, 1), decimal(41,0) BINARY)->Column#9", + " └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" ] }, { @@ -1344,8 +1417,9 @@ "HashAgg 1.00 root funcs:count(Column#8)->Column#5", "└─TableReader 1.00 root data:ExchangeSender", " └─ExchangeSender 1.00 batchCop[tiflash] ExchangeType: PassThrough", - " └─HashAgg 1.00 batchCop[tiflash] funcs:count(plus(test.t.id, 1))->Column#8", - " └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" + " └─HashAgg 1.00 batchCop[tiflash] funcs:count(Column#9)->Column#8", + " └─Projection 10000.00 batchCop[tiflash] plus(test.t.id, 1)->Column#9", + " └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" ] }, { @@ -1364,8 +1438,9 @@ "HashAgg 1.00 root funcs:sum(Column#8)->Column#5", "└─TableReader 1.00 root data:ExchangeSender", " └─ExchangeSender 1.00 batchCop[tiflash] ExchangeType: PassThrough", - " └─HashAgg 1.00 batchCop[tiflash] funcs:sum(plus(test.t.id, 1))->Column#8", - " └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" + " └─HashAgg 1.00 batchCop[tiflash] funcs:sum(Column#9)->Column#8", + " └─Projection 10000.00 batchCop[tiflash] cast(plus(test.t.id, 1), decimal(41,0) BINARY)->Column#9", + " └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" ] }, { @@ -1373,8 +1448,9 @@ "Plan": [ "StreamAgg 1.00 root funcs:count(Column#7)->Column#5", "└─TableReader 1.00 root data:StreamAgg", - " └─StreamAgg 1.00 batchCop[tiflash] funcs:count(plus(test.t.id, 1))->Column#7", - " └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" + " └─StreamAgg 1.00 batchCop[tiflash] funcs:count(Column#9)->Column#7", + " └─Projection 10000.00 batchCop[tiflash] plus(test.t.id, 1)->Column#9", + " └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" ] }, { @@ -1391,8 +1467,9 @@ "Plan": [ "StreamAgg 1.00 root funcs:sum(Column#7)->Column#5", "└─TableReader 1.00 root data:StreamAgg", - " └─StreamAgg 1.00 batchCop[tiflash] 
funcs:sum(plus(test.t.id, 1))->Column#7", - " └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" + " └─StreamAgg 1.00 batchCop[tiflash] funcs:sum(Column#9)->Column#7", + " └─Projection 10000.00 batchCop[tiflash] cast(plus(test.t.id, 1), decimal(41,0) BINARY)->Column#9", + " └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" ] }, { @@ -1544,7 +1621,7 @@ " └─Projection 7984.01 batchCop[tiflash] Column#14, test.t.c1, test.t.c2, test.t.c3", " └─HashAgg 7984.01 batchCop[tiflash] group by:test.t.c1, test.t.c2, test.t.c3, funcs:sum(Column#23)->Column#14, funcs:firstrow(test.t.c1)->test.t.c1, funcs:firstrow(test.t.c2)->test.t.c2, funcs:firstrow(test.t.c3)->test.t.c3", " └─ExchangeReceiver 7984.01 batchCop[tiflash] ", - " └─ExchangeSender 7984.01 batchCop[tiflash] ExchangeType: HashPartition, Hash Cols: test.t.c1, test.t.c2, test.t.c3", + " └─ExchangeSender 7984.01 batchCop[tiflash] ExchangeType: HashPartition, Hash Cols: test.t.c2, test.t.c3, test.t.c1", " └─HashAgg 7984.01 batchCop[tiflash] group by:test.t.c1, test.t.c2, test.t.c3, funcs:count(1)->Column#23", " └─Selection 9980.01 batchCop[tiflash] not(isnull(test.t.c1)), not(isnull(test.t.c2))", " └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" @@ -1661,8 +1738,9 @@ "HashAgg 1.00 root funcs:count(Column#8)->Column#5", "└─TableReader 1.00 root data:ExchangeSender", " └─ExchangeSender 1.00 batchCop[tiflash] ExchangeType: PassThrough", - " └─HashAgg 1.00 batchCop[tiflash] funcs:count(plus(test.t.id, 1))->Column#8", - " └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" + " └─HashAgg 1.00 batchCop[tiflash] funcs:count(Column#9)->Column#8", + " └─Projection 10000.00 batchCop[tiflash] plus(test.t.id, 1)->Column#9", + " └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" ] }, { @@ -1681,8 +1759,9 @@ "HashAgg 1.00 root funcs:sum(Column#8)->Column#5", "└─TableReader 1.00 root data:ExchangeSender", " └─ExchangeSender 1.00 batchCop[tiflash] ExchangeType: PassThrough", - " └─HashAgg 1.00 batchCop[tiflash] funcs:sum(plus(test.t.id, 1))->Column#8", - " └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" + " └─HashAgg 1.00 batchCop[tiflash] funcs:sum(Column#9)->Column#8", + " └─Projection 10000.00 batchCop[tiflash] cast(plus(test.t.id, 1), decimal(41,0) BINARY)->Column#9", + " └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" ] }, { @@ -1718,8 +1797,9 @@ " └─HashAgg 8000.00 batchCop[tiflash] group by:Column#10, funcs:sum(Column#11)->Column#4, funcs:firstrow(Column#12)->test.t.id", " └─ExchangeReceiver 8000.00 batchCop[tiflash] ", " └─ExchangeSender 8000.00 batchCop[tiflash] ExchangeType: HashPartition, Hash Cols: Column#10", - " └─HashAgg 8000.00 batchCop[tiflash] group by:plus(test.t.id, 1), funcs:count(1)->Column#11, funcs:firstrow(test.t.id)->Column#12", - " └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" + " └─HashAgg 8000.00 batchCop[tiflash] group by:Column#17, funcs:count(1)->Column#11, funcs:firstrow(Column#16)->Column#12", + " └─Projection 10000.00 batchCop[tiflash] test.t.id, plus(test.t.id, 1)->Column#17", + " └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" ] }, { @@ -1832,16 +1912,17 @@ "TableReader 7992.00 root data:ExchangeSender", "└─ExchangeSender 7992.00 batchCop[tiflash] ExchangeType: PassThrough", " └─Projection 7992.00 batchCop[tiflash] Column#7", - 
" └─HashAgg 7992.00 batchCop[tiflash] group by:test.t.id, funcs:sum(test.t.id)->Column#7", - " └─HashJoin 12487.50 batchCop[tiflash] inner join, equal:[eq(test.t.id, test.t.id)]", - " ├─ExchangeReceiver(Build) 9990.00 batchCop[tiflash] ", - " │ └─ExchangeSender 9990.00 batchCop[tiflash] ExchangeType: HashPartition, Hash Cols: test.t.id", - " │ └─Selection 9990.00 batchCop[tiflash] not(isnull(test.t.id))", - " │ └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo", - " └─ExchangeReceiver(Probe) 9990.00 batchCop[tiflash] ", - " └─ExchangeSender 9990.00 batchCop[tiflash] ExchangeType: HashPartition, Hash Cols: test.t.id", - " └─Selection 9990.00 batchCop[tiflash] not(isnull(test.t.id))", - " └─TableFullScan 10000.00 batchCop[tiflash] table:t1 keep order:false, stats:pseudo" + " └─HashAgg 7992.00 batchCop[tiflash] group by:Column#11, funcs:sum(Column#10)->Column#7", + " └─Projection 12487.50 batchCop[tiflash] cast(test.t.id, decimal(32,0) BINARY)->Column#10, test.t.id", + " └─HashJoin 12487.50 batchCop[tiflash] inner join, equal:[eq(test.t.id, test.t.id)]", + " ├─ExchangeReceiver(Build) 9990.00 batchCop[tiflash] ", + " │ └─ExchangeSender 9990.00 batchCop[tiflash] ExchangeType: HashPartition, Hash Cols: test.t.id", + " │ └─Selection 9990.00 batchCop[tiflash] not(isnull(test.t.id))", + " │ └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo", + " └─ExchangeReceiver(Probe) 9990.00 batchCop[tiflash] ", + " └─ExchangeSender 9990.00 batchCop[tiflash] ExchangeType: HashPartition, Hash Cols: test.t.id", + " └─Selection 9990.00 batchCop[tiflash] not(isnull(test.t.id))", + " └─TableFullScan 10000.00 batchCop[tiflash] table:t1 keep order:false, stats:pseudo" ] }, { @@ -1857,18 +1938,19 @@ " │ └─Selection 9990.00 batchCop[tiflash] not(isnull(test.t.id))", " │ └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo", " └─Projection(Probe) 7992.00 batchCop[tiflash] Column#11, test.t.id", - " └─HashAgg 7992.00 batchCop[tiflash] group by:test.t.id, funcs:sum(test.t.id)->Column#11, funcs:firstrow(test.t.id)->test.t.id", - " └─HashJoin 9990.00 batchCop[tiflash] inner join, equal:[eq(test.t.id, test.t.id)]", - " ├─Projection(Build) 7992.00 batchCop[tiflash] test.t.id", - " │ └─HashAgg 7992.00 batchCop[tiflash] group by:test.t.id, funcs:firstrow(test.t.id)->test.t.id", - " │ └─ExchangeReceiver 9990.00 batchCop[tiflash] ", - " │ └─ExchangeSender 9990.00 batchCop[tiflash] ExchangeType: HashPartition, Hash Cols: test.t.id", - " │ └─Selection 9990.00 batchCop[tiflash] not(isnull(test.t.id))", - " │ └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo", - " └─ExchangeReceiver(Probe) 9990.00 batchCop[tiflash] ", - " └─ExchangeSender 9990.00 batchCop[tiflash] ExchangeType: HashPartition, Hash Cols: test.t.id", - " └─Selection 9990.00 batchCop[tiflash] not(isnull(test.t.id))", - " └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" + " └─HashAgg 7992.00 batchCop[tiflash] group by:Column#32, funcs:sum(Column#30)->Column#11, funcs:firstrow(Column#31)->test.t.id", + " └─Projection 9990.00 batchCop[tiflash] cast(test.t.id, decimal(32,0) BINARY)->Column#30, test.t.id, test.t.id", + " └─HashJoin 9990.00 batchCop[tiflash] inner join, equal:[eq(test.t.id, test.t.id)]", + " ├─Projection(Build) 7992.00 batchCop[tiflash] test.t.id", + " │ └─HashAgg 7992.00 batchCop[tiflash] group by:test.t.id, funcs:firstrow(test.t.id)->test.t.id", + " │ └─ExchangeReceiver 9990.00 batchCop[tiflash] ", + " 
│ └─ExchangeSender 9990.00 batchCop[tiflash] ExchangeType: HashPartition, Hash Cols: test.t.id", + " │ └─Selection 9990.00 batchCop[tiflash] not(isnull(test.t.id))", + " │ └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo", + " └─ExchangeReceiver(Probe) 9990.00 batchCop[tiflash] ", + " └─ExchangeSender 9990.00 batchCop[tiflash] ExchangeType: HashPartition, Hash Cols: test.t.id", + " └─Selection 9990.00 batchCop[tiflash] not(isnull(test.t.id))", + " └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" ] }, { @@ -1921,14 +2003,53 @@ " │ └─HashAgg 7992.00 batchCop[tiflash] group by:test.t.id, funcs:sum(Column#8)->Column#7, funcs:firstrow(test.t.id)->test.t.id", " │ └─ExchangeReceiver 7992.00 batchCop[tiflash] ", " │ └─ExchangeSender 7992.00 batchCop[tiflash] ExchangeType: HashPartition, Hash Cols: test.t.id", - " │ └─HashAgg 7992.00 batchCop[tiflash] group by:test.t.id, funcs:count(div(1, test.t.value))->Column#8", - " │ └─Selection 9990.00 batchCop[tiflash] not(isnull(test.t.id))", - " │ └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo", + " │ └─HashAgg 7992.00 batchCop[tiflash] group by:Column#19, funcs:count(Column#18)->Column#8", + " │ └─Projection 9990.00 batchCop[tiflash] div(1, test.t.value)->Column#18, test.t.id", + " │ └─Selection 9990.00 batchCop[tiflash] not(isnull(test.t.id))", + " │ └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo", " └─ExchangeReceiver(Probe) 9990.00 batchCop[tiflash] ", " └─ExchangeSender 9990.00 batchCop[tiflash] ExchangeType: HashPartition, Hash Cols: test.t.id", " └─Selection 9990.00 batchCop[tiflash] not(isnull(test.t.id))", " └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" ] + }, + { + "SQL": "desc format = 'brief' select /*+hash_agg()*/ sum(id) from (select value, id from t where id > value group by id, value)A group by value /*the exchange should have only one partition column: test.t.value*/", + "Plan": [ + "TableReader 6400.00 root data:ExchangeSender", + "└─ExchangeSender 6400.00 batchCop[tiflash] ExchangeType: PassThrough", + " └─Projection 6400.00 batchCop[tiflash] Column#4", + " └─HashAgg 6400.00 batchCop[tiflash] group by:Column#22, funcs:sum(Column#21)->Column#4", + " └─Projection 6400.00 batchCop[tiflash] cast(test.t.id, decimal(32,0) BINARY)->Column#21, test.t.value", + " └─Projection 6400.00 batchCop[tiflash] test.t.id, test.t.value", + " └─HashAgg 6400.00 batchCop[tiflash] group by:test.t.id, test.t.value, funcs:firstrow(test.t.id)->test.t.id, funcs:firstrow(test.t.value)->test.t.value", + " └─ExchangeReceiver 8000.00 batchCop[tiflash] ", + " └─ExchangeSender 8000.00 batchCop[tiflash] ExchangeType: HashPartition, Hash Cols: test.t.value", + " └─Selection 8000.00 batchCop[tiflash] gt(cast(test.t.id, decimal(20,0) BINARY), test.t.value)", + " └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" + ] + }, + { + "SQL": "desc format = 'brief' select /*+hash_agg()*/ sum(B.value) from t as B where B.id+1 > (select count(*) from t where t.id= B.id and t.value=B.value) group by B.id /*the exchange should have only one partition column: test.t.id*/", + "Plan": [ + "TableReader 6400.00 root data:ExchangeSender", + "└─ExchangeSender 6400.00 batchCop[tiflash] ExchangeType: PassThrough", + " └─Projection 6400.00 batchCop[tiflash] Column#8", + " └─HashAgg 6400.00 batchCop[tiflash] group by:test.t.id, funcs:sum(test.t.value)->Column#8", + " └─Selection 8000.00 
batchCop[tiflash] gt(plus(test.t.id, 1), ifnull(Column#7, 0))", + " └─HashJoin 10000.00 batchCop[tiflash] left outer join, equal:[eq(test.t.id, test.t.id) eq(test.t.value, test.t.value)]", + " ├─Selection(Build) 6387.21 batchCop[tiflash] gt(plus(test.t.id, 1), ifnull(Column#7, 0))", + " │ └─Projection 7984.01 batchCop[tiflash] Column#7, test.t.id, test.t.value", + " │ └─HashAgg 7984.01 batchCop[tiflash] group by:test.t.id, test.t.value, funcs:sum(Column#24)->Column#7, funcs:firstrow(test.t.id)->test.t.id, funcs:firstrow(test.t.value)->test.t.value", + " │ └─ExchangeReceiver 7984.01 batchCop[tiflash] ", + " │ └─ExchangeSender 7984.01 batchCop[tiflash] ExchangeType: HashPartition, Hash Cols: test.t.id", + " │ └─HashAgg 7984.01 batchCop[tiflash] group by:test.t.id, test.t.value, funcs:count(1)->Column#24", + " │ └─Selection 9980.01 batchCop[tiflash] not(isnull(test.t.id)), not(isnull(test.t.value))", + " │ └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo", + " └─ExchangeReceiver(Probe) 10000.00 batchCop[tiflash] ", + " └─ExchangeSender 10000.00 batchCop[tiflash] ExchangeType: HashPartition, Hash Cols: test.t.id", + " └─TableFullScan 10000.00 batchCop[tiflash] table:B keep order:false, stats:pseudo" + ] } ] }, @@ -2018,14 +2139,15 @@ " └─HashAgg 7992.00 batchCop[tiflash] group by:test.t.id, funcs:sum(Column#8)->Column#7", " └─ExchangeReceiver 7992.00 batchCop[tiflash] ", " └─ExchangeSender 7992.00 batchCop[tiflash] ExchangeType: HashPartition, Hash Cols: test.t.id", - " └─HashAgg 7992.00 batchCop[tiflash] group by:test.t.id, funcs:sum(test.t.id)->Column#8", - " └─HashJoin 12487.50 batchCop[tiflash] inner join, equal:[eq(test.t.id, test.t.id)]", - " ├─ExchangeReceiver(Build) 9990.00 batchCop[tiflash] ", - " │ └─ExchangeSender 9990.00 batchCop[tiflash] ExchangeType: Broadcast", - " │ └─Selection 9990.00 batchCop[tiflash] not(isnull(test.t.id))", - " │ └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo", - " └─Selection(Probe) 9990.00 batchCop[tiflash] not(isnull(test.t.id))", - " └─TableFullScan 10000.00 batchCop[tiflash] table:t1 keep order:false, stats:pseudo" + " └─HashAgg 7992.00 batchCop[tiflash] group by:Column#11, funcs:sum(Column#10)->Column#8", + " └─Projection 12487.50 batchCop[tiflash] cast(test.t.id, decimal(32,0) BINARY)->Column#10, test.t.id", + " └─HashJoin 12487.50 batchCop[tiflash] inner join, equal:[eq(test.t.id, test.t.id)]", + " ├─ExchangeReceiver(Build) 9990.00 batchCop[tiflash] ", + " │ └─ExchangeSender 9990.00 batchCop[tiflash] ExchangeType: Broadcast", + " │ └─Selection 9990.00 batchCop[tiflash] not(isnull(test.t.id))", + " │ └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo", + " └─Selection(Probe) 9990.00 batchCop[tiflash] not(isnull(test.t.id))", + " └─TableFullScan 10000.00 batchCop[tiflash] table:t1 keep order:false, stats:pseudo" ] }, { @@ -2069,18 +2191,19 @@ " └─HashAgg 7992.00 batchCop[tiflash] group by:test.t.id, funcs:sum(Column#14)->Column#11, funcs:firstrow(test.t.id)->test.t.id", " └─ExchangeReceiver 7992.00 batchCop[tiflash] ", " └─ExchangeSender 7992.00 batchCop[tiflash] ExchangeType: HashPartition, Hash Cols: test.t.id", - " └─HashAgg 7992.00 batchCop[tiflash] group by:test.t.id, funcs:sum(test.t.id)->Column#14", - " └─HashJoin 9990.00 batchCop[tiflash] inner join, equal:[eq(test.t.id, test.t.id)]", - " ├─ExchangeReceiver(Build) 7992.00 batchCop[tiflash] ", - " │ └─ExchangeSender 7992.00 batchCop[tiflash] ExchangeType: Broadcast", - " │ └─Projection 7992.00 
batchCop[tiflash] test.t.id", - " │ └─HashAgg 7992.00 batchCop[tiflash] group by:test.t.id, funcs:firstrow(test.t.id)->test.t.id", - " │ └─ExchangeReceiver 9990.00 batchCop[tiflash] ", - " │ └─ExchangeSender 9990.00 batchCop[tiflash] ExchangeType: HashPartition, Hash Cols: test.t.id", - " │ └─Selection 9990.00 batchCop[tiflash] not(isnull(test.t.id))", - " │ └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo", - " └─Selection(Probe) 9990.00 batchCop[tiflash] not(isnull(test.t.id))", - " └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" + " └─HashAgg 7992.00 batchCop[tiflash] group by:Column#27, funcs:sum(Column#26)->Column#14", + " └─Projection 9990.00 batchCop[tiflash] cast(test.t.id, decimal(32,0) BINARY)->Column#26, test.t.id", + " └─HashJoin 9990.00 batchCop[tiflash] inner join, equal:[eq(test.t.id, test.t.id)]", + " ├─ExchangeReceiver(Build) 7992.00 batchCop[tiflash] ", + " │ └─ExchangeSender 7992.00 batchCop[tiflash] ExchangeType: Broadcast", + " │ └─Projection 7992.00 batchCop[tiflash] test.t.id", + " │ └─HashAgg 7992.00 batchCop[tiflash] group by:test.t.id, funcs:firstrow(test.t.id)->test.t.id", + " │ └─ExchangeReceiver 9990.00 batchCop[tiflash] ", + " │ └─ExchangeSender 9990.00 batchCop[tiflash] ExchangeType: HashPartition, Hash Cols: test.t.id", + " │ └─Selection 9990.00 batchCop[tiflash] not(isnull(test.t.id))", + " │ └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo", + " └─Selection(Probe) 9990.00 batchCop[tiflash] not(isnull(test.t.id))", + " └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" ] } ] diff --git a/server/conn_test.go b/server/conn_test.go index 39dd348607b03..9b4f26f27cde9 100644 --- a/server/conn_test.go +++ b/server/conn_test.go @@ -31,6 +31,7 @@ import ( "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/session" "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/store/mockstore" "github.com/pingcap/tidb/store/mockstore/unistore" "github.com/pingcap/tidb/store/tikv" @@ -698,7 +699,7 @@ func (ts *ConnTestSuite) TestPrefetchPointKeys(c *C) { tk := testkit.NewTestKitWithInit(c, ts.store) cc.ctx = &TiDBContext{Session: tk.Se} ctx := context.Background() - tk.Se.GetSessionVars().EnableClusteredIndex = false + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeIntOnly tk.MustExec("create table prefetch (a int, b int, c int, primary key (a, b))") tk.MustExec("insert prefetch values (1, 1, 1), (2, 2, 2), (3, 3, 3)") tk.MustExec("begin optimistic") @@ -802,11 +803,11 @@ func (ts *ConnTestSuite) TestTiFlashFallback(c *C) { tk.MustExec("set @@tidb_allow_batch_cop=1; set @@tidb_allow_mpp=0;") c.Assert(failpoint.Enable("github.com/pingcap/tidb/store/mockstore/unistore/BatchCopRpcErrtiflash0", "return(\"tiflash0\")"), IsNil) - testFallbackWork(c, tk, cc, "select sum(a) from t") + testFallbackWork(c, tk, cc, "select count(*) from t") c.Assert(failpoint.Disable("github.com/pingcap/tidb/store/mockstore/unistore/BatchCopRpcErrtiflash0"), IsNil) c.Assert(failpoint.Enable("github.com/pingcap/tidb/store/mockstore/unistore/batchCopRecvTimeout", "return(true)"), IsNil) - testFallbackWork(c, tk, cc, "select sum(a) from t") + testFallbackWork(c, tk, cc, "select count(*) from t") c.Assert(failpoint.Disable("github.com/pingcap/tidb/store/mockstore/unistore/batchCopRecvTimeout"), IsNil) // TiFlash MPP query (MPP + streaming) diff --git a/server/server_test.go 
b/server/server_test.go index 1d97e510c3770..0a68bbfd29fde 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -41,7 +41,6 @@ import ( "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/errno" "github.com/pingcap/tidb/kv" - "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/util/logutil" "github.com/pingcap/tidb/util/versioninfo" "go.uber.org/zap" @@ -113,7 +112,7 @@ func (cli *testServerClient) getDSN(overriders ...configOverrider) string { config.Net = "tcp" config.Addr = fmt.Sprintf("127.0.0.1:%d", cli.port) config.DBName = "test" - config.Params = map[string]string{variable.TiDBIntPrimaryKeyDefaultAsClustered: "true"} + config.Params = make(map[string]string) for _, overrider := range overriders { if overrider != nil { overrider(config) diff --git a/session/clustered_index_test.go b/session/clustered_index_test.go index 4fd94072c12ba..660bfeb900feb 100644 --- a/session/clustered_index_test.go +++ b/session/clustered_index_test.go @@ -15,7 +15,9 @@ package session_test import ( . "github.com/pingcap/check" + "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/errno" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/store/tikv" "github.com/pingcap/tidb/util/collate" "github.com/pingcap/tidb/util/testkit" @@ -31,7 +33,7 @@ type testClusteredSerialSuite struct{ testClusteredSuiteBase } func (s *testClusteredSuiteBase) newTK(c *C) *testkit.TestKit { tk := testkit.NewTestKitWithInit(c, s.store) - tk.Se.GetSessionVars().EnableClusteredIndex = true + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn return tk } @@ -307,6 +309,86 @@ func (s *testClusteredSuite) TestClusteredPrefixingPrimaryKey(c *C) { tk.MustQuery(`select /*+ INL_MERGE_JOIN(t1,t2) */ * from t1, t2 where t1.c_int = t2.c_int and t1.c_str >= t2.c_str;`).Check(testkit.Rows("1 nifty elion 1 funny shaw")) } +func (s *testClusteredSerialSuite) TestCreateClusteredTable(c *C) { + tk := s.newTK(c) + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeIntOnly + tk.MustExec("drop table if exists t1, t2, t3, t4, t5, t6, t7, t8") + tk.MustExec("create table t1(id int primary key, v int)") + tk.MustExec("create table t2(id varchar(10) primary key, v int)") + tk.MustExec("create table t3(id int primary key clustered, v int)") + tk.MustExec("create table t4(id varchar(10) primary key clustered, v int)") + tk.MustExec("create table t5(id int primary key nonclustered, v int)") + tk.MustExec("create table t6(id varchar(10) primary key nonclustered, v int)") + tk.MustExec("create table t7(id varchar(10), v int, primary key (id) /*T![clustered_index] CLUSTERED */)") + tk.MustExec("create table t8(id varchar(10), v int, primary key (id) /*T![clustered_index] NONCLUSTERED */)") + tk.MustQuery("show index from t1").Check(testkit.Rows("t1 0 PRIMARY 1 id A 0 BTREE YES NULL YES")) + tk.MustQuery("show index from t2").Check(testkit.Rows("t2 0 PRIMARY 1 id A 0 BTREE YES NULL NO")) + tk.MustQuery("show index from t3").Check(testkit.Rows("t3 0 PRIMARY 1 id A 0 BTREE YES NULL YES")) + tk.MustQuery("show index from t4").Check(testkit.Rows("t4 0 PRIMARY 1 id A 0 BTREE YES NULL YES")) + tk.MustQuery("show index from t5").Check(testkit.Rows("t5 0 PRIMARY 1 id A 0 BTREE YES NULL NO")) + tk.MustQuery("show index from t6").Check(testkit.Rows("t6 0 PRIMARY 1 id A 0 BTREE YES NULL NO")) + tk.MustQuery("show index from t7").Check(testkit.Rows("t7 0 PRIMARY 1 id A 0 BTREE YES NULL YES")) + tk.MustQuery("show index from t8").Check(testkit.Rows("t8 0 
PRIMARY 1 id A 0 BTREE YES NULL NO")) + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOff + tk.MustExec("drop table if exists t1, t2, t3, t4, t5, t6, t7, t8") + tk.MustExec("create table t1(id int primary key, v int)") + tk.MustExec("create table t2(id varchar(10) primary key, v int)") + tk.MustExec("create table t3(id int primary key clustered, v int)") + tk.MustExec("create table t4(id varchar(10) primary key clustered, v int)") + tk.MustExec("create table t5(id int primary key nonclustered, v int)") + tk.MustExec("create table t6(id varchar(10) primary key nonclustered, v int)") + tk.MustExec("create table t7(id varchar(10), v int, primary key (id) /*T![clustered_index] CLUSTERED */)") + tk.MustExec("create table t8(id varchar(10), v int, primary key (id) /*T![clustered_index] NONCLUSTERED */)") + tk.MustQuery("show index from t1").Check(testkit.Rows("t1 0 PRIMARY 1 id A 0 BTREE YES NULL NO")) + tk.MustQuery("show index from t2").Check(testkit.Rows("t2 0 PRIMARY 1 id A 0 BTREE YES NULL NO")) + tk.MustQuery("show index from t3").Check(testkit.Rows("t3 0 PRIMARY 1 id A 0 BTREE YES NULL YES")) + tk.MustQuery("show index from t4").Check(testkit.Rows("t4 0 PRIMARY 1 id A 0 BTREE YES NULL YES")) + tk.MustQuery("show index from t5").Check(testkit.Rows("t5 0 PRIMARY 1 id A 0 BTREE YES NULL NO")) + tk.MustQuery("show index from t6").Check(testkit.Rows("t6 0 PRIMARY 1 id A 0 BTREE YES NULL NO")) + tk.MustQuery("show index from t7").Check(testkit.Rows("t7 0 PRIMARY 1 id A 0 BTREE YES NULL YES")) + tk.MustQuery("show index from t8").Check(testkit.Rows("t8 0 PRIMARY 1 id A 0 BTREE YES NULL NO")) + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn + tk.MustExec("drop table if exists t1, t2, t3, t4, t5, t6, t7, t8") + tk.MustExec("create table t1(id int primary key, v int)") + tk.MustExec("create table t2(id varchar(10) primary key, v int)") + tk.MustExec("create table t3(id int primary key clustered, v int)") + tk.MustExec("create table t4(id varchar(10) primary key clustered, v int)") + tk.MustExec("create table t5(id int primary key nonclustered, v int)") + tk.MustExec("create table t6(id varchar(10) primary key nonclustered, v int)") + tk.MustExec("create table t7(id varchar(10), v int, primary key (id) /*T![clustered_index] CLUSTERED */)") + tk.MustExec("create table t8(id varchar(10), v int, primary key (id) /*T![clustered_index] NONCLUSTERED */)") + tk.MustQuery("show index from t1").Check(testkit.Rows("t1 0 PRIMARY 1 id A 0 BTREE YES NULL YES")) + tk.MustQuery("show index from t2").Check(testkit.Rows("t2 0 PRIMARY 1 id A 0 BTREE YES NULL YES")) + tk.MustQuery("show index from t3").Check(testkit.Rows("t3 0 PRIMARY 1 id A 0 BTREE YES NULL YES")) + tk.MustQuery("show index from t4").Check(testkit.Rows("t4 0 PRIMARY 1 id A 0 BTREE YES NULL YES")) + tk.MustQuery("show index from t5").Check(testkit.Rows("t5 0 PRIMARY 1 id A 0 BTREE YES NULL NO")) + tk.MustQuery("show index from t6").Check(testkit.Rows("t6 0 PRIMARY 1 id A 0 BTREE YES NULL NO")) + tk.MustQuery("show index from t7").Check(testkit.Rows("t7 0 PRIMARY 1 id A 0 BTREE YES NULL YES")) + tk.MustQuery("show index from t8").Check(testkit.Rows("t8 0 PRIMARY 1 id A 0 BTREE YES NULL NO")) + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeIntOnly + defer config.RestoreFunc()() + config.UpdateGlobal(func(conf *config.Config) { + conf.AlterPrimaryKey = true + }) + tk.MustExec("drop table if exists t1, t2, t3, t4, t5, t6, t7, t8") + tk.MustExec("create table t1(id 
int primary key, v int)") + tk.MustExec("create table t2(id varchar(10) primary key, v int)") + tk.MustExec("create table t3(id int primary key clustered, v int)") + tk.MustExec("create table t4(id varchar(10) primary key clustered, v int)") + tk.MustExec("create table t5(id int primary key nonclustered, v int)") + tk.MustExec("create table t6(id varchar(10) primary key nonclustered, v int)") + tk.MustExec("create table t7(id varchar(10), v int, primary key (id) /*T![clustered_index] CLUSTERED */)") + tk.MustExec("create table t8(id varchar(10), v int, primary key (id) /*T![clustered_index] NONCLUSTERED */)") + tk.MustQuery("show index from t1").Check(testkit.Rows("t1 0 PRIMARY 1 id A 0 BTREE YES NULL NO")) + tk.MustQuery("show index from t2").Check(testkit.Rows("t2 0 PRIMARY 1 id A 0 BTREE YES NULL NO")) + tk.MustQuery("show index from t3").Check(testkit.Rows("t3 0 PRIMARY 1 id A 0 BTREE YES NULL YES")) + tk.MustQuery("show index from t4").Check(testkit.Rows("t4 0 PRIMARY 1 id A 0 BTREE YES NULL YES")) + tk.MustQuery("show index from t5").Check(testkit.Rows("t5 0 PRIMARY 1 id A 0 BTREE YES NULL NO")) + tk.MustQuery("show index from t6").Check(testkit.Rows("t6 0 PRIMARY 1 id A 0 BTREE YES NULL NO")) + tk.MustQuery("show index from t7").Check(testkit.Rows("t7 0 PRIMARY 1 id A 0 BTREE YES NULL YES")) + tk.MustQuery("show index from t8").Check(testkit.Rows("t8 0 PRIMARY 1 id A 0 BTREE YES NULL NO")) +} + // Test for union scan in prefixed clustered index table. // See https://github.com/pingcap/tidb/issues/22069. func (s *testClusteredSerialSuite) TestClusteredUnionScanOnPrefixingPrimaryKey(c *C) { @@ -435,11 +517,11 @@ func (s *testClusteredSuite) TestClusteredIndexSyntax(c *C) { assertPkType("create table t (a int, b int, primary key(a) /*T![clustered_index] nonclustered */);", nonClustered) // Test for clustered index. 
- tk.Se.GetSessionVars().EnableClusteredIndex = false + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeIntOnly assertPkType("create table t (a int, b varchar(255), primary key(b, a));", nonClustered) assertPkType("create table t (a int, b varchar(255), primary key(b, a) nonclustered);", nonClustered) assertPkType("create table t (a int, b varchar(255), primary key(b, a) clustered);", clustered) - tk.Se.GetSessionVars().EnableClusteredIndex = true + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn assertPkType("create table t (a int, b varchar(255), primary key(b, a));", clusteredDefault) assertPkType("create table t (a int, b varchar(255), primary key(b, a) nonclustered);", nonClustered) assertPkType("create table t (a int, b varchar(255), primary key(b, a) /*T![clustered_index] nonclustered */);", nonClustered) @@ -488,7 +570,7 @@ func (s *testClusteredSerialSuite) TestClusteredIndexDecodeRestoredDataV5(c *C) defer collate.SetNewCollationEnabledForTest(false) collate.SetNewCollationEnabledForTest(true) tk.MustExec("use test") - tk.Se.GetSessionVars().EnableClusteredIndex = true + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn tk.MustExec("drop table if exists t;") tk.MustExec("create table t (id1 int, id2 varchar(10), a1 int, primary key(id1, id2) clustered) collate utf8mb4_general_ci;") tk.MustExec("insert into t values (1, 'asd', 1), (1, 'dsa', 1);") @@ -510,7 +592,7 @@ func (s *testClusteredSerialSuite) TestPrefixedClusteredIndexUniqueKeyWithNewCol collate.SetNewCollationEnabledForTest(true) tk := testkit.NewTestKitWithInit(c, s.store) tk.MustExec("use test;") - tk.Se.GetSessionVars().EnableClusteredIndex = true + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn tk.MustExec("create table t (a text collate utf8mb4_general_ci not null, b int(11) not null, " + "primary key (a(10), b) clustered, key idx(a(2)) ) default charset=utf8mb4 collate=utf8mb4_bin;") tk.MustExec("insert into t values ('aaa', 2);") diff --git a/session/pessimistic_test.go b/session/pessimistic_test.go index 90146de91594d..3341484d1cccb 100644 --- a/session/pessimistic_test.go +++ b/session/pessimistic_test.go @@ -47,6 +47,12 @@ func (s *testPessimisticSuite) newAsyncCommitTestKitWithInit(c *C) *testkit.Test return tk } +func (s *testPessimisticSuite) new1PCTestKitWithInit(c *C) *testkit.TestKit { + tk := testkit.NewTestKitWithInit(c, s.store) + tk.Se.GetSessionVars().Enable1PC = true + return tk +} + type testPessimisticSuite struct { testSessionSuiteBase } @@ -2122,9 +2128,9 @@ func (s *testPessimisticSuite) Test1PCWithSchemaChange(c *C) { conf.TiKVClient.AsyncCommit.AllowedClockDrift = 0 }) - tk := s.newAsyncCommitTestKitWithInit(c) - tk2 := s.newAsyncCommitTestKitWithInit(c) - tk3 := s.newAsyncCommitTestKitWithInit(c) + tk := s.new1PCTestKitWithInit(c) + tk2 := s.new1PCTestKitWithInit(c) + tk3 := s.new1PCTestKitWithInit(c) tk.MustExec("drop table if exists tk") tk.MustExec("create table tk (c1 int primary key, c2 int)") @@ -2160,7 +2166,7 @@ func (s *testPessimisticSuite) Test1PCWithSchemaChange(c *C) { time.Sleep(200 * time.Millisecond) tk2.MustExec("alter table tk add index k2(c2)") }() - c.Assert(failpoint.Enable("github.com/pingcap/tidb/store/tikv/beforePrewrite", "1*sleep(1000)"), IsNil) + c.Assert(failpoint.Enable("github.com/pingcap/tidb/store/tikv/beforePrewrite", "1*sleep(1200)"), IsNil) tk.MustExec("commit") 
c.Assert(failpoint.Disable("github.com/pingcap/tidb/store/tikv/beforePrewrite"), IsNil) tk3.MustExec("admin check table tk") diff --git a/session/session.go b/session/session.go index 42b28a6a48b0d..2753423590176 100644 --- a/session/session.go +++ b/session/session.go @@ -984,7 +984,7 @@ func (s *session) SetGlobalSysVar(name, value string) error { return err } } - variable.CheckDeprecationSetSystemVar(s.sessionVars, name) + variable.CheckDeprecationSetSystemVar(s.sessionVars, name, sVal) stmt, err := s.ParseWithParams(context.TODO(), "REPLACE %n.%n VALUES (%?, %?)", mysql.SystemDB, mysql.GlobalVariablesTable, name, sVal) if err != nil { return err @@ -2060,7 +2060,6 @@ func CreateSession4TestWithOpt(store kv.Storage, opt *Opt) (Session, error) { // initialize session variables for test. s.GetSessionVars().InitChunkSize = 2 s.GetSessionVars().MaxChunkSize = 32 - s.GetSessionVars().IntPrimaryKeyDefaultAsClustered = true } return s, err } diff --git a/session/tidb_test.go b/session/tidb_test.go index 80da89191d9a9..02e9571c18482 100644 --- a/session/tidb_test.go +++ b/session/tidb_test.go @@ -30,6 +30,7 @@ import ( "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/planner/core" "github.com/pingcap/tidb/store/mockstore" + tikvstore "github.com/pingcap/tidb/store/tikv/kv" "github.com/pingcap/tidb/tablecodec" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/logutil" @@ -210,7 +211,7 @@ func (s *testMainSuite) TestKeysNeedLock(c *C) { for _, tt := range tests { c.Assert(keyNeedToLock(tt.key, tt.val, 0), Equals, tt.need) } - flag := kv.KeyFlags(1) + flag := tikvstore.KeyFlags(1) c.Assert(flag.HasPresumeKeyNotExists(), IsTrue) c.Assert(keyNeedToLock(indexKey, deleteVal, flag), IsTrue) } diff --git a/session/txn.go b/session/txn.go index bccf130fb988a..2cec543ed072b 100644 --- a/session/txn.go +++ b/session/txn.go @@ -30,6 +30,7 @@ import ( "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/binloginfo" + tikvstore "github.com/pingcap/tidb/store/tikv/kv" "github.com/pingcap/tidb/store/tikv/oracle" "github.com/pingcap/tidb/tablecodec" "github.com/pingcap/tidb/util/logutil" @@ -281,7 +282,7 @@ func (txn *TxnState) KeysNeedToLock() ([]kv.Key, error) { } keys := make([]kv.Key, 0, txn.countHint()) buf := txn.Transaction.GetMemBuffer() - buf.InspectStage(txn.stagingHandle, func(k kv.Key, flags kv.KeyFlags, v []byte) { + buf.InspectStage(txn.stagingHandle, func(k kv.Key, flags tikvstore.KeyFlags, v []byte) { if !keyNeedToLock(k, v, flags) { return } @@ -290,7 +291,7 @@ func (txn *TxnState) KeysNeedToLock() ([]kv.Key, error) { return keys, nil } -func keyNeedToLock(k, v []byte, flags kv.KeyFlags) bool { +func keyNeedToLock(k, v []byte, flags tikvstore.KeyFlags) bool { isTableKey := bytes.HasPrefix(k, tablecodec.TablePrefix()) if !isTableKey { // meta key always need to lock. 
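Reviewer note: the test and session changes above replace the boolean EnableClusteredIndex session variable with the tri-valued ClusteredIndexDefMode. A minimal sketch, assuming the sessionctx/variable package as modified in this patch (the standalone main program below is hypothetical and only illustrates how the accepted string values map onto the three modes; it is not part of the change):

package main

import (
	"fmt"

	"github.com/pingcap/tidb/sessionctx/variable"
)

func main() {
	// Values accepted by tidb_enable_clustered_index after this patch.
	for _, v := range []string{"ON", "OFF", "INT_ONLY", "1", "0"} {
		switch variable.TiDBOptEnableClustered(v) {
		case variable.ClusteredIndexDefModeOn:
			fmt.Println(v, "-> primary keys default to clustered")
		case variable.ClusteredIndexDefModeOff:
			fmt.Println(v, "-> primary keys default to nonclustered")
		default: // variable.ClusteredIndexDefModeIntOnly
			fmt.Println(v, "-> only single integer primary keys default to clustered")
		}
	}
}

For context, the raw value is also threaded into CheckDeprecationSetSystemVar (see the session.go and varsutil.go hunks in this diff), so setting the variable to INT_ONLY appends a deprecation warning suggesting ON or OFF instead.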
diff --git a/sessionctx/binloginfo/binloginfo_test.go b/sessionctx/binloginfo/binloginfo_test.go index 03a003543ada6..c059288228ba2 100644 --- a/sessionctx/binloginfo/binloginfo_test.go +++ b/sessionctx/binloginfo/binloginfo_test.go @@ -36,6 +36,7 @@ import ( "github.com/pingcap/tidb/session" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/binloginfo" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/store/mockstore" "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/table/tables" @@ -199,7 +200,7 @@ func (s *testBinlogSuite) TestBinlog(c *C) { c.Assert(gotRows, DeepEquals, expected) // Test table primary key is not integer. - tk.Se.GetSessionVars().EnableClusteredIndex = false + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeIntOnly tk.MustExec("create table local_binlog2 (name varchar(64) primary key, age int)") tk.MustExec("insert local_binlog2 values ('abc', 16), ('def', 18)") tk.MustExec("delete from local_binlog2 where name = 'def'") diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go index 3d7da1432b206..427db8c69abcd 100644 --- a/sessionctx/variable/session.go +++ b/sessionctx/variable/session.go @@ -762,7 +762,7 @@ type SessionVars struct { SelectLimit uint64 // EnableClusteredIndex indicates whether to enable clustered index when creating a new table. - EnableClusteredIndex bool + EnableClusteredIndex ClusteredIndexDefMode // PresumeKeyNotExists indicates lazy existence checking is enabled. PresumeKeyNotExists bool @@ -818,10 +818,6 @@ type SessionVars struct { // AllowFallbackToTiKV indicates the engine types whose unavailability triggers fallback to TiKV. // Now we only support TiFlash. AllowFallbackToTiKV map[kv.StoreType]struct{} - - // IntPrimaryKeyDefaultAsClustered indicates whether create integer primary table as clustered - // If it's true, the behavior is the same as the TiDB 4.0 and the below versions. - IntPrimaryKeyDefaultAsClustered bool } // CheckAndGetTxnScope will return the transaction scope we should use in the current session. @@ -1704,7 +1700,7 @@ func (s *SessionVars) SetSystemVar(name string, val string) error { case TiDBAllowAutoRandExplicitInsert: s.AllowAutoRandExplicitInsert = TiDBOptOn(val) case TiDBEnableClusteredIndex: - s.EnableClusteredIndex = TiDBOptOn(val) + s.EnableClusteredIndex = TiDBOptEnableClustered(val) case TiDBPartitionPruneMode: s.PartitionPruneMode.Store(strings.ToLower(strings.TrimSpace(val))) case TiDBEnableParallelApply: @@ -1762,8 +1758,6 @@ func (s *SessionVars) SetSystemVar(name string, val string) error { s.AllowFallbackToTiKV[kv.TiFlash] = struct{}{} } } - case TiDBIntPrimaryKeyDefaultAsClustered: - s.IntPrimaryKeyDefaultAsClustered = TiDBOptOn(val) default: sv := GetSysVar(name) if err := sv.SetSessionFromHook(s, val); err != nil { diff --git a/sessionctx/variable/sysvar.go b/sessionctx/variable/sysvar.go index 678c811ad95fd..407d74a351dbc 100644 --- a/sessionctx/variable/sysvar.go +++ b/sessionctx/variable/sysvar.go @@ -79,6 +79,9 @@ const ( // Warn means return warnings Warn = "WARN" + + // IntOnly means enable for int type + IntOnly = "INT_ONLY" ) // SysVar is for system variable. @@ -705,7 +708,6 @@ var defaultSysVars = []*SysVar{ } return formatVal, nil }}, - {Scope: ScopeGlobal | ScopeSession, Name: TiDBIntPrimaryKeyDefaultAsClustered, Value: BoolToOnOff(false), Type: TypeBool}, /* The following variable is defined as session scope but is actually server scope. 
*/ {Scope: ScopeSession, Name: TiDBGeneralLog, Value: BoolToOnOff(DefTiDBGeneralLog), Type: TypeBool}, {Scope: ScopeSession, Name: TiDBPProfSQLCPU, Value: strconv.Itoa(DefTiDBPProfSQLCPU), Type: TypeInt, MinValue: 0, MaxValue: 1}, @@ -779,7 +781,7 @@ var defaultSysVars = []*SysVar{ {Scope: ScopeSession, Name: TiDBFoundInBinding, Value: BoolToOnOff(DefTiDBFoundInBinding), Type: TypeBool, ReadOnly: true}, {Scope: ScopeSession, Name: TiDBEnableCollectExecutionInfo, Value: BoolToOnOff(DefTiDBEnableCollectExecutionInfo), Type: TypeBool}, {Scope: ScopeGlobal | ScopeSession, Name: TiDBAllowAutoRandExplicitInsert, Value: BoolToOnOff(DefTiDBAllowAutoRandExplicitInsert), Type: TypeBool}, - {Scope: ScopeGlobal, Name: TiDBEnableClusteredIndex, Value: BoolToOnOff(DefTiDBEnableClusteredIndex), Type: TypeBool}, + {Scope: ScopeGlobal, Name: TiDBEnableClusteredIndex, Value: IntOnly, Type: TypeEnum, PossibleValues: []string{Off, On, IntOnly, "1", "0"}}, {Scope: ScopeGlobal | ScopeSession, Name: TiDBPartitionPruneMode, Value: string(Static), Type: TypeStr, Validation: func(vars *SessionVars, normalizedValue string, originalValue string, scope ScopeFlag) (string, error) { mode := PartitionPruneMode(normalizedValue).Update() if !mode.Valid() { diff --git a/sessionctx/variable/tidb_vars.go b/sessionctx/variable/tidb_vars.go index 5c769245415ce..cfddf24407fe0 100644 --- a/sessionctx/variable/tidb_vars.go +++ b/sessionctx/variable/tidb_vars.go @@ -529,9 +529,6 @@ const ( // TiDBAllowFallbackToTiKV indicates the engine types whose unavailability triggers fallback to TiKV. // Now we only support TiFlash. TiDBAllowFallbackToTiKV = "tidb_allow_fallback_to_tikv" - - // TiDBIntPrimaryKeyDefaultAsClustered indicates whether create int primary key as clustered as 4.0 behavior. - TiDBIntPrimaryKeyDefaultAsClustered = "tidb_int_primary_key_default_as_clustered" ) // TiDB vars that have only global scope @@ -657,7 +654,7 @@ const ( DefTiDBFoundInBinding = false DefTiDBEnableCollectExecutionInfo = true DefTiDBAllowAutoRandExplicitInsert = false - DefTiDBEnableClusteredIndex = false + DefTiDBEnableClusteredIndex = ClusteredIndexDefModeIntOnly DefTiDBRedactLog = false DefTiDBShardAllocateStep = math.MaxInt64 DefTiDBEnableTelemetry = true @@ -708,11 +705,9 @@ var FeatureSwitchVariables = []string{ TiDBEnableAsyncCommit, TiDBEnable1PC, TiDBGuaranteeLinearizability, - TiDBEnableClusteredIndex, TiDBTrackAggregateMemoryUsage, TiDBAnalyzeVersion, TiDBPartitionPruneMode, - TiDBIntPrimaryKeyDefaultAsClustered, TiDBEnableExtendedStats, TiDBEnableIndexMergeJoin, } diff --git a/sessionctx/variable/varsutil.go b/sessionctx/variable/varsutil.go index d7ca6debcaae0..9039eab93e733 100644 --- a/sessionctx/variable/varsutil.go +++ b/sessionctx/variable/varsutil.go @@ -239,7 +239,7 @@ func SetSessionSystemVar(vars *SessionVars, name string, value types.Datum) erro if err != nil { return err } - CheckDeprecationSetSystemVar(vars, name) + CheckDeprecationSetSystemVar(vars, name, sVal) return vars.SetSystemVar(name, sVal) } @@ -254,7 +254,7 @@ func SetStmtVar(vars *SessionVars, name string, value string) error { if err != nil { return err } - CheckDeprecationSetSystemVar(vars, name) + CheckDeprecationSetSystemVar(vars, name, sVal) return vars.SetStmtVar(name, sVal) } @@ -285,7 +285,7 @@ const ( ) // CheckDeprecationSetSystemVar checks if the system variable is deprecated. 
-func CheckDeprecationSetSystemVar(s *SessionVars, name string) { +func CheckDeprecationSetSystemVar(s *SessionVars, name string, val string) { switch name { case TiDBIndexLookupConcurrency, TiDBIndexLookupJoinConcurrency, TiDBHashJoinConcurrency, TiDBHashAggPartialConcurrency, TiDBHashAggFinalConcurrency, @@ -295,6 +295,10 @@ func CheckDeprecationSetSystemVar(s *SessionVars, name string) { TIDBMemQuotaSort, TIDBMemQuotaTopn, TIDBMemQuotaIndexLookupReader, TIDBMemQuotaIndexLookupJoin: s.StmtCtx.AppendWarning(errWarnDeprecatedSyntax.FastGenByArgs(name, TIDBMemQuotaQuery)) + case TiDBEnableClusteredIndex: + if strings.EqualFold(val, IntOnly) { + s.StmtCtx.AppendWarning(errWarnDeprecatedSyntax.FastGenByArgs(val, fmt.Sprintf("'%s' or '%s'", On, Off))) + } } } @@ -339,6 +343,30 @@ func TiDBOptMultiStmt(opt string) int { return WarnInt } +// ClusteredIndexDefMode controls the default clustered property for primary key. +type ClusteredIndexDefMode int + +const ( + // ClusteredIndexDefModeIntOnly indicates only single int primary key will default be clustered. + ClusteredIndexDefModeIntOnly ClusteredIndexDefMode = 0 + // ClusteredIndexDefModeOn indicates primary key will default be clustered. + ClusteredIndexDefModeOn ClusteredIndexDefMode = 1 + // ClusteredIndexDefModeOff indicates primary key will default be non-clustered. + ClusteredIndexDefModeOff ClusteredIndexDefMode = 2 +) + +// TiDBOptEnableClustered converts enable clustered options to ClusteredIndexDefMode. +func TiDBOptEnableClustered(opt string) ClusteredIndexDefMode { + switch { + case strings.EqualFold(opt, "ON") || opt == "1": + return ClusteredIndexDefModeOn + case strings.EqualFold(opt, "OFF") || opt == "0": + return ClusteredIndexDefModeOff + default: + return ClusteredIndexDefModeIntOnly + } +} + func tidbOptPositiveInt32(opt string, defaultVal int) int { val, err := strconv.Atoi(opt) if err != nil || val <= 0 { diff --git a/statistics/handle/gc_test.go b/statistics/handle/gc_test.go index 9acaf8234a0f6..f934f8973fe6c 100644 --- a/statistics/handle/gc_test.go +++ b/statistics/handle/gc_test.go @@ -113,23 +113,23 @@ func (s *testStatsSuite) TestGCExtendedStats(c *C) { h := s.do.StatsHandle() ddlLease := time.Duration(0) c.Assert(h.GCStats(s.do.InfoSchema(), ddlLease), IsNil) - testKit.MustQuery("select name, type, column_ids, stats, status from mysql.stats_extended").Check(testkit.Rows( + testKit.MustQuery("select name, type, column_ids, stats, status from mysql.stats_extended").Sort().Check(testkit.Rows( "s1 2 [1,2] 1.000000 2", "s2 2 [2,3] 1.000000 1", )) c.Assert(h.GCStats(s.do.InfoSchema(), ddlLease), IsNil) - testKit.MustQuery("select name, type, column_ids, stats, status from mysql.stats_extended").Check(testkit.Rows( + testKit.MustQuery("select name, type, column_ids, stats, status from mysql.stats_extended").Sort().Check(testkit.Rows( "s2 2 [2,3] 1.000000 1", )) testKit.MustExec("drop table t") - testKit.MustQuery("select name, type, column_ids, stats, status from mysql.stats_extended").Check(testkit.Rows( + testKit.MustQuery("select name, type, column_ids, stats, status from mysql.stats_extended").Sort().Check(testkit.Rows( "s2 2 [2,3] 1.000000 1", )) c.Assert(h.GCStats(s.do.InfoSchema(), ddlLease), IsNil) - testKit.MustQuery("select name, type, column_ids, stats, status from mysql.stats_extended").Check(testkit.Rows( + testKit.MustQuery("select name, type, column_ids, stats, status from mysql.stats_extended").Sort().Check(testkit.Rows( "s2 2 [2,3] 1.000000 2", )) c.Assert(h.GCStats(s.do.InfoSchema(), ddlLease), IsNil) - 
testKit.MustQuery("select name, type, column_ids, stats, status from mysql.stats_extended").Check(testkit.Rows()) + testKit.MustQuery("select name, type, column_ids, stats, status from mysql.stats_extended").Sort().Check(testkit.Rows()) } diff --git a/statistics/handle/handle.go b/statistics/handle/handle.go index a73fb151bb225..ad1539869bbea 100644 --- a/statistics/handle/handle.go +++ b/statistics/handle/handle.go @@ -462,8 +462,10 @@ func (h *Handle) mergePartitionStats2GlobalStats(sc sessionctx.Context, opts map } else { // For the index stats, we get the final NDV by accumulating the NDV of each bucket in the index histogram. globalStatsNDV := int64(0) - for _, bucket := range globalStats.Hg[i].Buckets { - globalStatsNDV += bucket.NDV + for j := range globalStats.Hg[i].Buckets { + globalStatsNDV += globalStats.Hg[i].Buckets[j].NDV + // NOTICE: after merging bucket NDVs have the trend to be underestimated, so for safe we don't use them. + globalStats.Hg[i].Buckets[j].NDV = 0 } globalStats.Hg[i].NDV = globalStatsNDV diff --git a/statistics/handle/handle_test.go b/statistics/handle/handle_test.go index c5bb159f0b520..98130285356d1 100644 --- a/statistics/handle/handle_test.go +++ b/statistics/handle/handle_test.go @@ -989,8 +989,8 @@ partition by range (a) ( "test t p1 a 0 0 6 1 11 16 0", "test t p1 a 0 1 10 2 17 19 0")) tk.MustQuery("show stats_buckets where is_index=1").Check( - testkit.Rows("test t global a 1 0 7 2 1 6 6", - "test t global a 1 1 17 2 6 19 9", + testkit.Rows("test t global a 1 0 7 2 1 6 0", + "test t global a 1 1 17 2 6 19 0", "test t p0 a 1 0 4 1 1 4 4", "test t p0 a 1 1 7 2 5 6 2", "test t p1 a 1 0 8 1 11 18 8", @@ -1035,7 +1035,7 @@ func (s *testStatsSuite) TestGlobalStatsData2(c *C) { tk.MustQuery("show stats_buckets where is_index=0").Check(testkit.Rows( // db, tbl, part, col, isIdx, bucketID, count, repeat, lower, upper, ndv "test tint global c 0 0 5 2 1 4 0", // bucket.ndv is not maintained for column histograms - "test tint global c 0 1 12 2 4 17 0", + "test tint global c 0 1 12 2 17 17 0", "test tint p0 c 0 0 2 1 1 2 0", "test tint p0 c 0 1 3 1 3 3 0", "test tint p1 c 0 0 3 1 11 13 0", @@ -1048,8 +1048,8 @@ func (s *testStatsSuite) TestGlobalStatsData2(c *C) { tk.MustQuery("show stats_buckets where is_index=1").Check(testkit.Rows( // db, tbl, part, col, isIdx, bucketID, count, repeat, lower, upper, ndv - "test tint global c 1 0 5 0 1 5 4", // 4 is popped from p0.TopN, so g.ndv = p0.ndv+1 - "test tint global c 1 1 12 2 5 17 6", + "test tint global c 1 0 5 2 1 4 0", // 4 is popped from p0.TopN, so g.ndv = p0.ndv+1 + "test tint global c 1 1 12 2 4 17 0", "test tint p0 c 1 0 3 0 1 4 3", "test tint p0 c 1 1 3 0 5 5 0", "test tint p1 c 1 0 5 0 11 16 5", @@ -1094,7 +1094,7 @@ func (s *testStatsSuite) TestGlobalStatsData2(c *C) { tk.MustQuery("show stats_buckets where table_name='tdouble' and is_index=0 and column_name='c'").Check(testkit.Rows( // db, tbl, part, col, isIdx, bucketID, count, repeat, lower, upper, ndv "test tdouble global c 0 0 5 2 1 4 0", // bucket.ndv is not maintained for column histograms - "test tdouble global c 0 1 12 2 4 17 0", + "test tdouble global c 0 1 12 2 17 17 0", "test tdouble p0 c 0 0 2 1 1 2 0", "test tdouble p0 c 0 1 3 1 3 3 0", "test tdouble p1 c 0 0 3 1 11 13 0", @@ -1110,8 +1110,8 @@ func (s *testStatsSuite) TestGlobalStatsData2(c *C) { tk.MustQuery("show stats_buckets where table_name='tdouble' and is_index=1 and column_name='c'").Check(testkit.Rows( // db, tbl, part, col, isIdx, bucketID, count, repeat, lower, upper, ndv - "test 
tdouble global c 1 0 5 0 1 5 4", // 4 is popped from p0.TopN, so g.ndv = p0.ndv+1 - "test tdouble global c 1 1 12 2 5 17 6", + "test tdouble global c 1 0 5 2 1 4 0", // 4 is popped from p0.TopN, so g.ndv = p0.ndv+1 + "test tdouble global c 1 1 12 2 4 17 0", "test tdouble p0 c 1 0 3 0 1 4 3", "test tdouble p0 c 1 1 3 0 5 5 0", "test tdouble p1 c 1 0 5 0 11 16 5", @@ -1159,7 +1159,7 @@ func (s *testStatsSuite) TestGlobalStatsData2(c *C) { tk.MustQuery("show stats_buckets where table_name='tdecimal' and is_index=0 and column_name='c'").Check(testkit.Rows( // db, tbl, part, col, isIdx, bucketID, count, repeat, lower, upper, ndv "test tdecimal global c 0 0 5 2 1.00 4.00 0", // bucket.ndv is not maintained for column histograms - "test tdecimal global c 0 1 12 2 4.00 17.00 0", + "test tdecimal global c 0 1 12 2 17.00 17.00 0", "test tdecimal p0 c 0 0 2 1 1.00 2.00 0", "test tdecimal p0 c 0 1 3 1 3.00 3.00 0", "test tdecimal p1 c 0 0 3 1 11.00 13.00 0", @@ -1175,8 +1175,8 @@ func (s *testStatsSuite) TestGlobalStatsData2(c *C) { tk.MustQuery("show stats_buckets where table_name='tdecimal' and is_index=1 and column_name='c'").Check(testkit.Rows( // db, tbl, part, col, isIdx, bucketID, count, repeat, lower, upper, ndv - "test tdecimal global c 1 0 5 0 1.00 5.00 4", // 4 is popped from p0.TopN, so g.ndv = p0.ndv+1 - "test tdecimal global c 1 1 12 2 5.00 17.00 6", + "test tdecimal global c 1 0 5 2 1.00 4.00 0", // 4 is popped from p0.TopN, so g.ndv = p0.ndv+1 + "test tdecimal global c 1 1 12 2 4.00 17.00 0", "test tdecimal p0 c 1 0 3 0 1.00 4.00 3", "test tdecimal p0 c 1 1 3 0 5.00 5.00 0", "test tdecimal p1 c 1 0 5 0 11.00 16.00 5", @@ -1224,7 +1224,7 @@ func (s *testStatsSuite) TestGlobalStatsData2(c *C) { tk.MustQuery("show stats_buckets where table_name='tdatetime' and is_index=0 and column_name='c'").Check(testkit.Rows( // db, tbl, part, col, isIdx, bucketID, count, repeat, lower, upper, ndv "test tdatetime global c 0 0 5 2 2000-01-01 00:00:00 2000-01-04 00:00:00 0", // bucket.ndv is not maintained for column histograms - "test tdatetime global c 0 1 12 2 2000-01-04 00:00:00 2000-01-17 00:00:00 0", + "test tdatetime global c 0 1 12 2 2000-01-17 00:00:00 2000-01-17 00:00:00 0", "test tdatetime p0 c 0 0 2 1 2000-01-01 00:00:00 2000-01-02 00:00:00 0", "test tdatetime p0 c 0 1 3 1 2000-01-03 00:00:00 2000-01-03 00:00:00 0", "test tdatetime p1 c 0 0 3 1 2000-01-11 00:00:00 2000-01-13 00:00:00 0", @@ -1240,8 +1240,8 @@ func (s *testStatsSuite) TestGlobalStatsData2(c *C) { tk.MustQuery("show stats_buckets where table_name='tdatetime' and is_index=1 and column_name='c'").Check(testkit.Rows( // db, tbl, part, col, isIdx, bucketID, count, repeat, lower, upper, ndv - "test tdatetime global c 1 0 5 0 2000-01-01 00:00:00 2000-01-05 00:00:00 4", // 4 is popped from p0.TopN, so g.ndv = p0.ndv+1 - "test tdatetime global c 1 1 12 2 2000-01-05 00:00:00 2000-01-17 00:00:00 6", + "test tdatetime global c 1 0 5 2 2000-01-01 00:00:00 2000-01-04 00:00:00 0", // 4 is popped from p0.TopN, so g.ndv = p0.ndv+1 + "test tdatetime global c 1 1 12 2 2000-01-04 00:00:00 2000-01-17 00:00:00 0", "test tdatetime p0 c 1 0 3 0 2000-01-01 00:00:00 2000-01-04 00:00:00 3", "test tdatetime p0 c 1 1 3 0 2000-01-05 00:00:00 2000-01-05 00:00:00 0", "test tdatetime p1 c 1 0 5 0 2000-01-11 00:00:00 2000-01-16 00:00:00 5", @@ -1289,7 +1289,7 @@ func (s *testStatsSuite) TestGlobalStatsData2(c *C) { tk.MustQuery("show stats_buckets where table_name='tstring' and is_index=0 and column_name='c'").Check(testkit.Rows( // db, tbl, part, col, isIdx, 
bucketID, count, repeat, lower, upper, ndv "test tstring global c 0 0 5 2 a1 a4 0", // bucket.ndv is not maintained for column histograms - "test tstring global c 0 1 12 2 a4 b17 0", + "test tstring global c 0 1 12 2 b17 b17 0", "test tstring p0 c 0 0 2 1 a1 a2 0", "test tstring p0 c 0 1 3 1 a3 a3 0", "test tstring p1 c 0 0 3 1 b11 b13 0", @@ -1305,8 +1305,8 @@ func (s *testStatsSuite) TestGlobalStatsData2(c *C) { tk.MustQuery("show stats_buckets where table_name='tstring' and is_index=1 and column_name='c'").Check(testkit.Rows( // db, tbl, part, col, isIdx, bucketID, count, repeat, lower, upper, ndv - "test tstring global c 1 0 5 0 a1 a5 4", // 4 is popped from p0.TopN, so g.ndv = p0.ndv+1 - "test tstring global c 1 1 12 2 a5 b17 6", + "test tstring global c 1 0 5 2 a1 a4 0", // 4 is popped from p0.TopN, so g.ndv = p0.ndv+1 + "test tstring global c 1 1 12 2 a4 b17 0", "test tstring p0 c 1 0 3 0 a1 a4 3", "test tstring p0 c 1 1 3 0 a5 a5 0", "test tstring p1 c 1 0 5 0 b11 b16 5", @@ -1350,8 +1350,8 @@ func (s *testStatsSuite) TestGlobalStatsData3(c *C) { "test tintint p1 a 1 (13, 2) 3")) tk.MustQuery("show stats_buckets where table_name='tintint' and is_index=1").Check(testkit.Rows( - "test tintint global a 1 0 6 0 (1, 1) (3, 1) 5", // (2, 3) is popped into it - "test tintint global a 1 1 11 0 (3, 1) (13, 2) 4", // (13, 1) is popped into it + "test tintint global a 1 0 6 2 (1, 1) (2, 3) 0", // (2, 3) is popped into it + "test tintint global a 1 1 11 2 (2, 3) (13, 1) 0", // (13, 1) is popped into it "test tintint p0 a 1 0 4 1 (1, 1) (2, 2) 4", "test tintint p0 a 1 1 4 0 (2, 3) (3, 1) 0", "test tintint p1 a 1 0 3 0 (11, 1) (13, 1) 3", @@ -1384,8 +1384,8 @@ func (s *testStatsSuite) TestGlobalStatsData3(c *C) { "test tintstr p1 a 1 (13, 2) 3")) tk.MustQuery("show stats_buckets where table_name='tintstr' and is_index=1").Check(testkit.Rows( - "test tintstr global a 1 0 6 0 (1, 1) (3, 1) 5", // (2, 3) is popped into it - "test tintstr global a 1 1 11 0 (3, 1) (13, 2) 4", // (13, 1) is popped into it + "test tintstr global a 1 0 6 2 (1, 1) (2, 3) 0", // (2, 3) is popped into it + "test tintstr global a 1 1 11 2 (2, 3) (13, 1) 0", // (13, 1) is popped into it "test tintstr p0 a 1 0 4 1 (1, 1) (2, 2) 4", "test tintstr p0 a 1 1 4 0 (2, 3) (3, 1) 0", "test tintstr p1 a 1 0 3 0 (11, 1) (13, 1) 3", @@ -1418,8 +1418,8 @@ func (s *testStatsSuite) TestGlobalStatsData3(c *C) { "test tintdouble p1 a 1 (13, 2) 3")) tk.MustQuery("show stats_buckets where table_name='tintdouble' and is_index=1").Check(testkit.Rows( - "test tintdouble global a 1 0 6 0 (1, 1) (3, 1) 5", // (2, 3) is popped into it - "test tintdouble global a 1 1 11 0 (3, 1) (13, 2) 4", // (13, 1) is popped into it + "test tintdouble global a 1 0 6 2 (1, 1) (2, 3) 0", // (2, 3) is popped into it + "test tintdouble global a 1 1 11 2 (2, 3) (13, 1) 0", // (13, 1) is popped into it "test tintdouble p0 a 1 0 4 1 (1, 1) (2, 2) 4", "test tintdouble p0 a 1 1 4 0 (2, 3) (3, 1) 0", "test tintdouble p1 a 1 0 3 0 (11, 1) (13, 1) 3", @@ -1452,8 +1452,8 @@ func (s *testStatsSuite) TestGlobalStatsData3(c *C) { "test tdoubledecimal p1 a 1 (13, 2.00) 3")) tk.MustQuery("show stats_buckets where table_name='tdoubledecimal' and is_index=1").Check(testkit.Rows( - "test tdoubledecimal global a 1 0 6 0 (1, 1.00) (3, 1.00) 5", // (2, 3) is popped into it - "test tdoubledecimal global a 1 1 11 0 (3, 1.00) (13, 2.00) 4", // (13, 1) is popped into it + "test tdoubledecimal global a 1 0 6 2 (1, 1.00) (2, 3.00) 0", // (2, 3) is popped into it + "test tdoubledecimal global a 
1 1 11 2 (2, 3.00) (13, 1.00) 0", // (13, 1) is popped into it "test tdoubledecimal p0 a 1 0 4 1 (1, 1.00) (2, 2.00) 4", "test tdoubledecimal p0 a 1 1 4 0 (2, 3.00) (3, 1.00) 0", "test tdoubledecimal p1 a 1 0 3 0 (11, 1.00) (13, 1.00) 3", @@ -1486,8 +1486,8 @@ func (s *testStatsSuite) TestGlobalStatsData3(c *C) { "test tstrdt p1 a 1 (13, 2000-01-02 00:00:00) 3")) tk.MustQuery("show stats_buckets where table_name='tstrdt' and is_index=1").Check(testkit.Rows( - "test tstrdt global a 1 0 6 0 (1, 2000-01-01 00:00:00) (3, 2000-01-01 00:00:00) 5", // (2, 3) is popped into it - "test tstrdt global a 1 1 11 0 (3, 2000-01-01 00:00:00) (13, 2000-01-02 00:00:00) 4", // (13, 1) is popped into it + "test tstrdt global a 1 0 6 2 (1, 2000-01-01 00:00:00) (2, 2000-01-03 00:00:00) 0", // (2, 3) is popped into it + "test tstrdt global a 1 1 11 2 (2, 2000-01-03 00:00:00) (13, 2000-01-01 00:00:00) 0", // (13, 1) is popped into it "test tstrdt p0 a 1 0 4 1 (1, 2000-01-01 00:00:00) (2, 2000-01-02 00:00:00) 4", "test tstrdt p0 a 1 1 4 0 (2, 2000-01-03 00:00:00) (3, 2000-01-01 00:00:00) 0", "test tstrdt p1 a 1 0 3 0 (11, 2000-01-01 00:00:00) (13, 2000-01-01 00:00:00) 3", diff --git a/statistics/histogram.go b/statistics/histogram.go index e1b1c11c74c6b..877f36956ccb5 100644 --- a/statistics/histogram.go +++ b/statistics/histogram.go @@ -1756,7 +1756,10 @@ func mergePartitionBuckets(sc *stmtctx.StatementContext, buckets []*bucket4Mergi res := bucket4Merging{} res.upper = buckets[len(buckets)-1].upper.Clone() right := buckets[len(buckets)-1].Clone() + + totNDV := int64(0) for i := len(buckets) - 1; i >= 0; i-- { + totNDV += buckets[i].NDV res.Count += buckets[i].Count compare, err := buckets[i].upper.CompareDatum(sc, res.upper) if err != nil { @@ -1774,6 +1777,14 @@ func mergePartitionBuckets(sc *stmtctx.StatementContext, buckets []*bucket4Mergi } } res.NDV = right.NDV + right.disjointNDV + + // since `mergeBucketNDV` is based on uniform and inclusion assumptions, it has the trend to under-estimate, + // and as the number of buckets increases, these assumptions become weak, + // so to mitigate this problem, a damping factor based on the number of buckets is introduced. + res.NDV = int64(float64(res.NDV) * math.Pow(1.15, float64(len(buckets)-1))) + if res.NDV > totNDV { + res.NDV = totNDV + } return &res, nil } @@ -1855,6 +1866,16 @@ func MergePartitionHist2GlobalHist(sc *stmtctx.StatementContext, hists []*Histog buckets = append(buckets, meta.buildBucket4Merging(&d)) } + // Remove empty buckets + tail := 0 + for i := range buckets { + if buckets[i].Count != 0 { + buckets[tail] = buckets[i] + tail++ + } + } + buckets = buckets[:tail] + var sortError error sort.Slice(buckets, func(i, j int) bool { res, err := buckets[i].upper.CompareDatum(sc, buckets[j].upper) @@ -1873,14 +1894,17 @@ func MergePartitionHist2GlobalHist(sc *stmtctx.StatementContext, hists []*Histog if sortError != nil { return nil, sortError } - var sum int64 - r := len(buckets) + + var sum, prevSum int64 + r, prevR := len(buckets), 0 bucketCount := int64(1) + gBucketCountThreshold := (totCount / expBucketNumber) * 80 / 100 // expectedBucketSize * 0.8 + var bucketNDV int64 for i := len(buckets) - 1; i >= 0; i-- { sum += buckets[i].Count - if sum >= totCount*bucketCount/expBucketNumber { - // if the buckets have the same upper, we merge them into the same new buckets. 
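// A minimal, self-contained sketch of the NDV damping applied in mergePartitionBuckets above;
// the helper name dampNDV and the sample numbers are illustrative and not part of the diff.
// mergeBucketNDV tends to underestimate as more buckets are merged, so the merged NDV is scaled
// by 1.15^(numBuckets-1) and then capped by the sum of the per-bucket NDVs (totNDV).
package main

import (
	"fmt"
	"math"
)

func dampNDV(mergedNDV, totNDV int64, numBuckets int) int64 {
	ndv := int64(float64(mergedNDV) * math.Pow(1.15, float64(numBuckets-1)))
	if ndv > totNDV {
		ndv = totNDV
	}
	return ndv
}

func main() {
	// Merging 4 partition buckets whose estimated merged NDV is 40 while the
	// per-bucket NDVs sum to 70: int64(40 * 1.15^3) = 60, which stays below the cap of 70.
	fmt.Println(dampNDV(40, 70, 4)) // 60
}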
- for ; i > 0; i-- { + bucketNDV += buckets[i].NDV + if sum >= totCount*bucketCount/expBucketNumber && sum-prevSum >= gBucketCountThreshold { + for ; i > 0; i-- { // if the buckets have the same upper, we merge them into the same new buckets. res, err := buckets[i-1].upper.CompareDatum(sc, buckets[i].upper) if err != nil { return nil, err @@ -1888,18 +1912,33 @@ func MergePartitionHist2GlobalHist(sc *stmtctx.StatementContext, hists []*Histog if res != 0 { break } + sum += buckets[i-1].Count + bucketNDV += buckets[i-1].NDV } merged, err := mergePartitionBuckets(sc, buckets[i:r]) if err != nil { return nil, err } globalBuckets = append(globalBuckets, merged) + prevR = r r = i bucketCount++ + prevSum = sum + bucketNDV = 0 } } if r > 0 { - merged, err := mergePartitionBuckets(sc, buckets[0:r]) + bucketSum := int64(0) + for _, b := range buckets[:r] { + bucketSum += b.Count + } + + if len(globalBuckets) > 0 && bucketSum < gBucketCountThreshold { // merge them into the previous global bucket + r = prevR + globalBuckets = globalBuckets[:len(globalBuckets)-1] + } + + merged, err := mergePartitionBuckets(sc, buckets[:r]) if err != nil { return nil, err } @@ -1911,12 +1950,16 @@ func MergePartitionHist2GlobalHist(sc *stmtctx.StatementContext, hists []*Histog } // Calc the bucket lower. - if minValue == nil { // both hists and popedTopN are empty, returns an empty hist in this case + if minValue == nil || len(globalBuckets) == 0 { // both hists and popedTopN are empty, returns an empty hist in this case return NewHistogram(hists[0].ID, 0, totNull, hists[0].LastUpdateVersion, hists[0].Tp, len(globalBuckets), totColSize), nil } globalBuckets[0].lower = minValue.Clone() for i := 1; i < len(globalBuckets); i++ { - globalBuckets[i].lower = globalBuckets[i-1].upper.Clone() + if globalBuckets[i].NDV == 1 { // there is only 1 value so lower = upper + globalBuckets[i].lower = globalBuckets[i].upper.Clone() + } else { + globalBuckets[i].lower = globalBuckets[i-1].upper.Clone() + } globalBuckets[i].Count = globalBuckets[i].Count + globalBuckets[i-1].Count } diff --git a/statistics/histogram_test.go b/statistics/histogram_test.go index 6a906628d9686..57958fe759279 100644 --- a/statistics/histogram_test.go +++ b/statistics/histogram_test.go @@ -250,7 +250,7 @@ func (s *testStatisticsSuite) TestMergePartitionLevelHist(c *C) { upper: 7, count: 7, repeat: 3, - ndv: 4, + ndv: 5, }, { lower: 7, @@ -264,7 +264,7 @@ func (s *testStatisticsSuite) TestMergePartitionLevelHist(c *C) { upper: 17, count: 22, repeat: 1, - ndv: 5, + ndv: 6, }, }, expBucketNumber: 3, @@ -358,14 +358,14 @@ func (s *testStatisticsSuite) TestMergePartitionLevelHist(c *C) { upper: 12, count: 22, repeat: 3, - ndv: 5, + ndv: 6, }, { lower: 12, upper: 18, count: 33, repeat: 5, - ndv: 5, + ndv: 6, }, }, expBucketNumber: 3, diff --git a/statistics/selectivity_test.go b/statistics/selectivity_test.go index f371d2189b0b7..21bfeaf0114b4 100644 --- a/statistics/selectivity_test.go +++ b/statistics/selectivity_test.go @@ -35,6 +35,7 @@ import ( "github.com/pingcap/tidb/session" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/stmtctx" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/statistics" "github.com/pingcap/tidb/statistics/handle" "github.com/pingcap/tidb/store/mockstore" @@ -512,7 +513,7 @@ func (s *testStatsSuite) TestPrimaryKeySelectivity(c *C) { testKit := testkit.NewTestKit(c, s.store) testKit.MustExec("use test") testKit.MustExec("drop table if exists t") - 
testKit.Se.GetSessionVars().EnableClusteredIndex = false + testKit.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeIntOnly testKit.MustExec("create table t(a char(10) primary key, b int)") var input, output [][]string s.testData.GetTestCases(c, &input, &output) @@ -672,7 +673,7 @@ func (s *testStatsSuite) TestUniqCompEqualEst(c *C) { defer cleanEnv(c, s.store, s.do) testKit := testkit.NewTestKit(c, s.store) testKit.MustExec("use test") - testKit.Se.GetSessionVars().EnableClusteredIndex = true + testKit.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn testKit.MustExec("drop table if exists t") testKit.MustExec("create table t(a int, b int, primary key(a, b))") testKit.MustExec("insert into t values(1,1),(1,2),(1,3),(1,4),(1,5),(1,6),(1,7),(1,8),(1,9),(1,10)") diff --git a/store/copr/batch_coprocessor.go b/store/copr/batch_coprocessor.go index e5ce311030425..fa2dc90e0d9d6 100644 --- a/store/copr/batch_coprocessor.go +++ b/store/copr/batch_coprocessor.go @@ -121,7 +121,7 @@ func buildBatchCopTasks(bo *tikv.Backoffer, cache *tikv.RegionCache, ranges *tik storeTaskMap := make(map[string]*batchCopTask) needRetry := false for _, task := range tasks { - rpcCtx, err := cache.GetTiFlashRPCContext(bo, task.region) + rpcCtx, err := cache.GetTiFlashRPCContext(bo, task.region, false) if err != nil { return nil, errors.Trace(err) } diff --git a/store/copr/mpp.go b/store/copr/mpp.go index 46d57949fcc84..13488ae5b3e03 100644 --- a/store/copr/mpp.go +++ b/store/copr/mpp.go @@ -179,20 +179,22 @@ func (m *mppIterator) handleDispatchReq(ctx context.Context, bo *tikv.Backoffer, m.wg.Done() }() var regionInfos []*coprocessor.RegionInfo - originalTask := req.Meta.(*batchCopTask) - for _, task := range originalTask.copTasks { - regionInfos = append(regionInfos, &coprocessor.RegionInfo{ - RegionId: task.task.region.GetID(), - RegionEpoch: &metapb.RegionEpoch{ - ConfVer: task.task.region.GetConfVer(), - Version: task.task.region.GetVer(), - }, - Ranges: task.task.ranges.ToPBRanges(), - }) + originalTask, ok := req.Meta.(*batchCopTask) + if ok { + for _, task := range originalTask.copTasks { + regionInfos = append(regionInfos, &coprocessor.RegionInfo{ + RegionId: task.task.region.GetID(), + RegionEpoch: &metapb.RegionEpoch{ + ConfVer: task.task.region.GetConfVer(), + Version: task.task.region.GetVer(), + }, + Ranges: task.task.ranges.ToPBRanges(), + }) + } } // meta for current task. - taskMeta := &mpp.TaskMeta{StartTs: req.StartTs, TaskId: req.ID, Address: originalTask.storeAddr} + taskMeta := &mpp.TaskMeta{StartTs: req.StartTs, TaskId: req.ID, Address: req.Meta.GetAddress()} mppReq := &mpp.DispatchTaskRequest{ Meta: taskMeta, @@ -212,7 +214,7 @@ func (m *mppIterator) handleDispatchReq(ctx context.Context, bo *tikv.Backoffer, // If copTasks is not empty, we should send request according to region distribution. // Or else it's the task without region, which always happens in high layer task without table. // In that case - if len(originalTask.copTasks) != 0 { + if originalTask != nil { sender := NewRegionBatchRequestSender(m.store.GetRegionCache(), m.store.GetTiKVClient()) rpcResp, _, _, err = sender.sendStreamReqToAddr(bo, originalTask.copTasks, wrappedReq, tikv.ReadTimeoutMedium) // No matter what the rpc error is, we won't retry the mpp dispatch tasks. 
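A minimal sketch of the dispatch decision introduced in these mpp.go hunks, assuming a simplified Meta interface and task type; only GetAddress and the *batchCopTask assertion come from the diff, the other names are illustrative.

package main

import "fmt"

// Meta is a simplified stand-in for the task meta carried by an MPP dispatch request.
type Meta interface{ GetAddress() string }

type plainMeta struct{ addr string }

func (m plainMeta) GetAddress() string { return m.addr }

// batchCopTask mirrors the region-bound task: it also carries per-region coprocessor tasks.
type batchCopTask struct {
	addr     string
	copTasks []string
}

func (t *batchCopTask) GetAddress() string { return t.addr }

// dispatch sends via the region-aware batch path only when the meta really is a
// *batchCopTask; otherwise it falls back to a direct send to Meta.GetAddress().
func dispatch(meta Meta) string {
	if task, ok := meta.(*batchCopTask); ok {
		return fmt.Sprintf("batch send %d region tasks to %s", len(task.copTasks), task.GetAddress())
	}
	return "direct send to " + meta.GetAddress()
}

func main() {
	fmt.Println(dispatch(&batchCopTask{addr: "tiflash-1", copTasks: []string{"r1", "r2"}}))
	fmt.Println(dispatch(plainMeta{addr: "tiflash-2"}))
}

The point of the change is that a root MPP task may carry a meta that is not a *batchCopTask, so the store address must come from Meta.GetAddress() instead of the batch task.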
@@ -225,7 +227,7 @@ func (m *mppIterator) handleDispatchReq(ctx context.Context, bo *tikv.Backoffer, return } } else { - rpcResp, err = m.store.GetTiKVClient().SendRequest(ctx, originalTask.storeAddr, wrappedReq, tikv.ReadTimeoutMedium) + rpcResp, err = m.store.GetTiKVClient().SendRequest(ctx, req.Meta.GetAddress(), wrappedReq, tikv.ReadTimeoutMedium) } if err != nil { @@ -280,7 +282,7 @@ func (m *mppIterator) cancelMppTasks() { // send cancel cmd to all stores where tasks run for addr := range usedStoreAddrs { - _, err := m.store.GetTiKVClient().SendRequest(context.Background(), addr, wrappedReq, tikv.ReadTimeoutUltraLong) + _, err := m.store.GetTiKVClient().SendRequest(context.Background(), addr, wrappedReq, tikv.ReadTimeoutShort) logutil.BgLogger().Debug("cancel task ", zap.Uint64("query id ", m.startTs), zap.String(" on addr ", addr)) if err != nil { logutil.BgLogger().Error("cancel task error: ", zap.Error(err), zap.Uint64(" for query id ", m.startTs), zap.String(" on addr ", addr)) diff --git a/store/driver/txn/txn_driver.go b/store/driver/txn/txn_driver.go index 95e3d916ac1a6..6ab90b089e084 100644 --- a/store/driver/txn/txn_driver.go +++ b/store/driver/txn/txn_driver.go @@ -25,6 +25,7 @@ import ( "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/store/tikv" "github.com/pingcap/tidb/store/tikv/logutil" + "github.com/pingcap/tidb/store/tikv/unionstore" "github.com/pingcap/tidb/table/tables" "github.com/pingcap/tidb/tablecodec" "github.com/pingcap/tidb/types" @@ -65,6 +66,14 @@ func (txn *tikvTxn) GetSnapshot() kv.Snapshot { return txn.KVTxn.GetSnapshot() } +func (txn *tikvTxn) GetMemBuffer() kv.MemBuffer { + return txn.KVTxn.GetMemBuffer() +} + +func (txn *tikvTxn) GetUnionStore() kv.UnionStore { + return &tikvUnionStore{txn.KVTxn.GetUnionStore()} +} + func (txn *tikvTxn) extractKeyErr(err error) error { if e, ok := errors.Cause(err).(*tikv.ErrKeyExist); ok { return txn.extractKeyExistsErr(e.GetKey()) @@ -82,8 +91,7 @@ func (txn *tikvTxn) extractKeyExistsErr(key kv.Key) error { if tblInfo == nil { return genKeyExistsError("UNKNOWN", key.String(), errors.New("cannot find table info")) } - - value, err := txn.GetUnionStore().GetMemBuffer().SelectValueHistory(key, func(value []byte) bool { return len(value) != 0 }) + value, err := txn.KVTxn.GetUnionStore().GetMemBuffer().SelectValueHistory(key, func(value []byte) bool { return len(value) != 0 }) if err != nil { return genKeyExistsError("UNKNOWN", key.String(), err) } @@ -193,3 +201,12 @@ func extractKeyExistsErrFromIndex(key kv.Key, value []byte, tblInfo *model.Table } return genKeyExistsError(name, strings.Join(valueStr, "-"), nil) } + +//tikvUnionStore implements kv.UnionStore +type tikvUnionStore struct { + *unionstore.KVUnionStore +} + +func (u *tikvUnionStore) GetMemBuffer() kv.MemBuffer { + return u.KVUnionStore.GetMemBuffer() +} diff --git a/store/mockstore/mocktikv/analyze.go b/store/mockstore/mocktikv/analyze.go index a575f5536015d..2a013d63313a2 100644 --- a/store/mockstore/mocktikv/analyze.go +++ b/store/mockstore/mocktikv/analyze.go @@ -33,7 +33,7 @@ import ( "github.com/pingcap/tipb/go-tipb" ) -func (h *rpcHandler) handleCopAnalyzeRequest(req *coprocessor.Request) *coprocessor.Response { +func (h coprHandler) handleCopAnalyzeRequest(req *coprocessor.Request) *coprocessor.Response { resp := &coprocessor.Response{} if len(req.Ranges) == 0 { return resp @@ -62,7 +62,7 @@ func (h *rpcHandler) handleCopAnalyzeRequest(req *coprocessor.Request) *coproces return resp } -func (h *rpcHandler) handleAnalyzeIndexReq(req 
*coprocessor.Request, analyzeReq *tipb.AnalyzeReq) (*coprocessor.Response, error) { +func (h coprHandler) handleAnalyzeIndexReq(req *coprocessor.Request, analyzeReq *tipb.AnalyzeReq) (*coprocessor.Response, error) { ranges, err := h.extractKVRanges(req.Ranges, false) if err != nil { return nil, errors.Trace(err) @@ -125,7 +125,7 @@ type analyzeColumnsExec struct { fields []*ast.ResultField } -func (h *rpcHandler) handleAnalyzeColumnsReq(req *coprocessor.Request, analyzeReq *tipb.AnalyzeReq) (_ *coprocessor.Response, err error) { +func (h coprHandler) handleAnalyzeColumnsReq(req *coprocessor.Request, analyzeReq *tipb.AnalyzeReq) (_ *coprocessor.Response, err error) { sc := flagsToStatementContext(analyzeReq.Flags) sc.TimeZone, err = constructTimeZone("", int(analyzeReq.TimeZoneOffset)) if err != nil { diff --git a/store/mockstore/mocktikv/checksum.go b/store/mockstore/mocktikv/checksum.go index 13f54d26ab5a6..5c99a55ee70bf 100644 --- a/store/mockstore/mocktikv/checksum.go +++ b/store/mockstore/mocktikv/checksum.go @@ -20,7 +20,7 @@ import ( "github.com/pingcap/tipb/go-tipb" ) -func (h *rpcHandler) handleCopChecksumRequest(req *coprocessor.Request) *coprocessor.Response { +func (h coprHandler) handleCopChecksumRequest(req *coprocessor.Request) *coprocessor.Response { resp := &tipb.ChecksumResponse{ Checksum: 1, TotalKvs: 1, diff --git a/store/mockstore/mocktikv/cop_handler_dag.go b/store/mockstore/mocktikv/cop_handler_dag.go index d020d058467ee..82c75e99bb69f 100644 --- a/store/mockstore/mocktikv/cop_handler_dag.go +++ b/store/mockstore/mocktikv/cop_handler_dag.go @@ -54,7 +54,7 @@ type dagContext struct { evalCtx *evalContext } -func (h *rpcHandler) handleCopDAGRequest(req *coprocessor.Request) *coprocessor.Response { +func (h coprHandler) handleCopDAGRequest(req *coprocessor.Request) *coprocessor.Response { resp := &coprocessor.Response{} dagCtx, e, dagReq, err := h.buildDAGExecutor(req) if err != nil { @@ -88,7 +88,7 @@ func (h *rpcHandler) handleCopDAGRequest(req *coprocessor.Request) *coprocessor. 
return buildResp(selResp, execDetails, err) } -func (h *rpcHandler) buildDAGExecutor(req *coprocessor.Request) (*dagContext, executor, *tipb.DAGRequest, error) { +func (h coprHandler) buildDAGExecutor(req *coprocessor.Request) (*dagContext, executor, *tipb.DAGRequest, error) { if len(req.Ranges) == 0 { return nil, nil, nil, errors.New("request range is null") } @@ -133,7 +133,7 @@ func constructTimeZone(name string, offset int) (*time.Location, error) { return timeutil.ConstructTimeZone(name, offset) } -func (h *rpcHandler) handleCopStream(ctx context.Context, req *coprocessor.Request) (tikvpb.Tikv_CoprocessorStreamClient, error) { +func (h coprHandler) handleCopStream(ctx context.Context, req *coprocessor.Request) (tikvpb.Tikv_CoprocessorStreamClient, error) { dagCtx, e, dagReq, err := h.buildDAGExecutor(req) if err != nil { return nil, errors.Trace(err) @@ -147,7 +147,7 @@ func (h *rpcHandler) handleCopStream(ctx context.Context, req *coprocessor.Reque }, nil } -func (h *rpcHandler) buildExec(ctx *dagContext, curr *tipb.Executor) (executor, *tipb.Executor, error) { +func (h coprHandler) buildExec(ctx *dagContext, curr *tipb.Executor) (executor, *tipb.Executor, error) { var currExec executor var err error var childExec *tipb.Executor @@ -179,7 +179,7 @@ func (h *rpcHandler) buildExec(ctx *dagContext, curr *tipb.Executor) (executor, return currExec, childExec, errors.Trace(err) } -func (h *rpcHandler) buildDAGForTiFlash(ctx *dagContext, farther *tipb.Executor) (executor, error) { +func (h coprHandler) buildDAGForTiFlash(ctx *dagContext, farther *tipb.Executor) (executor, error) { curr, child, err := h.buildExec(ctx, farther) if err != nil { return nil, errors.Trace(err) @@ -194,7 +194,7 @@ func (h *rpcHandler) buildDAGForTiFlash(ctx *dagContext, farther *tipb.Executor) return curr, nil } -func (h *rpcHandler) buildDAG(ctx *dagContext, executors []*tipb.Executor) (executor, error) { +func (h coprHandler) buildDAG(ctx *dagContext, executors []*tipb.Executor) (executor, error) { var src executor for i := 0; i < len(executors); i++ { curr, _, err := h.buildExec(ctx, executors[i]) @@ -207,7 +207,7 @@ func (h *rpcHandler) buildDAG(ctx *dagContext, executors []*tipb.Executor) (exec return src, nil } -func (h *rpcHandler) buildTableScan(ctx *dagContext, executor *tipb.Executor) (*tableScanExec, error) { +func (h coprHandler) buildTableScan(ctx *dagContext, executor *tipb.Executor) (*tableScanExec, error) { columns := executor.TblScan.Columns ctx.evalCtx.setColumnInfo(columns) ranges, err := h.extractKVRanges(ctx.keyRanges, executor.TblScan.Desc) @@ -258,7 +258,7 @@ func (h *rpcHandler) buildTableScan(ctx *dagContext, executor *tipb.Executor) (* return e, nil } -func (h *rpcHandler) buildIndexScan(ctx *dagContext, executor *tipb.Executor) (*indexScanExec, error) { +func (h coprHandler) buildIndexScan(ctx *dagContext, executor *tipb.Executor) (*indexScanExec, error) { var err error columns := executor.IdxScan.Columns ctx.evalCtx.setColumnInfo(columns) @@ -311,7 +311,7 @@ func (h *rpcHandler) buildIndexScan(ctx *dagContext, executor *tipb.Executor) (* return e, nil } -func (h *rpcHandler) buildSelection(ctx *dagContext, executor *tipb.Executor) (*selectionExec, error) { +func (h coprHandler) buildSelection(ctx *dagContext, executor *tipb.Executor) (*selectionExec, error) { var err error var relatedColOffsets []int pbConds := executor.Selection.Conditions @@ -335,7 +335,7 @@ func (h *rpcHandler) buildSelection(ctx *dagContext, executor *tipb.Executor) (* }, nil } -func (h *rpcHandler) getAggInfo(ctx 
*dagContext, executor *tipb.Executor) ([]aggregation.Aggregation, []expression.Expression, []int, error) { +func (h coprHandler) getAggInfo(ctx *dagContext, executor *tipb.Executor) ([]aggregation.Aggregation, []expression.Expression, []int, error) { length := len(executor.Aggregation.AggFunc) aggs := make([]aggregation.Aggregation, 0, length) var err error @@ -366,7 +366,7 @@ func (h *rpcHandler) getAggInfo(ctx *dagContext, executor *tipb.Executor) ([]agg return aggs, groupBys, relatedColOffsets, nil } -func (h *rpcHandler) buildHashAgg(ctx *dagContext, executor *tipb.Executor) (*hashAggExec, error) { +func (h coprHandler) buildHashAgg(ctx *dagContext, executor *tipb.Executor) (*hashAggExec, error) { aggs, groupBys, relatedColOffsets, err := h.getAggInfo(ctx, executor) if err != nil { return nil, errors.Trace(err) @@ -384,7 +384,7 @@ func (h *rpcHandler) buildHashAgg(ctx *dagContext, executor *tipb.Executor) (*ha }, nil } -func (h *rpcHandler) buildStreamAgg(ctx *dagContext, executor *tipb.Executor) (*streamAggExec, error) { +func (h coprHandler) buildStreamAgg(ctx *dagContext, executor *tipb.Executor) (*streamAggExec, error) { aggs, groupBys, relatedColOffsets, err := h.getAggInfo(ctx, executor) if err != nil { return nil, errors.Trace(err) @@ -406,7 +406,7 @@ func (h *rpcHandler) buildStreamAgg(ctx *dagContext, executor *tipb.Executor) (* }, nil } -func (h *rpcHandler) buildTopN(ctx *dagContext, executor *tipb.Executor) (*topNExec, error) { +func (h coprHandler) buildTopN(ctx *dagContext, executor *tipb.Executor) (*topNExec, error) { topN := executor.TopN var err error var relatedColOffsets []int @@ -664,7 +664,7 @@ func (mock *mockCopStreamClient) readBlockFromExecutor() (tipb.Chunk, bool, *cop return chunk, finish, &ran, mock.exec.Counts(), warnings, nil } -func (h *rpcHandler) initSelectResponse(err error, warnings []stmtctx.SQLWarn, counts []int64) *tipb.SelectResponse { +func (h coprHandler) initSelectResponse(err error, warnings []stmtctx.SQLWarn, counts []int64) *tipb.SelectResponse { selResp := &tipb.SelectResponse{ Error: toPBError(err), OutputCounts: counts, @@ -675,7 +675,7 @@ func (h *rpcHandler) initSelectResponse(err error, warnings []stmtctx.SQLWarn, c return selResp } -func (h *rpcHandler) fillUpData4SelectResponse(selResp *tipb.SelectResponse, dagReq *tipb.DAGRequest, dagCtx *dagContext, rows [][][]byte) error { +func (h coprHandler) fillUpData4SelectResponse(selResp *tipb.SelectResponse, dagReq *tipb.DAGRequest, dagCtx *dagContext, rows [][][]byte) error { switch dagReq.EncodeType { case tipb.EncodeType_TypeDefault: h.encodeDefault(selResp, rows, dagReq.OutputOffsets) @@ -690,7 +690,7 @@ func (h *rpcHandler) fillUpData4SelectResponse(selResp *tipb.SelectResponse, dag return nil } -func (h *rpcHandler) constructRespSchema(dagCtx *dagContext) []*types.FieldType { +func (h coprHandler) constructRespSchema(dagCtx *dagContext) []*types.FieldType { var root *tipb.Executor if len(dagCtx.dagReq.Executors) == 0 { root = dagCtx.dagReq.RootExecutor @@ -717,7 +717,7 @@ func (h *rpcHandler) constructRespSchema(dagCtx *dagContext) []*types.FieldType return schema } -func (h *rpcHandler) encodeDefault(selResp *tipb.SelectResponse, rows [][][]byte, colOrdinal []uint32) { +func (h coprHandler) encodeDefault(selResp *tipb.SelectResponse, rows [][][]byte, colOrdinal []uint32) { var chunks []tipb.Chunk for i := range rows { requestedRow := dummySlice @@ -730,7 +730,7 @@ func (h *rpcHandler) encodeDefault(selResp *tipb.SelectResponse, rows [][][]byte selResp.EncodeType = 
tipb.EncodeType_TypeDefault } -func (h *rpcHandler) encodeChunk(selResp *tipb.SelectResponse, rows [][][]byte, colTypes []*types.FieldType, colOrdinal []uint32, loc *time.Location) error { +func (h coprHandler) encodeChunk(selResp *tipb.SelectResponse, rows [][][]byte, colTypes []*types.FieldType, colOrdinal []uint32, loc *time.Location) error { var chunks []tipb.Chunk respColTypes := make([]*types.FieldType, 0, len(colOrdinal)) for _, ordinal := range colOrdinal { @@ -826,7 +826,7 @@ func toPBError(err error) *tipb.Error { } // extractKVRanges extracts kv.KeyRanges slice from a SelectRequest. -func (h *rpcHandler) extractKVRanges(keyRanges []*coprocessor.KeyRange, descScan bool) (kvRanges []kv.KeyRange, err error) { +func (h coprHandler) extractKVRanges(keyRanges []*coprocessor.KeyRange, descScan bool) (kvRanges []kv.KeyRange, err error) { for _, kran := range keyRanges { if bytes.Compare(kran.GetStart(), kran.GetEnd()) >= 0 { err = errors.Errorf("invalid range, start should be smaller than end: %v %v", kran.GetStart(), kran.GetEnd()) diff --git a/store/mockstore/mocktikv/rpc.go b/store/mockstore/mocktikv/rpc.go index 2ef026b408249..320f545c550ca 100644 --- a/store/mockstore/mocktikv/rpc.go +++ b/store/mockstore/mocktikv/rpc.go @@ -22,7 +22,6 @@ import ( "sync" "time" - "github.com/golang/protobuf/proto" "github.com/opentracing/opentracing-go" "github.com/pingcap/errors" "github.com/pingcap/failpoint" @@ -32,7 +31,6 @@ import ( "github.com/pingcap/kvproto/pkg/kvrpcpb" "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/parser/terror" - "github.com/pingcap/tidb/ddl/placement" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/store/tikv/tikvrpc" "github.com/pingcap/tipb/go-tipb" @@ -141,132 +139,13 @@ func convertToPbPairs(pairs []Pair) []*kvrpcpb.KvPair { return kvPairs } -// rpcHandler mocks tikv's side handler behavior. In general, you may assume +// kvHandler mocks tikv's side handler behavior. In general, you may assume // TiKV just translate the logic from Go to Rust. -type rpcHandler struct { - cluster *Cluster - mvccStore MVCCStore - - // storeID stores id for current request - storeID uint64 - // startKey is used for handling normal request. - startKey []byte - endKey []byte - // rawStartKey is used for handling coprocessor request. - rawStartKey []byte - rawEndKey []byte - // isolationLevel is used for current request. - isolationLevel kvrpcpb.IsolationLevel - resolvedLocks []uint64 +type kvHandler struct { + *Session } -func isTiFlashStore(store *metapb.Store) bool { - for _, l := range store.GetLabels() { - if l.GetKey() == placement.EngineLabelKey && l.GetValue() == placement.EngineLabelTiFlash { - return true - } - } - return false -} - -func (h *rpcHandler) checkRequestContext(ctx *kvrpcpb.Context) *errorpb.Error { - ctxPeer := ctx.GetPeer() - if ctxPeer != nil && ctxPeer.GetStoreId() != h.storeID { - return &errorpb.Error{ - Message: *proto.String("store not match"), - StoreNotMatch: &errorpb.StoreNotMatch{}, - } - } - region, leaderID := h.cluster.GetRegion(ctx.GetRegionId()) - // No region found. - if region == nil { - return &errorpb.Error{ - Message: *proto.String("region not found"), - RegionNotFound: &errorpb.RegionNotFound{ - RegionId: *proto.Uint64(ctx.GetRegionId()), - }, - } - } - var storePeer, leaderPeer *metapb.Peer - for _, p := range region.Peers { - if p.GetStoreId() == h.storeID { - storePeer = p - } - if p.GetId() == leaderID { - leaderPeer = p - } - } - // The Store does not contain a Peer of the Region. 
- if storePeer == nil { - return &errorpb.Error{ - Message: *proto.String("region not found"), - RegionNotFound: &errorpb.RegionNotFound{ - RegionId: *proto.Uint64(ctx.GetRegionId()), - }, - } - } - // No leader. - if leaderPeer == nil { - return &errorpb.Error{ - Message: *proto.String("no leader"), - NotLeader: &errorpb.NotLeader{ - RegionId: *proto.Uint64(ctx.GetRegionId()), - }, - } - } - // The Peer on the Store is not leader. If it's tiflash store , we pass this check. - if storePeer.GetId() != leaderPeer.GetId() && !isTiFlashStore(h.cluster.GetStore(storePeer.GetStoreId())) { - return &errorpb.Error{ - Message: *proto.String("not leader"), - NotLeader: &errorpb.NotLeader{ - RegionId: *proto.Uint64(ctx.GetRegionId()), - Leader: leaderPeer, - }, - } - } - // Region epoch does not match. - if !proto.Equal(region.GetRegionEpoch(), ctx.GetRegionEpoch()) { - nextRegion, _ := h.cluster.GetRegionByKey(region.GetEndKey()) - currentRegions := []*metapb.Region{region} - if nextRegion != nil { - currentRegions = append(currentRegions, nextRegion) - } - return &errorpb.Error{ - Message: *proto.String("epoch not match"), - EpochNotMatch: &errorpb.EpochNotMatch{ - CurrentRegions: currentRegions, - }, - } - } - h.startKey, h.endKey = region.StartKey, region.EndKey - h.isolationLevel = ctx.IsolationLevel - h.resolvedLocks = ctx.ResolvedLocks - return nil -} - -func (h *rpcHandler) checkRequestSize(size int) *errorpb.Error { - // TiKV has a limitation on raft log size. - // mocktikv has no raft inside, so we check the request's size instead. - if size >= requestMaxSize { - return &errorpb.Error{ - RaftEntryTooLarge: &errorpb.RaftEntryTooLarge{}, - } - } - return nil -} - -func (h *rpcHandler) checkRequest(ctx *kvrpcpb.Context, size int) *errorpb.Error { - if err := h.checkRequestContext(ctx); err != nil { - return err - } - return h.checkRequestSize(size) -} - -func (h *rpcHandler) checkKeyInRegion(key []byte) bool { - return regionContains(h.startKey, h.endKey, NewMvccKey(key)) -} - -func (h *rpcHandler) handleKvGet(req *kvrpcpb.GetRequest) *kvrpcpb.GetResponse { +func (h kvHandler) handleKvGet(req *kvrpcpb.GetRequest) *kvrpcpb.GetResponse { if !h.checkKeyInRegion(req.Key) { panic("KvGet: key not in region") } @@ -282,7 +161,7 @@ func (h *rpcHandler) handleKvGet(req *kvrpcpb.GetRequest) *kvrpcpb.GetResponse { } } -func (h *rpcHandler) handleKvScan(req *kvrpcpb.ScanRequest) *kvrpcpb.ScanResponse { +func (h kvHandler) handleKvScan(req *kvrpcpb.ScanRequest) *kvrpcpb.ScanResponse { endKey := MvccKey(h.endKey).Raw() var pairs []Pair if !req.Reverse { @@ -314,7 +193,7 @@ func (h *rpcHandler) handleKvScan(req *kvrpcpb.ScanRequest) *kvrpcpb.ScanRespons } } -func (h *rpcHandler) handleKvPrewrite(req *kvrpcpb.PrewriteRequest) *kvrpcpb.PrewriteResponse { +func (h kvHandler) handleKvPrewrite(req *kvrpcpb.PrewriteRequest) *kvrpcpb.PrewriteResponse { regionID := req.Context.RegionId h.cluster.handleDelay(req.StartVersion, regionID) @@ -329,7 +208,7 @@ func (h *rpcHandler) handleKvPrewrite(req *kvrpcpb.PrewriteRequest) *kvrpcpb.Pre } } -func (h *rpcHandler) handleKvPessimisticLock(req *kvrpcpb.PessimisticLockRequest) *kvrpcpb.PessimisticLockResponse { +func (h kvHandler) handleKvPessimisticLock(req *kvrpcpb.PessimisticLockRequest) *kvrpcpb.PessimisticLockResponse { for _, m := range req.Mutations { if !h.checkKeyInRegion(m.Key) { panic("KvPessimisticLock: key not in region") @@ -350,7 +229,7 @@ func simulateServerSideWaitLock(errs []error) { } } -func (h *rpcHandler) handleKvPessimisticRollback(req 
*kvrpcpb.PessimisticRollbackRequest) *kvrpcpb.PessimisticRollbackResponse { +func (h kvHandler) handleKvPessimisticRollback(req *kvrpcpb.PessimisticRollbackRequest) *kvrpcpb.PessimisticRollbackResponse { for _, key := range req.Keys { if !h.checkKeyInRegion(key) { panic("KvPessimisticRollback: key not in region") @@ -362,7 +241,7 @@ func (h *rpcHandler) handleKvPessimisticRollback(req *kvrpcpb.PessimisticRollbac } } -func (h *rpcHandler) handleKvCommit(req *kvrpcpb.CommitRequest) *kvrpcpb.CommitResponse { +func (h kvHandler) handleKvCommit(req *kvrpcpb.CommitRequest) *kvrpcpb.CommitResponse { for _, k := range req.Keys { if !h.checkKeyInRegion(k) { panic("KvCommit: key not in region") @@ -376,7 +255,7 @@ func (h *rpcHandler) handleKvCommit(req *kvrpcpb.CommitRequest) *kvrpcpb.CommitR return &resp } -func (h *rpcHandler) handleKvCleanup(req *kvrpcpb.CleanupRequest) *kvrpcpb.CleanupResponse { +func (h kvHandler) handleKvCleanup(req *kvrpcpb.CleanupRequest) *kvrpcpb.CleanupResponse { if !h.checkKeyInRegion(req.Key) { panic("KvCleanup: key not in region") } @@ -392,7 +271,7 @@ func (h *rpcHandler) handleKvCleanup(req *kvrpcpb.CleanupRequest) *kvrpcpb.Clean return &resp } -func (h *rpcHandler) handleKvCheckTxnStatus(req *kvrpcpb.CheckTxnStatusRequest) *kvrpcpb.CheckTxnStatusResponse { +func (h kvHandler) handleKvCheckTxnStatus(req *kvrpcpb.CheckTxnStatusRequest) *kvrpcpb.CheckTxnStatusResponse { if !h.checkKeyInRegion(req.PrimaryKey) { panic("KvCheckTxnStatus: key not in region") } @@ -406,7 +285,7 @@ func (h *rpcHandler) handleKvCheckTxnStatus(req *kvrpcpb.CheckTxnStatusRequest) return &resp } -func (h *rpcHandler) handleTxnHeartBeat(req *kvrpcpb.TxnHeartBeatRequest) *kvrpcpb.TxnHeartBeatResponse { +func (h kvHandler) handleTxnHeartBeat(req *kvrpcpb.TxnHeartBeatRequest) *kvrpcpb.TxnHeartBeatResponse { if !h.checkKeyInRegion(req.PrimaryLock) { panic("KvTxnHeartBeat: key not in region") } @@ -419,7 +298,7 @@ func (h *rpcHandler) handleTxnHeartBeat(req *kvrpcpb.TxnHeartBeatRequest) *kvrpc return &resp } -func (h *rpcHandler) handleKvBatchGet(req *kvrpcpb.BatchGetRequest) *kvrpcpb.BatchGetResponse { +func (h kvHandler) handleKvBatchGet(req *kvrpcpb.BatchGetRequest) *kvrpcpb.BatchGetResponse { for _, k := range req.Keys { if !h.checkKeyInRegion(k) { panic("KvBatchGet: key not in region") @@ -431,7 +310,7 @@ func (h *rpcHandler) handleKvBatchGet(req *kvrpcpb.BatchGetRequest) *kvrpcpb.Bat } } -func (h *rpcHandler) handleMvccGetByKey(req *kvrpcpb.MvccGetByKeyRequest) *kvrpcpb.MvccGetByKeyResponse { +func (h kvHandler) handleMvccGetByKey(req *kvrpcpb.MvccGetByKeyRequest) *kvrpcpb.MvccGetByKeyResponse { debugger, ok := h.mvccStore.(MVCCDebugger) if !ok { return &kvrpcpb.MvccGetByKeyResponse{ @@ -447,7 +326,7 @@ func (h *rpcHandler) handleMvccGetByKey(req *kvrpcpb.MvccGetByKeyRequest) *kvrpc return &resp } -func (h *rpcHandler) handleMvccGetByStartTS(req *kvrpcpb.MvccGetByStartTsRequest) *kvrpcpb.MvccGetByStartTsResponse { +func (h kvHandler) handleMvccGetByStartTS(req *kvrpcpb.MvccGetByStartTsRequest) *kvrpcpb.MvccGetByStartTsResponse { debugger, ok := h.mvccStore.(MVCCDebugger) if !ok { return &kvrpcpb.MvccGetByStartTsResponse{ @@ -459,7 +338,7 @@ func (h *rpcHandler) handleMvccGetByStartTS(req *kvrpcpb.MvccGetByStartTsRequest return &resp } -func (h *rpcHandler) handleKvBatchRollback(req *kvrpcpb.BatchRollbackRequest) *kvrpcpb.BatchRollbackResponse { +func (h kvHandler) handleKvBatchRollback(req *kvrpcpb.BatchRollbackRequest) *kvrpcpb.BatchRollbackResponse { err := h.mvccStore.Rollback(req.Keys, 
req.StartVersion) if err != nil { return &kvrpcpb.BatchRollbackResponse{ @@ -469,7 +348,7 @@ func (h *rpcHandler) handleKvBatchRollback(req *kvrpcpb.BatchRollbackRequest) *k return &kvrpcpb.BatchRollbackResponse{} } -func (h *rpcHandler) handleKvScanLock(req *kvrpcpb.ScanLockRequest) *kvrpcpb.ScanLockResponse { +func (h kvHandler) handleKvScanLock(req *kvrpcpb.ScanLockRequest) *kvrpcpb.ScanLockResponse { startKey := MvccKey(h.startKey).Raw() endKey := MvccKey(h.endKey).Raw() locks, err := h.mvccStore.ScanLock(startKey, endKey, req.GetMaxVersion()) @@ -483,7 +362,7 @@ func (h *rpcHandler) handleKvScanLock(req *kvrpcpb.ScanLockRequest) *kvrpcpb.Sca } } -func (h *rpcHandler) handleKvResolveLock(req *kvrpcpb.ResolveLockRequest) *kvrpcpb.ResolveLockResponse { +func (h kvHandler) handleKvResolveLock(req *kvrpcpb.ResolveLockRequest) *kvrpcpb.ResolveLockResponse { startKey := MvccKey(h.startKey).Raw() endKey := MvccKey(h.endKey).Raw() err := h.mvccStore.ResolveLock(startKey, endKey, req.GetStartVersion(), req.GetCommitVersion()) @@ -495,7 +374,7 @@ func (h *rpcHandler) handleKvResolveLock(req *kvrpcpb.ResolveLockRequest) *kvrpc return &kvrpcpb.ResolveLockResponse{} } -func (h *rpcHandler) handleKvGC(req *kvrpcpb.GCRequest) *kvrpcpb.GCResponse { +func (h kvHandler) handleKvGC(req *kvrpcpb.GCRequest) *kvrpcpb.GCResponse { startKey := MvccKey(h.startKey).Raw() endKey := MvccKey(h.endKey).Raw() err := h.mvccStore.GC(startKey, endKey, req.GetSafePoint()) @@ -507,7 +386,7 @@ func (h *rpcHandler) handleKvGC(req *kvrpcpb.GCRequest) *kvrpcpb.GCResponse { return &kvrpcpb.GCResponse{} } -func (h *rpcHandler) handleKvDeleteRange(req *kvrpcpb.DeleteRangeRequest) *kvrpcpb.DeleteRangeResponse { +func (h kvHandler) handleKvDeleteRange(req *kvrpcpb.DeleteRangeRequest) *kvrpcpb.DeleteRangeResponse { if !h.checkKeyInRegion(req.StartKey) { panic("KvDeleteRange: key not in region") } @@ -519,7 +398,7 @@ func (h *rpcHandler) handleKvDeleteRange(req *kvrpcpb.DeleteRangeRequest) *kvrpc return &resp } -func (h *rpcHandler) handleKvRawGet(req *kvrpcpb.RawGetRequest) *kvrpcpb.RawGetResponse { +func (h kvHandler) handleKvRawGet(req *kvrpcpb.RawGetRequest) *kvrpcpb.RawGetResponse { rawKV, ok := h.mvccStore.(RawKV) if !ok { return &kvrpcpb.RawGetResponse{ @@ -531,7 +410,7 @@ func (h *rpcHandler) handleKvRawGet(req *kvrpcpb.RawGetRequest) *kvrpcpb.RawGetR } } -func (h *rpcHandler) handleKvRawBatchGet(req *kvrpcpb.RawBatchGetRequest) *kvrpcpb.RawBatchGetResponse { +func (h kvHandler) handleKvRawBatchGet(req *kvrpcpb.RawBatchGetRequest) *kvrpcpb.RawBatchGetResponse { rawKV, ok := h.mvccStore.(RawKV) if !ok { // TODO should we add error ? 
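A minimal sketch of the optional-interface check shared by these raw KV handlers, with simplified store and response types; the real handlers assert h.mvccStore against mocktikv's RawKV interface and put the error into the RPC response.

package main

import "fmt"

// RawKV is the optional capability; only some store implementations provide it.
type RawKV interface{ RawGet(key []byte) []byte }

type mvccOnlyStore struct{} // does not implement RawKV

type rawCapableStore struct{ data map[string][]byte }

func (s rawCapableStore) RawGet(key []byte) []byte { return s.data[string(key)] }

// handleRawGet mirrors the handler shape: assert the capability, or report "not implemented".
func handleRawGet(store interface{}, key []byte) (value []byte, errMsg string) {
	rawKV, ok := store.(RawKV)
	if !ok {
		return nil, "not implemented"
	}
	return rawKV.RawGet(key), ""
}

func main() {
	fmt.Println(handleRawGet(mvccOnlyStore{}, []byte("k")))
	fmt.Println(handleRawGet(rawCapableStore{data: map[string][]byte{"k": []byte("v")}}, []byte("k")))
}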
@@ -554,7 +433,7 @@ func (h *rpcHandler) handleKvRawBatchGet(req *kvrpcpb.RawBatchGetRequest) *kvrpc } } -func (h *rpcHandler) handleKvRawPut(req *kvrpcpb.RawPutRequest) *kvrpcpb.RawPutResponse { +func (h kvHandler) handleKvRawPut(req *kvrpcpb.RawPutRequest) *kvrpcpb.RawPutResponse { rawKV, ok := h.mvccStore.(RawKV) if !ok { return &kvrpcpb.RawPutResponse{ @@ -565,7 +444,7 @@ func (h *rpcHandler) handleKvRawPut(req *kvrpcpb.RawPutRequest) *kvrpcpb.RawPutR return &kvrpcpb.RawPutResponse{} } -func (h *rpcHandler) handleKvRawBatchPut(req *kvrpcpb.RawBatchPutRequest) *kvrpcpb.RawBatchPutResponse { +func (h kvHandler) handleKvRawBatchPut(req *kvrpcpb.RawBatchPutRequest) *kvrpcpb.RawBatchPutResponse { rawKV, ok := h.mvccStore.(RawKV) if !ok { return &kvrpcpb.RawBatchPutResponse{ @@ -582,7 +461,7 @@ func (h *rpcHandler) handleKvRawBatchPut(req *kvrpcpb.RawBatchPutRequest) *kvrpc return &kvrpcpb.RawBatchPutResponse{} } -func (h *rpcHandler) handleKvRawDelete(req *kvrpcpb.RawDeleteRequest) *kvrpcpb.RawDeleteResponse { +func (h kvHandler) handleKvRawDelete(req *kvrpcpb.RawDeleteRequest) *kvrpcpb.RawDeleteResponse { rawKV, ok := h.mvccStore.(RawKV) if !ok { return &kvrpcpb.RawDeleteResponse{ @@ -593,7 +472,7 @@ func (h *rpcHandler) handleKvRawDelete(req *kvrpcpb.RawDeleteRequest) *kvrpcpb.R return &kvrpcpb.RawDeleteResponse{} } -func (h *rpcHandler) handleKvRawBatchDelete(req *kvrpcpb.RawBatchDeleteRequest) *kvrpcpb.RawBatchDeleteResponse { +func (h kvHandler) handleKvRawBatchDelete(req *kvrpcpb.RawBatchDeleteRequest) *kvrpcpb.RawBatchDeleteResponse { rawKV, ok := h.mvccStore.(RawKV) if !ok { return &kvrpcpb.RawBatchDeleteResponse{ @@ -604,7 +483,7 @@ func (h *rpcHandler) handleKvRawBatchDelete(req *kvrpcpb.RawBatchDeleteRequest) return &kvrpcpb.RawBatchDeleteResponse{} } -func (h *rpcHandler) handleKvRawDeleteRange(req *kvrpcpb.RawDeleteRangeRequest) *kvrpcpb.RawDeleteRangeResponse { +func (h kvHandler) handleKvRawDeleteRange(req *kvrpcpb.RawDeleteRangeRequest) *kvrpcpb.RawDeleteRangeResponse { rawKV, ok := h.mvccStore.(RawKV) if !ok { return &kvrpcpb.RawDeleteRangeResponse{ @@ -615,7 +494,7 @@ func (h *rpcHandler) handleKvRawDeleteRange(req *kvrpcpb.RawDeleteRangeRequest) return &kvrpcpb.RawDeleteRangeResponse{} } -func (h *rpcHandler) handleKvRawScan(req *kvrpcpb.RawScanRequest) *kvrpcpb.RawScanResponse { +func (h kvHandler) handleKvRawScan(req *kvrpcpb.RawScanRequest) *kvrpcpb.RawScanResponse { rawKV, ok := h.mvccStore.(RawKV) if !ok { errStr := "not implemented" @@ -654,7 +533,7 @@ func (h *rpcHandler) handleKvRawScan(req *kvrpcpb.RawScanRequest) *kvrpcpb.RawSc } } -func (h *rpcHandler) handleSplitRegion(req *kvrpcpb.SplitRegionRequest) *kvrpcpb.SplitRegionResponse { +func (h kvHandler) handleSplitRegion(req *kvrpcpb.SplitRegionRequest) *kvrpcpb.SplitRegionResponse { keys := req.GetSplitKeys() resp := &kvrpcpb.SplitRegionResponse{Regions: make([]*metapb.Region, 0, len(keys)+1)} for i, key := range keys { @@ -690,7 +569,11 @@ func drainRowsFromExecutor(ctx context.Context, e executor, req *tipb.DAGRequest } } -func (h *rpcHandler) handleBatchCopRequest(ctx context.Context, req *coprocessor.BatchRequest) (*mockBatchCopDataClient, error) { +type coprHandler struct { + *Session +} + +func (h coprHandler) handleBatchCopRequest(ctx context.Context, req *coprocessor.BatchRequest) (*mockBatchCopDataClient, error) { client := &mockBatchCopDataClient{} for _, ri := range req.Regions { cop := coprocessor.Request{ @@ -766,7 +649,7 @@ func (c *RPCClient) getAndCheckStoreByAddr(addr string) (*metapb.Store, 
error) { return nil, errors.New("connection refused") } -func (c *RPCClient) checkArgs(ctx context.Context, addr string) (*rpcHandler, error) { +func (c *RPCClient) checkArgs(ctx context.Context, addr string) (*Session, error) { if err := checkGoContext(ctx); err != nil { return nil, err } @@ -775,13 +658,13 @@ func (c *RPCClient) checkArgs(ctx context.Context, addr string) (*rpcHandler, er if err != nil { return nil, err } - handler := &rpcHandler{ + session := &Session{ cluster: c.Cluster, mvccStore: c.MvccStore, // set store id for current request storeID: store.GetId(), } - return handler, nil + return session, nil } // GRPCClientFactory is the GRPC client factory. @@ -828,25 +711,25 @@ func (c *RPCClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.R return c.redirectRequestToRPCServer(ctx, addr, req, timeout) } - handler, err := c.checkArgs(ctx, addr) + session, err := c.checkArgs(ctx, addr) if err != nil { return nil, err } switch req.Type { case tikvrpc.CmdGet: r := req.Get() - if err := handler.checkRequest(reqCtx, r.Size()); err != nil { + if err := session.checkRequest(reqCtx, r.Size()); err != nil { resp.Resp = &kvrpcpb.GetResponse{RegionError: err} return resp, nil } - resp.Resp = handler.handleKvGet(r) + resp.Resp = kvHandler{session}.handleKvGet(r) case tikvrpc.CmdScan: r := req.Scan() - if err := handler.checkRequest(reqCtx, r.Size()); err != nil { + if err := session.checkRequest(reqCtx, r.Size()); err != nil { resp.Resp = &kvrpcpb.ScanResponse{RegionError: err} return resp, nil } - resp.Resp = handler.handleKvScan(r) + resp.Resp = kvHandler{session}.handleKvScan(r) case tikvrpc.CmdPrewrite: failpoint.Inject("rpcPrewriteResult", func(val failpoint.Value) { @@ -859,25 +742,25 @@ func (c *RPCClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.R }) r := req.Prewrite() - if err := handler.checkRequest(reqCtx, r.Size()); err != nil { + if err := session.checkRequest(reqCtx, r.Size()); err != nil { resp.Resp = &kvrpcpb.PrewriteResponse{RegionError: err} return resp, nil } - resp.Resp = handler.handleKvPrewrite(r) + resp.Resp = kvHandler{session}.handleKvPrewrite(r) case tikvrpc.CmdPessimisticLock: r := req.PessimisticLock() - if err := handler.checkRequest(reqCtx, r.Size()); err != nil { + if err := session.checkRequest(reqCtx, r.Size()); err != nil { resp.Resp = &kvrpcpb.PessimisticLockResponse{RegionError: err} return resp, nil } - resp.Resp = handler.handleKvPessimisticLock(r) + resp.Resp = kvHandler{session}.handleKvPessimisticLock(r) case tikvrpc.CmdPessimisticRollback: r := req.PessimisticRollback() - if err := handler.checkRequest(reqCtx, r.Size()); err != nil { + if err := session.checkRequest(reqCtx, r.Size()); err != nil { resp.Resp = &kvrpcpb.PessimisticRollbackResponse{RegionError: err} return resp, nil } - resp.Resp = handler.handleKvPessimisticRollback(r) + resp.Resp = kvHandler{session}.handleKvPessimisticRollback(r) case tikvrpc.CmdCommit: failpoint.Inject("rpcCommitResult", func(val failpoint.Value) { switch val.(string) { @@ -895,11 +778,11 @@ func (c *RPCClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.R }) r := req.Commit() - if err := handler.checkRequest(reqCtx, r.Size()); err != nil { + if err := session.checkRequest(reqCtx, r.Size()); err != nil { resp.Resp = &kvrpcpb.CommitResponse{RegionError: err} return resp, nil } - resp.Resp = handler.handleKvCommit(r) + resp.Resp = kvHandler{session}.handleKvCommit(r) failpoint.Inject("rpcCommitTimeout", func(val failpoint.Value) { if val.(bool) { 
failpoint.Return(nil, undeterminedErr) @@ -907,122 +790,122 @@ func (c *RPCClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.R }) case tikvrpc.CmdCleanup: r := req.Cleanup() - if err := handler.checkRequest(reqCtx, r.Size()); err != nil { + if err := session.checkRequest(reqCtx, r.Size()); err != nil { resp.Resp = &kvrpcpb.CleanupResponse{RegionError: err} return resp, nil } - resp.Resp = handler.handleKvCleanup(r) + resp.Resp = kvHandler{session}.handleKvCleanup(r) case tikvrpc.CmdCheckTxnStatus: r := req.CheckTxnStatus() - if err := handler.checkRequest(reqCtx, r.Size()); err != nil { + if err := session.checkRequest(reqCtx, r.Size()); err != nil { resp.Resp = &kvrpcpb.CheckTxnStatusResponse{RegionError: err} return resp, nil } - resp.Resp = handler.handleKvCheckTxnStatus(r) + resp.Resp = kvHandler{session}.handleKvCheckTxnStatus(r) case tikvrpc.CmdTxnHeartBeat: r := req.TxnHeartBeat() - if err := handler.checkRequest(reqCtx, r.Size()); err != nil { + if err := session.checkRequest(reqCtx, r.Size()); err != nil { resp.Resp = &kvrpcpb.TxnHeartBeatResponse{RegionError: err} return resp, nil } - resp.Resp = handler.handleTxnHeartBeat(r) + resp.Resp = kvHandler{session}.handleTxnHeartBeat(r) case tikvrpc.CmdBatchGet: r := req.BatchGet() - if err := handler.checkRequest(reqCtx, r.Size()); err != nil { + if err := session.checkRequest(reqCtx, r.Size()); err != nil { resp.Resp = &kvrpcpb.BatchGetResponse{RegionError: err} return resp, nil } - resp.Resp = handler.handleKvBatchGet(r) + resp.Resp = kvHandler{session}.handleKvBatchGet(r) case tikvrpc.CmdBatchRollback: r := req.BatchRollback() - if err := handler.checkRequest(reqCtx, r.Size()); err != nil { + if err := session.checkRequest(reqCtx, r.Size()); err != nil { resp.Resp = &kvrpcpb.BatchRollbackResponse{RegionError: err} return resp, nil } - resp.Resp = handler.handleKvBatchRollback(r) + resp.Resp = kvHandler{session}.handleKvBatchRollback(r) case tikvrpc.CmdScanLock: r := req.ScanLock() - if err := handler.checkRequest(reqCtx, r.Size()); err != nil { + if err := session.checkRequest(reqCtx, r.Size()); err != nil { resp.Resp = &kvrpcpb.ScanLockResponse{RegionError: err} return resp, nil } - resp.Resp = handler.handleKvScanLock(r) + resp.Resp = kvHandler{session}.handleKvScanLock(r) case tikvrpc.CmdResolveLock: r := req.ResolveLock() - if err := handler.checkRequest(reqCtx, r.Size()); err != nil { + if err := session.checkRequest(reqCtx, r.Size()); err != nil { resp.Resp = &kvrpcpb.ResolveLockResponse{RegionError: err} return resp, nil } - resp.Resp = handler.handleKvResolveLock(r) + resp.Resp = kvHandler{session}.handleKvResolveLock(r) case tikvrpc.CmdGC: r := req.GC() - if err := handler.checkRequest(reqCtx, r.Size()); err != nil { + if err := session.checkRequest(reqCtx, r.Size()); err != nil { resp.Resp = &kvrpcpb.GCResponse{RegionError: err} return resp, nil } - resp.Resp = handler.handleKvGC(r) + resp.Resp = kvHandler{session}.handleKvGC(r) case tikvrpc.CmdDeleteRange: r := req.DeleteRange() - if err := handler.checkRequest(reqCtx, r.Size()); err != nil { + if err := session.checkRequest(reqCtx, r.Size()); err != nil { resp.Resp = &kvrpcpb.DeleteRangeResponse{RegionError: err} return resp, nil } - resp.Resp = handler.handleKvDeleteRange(r) + resp.Resp = kvHandler{session}.handleKvDeleteRange(r) case tikvrpc.CmdRawGet: r := req.RawGet() - if err := handler.checkRequest(reqCtx, r.Size()); err != nil { + if err := session.checkRequest(reqCtx, r.Size()); err != nil { resp.Resp = &kvrpcpb.RawGetResponse{RegionError: err} 
return resp, nil } - resp.Resp = handler.handleKvRawGet(r) + resp.Resp = kvHandler{session}.handleKvRawGet(r) case tikvrpc.CmdRawBatchGet: r := req.RawBatchGet() - if err := handler.checkRequest(reqCtx, r.Size()); err != nil { + if err := session.checkRequest(reqCtx, r.Size()); err != nil { resp.Resp = &kvrpcpb.RawBatchGetResponse{RegionError: err} return resp, nil } - resp.Resp = handler.handleKvRawBatchGet(r) + resp.Resp = kvHandler{session}.handleKvRawBatchGet(r) case tikvrpc.CmdRawPut: r := req.RawPut() - if err := handler.checkRequest(reqCtx, r.Size()); err != nil { + if err := session.checkRequest(reqCtx, r.Size()); err != nil { resp.Resp = &kvrpcpb.RawPutResponse{RegionError: err} return resp, nil } - resp.Resp = handler.handleKvRawPut(r) + resp.Resp = kvHandler{session}.handleKvRawPut(r) case tikvrpc.CmdRawBatchPut: r := req.RawBatchPut() - if err := handler.checkRequest(reqCtx, r.Size()); err != nil { + if err := session.checkRequest(reqCtx, r.Size()); err != nil { resp.Resp = &kvrpcpb.RawBatchPutResponse{RegionError: err} return resp, nil } - resp.Resp = handler.handleKvRawBatchPut(r) + resp.Resp = kvHandler{session}.handleKvRawBatchPut(r) case tikvrpc.CmdRawDelete: r := req.RawDelete() - if err := handler.checkRequest(reqCtx, r.Size()); err != nil { + if err := session.checkRequest(reqCtx, r.Size()); err != nil { resp.Resp = &kvrpcpb.RawDeleteResponse{RegionError: err} return resp, nil } - resp.Resp = handler.handleKvRawDelete(r) + resp.Resp = kvHandler{session}.handleKvRawDelete(r) case tikvrpc.CmdRawBatchDelete: r := req.RawBatchDelete() - if err := handler.checkRequest(reqCtx, r.Size()); err != nil { + if err := session.checkRequest(reqCtx, r.Size()); err != nil { resp.Resp = &kvrpcpb.RawBatchDeleteResponse{RegionError: err} } - resp.Resp = handler.handleKvRawBatchDelete(r) + resp.Resp = kvHandler{session}.handleKvRawBatchDelete(r) case tikvrpc.CmdRawDeleteRange: r := req.RawDeleteRange() - if err := handler.checkRequest(reqCtx, r.Size()); err != nil { + if err := session.checkRequest(reqCtx, r.Size()); err != nil { resp.Resp = &kvrpcpb.RawDeleteRangeResponse{RegionError: err} return resp, nil } - resp.Resp = handler.handleKvRawDeleteRange(r) + resp.Resp = kvHandler{session}.handleKvRawDeleteRange(r) case tikvrpc.CmdRawScan: r := req.RawScan() - if err := handler.checkRequest(reqCtx, r.Size()); err != nil { + if err := session.checkRequest(reqCtx, r.Size()); err != nil { resp.Resp = &kvrpcpb.RawScanResponse{RegionError: err} return resp, nil } - resp.Resp = handler.handleKvRawScan(r) + resp.Resp = kvHandler{session}.handleKvRawScan(r) case tikvrpc.CmdUnsafeDestroyRange: panic("unimplemented") case tikvrpc.CmdRegisterLockObserver: @@ -1035,20 +918,20 @@ func (c *RPCClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.R return nil, errors.New("unimplemented") case tikvrpc.CmdCop: r := req.Cop() - if err := handler.checkRequestContext(reqCtx); err != nil { + if err := session.checkRequestContext(reqCtx); err != nil { resp.Resp = &coprocessor.Response{RegionError: err} return resp, nil } - handler.rawStartKey = MvccKey(handler.startKey).Raw() - handler.rawEndKey = MvccKey(handler.endKey).Raw() + session.rawStartKey = MvccKey(session.startKey).Raw() + session.rawEndKey = MvccKey(session.endKey).Raw() var res *coprocessor.Response switch r.GetTp() { case kv.ReqTypeDAG: - res = handler.handleCopDAGRequest(r) + res = coprHandler{session}.handleCopDAGRequest(r) case kv.ReqTypeAnalyze: - res = handler.handleCopAnalyzeRequest(r) + res = 
coprHandler{session}.handleCopAnalyzeRequest(r) case kv.ReqTypeChecksum: - res = handler.handleCopChecksumRequest(r) + res = coprHandler{session}.handleCopChecksumRequest(r) default: panic(fmt.Sprintf("unknown coprocessor request type: %v", r.GetTp())) } @@ -1066,7 +949,7 @@ func (c *RPCClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.R } }) r := req.BatchCop() - if err := handler.checkRequestContext(reqCtx); err != nil { + if err := session.checkRequestContext(reqCtx); err != nil { resp.Resp = &tikvrpc.BatchCopStreamResponse{ Tikv_BatchCoprocessorClient: &mockBathCopErrClient{Error: err}, BatchResponse: &coprocessor.BatchResponse{ @@ -1076,7 +959,7 @@ func (c *RPCClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.R return resp, nil } ctx1, cancel := context.WithCancel(ctx) - batchCopStream, err := handler.handleBatchCopRequest(ctx1, r) + batchCopStream, err := coprHandler{session}.handleBatchCopRequest(ctx1, r) if err != nil { cancel() return nil, errors.Trace(err) @@ -1094,7 +977,7 @@ func (c *RPCClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.R resp.Resp = batchResp case tikvrpc.CmdCopStream: r := req.Cop() - if err := handler.checkRequestContext(reqCtx); err != nil { + if err := session.checkRequestContext(reqCtx); err != nil { resp.Resp = &tikvrpc.CopStreamResponse{ Tikv_CoprocessorStreamClient: &mockCopStreamErrClient{Error: err}, Response: &coprocessor.Response{ @@ -1103,10 +986,10 @@ func (c *RPCClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.R } return resp, nil } - handler.rawStartKey = MvccKey(handler.startKey).Raw() - handler.rawEndKey = MvccKey(handler.endKey).Raw() + session.rawStartKey = MvccKey(session.startKey).Raw() + session.rawEndKey = MvccKey(session.endKey).Raw() ctx1, cancel := context.WithCancel(ctx) - copStream, err := handler.handleCopStream(ctx1, r) + copStream, err := coprHandler{session}.handleCopStream(ctx1, r) if err != nil { cancel() return nil, errors.Trace(err) @@ -1127,31 +1010,31 @@ func (c *RPCClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.R resp.Resp = streamResp case tikvrpc.CmdMvccGetByKey: r := req.MvccGetByKey() - if err := handler.checkRequest(reqCtx, r.Size()); err != nil { + if err := session.checkRequest(reqCtx, r.Size()); err != nil { resp.Resp = &kvrpcpb.MvccGetByKeyResponse{RegionError: err} return resp, nil } - resp.Resp = handler.handleMvccGetByKey(r) + resp.Resp = kvHandler{session}.handleMvccGetByKey(r) case tikvrpc.CmdMvccGetByStartTs: r := req.MvccGetByStartTs() - if err := handler.checkRequest(reqCtx, r.Size()); err != nil { + if err := session.checkRequest(reqCtx, r.Size()); err != nil { resp.Resp = &kvrpcpb.MvccGetByStartTsResponse{RegionError: err} return resp, nil } - resp.Resp = handler.handleMvccGetByStartTS(r) + resp.Resp = kvHandler{session}.handleMvccGetByStartTS(r) case tikvrpc.CmdSplitRegion: r := req.SplitRegion() - if err := handler.checkRequest(reqCtx, r.Size()); err != nil { + if err := session.checkRequest(reqCtx, r.Size()); err != nil { resp.Resp = &kvrpcpb.SplitRegionResponse{RegionError: err} return resp, nil } - resp.Resp = handler.handleSplitRegion(r) + resp.Resp = kvHandler{session}.handleSplitRegion(r) // DebugGetRegionProperties is for fast analyze in mock tikv. 
case tikvrpc.CmdDebugGetRegionProperties: r := req.DebugGetRegionProperties() region, _ := c.Cluster.GetRegion(r.RegionId) var reqCtx kvrpcpb.Context - scanResp := handler.handleKvScan(&kvrpcpb.ScanRequest{ + scanResp := kvHandler{session}.handleKvScan(&kvrpcpb.ScanRequest{ Context: &reqCtx, StartKey: MvccKey(region.StartKey).Raw(), EndKey: MvccKey(region.EndKey).Raw(), diff --git a/store/mockstore/mocktikv/session.go b/store/mockstore/mocktikv/session.go new file mode 100644 index 0000000000000..4d5e8b61678d8 --- /dev/null +++ b/store/mockstore/mocktikv/session.go @@ -0,0 +1,146 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package mocktikv + +import ( + "github.com/gogo/protobuf/proto" + "github.com/pingcap/kvproto/pkg/errorpb" + "github.com/pingcap/kvproto/pkg/kvrpcpb" + "github.com/pingcap/kvproto/pkg/metapb" + "github.com/pingcap/tidb/ddl/placement" +) + +// Session stores session scope rpc data. +type Session struct { + cluster *Cluster + mvccStore MVCCStore + + // storeID stores id for current request + storeID uint64 + // startKey is used for handling normal request. + startKey []byte + endKey []byte + // rawStartKey is used for handling coprocessor request. + rawStartKey []byte + rawEndKey []byte + // isolationLevel is used for current request. + isolationLevel kvrpcpb.IsolationLevel + resolvedLocks []uint64 +} + +func (s *Session) checkRequestContext(ctx *kvrpcpb.Context) *errorpb.Error { + ctxPeer := ctx.GetPeer() + if ctxPeer != nil && ctxPeer.GetStoreId() != s.storeID { + return &errorpb.Error{ + Message: *proto.String("store not match"), + StoreNotMatch: &errorpb.StoreNotMatch{}, + } + } + region, leaderID := s.cluster.GetRegion(ctx.GetRegionId()) + // No region found. + if region == nil { + return &errorpb.Error{ + Message: *proto.String("region not found"), + RegionNotFound: &errorpb.RegionNotFound{ + RegionId: *proto.Uint64(ctx.GetRegionId()), + }, + } + } + var storePeer, leaderPeer *metapb.Peer + for _, p := range region.Peers { + if p.GetStoreId() == s.storeID { + storePeer = p + } + if p.GetId() == leaderID { + leaderPeer = p + } + } + // The Store does not contain a Peer of the Region. + if storePeer == nil { + return &errorpb.Error{ + Message: *proto.String("region not found"), + RegionNotFound: &errorpb.RegionNotFound{ + RegionId: *proto.Uint64(ctx.GetRegionId()), + }, + } + } + // No leader. + if leaderPeer == nil { + return &errorpb.Error{ + Message: *proto.String("no leader"), + NotLeader: &errorpb.NotLeader{ + RegionId: *proto.Uint64(ctx.GetRegionId()), + }, + } + } + // The Peer on the Store is not leader. If it's tiflash store , we pass this check. + if storePeer.GetId() != leaderPeer.GetId() && !isTiFlashStore(s.cluster.GetStore(storePeer.GetStoreId())) { + return &errorpb.Error{ + Message: *proto.String("not leader"), + NotLeader: &errorpb.NotLeader{ + RegionId: *proto.Uint64(ctx.GetRegionId()), + Leader: leaderPeer, + }, + } + } + // Region epoch does not match. 
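The mocktikv refactor above splits the old rpcHandler into a request-scoped Session (cluster, MVCC store, store ID, key range) plus thin kvHandler/coprHandler values that SendRequest builds per command. A minimal, self-contained sketch of that composition pattern; the names and the simplified region check are illustrative only, not the real mocktikv API:

```go
package main

import "fmt"

// session holds per-request state shared by every handler, mirroring the
// Session type introduced in mocktikv (names here are illustrative only).
type session struct {
	storeID  uint64
	startKey []byte
	endKey   []byte
}

// checkKeyInRegion is the kind of helper both handler flavours reuse.
func (s *session) checkKeyInRegion(key []byte) bool {
	return string(key) >= string(s.startKey) &&
		(len(s.endKey) == 0 || string(key) < string(s.endKey))
}

// kvHandler and coprHandler are stateless wrappers: they borrow the session
// for a single call instead of owning long-lived state themselves.
type kvHandler struct{ *session }
type coprHandler struct{ *session }

func (h kvHandler) handleGet(key []byte) string {
	if !h.checkKeyInRegion(key) {
		return "KeyNotInRegion"
	}
	return fmt.Sprintf("get %q from store %d", key, h.storeID)
}

func (h coprHandler) handleDAG(key []byte) string {
	return fmt.Sprintf("run DAG starting at %q on store %d", key, h.storeID)
}

func main() {
	s := &session{storeID: 1, startKey: []byte("a"), endKey: []byte("m")}
	// Each request builds a throwaway handler value around the session,
	// just like kvHandler{session}.handleKvGet(r) in the diff.
	fmt.Println(kvHandler{s}.handleGet([]byte("b")))
	fmt.Println(coprHandler{s}.handleDAG([]byte("c")))
}
```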
+ if !proto.Equal(region.GetRegionEpoch(), ctx.GetRegionEpoch()) { + nextRegion, _ := s.cluster.GetRegionByKey(region.GetEndKey()) + currentRegions := []*metapb.Region{region} + if nextRegion != nil { + currentRegions = append(currentRegions, nextRegion) + } + return &errorpb.Error{ + Message: *proto.String("epoch not match"), + EpochNotMatch: &errorpb.EpochNotMatch{ + CurrentRegions: currentRegions, + }, + } + } + s.startKey, s.endKey = region.StartKey, region.EndKey + s.isolationLevel = ctx.IsolationLevel + s.resolvedLocks = ctx.ResolvedLocks + return nil +} + +func (s *Session) checkRequestSize(size int) *errorpb.Error { + // TiKV has a limitation on raft log size. + // mocktikv has no raft inside, so we check the request's size instead. + if size >= requestMaxSize { + return &errorpb.Error{ + RaftEntryTooLarge: &errorpb.RaftEntryTooLarge{}, + } + } + return nil +} + +func (s *Session) checkRequest(ctx *kvrpcpb.Context, size int) *errorpb.Error { + if err := s.checkRequestContext(ctx); err != nil { + return err + } + return s.checkRequestSize(size) +} + +func (s *Session) checkKeyInRegion(key []byte) bool { + return regionContains(s.startKey, s.endKey, NewMvccKey(key)) +} + +func isTiFlashStore(store *metapb.Store) bool { + for _, l := range store.GetLabels() { + if l.GetKey() == placement.EngineLabelKey && l.GetValue() == placement.EngineLabelTiFlash { + return true + } + } + return false +} diff --git a/store/mockstore/unistore/cophandler/closure_exec.go b/store/mockstore/unistore/cophandler/closure_exec.go index 841a4f19e14f6..c2a52bec6ae12 100644 --- a/store/mockstore/unistore/cophandler/closure_exec.go +++ b/store/mockstore/unistore/cophandler/closure_exec.go @@ -77,6 +77,8 @@ func getExecutorListFromRootExec(rootExec *tipb.Executor) ([]*tipb.Executor, err currentExec = currentExec.Limit.Child case tipb.ExecType_TypeExchangeSender: currentExec = currentExec.ExchangeSender.Child + case tipb.ExecType_TypeSelection: + currentExec = currentExec.Selection.Child default: return nil, errors.New("unsupported executor type " + currentExec.Tp.String()) } diff --git a/store/mockstore/unistore/cophandler/mpp.go b/store/mockstore/unistore/cophandler/mpp.go index 9849f53921099..f12e4b7f41bdf 100644 --- a/store/mockstore/unistore/cophandler/mpp.go +++ b/store/mockstore/unistore/cophandler/mpp.go @@ -275,6 +275,7 @@ func (b *mppExecBuilder) buildMPPAgg(agg *tipb.Aggregation) (*aggExec, error) { } e.aggExprs = append(e.aggExprs, aggExpr) } + e.sc = b.sc for _, gby := range agg.GroupBy { ft := expression.PbTypeToFieldType(gby.FieldType) diff --git a/store/tikv/2pc.go b/store/tikv/2pc.go index e27b2f6416cd2..deee0df28bdac 100644 --- a/store/tikv/2pc.go +++ b/store/tikv/2pc.go @@ -36,6 +36,7 @@ import ( "github.com/pingcap/tidb/store/tikv/metrics" "github.com/pingcap/tidb/store/tikv/oracle" "github.com/pingcap/tidb/store/tikv/tikvrpc" + "github.com/pingcap/tidb/store/tikv/unionstore" "github.com/pingcap/tidb/store/tikv/util" "github.com/pingcap/tidb/tablecodec" "github.com/pingcap/tidb/util/execdetails" @@ -107,13 +108,13 @@ type twoPhaseCommitter struct { } type memBufferMutations struct { - storage tidbkv.MemBuffer - handles []tidbkv.MemKeyHandle + storage *unionstore.MemDB + handles []unionstore.MemKeyHandle } -func newMemBufferMutations(sizeHint int, storage tidbkv.MemBuffer) *memBufferMutations { +func newMemBufferMutations(sizeHint int, storage *unionstore.MemDB) *memBufferMutations { return &memBufferMutations{ - handles: make([]tidbkv.MemKeyHandle, 0, sizeHint), + handles: 
make([]unionstore.MemKeyHandle, 0, sizeHint), storage: storage, } } @@ -154,7 +155,7 @@ func (m *memBufferMutations) Slice(from, to int) CommitterMutations { } } -func (m *memBufferMutations) Push(op pb.Op, isPessimisticLock bool, handle tidbkv.MemKeyHandle) { +func (m *memBufferMutations) Push(op pb.Op, isPessimisticLock bool, handle unionstore.MemKeyHandle) { aux := uint16(op) << 1 if isPessimisticLock { aux |= 1 @@ -358,7 +359,7 @@ func (c *twoPhaseCommitter) initKeysAndMutations() error { // due to `Op_CheckNotExists` doesn't prewrite lock, so mark those keys should not be used in commit-phase. op = pb.Op_CheckNotExists checkCnt++ - memBuf.UpdateFlags(key, tidbkv.SetPrewriteOnly) + memBuf.UpdateFlags(key, kv.SetPrewriteOnly) } else { // normal delete keys in optimistic txn can be delete without not exists checking // delete-your-writes keys in pessimistic txn can ensure must be no exists so can directly delete them @@ -785,7 +786,7 @@ func sendTxnHeartBeat(bo *Backoffer, store *KVStore, primary []byte, startTS, tt if err != nil { return 0, errors.Trace(err) } - resp, err := store.SendReq(bo, req, loc.Region, readTimeoutShort) + resp, err := store.SendReq(bo, req, loc.Region, ReadTimeoutShort) if err != nil { return 0, errors.Trace(err) } diff --git a/store/tikv/cleanup.go b/store/tikv/cleanup.go index 1dd56f54f63e3..dc96ed32ab54c 100644 --- a/store/tikv/cleanup.go +++ b/store/tikv/cleanup.go @@ -40,7 +40,7 @@ func (actionCleanup) handleSingleBatch(c *twoPhaseCommitter, bo *Backoffer, batc Keys: batch.mutations.GetKeys(), StartVersion: c.startTS, }, pb.Context{Priority: c.priority, SyncLog: c.syncLog}) - resp, err := c.store.SendReq(bo, req, batch.region, readTimeoutShort) + resp, err := c.store.SendReq(bo, req, batch.region, ReadTimeoutShort) if err != nil { return errors.Trace(err) } diff --git a/store/tikv/client.go b/store/tikv/client.go index e9db387f763a3..de5ef3b5377b1 100644 --- a/store/tikv/client.go +++ b/store/tikv/client.go @@ -56,7 +56,7 @@ var MaxRecvMsgSize = math.MaxInt64 // Timeout durations. var ( dialTimeout = 5 * time.Second - readTimeoutShort = 20 * time.Second // For requests that read/write several key-values. + ReadTimeoutShort = 20 * time.Second // For requests that read/write several key-values. ReadTimeoutMedium = 60 * time.Second // For requests that may need scan region. ReadTimeoutLong = 150 * time.Second // For requests that may need scan region multiple times. ReadTimeoutUltraLong = 3600 * time.Second // For requests that may scan many regions for tiflash. diff --git a/store/tikv/client_collapse.go b/store/tikv/client_collapse.go index 5fc99420c0012..e7f9cfadcf08b 100644 --- a/store/tikv/client_collapse.go +++ b/store/tikv/client_collapse.go @@ -74,7 +74,7 @@ func (r reqCollapse) tryCollapseRequest(ctx context.Context, addr string, req *t func (r reqCollapse) collapse(ctx context.Context, key string, sf *singleflight.Group, addr string, req *tikvrpc.Request, timeout time.Duration) (resp *tikvrpc.Response, err error) { rsC := sf.DoChan(key, func() (interface{}, error) { - return r.Client.SendRequest(context.Background(), addr, req, readTimeoutShort) // use resolveLock timeout. + return r.Client.SendRequest(context.Background(), addr, req, ReadTimeoutShort) // use resolveLock timeout. 
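With memBufferMutations now keyed on unionstore.MemKeyHandle, each mutation is stored as a compact handle plus one uint16 whose low bit carries isPessimisticLock and whose upper bits carry the operation, as Push shows above. A standalone sketch of that bit packing; Op here is a stand-in for the kvrpcpb enum, not the real type:

```go
package main

import "fmt"

// Op is a stand-in for kvrpcpb.Op; the real committer packs the proto enum.
type Op uint16

const (
	OpPut Op = iota
	OpDel
	OpCheckNotExists
)

// packMutation mirrors memBufferMutations.Push in the diff: the operation is
// shifted left by one and the low bit records isPessimisticLock, so each
// mutation costs a single uint16 next to its MemKeyHandle.
func packMutation(op Op, isPessimisticLock bool) uint16 {
	aux := uint16(op) << 1
	if isPessimisticLock {
		aux |= 1
	}
	return aux
}

func unpackMutation(aux uint16) (Op, bool) {
	return Op(aux >> 1), aux&1 == 1
}

func main() {
	aux := packMutation(OpCheckNotExists, true)
	op, pessimistic := unpackMutation(aux)
	fmt.Printf("aux=%#04x op=%d pessimistic=%v\n", aux, op, pessimistic)
}
```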
}) timer := time.NewTimer(timeout) defer timer.Stop() diff --git a/store/tikv/commit.go b/store/tikv/commit.go index 54dad304705df..f1e9da6f5e103 100644 --- a/store/tikv/commit.go +++ b/store/tikv/commit.go @@ -48,7 +48,7 @@ func (actionCommit) handleSingleBatch(c *twoPhaseCommitter, bo *Backoffer, batch }, pb.Context{Priority: c.priority, SyncLog: c.syncLog}) sender := NewRegionRequestSender(c.store.regionCache, c.store.client) - resp, err := sender.SendReq(bo, req, batch.region, readTimeoutShort) + resp, err := sender.SendReq(bo, req, batch.region, ReadTimeoutShort) // If we fail to receive response for the request that commits primary key, it will be undetermined whether this // transaction has been successfully committed. diff --git a/store/tikv/kv/keyflags.go b/store/tikv/kv/keyflags.go new file mode 100644 index 0000000000000..a98330f080f71 --- /dev/null +++ b/store/tikv/kv/keyflags.go @@ -0,0 +1,129 @@ +// Copyright 2021 PingCAP, Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package kv + +// KeyFlags are metadata associated with key +type KeyFlags uint8 + +const ( + flagPresumeKNE KeyFlags = 1 << iota + flagKeyLocked + flagNeedLocked + flagKeyLockedValExist + flagNeedCheckExists + flagPrewriteOnly + flagIgnoredIn2PC + + persistentFlags = flagKeyLocked | flagKeyLockedValExist +) + +// HasPresumeKeyNotExists returns whether the associated key use lazy check. +func (f KeyFlags) HasPresumeKeyNotExists() bool { + return f&flagPresumeKNE != 0 +} + +// HasLocked returns whether the associated key has acquired pessimistic lock. +func (f KeyFlags) HasLocked() bool { + return f&flagKeyLocked != 0 +} + +// HasNeedLocked return whether the key needed to be locked +func (f KeyFlags) HasNeedLocked() bool { + return f&flagNeedLocked != 0 +} + +// HasLockedValueExists returns whether the value exists when key locked. +func (f KeyFlags) HasLockedValueExists() bool { + return f&flagKeyLockedValExist != 0 +} + +// HasNeedCheckExists returns whether the key need to check existence when it has been locked. +func (f KeyFlags) HasNeedCheckExists() bool { + return f&flagNeedCheckExists != 0 +} + +// HasPrewriteOnly returns whether the key should be used in 2pc commit phase. +func (f KeyFlags) HasPrewriteOnly() bool { + return f&flagPrewriteOnly != 0 +} + +// HasIgnoredIn2PC returns whether the key will be ignored in 2pc. +func (f KeyFlags) HasIgnoredIn2PC() bool { + return f&flagIgnoredIn2PC != 0 +} + +// AndPersistent returns the value of current flags&persistentFlags +func (f KeyFlags) AndPersistent() KeyFlags { + return f & persistentFlags +} + +// ApplyFlagsOps applys flagspos to origin. 
+func ApplyFlagsOps(origin KeyFlags, ops ...FlagsOp) KeyFlags { + for _, op := range ops { + switch op { + case SetPresumeKeyNotExists: + origin |= flagPresumeKNE | flagNeedCheckExists + case DelPresumeKeyNotExists: + origin &= ^(flagPresumeKNE | flagNeedCheckExists) + case SetKeyLocked: + origin |= flagKeyLocked + case DelKeyLocked: + origin &= ^flagKeyLocked + case SetNeedLocked: + origin |= flagNeedLocked + case DelNeedLocked: + origin &= ^flagNeedLocked + case SetKeyLockedValueExists: + origin |= flagKeyLockedValExist + case DelNeedCheckExists: + origin &= ^flagNeedCheckExists + case SetKeyLockedValueNotExists: + origin &= ^flagKeyLockedValExist + case SetPrewriteOnly: + origin |= flagPrewriteOnly + case SetIgnoredIn2PC: + origin |= flagIgnoredIn2PC + } + } + return origin +} + +// FlagsOp describes KeyFlags modify operation. +type FlagsOp uint16 + +const ( + // SetPresumeKeyNotExists marks the existence of the associated key is checked lazily. + // Implies KeyFlags.HasNeedCheckExists() == true. + SetPresumeKeyNotExists FlagsOp = 1 << iota + // DelPresumeKeyNotExists reverts SetPresumeKeyNotExists. + DelPresumeKeyNotExists + // SetKeyLocked marks the associated key has acquired lock. + SetKeyLocked + // DelKeyLocked reverts SetKeyLocked. + DelKeyLocked + // SetNeedLocked marks the associated key need to be acquired lock. + SetNeedLocked + // DelNeedLocked reverts SetKeyNeedLocked. + DelNeedLocked + // SetKeyLockedValueExists marks the value exists when key has been locked in Transaction.LockKeys. + SetKeyLockedValueExists + // SetKeyLockedValueNotExists marks the value doesn't exists when key has been locked in Transaction.LockKeys. + SetKeyLockedValueNotExists + // DelNeedCheckExists marks the key no need to be checked in Transaction.LockKeys. + DelNeedCheckExists + // SetPrewriteOnly marks the key shouldn't be used in 2pc commit phase. + SetPrewriteOnly + // SetIgnoredIn2PC marks the key will be ignored in 2pc. 
+ SetIgnoredIn2PC +) diff --git a/store/tikv/lock_resolver.go b/store/tikv/lock_resolver.go index fb66b1c9be04a..27c4ac32bdd55 100644 --- a/store/tikv/lock_resolver.go +++ b/store/tikv/lock_resolver.go @@ -285,7 +285,7 @@ func (lr *LockResolver) BatchResolveLocks(bo *Backoffer, locks []*Lock, loc Regi req := tikvrpc.NewRequest(tikvrpc.CmdResolveLock, &kvrpcpb.ResolveLockRequest{TxnInfos: listTxnInfos}) startTime = time.Now() - resp, err := lr.store.SendReq(bo, req, loc, readTimeoutShort) + resp, err := lr.store.SendReq(bo, req, loc, ReadTimeoutShort) if err != nil { return false, errors.Trace(err) } @@ -593,7 +593,7 @@ func (lr *LockResolver) getTxnStatus(bo *Backoffer, txnID uint64, primary []byte if err != nil { return status, errors.Trace(err) } - resp, err := lr.store.SendReq(bo, req, loc.Region, readTimeoutShort) + resp, err := lr.store.SendReq(bo, req, loc.Region, ReadTimeoutShort) if err != nil { return status, errors.Trace(err) } @@ -729,7 +729,7 @@ func (lr *LockResolver) checkSecondaries(bo *Backoffer, txnID uint64, curKeys [] } req := tikvrpc.NewRequest(tikvrpc.CmdCheckSecondaryLocks, checkReq) metrics.LockResolverCountWithQueryCheckSecondaryLocks.Inc() - resp, err := lr.store.SendReq(bo, req, curRegionID, readTimeoutShort) + resp, err := lr.store.SendReq(bo, req, curRegionID, ReadTimeoutShort) if err != nil { return errors.Trace(err) } @@ -859,7 +859,7 @@ func (lr *LockResolver) resolveRegionLocks(bo *Backoffer, l *Lock, region Region lreq.Keys = keys req := tikvrpc.NewRequest(tikvrpc.CmdResolveLock, lreq) - resp, err := lr.store.SendReq(bo, req, region, readTimeoutShort) + resp, err := lr.store.SendReq(bo, req, region, ReadTimeoutShort) if err != nil { return errors.Trace(err) } @@ -928,7 +928,7 @@ func (lr *LockResolver) resolveLock(bo *Backoffer, l *Lock, status TxnStatus, li lreq.Keys = [][]byte{l.Key} } req := tikvrpc.NewRequest(tikvrpc.CmdResolveLock, lreq) - resp, err := lr.store.SendReq(bo, req, loc.Region, readTimeoutShort) + resp, err := lr.store.SendReq(bo, req, loc.Region, ReadTimeoutShort) if err != nil { return errors.Trace(err) } @@ -979,7 +979,7 @@ func (lr *LockResolver) resolvePessimisticLock(bo *Backoffer, l *Lock, cleanRegi Keys: [][]byte{l.Key}, } req := tikvrpc.NewRequest(tikvrpc.CmdPessimisticRollback, pessimisticRollbackReq) - resp, err := lr.store.SendReq(bo, req, loc.Region, readTimeoutShort) + resp, err := lr.store.SendReq(bo, req, loc.Region, ReadTimeoutShort) if err != nil { return errors.Trace(err) } diff --git a/store/tikv/lock_test.go b/store/tikv/lock_test.go index b61bfcd111d7e..6e8f98708d301 100644 --- a/store/tikv/lock_test.go +++ b/store/tikv/lock_test.go @@ -411,7 +411,7 @@ func (s *testLockSuite) mustGetLock(c *C, key []byte) *Lock { }) loc, err := s.store.regionCache.LocateKey(bo, key) c.Assert(err, IsNil) - resp, err := s.store.SendReq(bo, req, loc.Region, readTimeoutShort) + resp, err := s.store.SendReq(bo, req, loc.Region, ReadTimeoutShort) c.Assert(err, IsNil) c.Assert(resp.Resp, NotNil) keyErr := resp.Resp.(*kvrpcpb.GetResponse).GetError() diff --git a/store/tikv/pessimistic.go b/store/tikv/pessimistic.go index 2f85cc723a7c8..799e38076be57 100644 --- a/store/tikv/pessimistic.go +++ b/store/tikv/pessimistic.go @@ -108,7 +108,7 @@ func (action actionPessimisticLock) handleSingleBatch(c *twoPhaseCommitter, bo * return kv.ErrWriteConflict }) startTime := time.Now() - resp, err := c.store.SendReq(bo, req, batch.region, readTimeoutShort) + resp, err := c.store.SendReq(bo, req, batch.region, ReadTimeoutShort) if action.LockCtx.Stats != nil 
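The new store/tikv/kv/keyflags.go keeps per-key metadata in a single uint8 bitmask and funnels every update through ApplyFlagsOps. A short usage sketch of those helpers; it assumes it is compiled inside the tidb module at this revision so the import resolves:

```go
package main

import (
	"fmt"

	"github.com/pingcap/tidb/store/tikv/kv"
)

func main() {
	var f kv.KeyFlags

	// SetPresumeKeyNotExists also implies the need-check-exists bit.
	f = kv.ApplyFlagsOps(f, kv.SetPresumeKeyNotExists)
	fmt.Println(f.HasPresumeKeyNotExists(), f.HasNeedCheckExists()) // true true

	// Locking the key sets one of the two persistent flags.
	f = kv.ApplyFlagsOps(f, kv.SetKeyLocked, kv.SetKeyLockedValueExists)

	// AndPersistent drops everything except the locked/value-exists bits,
	// which is what survives when a MemDB staging buffer is cleaned up.
	p := f.AndPersistent()
	fmt.Println(p.HasLocked(), p.HasPresumeKeyNotExists()) // true false
}
```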
{ atomic.AddInt64(&action.LockCtx.Stats.LockRPCTime, int64(time.Since(startTime))) atomic.AddInt64(&action.LockCtx.Stats.LockRPCCount, 1) @@ -210,7 +210,7 @@ func (actionPessimisticRollback) handleSingleBatch(c *twoPhaseCommitter, bo *Bac ForUpdateTs: c.forUpdateTS, Keys: batch.mutations.GetKeys(), }) - resp, err := c.store.SendReq(bo, req, batch.region, readTimeoutShort) + resp, err := c.store.SendReq(bo, req, batch.region, ReadTimeoutShort) if err != nil { return errors.Trace(err) } diff --git a/store/tikv/prewrite.go b/store/tikv/prewrite.go index 5583a1e8525e2..43db9cb11500a 100644 --- a/store/tikv/prewrite.go +++ b/store/tikv/prewrite.go @@ -158,7 +158,7 @@ func (action actionPrewrite) handleSingleBatch(c *twoPhaseCommitter, bo *Backoff req := c.buildPrewriteRequest(batch, txnSize) for { sender := NewRegionRequestSender(c.store.regionCache, c.store.client) - resp, err := sender.SendReq(bo, req, batch.region, readTimeoutShort) + resp, err := sender.SendReq(bo, req, batch.region, ReadTimeoutShort) // If we fail to receive response for async commit prewrite, it will be undetermined whether this // transaction has been successfully committed. diff --git a/store/tikv/rawkv.go b/store/tikv/rawkv.go index dd5b621bb1092..0b07b4ddf014e 100644 --- a/store/tikv/rawkv.go +++ b/store/tikv/rawkv.go @@ -362,7 +362,7 @@ func (c *RawKVClient) sendReq(key []byte, req *tikvrpc.Request, reverse bool) (* if err != nil { return nil, nil, errors.Trace(err) } - resp, err := sender.SendReq(bo, req, loc.Region, readTimeoutShort) + resp, err := sender.SendReq(bo, req, loc.Region, ReadTimeoutShort) if err != nil { return nil, nil, errors.Trace(err) } @@ -442,7 +442,7 @@ func (c *RawKVClient) doBatchReq(bo *Backoffer, batch batch, cmdType tikvrpc.Cmd } sender := NewRegionRequestSender(c.regionCache, c.rpcClient) - resp, err := sender.SendReq(bo, req, batch.regionID, readTimeoutShort) + resp, err := sender.SendReq(bo, req, batch.regionID, ReadTimeoutShort) batchResp := singleBatchResp{} if err != nil { @@ -507,7 +507,7 @@ func (c *RawKVClient) sendDeleteRangeReq(startKey []byte, endKey []byte) (*tikvr EndKey: actualEndKey, }) - resp, err := sender.SendReq(bo, req, loc.Region, readTimeoutShort) + resp, err := sender.SendReq(bo, req, loc.Region, ReadTimeoutShort) if err != nil { return nil, nil, errors.Trace(err) } @@ -612,7 +612,7 @@ func (c *RawKVClient) doBatchPut(bo *Backoffer, batch batch) error { req := tikvrpc.NewRequest(tikvrpc.CmdRawBatchPut, &kvrpcpb.RawBatchPutRequest{Pairs: kvPair}) sender := NewRegionRequestSender(c.regionCache, c.rpcClient) - resp, err := sender.SendReq(bo, req, batch.regionID, readTimeoutShort) + resp, err := sender.SendReq(bo, req, batch.regionID, ReadTimeoutShort) if err != nil { return errors.Trace(err) } diff --git a/store/tikv/region_cache.go b/store/tikv/region_cache.go index 4429e35cd102a..0c350e2f73b7a 100644 --- a/store/tikv/region_cache.go +++ b/store/tikv/region_cache.go @@ -58,15 +58,38 @@ var RegionCacheTTLSec int64 = 600 const ( updated int32 = iota // region is updated and no need to reload. - needSync // need sync new region info. + needSync // need sync new region info. +) + +// InvalidReason is the reason why a cached region is invalidated. +// The region cache may take different strategies to handle different reasons. +// For example, when a cached region is invalidated due to no leader, region cache +// will always access to a different peer. 
+type InvalidReason int32 + +const ( + // Ok indicates the cached region is valid + Ok InvalidReason = iota + // NoLeader indicates it's invalidated due to no leader + NoLeader + // RegionNotFound indicates it's invalidated due to region not found in the store + RegionNotFound + // EpochNotMatch indicates it's invalidated due to epoch not match + EpochNotMatch + // StoreNotFound indicates it's invalidated due to store not found in PD + StoreNotFound + // Other indicates it's invalidated due to other reasons, e.g., the store + // is removed from the cluster, fail to send requests to the store. + Other ) // Region presents kv region type Region struct { - meta *metapb.Region // raw region meta from PD immutable after init - store unsafe.Pointer // point to region store info, see RegionStore - syncFlag int32 // region need be sync in next turn - lastAccess int64 // last region access time, see checkRegionCacheTTL + meta *metapb.Region // raw region meta from PD immutable after init + store unsafe.Pointer // point to region store info, see RegionStore + syncFlag int32 // region need be sync in next turn + lastAccess int64 // last region access time, see checkRegionCacheTTL + invalidReason InvalidReason // the reason why the region is invalidated } // AccessIndex represent the index for accessIndex array @@ -203,7 +226,7 @@ func (r *Region) compareAndSwapStore(oldStore, newStore *RegionStore) bool { func (r *Region) checkRegionCacheTTL(ts int64) bool { // Only consider use percentage on this failpoint, for example, "2%return" failpoint.Inject("invalidateRegionCache", func() { - r.invalidate() + r.invalidate(Other) }) for { lastAccess := atomic.LoadInt64(&r.lastAccess) @@ -217,8 +240,9 @@ func (r *Region) checkRegionCacheTTL(ts int64) bool { } // invalidate invalidates a region, next time it will got null result. -func (r *Region) invalidate() { +func (r *Region) invalidate(reason InvalidReason) { metrics.RegionCacheCounterWithInvalidateRegionFromCacheOK.Inc() + atomic.StoreInt32((*int32)(&r.invalidReason), int32(reason)) atomic.StoreInt64(&r.lastAccess, invalidatedLastAccessTime) } @@ -456,13 +480,13 @@ func (c *RegionCache) GetTiKVRPCContext(bo *Backoffer, id RegionVerID, replicaRe }) if store == nil || len(addr) == 0 { // Store not found, region must be out of date. - cachedRegion.invalidate() + cachedRegion.invalidate(StoreNotFound) return nil, nil } storeFailEpoch := atomic.LoadUint32(&store.epoch) if storeFailEpoch != regionStore.storeEpochs[storeIdx] { - cachedRegion.invalidate() + cachedRegion.invalidate(Other) logutil.BgLogger().Info("invalidate current region, because others failed on same store", zap.Uint64("region", id.GetID()), zap.String("store", store.addr)) @@ -506,7 +530,8 @@ func (c *RegionCache) GetTiKVRPCContext(bo *Backoffer, id RegionVerID, replicaRe // GetTiFlashRPCContext returns RPCContext for a region must access flash store. If it returns nil, the region // must be out of date and already dropped from cache or not flash store found. -func (c *RegionCache) GetTiFlashRPCContext(bo *Backoffer, id RegionVerID) (*RPCContext, error) { +// `loadBalance` is an option. For MPP and batch cop, it is pointless and might cause try the failed store repeatly. 
+func (c *RegionCache) GetTiFlashRPCContext(bo *Backoffer, id RegionVerID, loadBalance bool) (*RPCContext, error) { ts := time.Now().Unix() cachedRegion := c.getCachedRegionWithRLock(id) @@ -520,7 +545,12 @@ func (c *RegionCache) GetTiFlashRPCContext(bo *Backoffer, id RegionVerID) (*RPCC regionStore := cachedRegion.getStore() // sIdx is for load balance of TiFlash store. - sIdx := int(atomic.AddInt32(®ionStore.workTiFlashIdx, 1)) + var sIdx int + if loadBalance { + sIdx = int(atomic.AddInt32(®ionStore.workTiFlashIdx, 1)) + } else { + sIdx = int(atomic.LoadInt32(®ionStore.workTiFlashIdx)) + } for i := 0; i < regionStore.accessStoreNum(TiFlashOnly); i++ { accessIdx := AccessIndex((sIdx + i) % regionStore.accessStoreNum(TiFlashOnly)) storeIdx, store := regionStore.accessStore(TiFlashOnly, accessIdx) @@ -529,7 +559,7 @@ func (c *RegionCache) GetTiFlashRPCContext(bo *Backoffer, id RegionVerID) (*RPCC return nil, err } if len(addr) == 0 { - cachedRegion.invalidate() + cachedRegion.invalidate(StoreNotFound) return nil, nil } if store.getResolveState() == needCheck { @@ -540,7 +570,7 @@ func (c *RegionCache) GetTiFlashRPCContext(bo *Backoffer, id RegionVerID) (*RPCC peer := cachedRegion.meta.Peers[storeIdx] storeFailEpoch := atomic.LoadUint32(&store.epoch) if storeFailEpoch != regionStore.storeEpochs[storeIdx] { - cachedRegion.invalidate() + cachedRegion.invalidate(Other) logutil.BgLogger().Info("invalidate current region, because others failed on same store", zap.Uint64("region", id.GetID()), zap.String("store", store.addr)) @@ -559,7 +589,7 @@ func (c *RegionCache) GetTiFlashRPCContext(bo *Backoffer, id RegionVerID) (*RPCC }, nil } - cachedRegion.invalidate() + cachedRegion.invalidate(Other) return nil, nil } @@ -901,11 +931,16 @@ func (c *RegionCache) BatchLoadRegionsFromKey(bo *Backoffer, startKey []byte, co // InvalidateCachedRegion removes a cached Region. func (c *RegionCache) InvalidateCachedRegion(id RegionVerID) { + c.InvalidateCachedRegionWithReason(id, Other) +} + +// InvalidateCachedRegionWithReason removes a cached Region with the reason why it's invalidated. +func (c *RegionCache) InvalidateCachedRegionWithReason(id RegionVerID, reason InvalidReason) { cachedRegion := c.getCachedRegionWithRLock(id) if cachedRegion == nil { return } - cachedRegion.invalidate() + cachedRegion.invalidate(reason) } // UpdateLeader update some region cache with newer leader info. @@ -932,7 +967,7 @@ func (c *RegionCache) UpdateLeader(regionID RegionVerID, leaderStoreID uint64, c zap.Uint64("regionID", regionID.GetID()), zap.Int("currIdx", int(currentPeerIdx)), zap.Uint64("leaderStoreID", leaderStoreID)) - r.invalidate() + r.invalidate(StoreNotFound) } else { logutil.BgLogger().Info("switch region leader to specific leader due to kv return NotLeader", zap.Uint64("regionID", regionID.GetID()), @@ -942,13 +977,25 @@ func (c *RegionCache) UpdateLeader(regionID RegionVerID, leaderStoreID uint64, c } // insertRegionToCache tries to insert the Region to cache. +// It should be protected by c.mu.Lock(). func (c *RegionCache) insertRegionToCache(cachedRegion *Region) { old := c.mu.sorted.ReplaceOrInsert(newBtreeItem(cachedRegion)) if old != nil { + store := cachedRegion.getStore() + oldRegion := old.(*btreeItem).cachedRegion + oldRegionStore := oldRegion.getStore() + // Joint consensus is enabled in v5.0, which is possible to make a leader step down as a learner during a conf change. 
+ // And if hibernate region is enabled, after the leader step down, there can be a long time that there is no leader + // in the region and the leader info in PD is stale until requests are sent to followers or hibernate timeout. + // To solve it, one solution is always to try a different peer if the invalid reason of the old cached region is no-leader. + // There is a small probability that the current peer who reports no-leader becomes a leader and TiDB has to retry once in this case. + if InvalidReason(atomic.LoadInt32((*int32)(&oldRegion.invalidReason))) == NoLeader { + store.workTiKVIdx = (oldRegionStore.workTiKVIdx + 1) % AccessIndex(store.accessStoreNum(TiKVOnly)) + } // Don't refresh TiFlash work idx for region. Otherwise, it will always goto a invalid store which // is under transferring regions. - atomic.StoreInt32(&cachedRegion.getStore().workTiFlashIdx, atomic.LoadInt32(&old.(*btreeItem).cachedRegion.getStore().workTiFlashIdx)) - delete(c.mu.regions, old.(*btreeItem).cachedRegion.VerID()) + store.workTiFlashIdx = atomic.LoadInt32(&oldRegionStore.workTiFlashIdx) + delete(c.mu.regions, oldRegion.VerID()) } c.mu.regions[cachedRegion.VerID()] = cachedRegion } @@ -1364,7 +1411,7 @@ func (c *RegionCache) OnRegionEpochNotMatch(bo *Backoffer, ctx *RPCContext, curr if needInvalidateOld { cachedRegion, ok := c.mu.regions[ctx.Region] if ok { - cachedRegion.invalidate() + cachedRegion.invalidate(EpochNotMatch) } } return nil diff --git a/store/tikv/region_cache_test.go b/store/tikv/region_cache_test.go index c47b154ba0bac..5e27c3fe0434b 100644 --- a/store/tikv/region_cache_test.go +++ b/store/tikv/region_cache_test.go @@ -876,7 +876,7 @@ func (s *testRegionCacheSuite) TestRegionEpochOnTiFlash(c *C) { c.Assert(lctx.Peer.Id, Equals, peer3) // epoch-not-match on tiflash - ctxTiFlash, err := s.cache.GetTiFlashRPCContext(s.bo, loc1.Region) + ctxTiFlash, err := s.cache.GetTiFlashRPCContext(s.bo, loc1.Region, true) c.Assert(err, IsNil) c.Assert(ctxTiFlash.Peer.Id, Equals, s.peer1) ctxTiFlash.Peer.Role = metapb.PeerRole_Learner @@ -1392,6 +1392,25 @@ func (s *testRegionCacheSuite) TestContainsByEnd(c *C) { c.Assert(createSampleRegion([]byte{10}, []byte{20}).ContainsByEnd([]byte{30}), IsFalse) } +func (s *testRegionCacheSuite) TestSwitchPeerWhenNoLeader(c *C) { + var prevCtx *RPCContext + for i := 0; i <= len(s.cluster.GetAllStores()); i++ { + loc, err := s.cache.LocateKey(s.bo, []byte("a")) + c.Assert(err, IsNil) + ctx, err := s.cache.GetTiKVRPCContext(s.bo, loc.Region, kv.ReplicaReadLeader, 0) + c.Assert(err, IsNil) + if prevCtx == nil { + c.Assert(i, Equals, 0) + } else { + c.Assert(ctx.AccessIdx, Not(Equals), prevCtx.AccessIdx) + c.Assert(ctx.Peer, Not(DeepEquals), prevCtx.Peer) + } + s.cache.InvalidateCachedRegionWithReason(loc.Region, NoLeader) + c.Assert(s.cache.getCachedRegionWithRLock(loc.Region).invalidReason, Equals, NoLeader) + prevCtx = ctx + } +} + func BenchmarkOnRequestFail(b *testing.B) { /* This benchmark simulate many concurrent requests call OnSendRequestFail method diff --git a/store/tikv/region_request.go b/store/tikv/region_request.go index 59db371394170..f8e5ec7e5c2b7 100644 --- a/store/tikv/region_request.go +++ b/store/tikv/region_request.go @@ -203,7 +203,7 @@ func (s *RegionRequestSender) getRPCContext( } return s.regionCache.GetTiKVRPCContext(bo, regionID, req.ReplicaReadType, seed, opts...) 
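Region cache entries now remember why they were invalidated, and insertRegionToCache uses that: if the previous copy died with NoLeader, the TiKV work index is advanced so the reloaded region is tried on a different peer first (TiFlash, by contrast, only rotates its index when loadBalance is requested). A toy model of that rotation, not the real RegionCache:

```go
package main

import (
	"fmt"
	"sync/atomic"
)

// invalidReason mirrors the new InvalidReason enum; only the values used
// below are modeled here.
type invalidReason int32

const (
	ok invalidReason = iota
	noLeader
)

// cachedRegion is a toy stand-in for Region + RegionStore: a ring of peers,
// a work index, and the reason the previous copy was invalidated.
type cachedRegion struct {
	peers      []string
	workIdx    int32
	invalidFor int32 // stores an invalidReason atomically
}

func (r *cachedRegion) invalidate(reason invalidReason) {
	atomic.StoreInt32(&r.invalidFor, int32(reason))
}

// reload simulates insertRegionToCache replacing an old cached copy: if the
// old copy was invalidated because there was no leader, start from the next
// peer instead of retrying the one that just reported no-leader.
func reload(old *cachedRegion) *cachedRegion {
	fresh := &cachedRegion{peers: old.peers, workIdx: old.workIdx}
	if invalidReason(atomic.LoadInt32(&old.invalidFor)) == noLeader {
		fresh.workIdx = (old.workIdx + 1) % int32(len(old.peers))
	}
	return fresh
}

func main() {
	r := &cachedRegion{peers: []string{"store1", "store2", "store3"}}
	fmt.Println("first try:", r.peers[r.workIdx]) // store1
	r.invalidate(noLeader)
	r = reload(r)
	fmt.Println("after no-leader:", r.peers[r.workIdx]) // store2
}
```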
case tidbkv.TiFlash: - return s.regionCache.GetTiFlashRPCContext(bo, regionID) + return s.regionCache.GetTiFlashRPCContext(bo, regionID, true) case tidbkv.TiDB: return &RPCContext{Addr: s.storeAddr}, nil default: @@ -655,7 +655,7 @@ func (s *RegionRequestSender) onRegionError(bo *Backoffer, ctx *RPCContext, seed // the Raft group is in an election, but it's possible that the peer is // isolated and removed from the Raft group. So it's necessary to reload // the region from PD. - s.regionCache.InvalidateCachedRegion(ctx.Region) + s.regionCache.InvalidateCachedRegionWithReason(ctx.Region, NoLeader) if err = bo.Backoff(BoRegionMiss, errors.Errorf("not leader: %v, ctx: %v", notLeader, ctx)); err != nil { return false, errors.Trace(err) } diff --git a/store/tikv/region_request_test.go b/store/tikv/region_request_test.go index 7fcfa4d21a855..2b488861077f5 100644 --- a/store/tikv/region_request_test.go +++ b/store/tikv/region_request_test.go @@ -572,6 +572,34 @@ func (s *testRegionRequestToSingleStoreSuite) TestOnMaxTimestampNotSyncedError(c }() } +func (s *testRegionRequestToThreeStoresSuite) TestSwitchPeerWhenNoLeader(c *C) { + var leaderAddr string + s.regionRequestSender.client = &fnClient{func(ctx context.Context, addr string, req *tikvrpc.Request, timeout time.Duration) (response *tikvrpc.Response, err error) { + if leaderAddr == "" { + leaderAddr = addr + } + // Returns OK when switches to a different peer. + if leaderAddr != addr { + return &tikvrpc.Response{Resp: &kvrpcpb.RawPutResponse{}}, nil + } + return &tikvrpc.Response{Resp: &kvrpcpb.RawPutResponse{ + RegionError: &errorpb.Error{NotLeader: &errorpb.NotLeader{}}, + }}, nil + }} + + req := tikvrpc.NewRequest(tikvrpc.CmdRawPut, &kvrpcpb.RawPutRequest{ + Key: []byte("key"), + Value: []byte("value"), + }) + + bo := NewBackofferWithVars(context.Background(), 5, nil) + loc, err := s.cache.LocateKey(s.bo, []byte("key")) + c.Assert(err, IsNil) + resp, err := s.regionRequestSender.SendReq(bo, req, loc.Region, time.Second) + c.Assert(err, IsNil) + c.Assert(resp, NotNil) +} + func (s *testRegionRequestToThreeStoresSuite) loadAndGetLeaderStore(c *C) (*Store, string) { region, err := s.regionRequestSender.regionCache.findRegionByKey(s.bo, []byte("a"), false) c.Assert(err, IsNil) diff --git a/store/tikv/snapshot.go b/store/tikv/snapshot.go index c504e3f1e1fbb..df096c9c49fdc 100644 --- a/store/tikv/snapshot.go +++ b/store/tikv/snapshot.go @@ -458,7 +458,7 @@ func (s *KVSnapshot) get(ctx context.Context, bo *Backoffer, k tidbkv.Key) ([]by if err != nil { return nil, errors.Trace(err) } - resp, _, _, err := cli.SendReqCtx(bo, req, loc.Region, readTimeoutShort, tidbkv.TiKV, "", ops...) + resp, _, _, err := cli.SendReqCtx(bo, req, loc.Region, ReadTimeoutShort, tidbkv.TiKV, "", ops...) 
if err != nil { return nil, errors.Trace(err) } diff --git a/store/tikv/split_region.go b/store/tikv/split_region.go index 0357e55a26a4a..002b468102cdf 100644 --- a/store/tikv/split_region.go +++ b/store/tikv/split_region.go @@ -122,7 +122,7 @@ func (s *KVStore) batchSendSingleRegion(bo *Backoffer, batch batch, scatter bool }) sender := NewRegionRequestSender(s.regionCache, s.client) - resp, err := sender.SendReq(bo, req, batch.regionID, readTimeoutShort) + resp, err := sender.SendReq(bo, req, batch.regionID, ReadTimeoutShort) batchResp := singleBatchResp{resp: resp} if err != nil { diff --git a/store/tikv/test_probe.go b/store/tikv/test_probe.go index 63b9e56e24e0b..afdb736d1dc5e 100644 --- a/store/tikv/test_probe.go +++ b/store/tikv/test_probe.go @@ -20,8 +20,8 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/kvproto/pkg/kvrpcpb" pb "github.com/pingcap/kvproto/pkg/kvrpcpb" - "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/store/tikv/tikvrpc" + "github.com/pingcap/tidb/store/tikv/unionstore" ) // StoreProbe wraps KVSTore and exposes internal states for testing purpose. @@ -50,7 +50,7 @@ func (txn TxnProbe) GetCommitTS() uint64 { } // GetUnionStore returns transaction's embedded unionstore. -func (txn TxnProbe) GetUnionStore() kv.UnionStore { +func (txn TxnProbe) GetUnionStore() *unionstore.KVUnionStore { return txn.us } @@ -136,6 +136,11 @@ func (c CommitterProbe) GetCommitTS() uint64 { return c.commitTS } +// GetMinCommitTS returns the minimal commit ts can be used. +func (c CommitterProbe) GetMinCommitTS() uint64 { + return c.minCommitTS +} + // SetMinCommitTS sets the minimal commit ts can be used. func (c CommitterProbe) SetMinCommitTS(ts uint64) { c.minCommitTS = ts @@ -146,6 +151,11 @@ func (c CommitterProbe) SetSessionID(id uint64) { c.sessionID = id } +// GetForUpdateTS returns the pessimistic ForUpdate ts. +func (c CommitterProbe) GetForUpdateTS() uint64 { + return c.forUpdateTS +} + // SetForUpdateTS sets pessimistic ForUpdate ts. func (c CommitterProbe) SetForUpdateTS(ts uint64) { c.forUpdateTS = ts @@ -167,6 +177,11 @@ func (c CommitterProbe) SetTxnSize(sz int) { c.lockTTL = txnLockTTL(c.txn.startTime, sz) } +// SetUseAsyncCommit enables async commit feature. +func (c CommitterProbe) SetUseAsyncCommit() { + c.useAsyncCommit = 1 +} + // Execute runs the commit process. 
func (c CommitterProbe) Execute(ctx context.Context) error { return c.execute(ctx) diff --git a/store/tikv/tests/2pc_test.go b/store/tikv/tests/2pc_test.go index 5970a7f36d04e..507a03414658d 100644 --- a/store/tikv/tests/2pc_test.go +++ b/store/tikv/tests/2pc_test.go @@ -150,7 +150,7 @@ func (s *testCommitterSuite) TestDeleteYourWritesTTL(c *C) { { txn := s.begin(c) - err := txn.GetMemBuffer().SetWithFlags(tidbkv.Key("bb"), []byte{0}, tidbkv.SetPresumeKeyNotExists) + err := txn.GetMemBuffer().SetWithFlags(tidbkv.Key("bb"), []byte{0}, kv.SetPresumeKeyNotExists) c.Assert(err, IsNil) err = txn.Set(tidbkv.Key("ba"), []byte{1}) c.Assert(err, IsNil) @@ -165,7 +165,7 @@ func (s *testCommitterSuite) TestDeleteYourWritesTTL(c *C) { { txn := s.begin(c) - err := txn.GetMemBuffer().SetWithFlags(tidbkv.Key("dd"), []byte{0}, tidbkv.SetPresumeKeyNotExists) + err := txn.GetMemBuffer().SetWithFlags(tidbkv.Key("dd"), []byte{0}, kv.SetPresumeKeyNotExists) c.Assert(err, IsNil) err = txn.Set(tidbkv.Key("de"), []byte{1}) c.Assert(err, IsNil) @@ -637,7 +637,7 @@ func (s *testCommitterSuite) TestUnsetPrimaryKey(c *C) { txn = s.begin(c) txn.SetOption(kv.Pessimistic, true) _, _ = txn.GetUnionStore().Get(context.TODO(), key) - c.Assert(txn.GetMemBuffer().SetWithFlags(key, key, tidbkv.SetPresumeKeyNotExists), IsNil) + c.Assert(txn.GetMemBuffer().SetWithFlags(key, key, kv.SetPresumeKeyNotExists), IsNil) lockCtx := &tidbkv.LockCtx{ForUpdateTS: txn.StartTS(), WaitStartTime: time.Now()} err := txn.LockKeys(context.Background(), lockCtx, key) c.Assert(err, NotNil) @@ -748,7 +748,7 @@ func (s *testCommitterSuite) TestDeleteYourWriteCauseGhostPrimary(c *C) { txn1.DelOption(kv.Pessimistic) txn1.ClearStoreTxnLatches() txn1.Get(context.Background(), k1) - txn1.GetMemBuffer().SetWithFlags(k1, []byte{0}, tidbkv.SetPresumeKeyNotExists) + txn1.GetMemBuffer().SetWithFlags(k1, []byte{0}, kv.SetPresumeKeyNotExists) txn1.Set(k2, []byte{1}) txn1.Set(k3, []byte{2}) txn1.Delete(k1) @@ -789,11 +789,11 @@ func (s *testCommitterSuite) TestDeleteAllYourWrites(c *C) { txn1 := s.begin(c) txn1.DelOption(kv.Pessimistic) txn1.ClearStoreTxnLatches() - txn1.GetMemBuffer().SetWithFlags(k1, []byte{0}, tidbkv.SetPresumeKeyNotExists) + txn1.GetMemBuffer().SetWithFlags(k1, []byte{0}, kv.SetPresumeKeyNotExists) txn1.Delete(k1) - txn1.GetMemBuffer().SetWithFlags(k2, []byte{1}, tidbkv.SetPresumeKeyNotExists) + txn1.GetMemBuffer().SetWithFlags(k2, []byte{1}, kv.SetPresumeKeyNotExists) txn1.Delete(k2) - txn1.GetMemBuffer().SetWithFlags(k3, []byte{2}, tidbkv.SetPresumeKeyNotExists) + txn1.GetMemBuffer().SetWithFlags(k3, []byte{2}, kv.SetPresumeKeyNotExists) txn1.Delete(k3) err1 := txn1.Commit(context.Background()) c.Assert(err1, IsNil) @@ -809,7 +809,7 @@ func (s *testCommitterSuite) TestDeleteAllYourWritesWithSFU(c *C) { txn1 := s.begin(c) txn1.DelOption(kv.Pessimistic) txn1.ClearStoreTxnLatches() - txn1.GetMemBuffer().SetWithFlags(k1, []byte{0}, tidbkv.SetPresumeKeyNotExists) + txn1.GetMemBuffer().SetWithFlags(k1, []byte{0}, kv.SetPresumeKeyNotExists) txn1.Delete(k1) err := txn1.LockKeys(context.Background(), &tidbkv.LockCtx{}, k2, k3) // select * from t where x in (k2, k3) for update c.Assert(err, IsNil) diff --git a/store/tikv/prewrite_test.go b/store/tikv/tests/prewrite_test.go similarity index 68% rename from store/tikv/prewrite_test.go rename to store/tikv/tests/prewrite_test.go index b9287c0d7e620..6f75959b4afe4 100644 --- a/store/tikv/prewrite_test.go +++ b/store/tikv/tests/prewrite_test.go @@ -11,16 +11,17 @@ // See the License for the specific 
language governing permissions and // limitations under the License. -package tikv +package tikv_test import ( . "github.com/pingcap/check" pb "github.com/pingcap/kvproto/pkg/kvrpcpb" "github.com/pingcap/tidb/store/mockstore/unistore" + "github.com/pingcap/tidb/store/tikv" ) type testPrewriteSuite struct { - store *KVStore + store *tikv.KVStore } var _ = Suite(&testPrewriteSuite{}) @@ -29,7 +30,7 @@ func (s *testPrewriteSuite) SetUpTest(c *C) { client, pdClient, cluster, err := unistore.New("") c.Assert(err, IsNil) unistore.BootstrapWithSingleStore(cluster) - store, err := NewTestTiKVStore(client, pdClient, nil, nil, 0) + store, err := tikv.NewTestTiKVStore(client, pdClient, nil, nil, 0) c.Assert(err, IsNil) s.store = store } @@ -37,31 +38,30 @@ func (s *testPrewriteSuite) SetUpTest(c *C) { func (s *testPrewriteSuite) TestSetMinCommitTSInAsyncCommit(c *C) { t, err := s.store.Begin() c.Assert(err, IsNil) - txn := t + txn := tikv.TxnProbe{KVTxn: t} err = txn.Set([]byte("k"), []byte("v")) c.Assert(err, IsNil) - committer, err := newTwoPhaseCommitterWithInit(txn, 1) + committer, err := txn.NewCommitter(1) c.Assert(err, IsNil) - committer.useAsyncCommit = 1 + committer.SetUseAsyncCommit() buildRequest := func() *pb.PrewriteRequest { - batch := batchMutations{mutations: committer.mutations} - req := committer.buildPrewriteRequest(batch, 1) + req := committer.BuildPrewriteRequest(1, 1, 1, committer.GetMutations(), 1) return req.Req.(*pb.PrewriteRequest) } // no forUpdateTS req := buildRequest() - c.Assert(req.MinCommitTs, Equals, txn.startTS+1) + c.Assert(req.MinCommitTs, Equals, txn.StartTS()+1) // forUpdateTS is set - committer.forUpdateTS = txn.startTS + (5 << 18) + committer.SetForUpdateTS(txn.StartTS() + (5 << 18)) req = buildRequest() - c.Assert(req.MinCommitTs, Equals, committer.forUpdateTS+1) + c.Assert(req.MinCommitTs, Equals, committer.GetForUpdateTS()+1) // minCommitTS is set - committer.minCommitTS = txn.startTS + (10 << 18) + committer.SetMinCommitTS(txn.StartTS() + (10 << 18)) req = buildRequest() - c.Assert(req.MinCommitTs, Equals, committer.minCommitTS) + c.Assert(req.MinCommitTs, Equals, committer.GetMinCommitTS()) } diff --git a/store/tikv/txn.go b/store/tikv/txn.go index cdd738fbf5ca0..d5dc1cee2e88e 100644 --- a/store/tikv/txn.go +++ b/store/tikv/txn.go @@ -34,6 +34,7 @@ import ( "github.com/pingcap/tidb/store/tikv/kv" "github.com/pingcap/tidb/store/tikv/logutil" "github.com/pingcap/tidb/store/tikv/metrics" + "github.com/pingcap/tidb/store/tikv/unionstore" "github.com/pingcap/tidb/store/tikv/util" "github.com/pingcap/tidb/util/execdetails" "go.uber.org/zap" @@ -49,7 +50,7 @@ type SchemaAmender interface { // KVTxn contains methods to interact with a TiKV transaction. type KVTxn struct { snapshot *KVSnapshot - us tidbkv.UnionStore + us *unionstore.KVUnionStore store *KVStore // for connection to region. startTS uint64 startTime time.Time // Monotonic timestamp for recording txn time consuming. 
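prewrite_test.go can move into the external tikv_test package because CommitterProbe now exposes the committer internals it needs (SetUseAsyncCommit, GetForUpdateTS, GetMinCommitTS). A compact sketch of that probe pattern, with illustrative types rather than the real twoPhaseCommitter:

```go
package main

import "fmt"

// committer stands in for twoPhaseCommitter: all of its interesting state is
// unexported, so code outside the package cannot touch it directly.
type committer struct {
	minCommitTS    uint64
	forUpdateTS    uint64
	useAsyncCommit uint32
}

// CommitterProbe is the testing facade: it embeds the real object and adds
// exported getters/setters, the same shape as store/tikv/test_probe.go.
type CommitterProbe struct {
	*committer
}

func (p CommitterProbe) SetUseAsyncCommit()       { p.useAsyncCommit = 1 }
func (p CommitterProbe) SetMinCommitTS(ts uint64) { p.minCommitTS = ts }
func (p CommitterProbe) GetMinCommitTS() uint64   { return p.minCommitTS }
func (p CommitterProbe) SetForUpdateTS(ts uint64) { p.forUpdateTS = ts }
func (p CommitterProbe) GetForUpdateTS() uint64   { return p.forUpdateTS }

func main() {
	probe := CommitterProbe{&committer{}}
	probe.SetUseAsyncCommit()
	probe.SetMinCommitTS(42 << 18)
	fmt.Println(probe.GetMinCommitTS(), probe.useAsyncCommit)
}
```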
@@ -84,7 +85,7 @@ func newTiKVTxnWithStartTS(store *KVStore, txnScope string, startTS uint64, repl snapshot := newTiKVSnapshot(store, startTS, replicaReadSeed) newTiKVTxn := &KVTxn{ snapshot: snapshot, - us: tidbkv.NewUnionStore(snapshot), + us: unionstore.NewUnionStore(snapshot), store: store, startTS: startTS, startTime: time.Now(), @@ -496,16 +497,16 @@ func (txn *KVTxn) LockKeys(ctx context.Context, lockCtx *tidbkv.LockCtx, keysInp } } for _, key := range keys { - valExists := tidbkv.SetKeyLockedValueExists + valExists := kv.SetKeyLockedValueExists // PointGet and BatchPointGet will return value in pessimistic lock response, the value may not exist. // For other lock modes, the locked key values always exist. if lockCtx.ReturnValues { val, _ := lockCtx.Values[string(key)] if len(val.Value) == 0 { - valExists = tidbkv.SetKeyLockedValueNotExists + valExists = kv.SetKeyLockedValueNotExists } } - memBuf.UpdateFlags(key, tidbkv.SetKeyLocked, tidbkv.DelNeedCheckExists, valExists) + memBuf.UpdateFlags(key, kv.SetKeyLocked, kv.DelNeedCheckExists, valExists) } txn.lockedCnt += len(keys) return nil @@ -603,12 +604,12 @@ func (txn *KVTxn) Reset() { } // GetUnionStore returns the UnionStore binding to this transaction. -func (txn *KVTxn) GetUnionStore() tidbkv.UnionStore { +func (txn *KVTxn) GetUnionStore() *unionstore.KVUnionStore { return txn.us } // GetMemBuffer return the MemBuffer binding to this transaction. -func (txn *KVTxn) GetMemBuffer() tidbkv.MemBuffer { +func (txn *KVTxn) GetMemBuffer() *unionstore.MemDB { return txn.us.GetMemBuffer() } diff --git a/store/tikv/unionstore/interface.go b/store/tikv/unionstore/interface.go new file mode 100644 index 0000000000000..27b0edeffc8ad --- /dev/null +++ b/store/tikv/unionstore/interface.go @@ -0,0 +1,28 @@ +// Copyright 2021 PingCAP, Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package unionstore + +import ( + tidbkv "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/store/tikv/kv" +) + +// MemBufferIterator is an Iterator with KeyFlags related functions. +type MemBufferIterator interface { + tidbkv.Iterator + HasValue() bool + Flags() kv.KeyFlags + UpdateFlags(...kv.FlagsOp) + Handle() MemKeyHandle +} diff --git a/kv/memdb.go b/store/tikv/unionstore/memdb.go similarity index 69% rename from kv/memdb.go rename to store/tikv/unionstore/memdb.go index 5e972bf904856..01b53936e3e66 100644 --- a/kv/memdb.go +++ b/store/tikv/unionstore/memdb.go @@ -11,7 +11,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-package kv +package unionstore import ( "bytes" @@ -20,120 +20,11 @@ import ( "sync" "sync/atomic" "unsafe" -) - -const ( - flagPresumeKNE KeyFlags = 1 << iota - flagKeyLocked - flagNeedLocked - flagKeyLockedValExist - flagNeedCheckExists - flagPrewriteOnly - flagIgnoredIn2PC - - persistentFlags = flagKeyLocked | flagKeyLockedValExist - // bit 1 => red, bit 0 => black - nodeColorBit uint8 = 0x80 - nodeFlagsMask = ^nodeColorBit -) - -// KeyFlags are metadata associated with key -type KeyFlags uint8 - -// HasPresumeKeyNotExists returns whether the associated key use lazy check. -func (f KeyFlags) HasPresumeKeyNotExists() bool { - return f&flagPresumeKNE != 0 -} - -// HasLocked returns whether the associated key has acquired pessimistic lock. -func (f KeyFlags) HasLocked() bool { - return f&flagKeyLocked != 0 -} - -// HasNeedLocked return whether the key needed to be locked -func (f KeyFlags) HasNeedLocked() bool { - return f&flagNeedLocked != 0 -} - -// HasLockedValueExists returns whether the value exists when key locked. -func (f KeyFlags) HasLockedValueExists() bool { - return f&flagKeyLockedValExist != 0 -} - -// HasNeedCheckExists returns whether the key need to check existence when it has been locked. -func (f KeyFlags) HasNeedCheckExists() bool { - return f&flagNeedCheckExists != 0 -} - -// HasPrewriteOnly returns whether the key should be used in 2pc commit phase. -func (f KeyFlags) HasPrewriteOnly() bool { - return f&flagPrewriteOnly != 0 -} - -// HasIgnoredIn2PC returns whether the key will be ignored in 2pc. -func (f KeyFlags) HasIgnoredIn2PC() bool { - return f&flagIgnoredIn2PC != 0 -} -// FlagsOp describes KeyFlags modify operation. -type FlagsOp uint16 - -const ( - // SetPresumeKeyNotExists marks the existence of the associated key is checked lazily. - // Implies KeyFlags.HasNeedCheckExists() == true. - SetPresumeKeyNotExists FlagsOp = 1 << iota - // DelPresumeKeyNotExists reverts SetPresumeKeyNotExists. - DelPresumeKeyNotExists - // SetKeyLocked marks the associated key has acquired lock. - SetKeyLocked - // DelKeyLocked reverts SetKeyLocked. - DelKeyLocked - // SetNeedLocked marks the associated key need to be acquired lock. - SetNeedLocked - // DelNeedLocked reverts SetKeyNeedLocked. - DelNeedLocked - // SetKeyLockedValueExists marks the value exists when key has been locked in Transaction.LockKeys. - SetKeyLockedValueExists - // SetKeyLockedValueNotExists marks the value doesn't exists when key has been locked in Transaction.LockKeys. - SetKeyLockedValueNotExists - // DelNeedCheckExists marks the key no need to be checked in Transaction.LockKeys. - DelNeedCheckExists - // SetPrewriteOnly marks the key shouldn't be used in 2pc commit phase. - SetPrewriteOnly - // SetIgnoredIn2PC marks the key will be ignored in 2pc. 
- SetIgnoredIn2PC + tidbkv "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/store/tikv/kv" ) -func applyFlagsOps(origin KeyFlags, ops ...FlagsOp) KeyFlags { - for _, op := range ops { - switch op { - case SetPresumeKeyNotExists: - origin |= flagPresumeKNE | flagNeedCheckExists - case DelPresumeKeyNotExists: - origin &= ^(flagPresumeKNE | flagNeedCheckExists) - case SetKeyLocked: - origin |= flagKeyLocked - case DelKeyLocked: - origin &= ^flagKeyLocked - case SetNeedLocked: - origin |= flagNeedLocked - case DelNeedLocked: - origin &= ^flagNeedLocked - case SetKeyLockedValueExists: - origin |= flagKeyLockedValExist - case DelNeedCheckExists: - origin &= ^flagNeedCheckExists - case SetKeyLockedValueNotExists: - origin &= ^flagKeyLockedValExist - case SetPrewriteOnly: - origin |= flagPrewriteOnly - case SetIgnoredIn2PC: - origin |= flagIgnoredIn2PC - } - } - return origin -} - var tombstone = []byte{} // IsTombstone returns whether the value is a tombstone. @@ -151,15 +42,15 @@ func (h MemKeyHandle) toAddr() memdbArenaAddr { return memdbArenaAddr{idx: uint32(h.idx), off: h.off} } -// memdb is rollbackable Red-Black Tree optimized for TiDB's transaction states buffer use scenario. -// You can think memdb is a combination of two separate tree map, one for key => value and another for key => keyFlags. +// MemDB is rollbackable Red-Black Tree optimized for TiDB's transaction states buffer use scenario. +// You can think MemDB is a combination of two separate tree map, one for key => value and another for key => keyFlags. // // The value map is rollbackable, that means you can use the `Staging`, `Release` and `Cleanup` API to safely modify KVs. // // The flags map is not rollbackable. There are two types of flag, persistent and non-persistent. // When discarding a newly added KV in `Cleanup`, the non-persistent flags will be cleared. // If there are persistent flags associated with key, we will keep this key in node without value. -type memdb struct { +type MemDB struct { // This RWMutex only used to ensure memdbSnapGetter.Get will not race with // concurrent memdb.Set, memdb.SetWithFlags, memdb.Delete and memdb.UpdateFlags. sync.RWMutex @@ -177,25 +68,29 @@ type memdb struct { stages []memdbCheckpoint } -func newMemDB() *memdb { - db := new(memdb) +func newMemDB() *MemDB { + db := new(MemDB) db.allocator.init() db.root = nullAddr db.stages = make([]memdbCheckpoint, 0, 2) - db.entrySizeLimit = atomic.LoadUint64(&TxnEntrySizeLimit) - db.bufferSizeLimit = atomic.LoadUint64(&TxnTotalSizeLimit) + db.entrySizeLimit = atomic.LoadUint64(&tidbkv.TxnEntrySizeLimit) + db.bufferSizeLimit = atomic.LoadUint64(&tidbkv.TxnTotalSizeLimit) return db } -func (db *memdb) Staging() StagingHandle { +// Staging create a new staging buffer inside the MemBuffer. +// Subsequent writes will be temporarily stored in this new staging buffer. +// When you think all modifications looks good, you can call `Release` to public all of them to the upper level buffer. +func (db *MemDB) Staging() tidbkv.StagingHandle { db.Lock() defer db.Unlock() db.stages = append(db.stages, db.vlog.checkpoint()) - return StagingHandle(len(db.stages)) + return tidbkv.StagingHandle(len(db.stages)) } -func (db *memdb) Release(h StagingHandle) { +// Release publish all modifications in the latest staging buffer to upper level. +func (db *MemDB) Release(h tidbkv.StagingHandle) { if int(h) != len(db.stages) { // This should never happens in production environment. // Use panic to make debug easier. 
@@ -213,7 +108,9 @@ func (db *memdb) Release(h StagingHandle) { db.stages = db.stages[:int(h)-1] } -func (db *memdb) Cleanup(h StagingHandle) { +// Cleanup cleans up the resources referenced by the StagingHandle. +// If the changes are not published by `Release`, they will be discarded. +func (db *MemDB) Cleanup(h tidbkv.StagingHandle) { if int(h) > len(db.stages) { return } @@ -236,7 +133,8 @@ func (db *memdb) Cleanup(h StagingHandle) { db.stages = db.stages[:int(h)-1] } -func (db *memdb) Reset() { +// Reset resets the MemBuffer to initial states. +func (db *MemDB) Reset() { db.root = nullAddr db.stages = db.stages[:0] db.dirty = false @@ -247,19 +145,24 @@ func (db *memdb) Reset() { db.allocator.reset() } -func (db *memdb) DiscardValues() { +// DiscardValues releases the memory used by all values. +// NOTE: any operation that needs the value will panic after this function. +func (db *MemDB) DiscardValues() { db.vlogInvalid = true db.vlog.reset() } -func (db *memdb) InspectStage(handle StagingHandle, f func(Key, KeyFlags, []byte)) { +// InspectStage is used to inspect the value updates in the given stage. +func (db *MemDB) InspectStage(handle tidbkv.StagingHandle, f func(tidbkv.Key, kv.KeyFlags, []byte)) { idx := int(handle) - 1 tail := db.vlog.checkpoint() head := db.stages[idx] db.vlog.inspectKVInLog(db, &head, &tail, f) } -func (db *memdb) Get(_ context.Context, key Key) ([]byte, error) { +// Get gets the value for key k from kv store. +// If the corresponding kv pair does not exist, it returns nil and ErrNotExist. +func (db *MemDB) Get(_ context.Context, key tidbkv.Key) ([]byte, error) { if db.vlogInvalid { // panic for easier debugging. panic("vlog is resetted") @@ -267,23 +170,24 @@ func (db *memdb) Get(_ context.Context, key Key) ([]byte, error) { x := db.traverse(key, false) if x.isNull() { - return nil, ErrNotExist + return nil, tidbkv.ErrNotExist } if x.vptr.isNull() { // A flag only key, act as value not exists - return nil, ErrNotExist + return nil, tidbkv.ErrNotExist } return db.vlog.getValue(x.vptr), nil } -func (db *memdb) SelectValueHistory(key Key, predicate func(value []byte) bool) ([]byte, error) { +// SelectValueHistory selects the latest value for which `predicate` returns true from the modification history. +func (db *MemDB) SelectValueHistory(key tidbkv.Key, predicate func(value []byte) bool) ([]byte, error) { x := db.traverse(key, false) if x.isNull() { - return nil, ErrNotExist + return nil, tidbkv.ErrNotExist } if x.vptr.isNull() { // A flag only key, act as value not exists - return nil, ErrNotExist + return nil, tidbkv.ErrNotExist } result := db.vlog.selectValueHistory(x.vptr, func(addr memdbArenaAddr) bool { return predicate(db.vlog.getValue(addr)) @@ -294,47 +198,56 @@ func (db *memdb) SelectValueHistory(key Key, predicate func(value []byte) bool) return db.vlog.getValue(result), nil } -func (db *memdb) GetFlags(key Key) (KeyFlags, error) { +// GetFlags returns the latest flags associated with key. +func (db *MemDB) GetFlags(key tidbkv.Key) (kv.KeyFlags, error) { x := db.traverse(key, false) if x.isNull() { - return 0, ErrNotExist + return 0, tidbkv.ErrNotExist } return x.getKeyFlags(), nil } -func (db *memdb) UpdateFlags(key Key, ops ...FlagsOp) { +// UpdateFlags updates the flags associated with key. +func (db *MemDB) UpdateFlags(key tidbkv.Key, ops ...kv.FlagsOp) { err := db.set(key, nil, ops...) _ = err // set without value will never fail } -func (db *memdb) Set(key Key, value []byte) error { +// Set sets the value for key k as v into kv store.
+// v must NOT be nil or empty, otherwise it returns ErrCannotSetNilValue. +func (db *MemDB) Set(key tidbkv.Key, value []byte) error { if len(value) == 0 { - return ErrCannotSetNilValue + return tidbkv.ErrCannotSetNilValue } return db.set(key, value) } -func (db *memdb) SetWithFlags(key Key, value []byte, ops ...FlagsOp) error { +// SetWithFlags puts the key-value pair into the last active staging buffer with the given KeyFlags. +func (db *MemDB) SetWithFlags(key tidbkv.Key, value []byte, ops ...kv.FlagsOp) error { if len(value) == 0 { - return ErrCannotSetNilValue + return tidbkv.ErrCannotSetNilValue } return db.set(key, value, ops...) } -func (db *memdb) Delete(key Key) error { +// Delete removes the entry for key k from kv store. +func (db *MemDB) Delete(key tidbkv.Key) error { return db.set(key, tombstone) } -func (db *memdb) DeleteWithFlags(key Key, ops ...FlagsOp) error { +// DeleteWithFlags deletes the key with the given KeyFlags. +func (db *MemDB) DeleteWithFlags(key tidbkv.Key, ops ...kv.FlagsOp) error { return db.set(key, tombstone, ops...) } -func (db *memdb) GetKeyByHandle(handle MemKeyHandle) []byte { +// GetKeyByHandle returns key by handle. +func (db *MemDB) GetKeyByHandle(handle MemKeyHandle) []byte { x := db.getNode(handle.toAddr()) return x.getKey() } -func (db *memdb) GetValueByHandle(handle MemKeyHandle) ([]byte, bool) { +// GetValueByHandle returns value by handle. +func (db *MemDB) GetValueByHandle(handle MemKeyHandle) ([]byte, bool) { if db.vlogInvalid { return nil, false } @@ -345,19 +258,22 @@ func (db *memdb) GetValueByHandle(handle MemKeyHandle) ([]byte, bool) { return db.vlog.getValue(x.vptr), true } -func (db *memdb) Len() int { +// Len returns the number of entries in the DB. +func (db *MemDB) Len() int { return db.count } -func (db *memdb) Size() int { +// Size returns the sum of key and value lengths. +func (db *MemDB) Size() int { return db.size } -func (db *memdb) Dirty() bool { +// Dirty returns whether the root staging buffer is updated. +func (db *MemDB) Dirty() bool { return db.dirty } -func (db *memdb) set(key Key, value []byte, ops ...FlagsOp) error { +func (db *MemDB) set(key tidbkv.Key, value []byte, ops ...kv.FlagsOp) error { if db.vlogInvalid { // panic for easier debugging. panic("vlog is resetted") @@ -365,7 +281,7 @@ func (db *memdb) set(key Key, value []byte, ops ...FlagsOp) error { if value != nil { if size := uint64(len(key) + len(value)); size > db.entrySizeLimit { - return ErrEntryTooLarge.GenWithStackByArgs(db.entrySizeLimit, size) + return tidbkv.ErrEntryTooLarge.GenWithStackByArgs(db.entrySizeLimit, size) } } @@ -378,8 +294,8 @@ func (db *memdb) set(key Key, value []byte, ops ...FlagsOp) error { x := db.traverse(key, true) if len(ops) != 0 { - flags := applyFlagsOps(x.getKeyFlags(), ops...) - if flags&persistentFlags != 0 { + flags := kv.ApplyFlagsOps(x.getKeyFlags(), ops...)
+ if flags.AndPersistent() != 0 { db.dirty = true } x.setKeyFlags(flags) @@ -391,12 +307,12 @@ func (db *memdb) set(key Key, value []byte, ops ...FlagsOp) error { db.setValue(x, value) if uint64(db.Size()) > db.bufferSizeLimit { - return ErrTxnTooLarge.GenWithStackByArgs(db.Size()) + return tidbkv.ErrTxnTooLarge.GenWithStackByArgs(db.Size()) } return nil } -func (db *memdb) setValue(x memdbNodeAddr, value []byte) { +func (db *MemDB) setValue(x memdbNodeAddr, value []byte) { var activeCp *memdbCheckpoint if len(db.stages) > 0 { activeCp = &db.stages[len(db.stages)-1] @@ -421,7 +337,7 @@ func (db *memdb) setValue(x memdbNodeAddr, value []byte) { // traverse search for and if not found and insert is true, will add a new node in. // Returns a pointer to the new node, or the node found. -func (db *memdb) traverse(key Key, insert bool) memdbNodeAddr { +func (db *MemDB) traverse(key tidbkv.Key, insert bool) memdbNodeAddr { x := db.getRoot() y := memdbNodeAddr{nil, nullAddr} found := false @@ -549,7 +465,7 @@ func (db *memdb) traverse(key Key, insert bool) memdbNodeAddr { // We assume that neither X nor Y is NULL // -func (db *memdb) leftRotate(x memdbNodeAddr) { +func (db *MemDB) leftRotate(x memdbNodeAddr) { y := x.getRight(db) // Turn Y's left subtree into X's right subtree (move B) @@ -583,7 +499,7 @@ func (db *memdb) leftRotate(x memdbNodeAddr) { x.up = y.addr } -func (db *memdb) rightRotate(y memdbNodeAddr) { +func (db *MemDB) rightRotate(y memdbNodeAddr) { x := y.getLeft(db) // Turn X's right subtree into Y's left subtree (move B) @@ -617,7 +533,7 @@ func (db *memdb) rightRotate(y memdbNodeAddr) { y.up = x.addr } -func (db *memdb) deleteNode(z memdbNodeAddr) { +func (db *MemDB) deleteNode(z memdbNodeAddr) { var x, y memdbNodeAddr db.count-- @@ -663,7 +579,7 @@ func (db *memdb) deleteNode(z memdbNodeAddr) { db.allocator.freeNode(z.addr) } -func (db *memdb) replaceNode(old memdbNodeAddr, new memdbNodeAddr) { +func (db *MemDB) replaceNode(old memdbNodeAddr, new memdbNodeAddr) { if !old.up.isNull() { oldUp := old.getUp(db) if old.addr == oldUp.left { @@ -691,7 +607,7 @@ func (db *memdb) replaceNode(old memdbNodeAddr, new memdbNodeAddr) { } } -func (db *memdb) deleteNodeFix(x memdbNodeAddr) { +func (db *MemDB) deleteNodeFix(x memdbNodeAddr) { for x.addr != db.root && x.isBlack() { xUp := x.getUp(db) if x.addr == xUp.left { @@ -761,7 +677,7 @@ func (db *memdb) deleteNodeFix(x memdbNodeAddr) { x.setBlack() } -func (db *memdb) successor(x memdbNodeAddr) (y memdbNodeAddr) { +func (db *MemDB) successor(x memdbNodeAddr) (y memdbNodeAddr) { if !x.right.isNull() { // If right is not NULL then go right one and // then keep going left until we find a node with @@ -786,7 +702,7 @@ func (db *memdb) successor(x memdbNodeAddr) (y memdbNodeAddr) { return y } -func (db *memdb) predecessor(x memdbNodeAddr) (y memdbNodeAddr) { +func (db *MemDB) predecessor(x memdbNodeAddr) (y memdbNodeAddr) { if !x.left.isNull() { // If left is not NULL then go left one and // then keep going right until we find a node with @@ -811,15 +727,15 @@ func (db *memdb) predecessor(x memdbNodeAddr) (y memdbNodeAddr) { return y } -func (db *memdb) getNode(x memdbArenaAddr) memdbNodeAddr { +func (db *MemDB) getNode(x memdbArenaAddr) memdbNodeAddr { return memdbNodeAddr{db.allocator.getNode(x), x} } -func (db *memdb) getRoot() memdbNodeAddr { +func (db *MemDB) getRoot() memdbNodeAddr { return db.getNode(db.root) } -func (db *memdb) allocNode(key Key) memdbNodeAddr { +func (db *MemDB) allocNode(key tidbkv.Key) memdbNodeAddr { db.size += 
len(key) db.count++ x, xn := db.allocator.allocNode(key) @@ -835,15 +751,15 @@ func (a *memdbNodeAddr) isNull() bool { return a.addr.isNull() } -func (a memdbNodeAddr) getUp(db *memdb) memdbNodeAddr { +func (a memdbNodeAddr) getUp(db *MemDB) memdbNodeAddr { return db.getNode(a.up) } -func (a memdbNodeAddr) getLeft(db *memdb) memdbNodeAddr { +func (a memdbNodeAddr) getLeft(db *MemDB) memdbNodeAddr { return db.getNode(a.left) } -func (a memdbNodeAddr) getRight(db *memdb) memdbNodeAddr { +func (a memdbNodeAddr) getRight(db *MemDB) memdbNodeAddr { return db.getNode(a.right) } @@ -872,7 +788,7 @@ func (n *memdbNode) setBlack() { n.flags &= ^nodeColorBit } -func (n *memdbNode) getKey() Key { +func (n *memdbNode) getKey() tidbkv.Key { var ret []byte hdr := (*reflect.SliceHeader)(unsafe.Pointer(&ret)) hdr.Data = uintptr(unsafe.Pointer(&n.flags)) + 1 @@ -881,10 +797,16 @@ func (n *memdbNode) getKey() Key { return ret } -func (n *memdbNode) getKeyFlags() KeyFlags { - return KeyFlags(n.flags & nodeFlagsMask) +const ( + // bit 1 => red, bit 0 => black + nodeColorBit uint8 = 0x80 + nodeFlagsMask = ^nodeColorBit +) + +func (n *memdbNode) getKeyFlags() kv.KeyFlags { + return kv.KeyFlags(n.flags & nodeFlagsMask) } -func (n *memdbNode) setKeyFlags(f KeyFlags) { +func (n *memdbNode) setKeyFlags(f kv.KeyFlags) { n.flags = (^nodeFlagsMask & n.flags) | uint8(f) } diff --git a/kv/memdb_arena.go b/store/tikv/unionstore/memdb_arena.go similarity index 95% rename from kv/memdb_arena.go rename to store/tikv/unionstore/memdb_arena.go index 68ed49a1e1afa..46f1846b99a80 100644 --- a/kv/memdb_arena.go +++ b/store/tikv/unionstore/memdb_arena.go @@ -11,12 +11,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -package kv +package unionstore import ( "encoding/binary" "math" "unsafe" + + tidbkv "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/store/tikv/kv" ) const ( @@ -190,7 +193,7 @@ func (a *nodeAllocator) getNode(addr memdbArenaAddr) *memdbNode { return (*memdbNode)(unsafe.Pointer(&a.blocks[addr.idx].buf[addr.off])) } -func (a *nodeAllocator) allocNode(key Key) (memdbArenaAddr, *memdbNode) { +func (a *nodeAllocator) allocNode(key tidbkv.Key) (memdbArenaAddr, *memdbNode) { nodeSize := 8*4 + 2 + 1 + len(key) addr, mem := a.alloc(nodeSize, true) n := (*memdbNode)(unsafe.Pointer(&mem[0])) @@ -297,7 +300,7 @@ func (l *memdbVlog) selectValueHistory(addr memdbArenaAddr, predicate func(memdb return nullAddr } -func (l *memdbVlog) revertToCheckpoint(db *memdb, cp *memdbCheckpoint) { +func (l *memdbVlog) revertToCheckpoint(db *MemDB, cp *memdbCheckpoint) { cursor := l.checkpoint() for !cp.isSamePosition(&cursor) { hdrOff := cursor.offsetInBlock - memdbVlogHdrSize @@ -311,7 +314,7 @@ func (l *memdbVlog) revertToCheckpoint(db *memdb, cp *memdbCheckpoint) { // oldValue.isNull() == true means this is a newly added value. if hdr.oldValue.isNull() { // If there are no flags associated with this key, we need to delete this node. 
- keptFlags := node.getKeyFlags() & persistentFlags + keptFlags := node.getKeyFlags().AndPersistent() if keptFlags == 0 { db.deleteNode(node) } else { @@ -326,7 +329,7 @@ func (l *memdbVlog) revertToCheckpoint(db *memdb, cp *memdbCheckpoint) { } } -func (l *memdbVlog) inspectKVInLog(db *memdb, head, tail *memdbCheckpoint, f func(Key, KeyFlags, []byte)) { +func (l *memdbVlog) inspectKVInLog(db *MemDB, head, tail *memdbCheckpoint, f func(tidbkv.Key, kv.KeyFlags, []byte)) { cursor := *tail for !head.isSamePosition(&cursor) { cursorAddr := memdbArenaAddr{idx: uint32(cursor.blocks - 1), off: uint32(cursor.offsetInBlock)} diff --git a/kv/memdb_bench_test.go b/store/tikv/unionstore/memdb_bench_test.go similarity index 95% rename from kv/memdb_bench_test.go rename to store/tikv/unionstore/memdb_bench_test.go index 363ca95d17d92..112ea9b0f2a7a 100644 --- a/kv/memdb_bench_test.go +++ b/store/tikv/unionstore/memdb_bench_test.go @@ -11,7 +11,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package kv +package unionstore import ( "context" @@ -26,7 +26,7 @@ const ( valueSize = 128 ) -func newMemDBForBench() *memdb { +func newMemDBForBench() *MemDB { db := newMemDB() db.bufferSizeLimit = math.MaxUint64 db.entrySizeLimit = math.MaxUint64 @@ -152,7 +152,7 @@ func shuffle(slc [][]byte) { slc[r], slc[i] = slc[i], slc[r] } } -func benchmarkSetGet(b *testing.B, buffer MemBuffer, data [][]byte) { +func benchmarkSetGet(b *testing.B, buffer *MemDB, data [][]byte) { b.ResetTimer() for i := 0; i < b.N; i++ { for _, k := range data { @@ -164,7 +164,7 @@ func benchmarkSetGet(b *testing.B, buffer MemBuffer, data [][]byte) { } } -func benchIterator(b *testing.B, buffer MemBuffer) { +func benchIterator(b *testing.B, buffer *MemDB) { for k := 0; k < opCnt; k++ { buffer.Set(encodeInt(k), encodeInt(k)) } diff --git a/kv/memdb_iterator.go b/store/tikv/unionstore/memdb_iterator.go similarity index 69% rename from kv/memdb_iterator.go rename to store/tikv/unionstore/memdb_iterator.go index b89c8197331e1..60fcf881b8826 100644 --- a/kv/memdb_iterator.go +++ b/store/tikv/unionstore/memdb_iterator.go @@ -11,20 +11,29 @@ // See the License for the specific language governing permissions and // limitations under the License. -package kv +package unionstore -import "bytes" +import ( + "bytes" + + tidbkv "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/store/tikv/kv" +) type memdbIterator struct { - db *memdb + db *MemDB curr memdbNodeAddr - start Key - end Key + start tidbkv.Key + end tidbkv.Key reverse bool includeFlags bool } -func (db *memdb) Iter(k Key, upperBound Key) (Iterator, error) { +// Iter creates an Iterator positioned on the first entry whose key is greater than or equal to k. +// If no such entry is found, it returns an invalid Iterator with no error. +// It yields only keys that are less than upperBound. If upperBound is nil, the upper bound is unbounded. +// The Iterator must be Closed after use. +func (db *MemDB) Iter(k tidbkv.Key, upperBound tidbkv.Key) (tidbkv.Iterator, error) { i := &memdbIterator{ db: db, start: k, @@ -34,7 +43,11 @@ func (db *memdb) Iter(k Key, upperBound Key) (Iterator, error) { return i, nil } -func (db *memdb) IterReverse(k Key) (Iterator, error) { +// IterReverse creates a reversed Iterator positioned on the first entry whose key is less than k. +// The returned iterator will iterate from greater keys to smaller keys. +// If k is nil, the returned iterator will be positioned at the last key.
+// TODO: Add lower bound limit +func (db *MemDB) IterReverse(k tidbkv.Key) (tidbkv.Iterator, error) { i := &memdbIterator{ db: db, end: k, @@ -44,7 +57,8 @@ func (db *memdb) IterReverse(k Key) (Iterator, error) { return i, nil } -func (db *memdb) IterWithFlags(k Key, upperBound Key) MemBufferIterator { +// IterWithFlags returns a MemBufferIterator. +func (db *MemDB) IterWithFlags(k tidbkv.Key, upperBound tidbkv.Key) MemBufferIterator { i := &memdbIterator{ db: db, start: k, @@ -55,7 +69,8 @@ func (db *memdb) IterWithFlags(k Key, upperBound Key) MemBufferIterator { return i } -func (db *memdb) IterReverseWithFlags(k Key) MemBufferIterator { +// IterReverseWithFlags returns a reversed MemBufferIterator. +func (db *MemDB) IterReverseWithFlags(k tidbkv.Key) MemBufferIterator { i := &memdbIterator{ db: db, end: k, @@ -94,13 +109,13 @@ func (i *memdbIterator) Valid() bool { return !i.curr.isNull() } -func (i *memdbIterator) Flags() KeyFlags { +func (i *memdbIterator) Flags() kv.KeyFlags { return i.curr.getKeyFlags() } -func (i *memdbIterator) UpdateFlags(ops ...FlagsOp) { +func (i *memdbIterator) UpdateFlags(ops ...kv.FlagsOp) { origin := i.curr.getKeyFlags() - n := applyFlagsOps(origin, ops...) + n := kv.ApplyFlagsOps(origin, ops...) i.curr.setKeyFlags(n) } @@ -108,7 +123,7 @@ func (i *memdbIterator) HasValue() bool { return !i.isFlagsOnly() } -func (i *memdbIterator) Key() Key { +func (i *memdbIterator) Key() tidbkv.Key { return i.curr.getKey() } @@ -165,7 +180,7 @@ func (i *memdbIterator) seekToLast() { i.curr = y } -func (i *memdbIterator) seek(key Key) { +func (i *memdbIterator) seek(key tidbkv.Key) { y := memdbNodeAddr{nil, nullAddr} x := i.db.getNode(i.db.root) diff --git a/kv/memdb_norace_test.go b/store/tikv/unionstore/memdb_norace_test.go similarity index 97% rename from kv/memdb_norace_test.go rename to store/tikv/unionstore/memdb_norace_test.go index c7bb3f422229c..865be38a15947 100644 --- a/kv/memdb_norace_test.go +++ b/store/tikv/unionstore/memdb_norace_test.go @@ -13,7 +13,7 @@ // +build !race -package kv +package unionstore import ( "encoding/binary" @@ -69,7 +69,7 @@ func (s testMemDBSuite) TestRandomDerive(c *C) { s.testRandomDeriveRecur(c, db, golden, 0) } -func (s testMemDBSuite) testRandomDeriveRecur(c *C, db *memdb, golden *leveldb.DB, depth int) [][2][]byte { +func (s testMemDBSuite) testRandomDeriveRecur(c *C, db *MemDB, golden *leveldb.DB, depth int) [][2][]byte { var keys [][]byte if op := rand.Float64(); op < 0.33 { start, end := rand.Intn(512), rand.Intn(512)+512 diff --git a/kv/memdb_snapshot.go b/store/tikv/unionstore/memdb_snapshot.go similarity index 75% rename from kv/memdb_snapshot.go rename to store/tikv/unionstore/memdb_snapshot.go index 96ae69ad12431..3303dab43624d 100644 --- a/kv/memdb_snapshot.go +++ b/store/tikv/unionstore/memdb_snapshot.go @@ -11,18 +11,24 @@ // See the License for the specific language governing permissions and // limitations under the License. -package kv +package unionstore -import "context" +import ( + "context" -func (db *memdb) SnapshotGetter() Getter { + tidbkv "github.com/pingcap/tidb/kv" +) + +// SnapshotGetter returns a Getter for a snapshot of MemBuffer. +func (db *MemDB) SnapshotGetter() tidbkv.Getter { return &memdbSnapGetter{ db: db, cp: db.getSnapshot(), } } -func (db *memdb) SnapshotIter(start, end Key) Iterator { +// SnapshotIter returns a Iterator for a snapshot of MemBuffer. 
+func (db *MemDB) SnapshotIter(start, end tidbkv.Key) tidbkv.Iterator { it := &memdbSnapIter{ memdbIterator: &memdbIterator{ db: db, @@ -35,7 +41,7 @@ func (db *memdb) SnapshotIter(start, end Key) Iterator { return it } -func (db *memdb) getSnapshot() memdbCheckpoint { +func (db *MemDB) getSnapshot() memdbCheckpoint { if len(db.stages) > 0 { return db.stages[0] } @@ -43,22 +49,22 @@ func (db *memdb) getSnapshot() memdbCheckpoint { } type memdbSnapGetter struct { - db *memdb + db *MemDB cp memdbCheckpoint } -func (snap *memdbSnapGetter) Get(_ context.Context, key Key) ([]byte, error) { +func (snap *memdbSnapGetter) Get(_ context.Context, key tidbkv.Key) ([]byte, error) { x := snap.db.traverse(key, false) if x.isNull() { - return nil, ErrNotExist + return nil, tidbkv.ErrNotExist } if x.vptr.isNull() { // A flag only key, act as value not exists - return nil, ErrNotExist + return nil, tidbkv.ErrNotExist } v, ok := snap.db.vlog.getSnapshotValue(x.vptr, &snap.cp) if !ok { - return nil, ErrNotExist + return nil, tidbkv.ErrNotExist } return v, nil } diff --git a/kv/memdb_test.go b/store/tikv/unionstore/memdb_test.go similarity index 95% rename from kv/memdb_test.go rename to store/tikv/unionstore/memdb_test.go index d0ad12c426095..dc0694972bdb9 100644 --- a/kv/memdb_test.go +++ b/store/tikv/unionstore/memdb_test.go @@ -13,7 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package kv +package unionstore import ( "context" @@ -24,9 +24,16 @@ import ( . "github.com/pingcap/check" leveldb "github.com/pingcap/goleveldb/leveldb/memdb" + tidbkv "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/store/tikv/kv" "github.com/pingcap/tidb/util/testleak" ) +type Key = tidbkv.Key +type KeyFlags = kv.KeyFlags +type StagingHandle = tidbkv.StagingHandle +type Iterator = tidbkv.Iterator + func init() { testMode = true } @@ -44,7 +51,7 @@ var ( type testMemDBSuite struct{} // DeleteKey is used in test to verify the `deleteNode` used in `vlog.revertToCheckpoint`. -func (db *memdb) DeleteKey(key []byte) { +func (db *MemDB) DeleteKey(key []byte) { x := db.traverse(key, false) if x.isNull() { return @@ -451,14 +458,14 @@ func (s *testMemDBSuite) TestDirty(c *C) { // persistent flags will make memdb dirty. db = newMemDB() h = db.Staging() - db.SetWithFlags([]byte{1}, []byte{1}, SetKeyLocked) + db.SetWithFlags([]byte{1}, []byte{1}, kv.SetKeyLocked) db.Cleanup(h) c.Assert(db.Dirty(), IsTrue) // non-persistent flags will not make memdb dirty. 
db = newMemDB() h = db.Staging() - db.SetWithFlags([]byte{1}, []byte{1}, SetPresumeKeyNotExists) + db.SetWithFlags([]byte{1}, []byte{1}, kv.SetPresumeKeyNotExists) db.Cleanup(h) c.Assert(db.Dirty(), IsFalse) } @@ -471,9 +478,9 @@ func (s *testMemDBSuite) TestFlags(c *C) { var buf [4]byte binary.BigEndian.PutUint32(buf[:], i) if i%2 == 0 { - db.SetWithFlags(buf[:], buf[:], SetPresumeKeyNotExists, SetKeyLocked) + db.SetWithFlags(buf[:], buf[:], kv.SetPresumeKeyNotExists, kv.SetKeyLocked) } else { - db.SetWithFlags(buf[:], buf[:], SetPresumeKeyNotExists) + db.SetWithFlags(buf[:], buf[:], kv.SetPresumeKeyNotExists) } } db.Cleanup(h) @@ -511,7 +518,7 @@ func (s *testMemDBSuite) TestFlags(c *C) { for i := uint32(0); i < cnt; i++ { var buf [4]byte binary.BigEndian.PutUint32(buf[:], i) - db.UpdateFlags(buf[:], DelKeyLocked) + db.UpdateFlags(buf[:], kv.DelKeyLocked) } for i := uint32(0); i < cnt; i++ { var buf [4]byte @@ -526,7 +533,7 @@ func (s *testMemDBSuite) TestFlags(c *C) { } } -func (s *testMemDBSuite) checkConsist(c *C, p1 *memdb, p2 *leveldb.DB) { +func (s *testMemDBSuite) checkConsist(c *C, p1 *MemDB, p2 *leveldb.DB) { c.Assert(p1.Len(), Equals, p2.Len()) c.Assert(p1.Size(), Equals, p2.Size()) @@ -565,14 +572,14 @@ func (s *testMemDBSuite) checkConsist(c *C, p1 *memdb, p2 *leveldb.DB) { } } -func (s *testMemDBSuite) fillDB(cnt int) *memdb { +func (s *testMemDBSuite) fillDB(cnt int) *MemDB { db := newMemDB() h := s.deriveAndFill(0, cnt, 0, db) db.Release(h) return db } -func (s *testMemDBSuite) deriveAndFill(start, end, valueBase int, db *memdb) StagingHandle { +func (s *testMemDBSuite) deriveAndFill(start, end, valueBase int, db *MemDB) StagingHandle { h := db.Staging() var kbuf, vbuf [4]byte for i := start; i < end; i++ { @@ -590,11 +597,11 @@ const ( ) type testKVSuite struct { - bs []MemBuffer + bs []*MemDB } func (s *testKVSuite) SetUpSuite(c *C) { - s.bs = make([]MemBuffer, 1) + s.bs = make([]*MemDB, 1) s.bs[0] = newMemDB() } @@ -602,7 +609,7 @@ func (s *testKVSuite) ResetMembuffers() { s.bs[0] = newMemDB() } -func insertData(c *C, buffer MemBuffer) { +func insertData(c *C, buffer *MemDB) { for i := startIndex; i < testCount; i++ { val := encodeInt(i * indexStep) err := buffer.Set(val, val) @@ -625,7 +632,7 @@ func valToStr(c *C, iter Iterator) string { return string(val) } -func checkNewIterator(c *C, buffer MemBuffer) { +func checkNewIterator(c *C, buffer *MemDB) { for i := startIndex; i < testCount; i++ { val := encodeInt(i * indexStep) iter, err := buffer.Iter(val, nil) @@ -670,7 +677,7 @@ func checkNewIterator(c *C, buffer MemBuffer) { iter.Close() } -func mustGet(c *C, buffer MemBuffer) { +func mustGet(c *C, buffer *MemDB) { for i := startIndex; i < testCount; i++ { s := encodeInt(i * indexStep) val, err := buffer.Get(context.TODO(), s) @@ -710,7 +717,7 @@ func (s *testKVSuite) TestIterNextUntil(c *C) { iter, err := buffer.Iter(nil, nil) c.Assert(err, IsNil) - err = NextUntil(iter, func(k Key) bool { + err = tidbkv.NextUntil(iter, func(k Key) bool { return false }) c.Assert(err, IsNil) @@ -829,7 +836,7 @@ func (s *testKVSuite) TestBufferBatchGetter(c *C) { buffer.Set(ka, []byte("a2")) buffer.Delete(kb) - batchGetter := NewBufferBatchGetter(buffer, middle, snap) + batchGetter := tidbkv.NewBufferBatchGetter(buffer, middle, snap) result, err := batchGetter.BatchGet(context.Background(), []Key{ka, kb, kc, kd}) c.Assert(err, IsNil) c.Assert(len(result), Equals, 3) diff --git a/store/tikv/unionstore/mock.go b/store/tikv/unionstore/mock.go new file mode 100644 index 
0000000000000..fa1e8b96ca233 --- /dev/null +++ b/store/tikv/unionstore/mock.go @@ -0,0 +1,58 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package unionstore + +import ( + "context" + + "github.com/pingcap/tidb/kv" +) + +type mockSnapshot struct { + store *MemDB +} + +func (s *mockSnapshot) Get(ctx context.Context, k kv.Key) ([]byte, error) { + return s.store.Get(ctx, k) +} + +func (s *mockSnapshot) SetPriority(priority int) { + +} + +func (s *mockSnapshot) BatchGet(ctx context.Context, keys []kv.Key) (map[string][]byte, error) { + m := make(map[string][]byte, len(keys)) + for _, k := range keys { + v, err := s.store.Get(ctx, k) + if kv.IsErrNotFound(err) { + continue + } + if err != nil { + return nil, err + } + m[string(k)] = v + } + return m, nil +} + +func (s *mockSnapshot) Iter(k kv.Key, upperBound kv.Key) (kv.Iterator, error) { + return s.store.Iter(k, upperBound) +} + +func (s *mockSnapshot) IterReverse(k kv.Key) (kv.Iterator, error) { + return s.store.IterReverse(k) +} + +func (s *mockSnapshot) SetOption(opt int, val interface{}) {} +func (s *mockSnapshot) DelOption(opt int) {} diff --git a/kv/union_iter.go b/store/tikv/unionstore/union_iter.go similarity index 93% rename from kv/union_iter.go rename to store/tikv/unionstore/union_iter.go index 4134d69cdc95f..e653ad3ad2ff3 100644 --- a/kv/union_iter.go +++ b/store/tikv/unionstore/union_iter.go @@ -11,17 +11,18 @@ // See the License for the specific language governing permissions and // limitations under the License. -package kv +package unionstore import ( + tidbkv "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/util/logutil" "go.uber.org/zap" ) // UnionIter is the iterator on an UnionStore. type UnionIter struct { - dirtyIt Iterator - snapshotIt Iterator + dirtyIt tidbkv.Iterator + snapshotIt tidbkv.Iterator dirtyValid bool snapshotValid bool @@ -32,7 +33,7 @@ type UnionIter struct { } // NewUnionIter returns a union iterator for BufferStore. -func NewUnionIter(dirtyIt Iterator, snapshotIt Iterator, reverse bool) (*UnionIter, error) { +func NewUnionIter(dirtyIt tidbkv.Iterator, snapshotIt tidbkv.Iterator, reverse bool) (*UnionIter, error) { it := &UnionIter{ dirtyIt: dirtyIt, snapshotIt: snapshotIt, @@ -161,7 +162,7 @@ func (iter *UnionIter) Value() []byte { } // Key implements the Iterator Key interface. -func (iter *UnionIter) Key() Key { +func (iter *UnionIter) Key() tidbkv.Key { if !iter.curIsDirty { return iter.snapshotIt.Key() } diff --git a/store/tikv/unionstore/union_store.go b/store/tikv/unionstore/union_store.go new file mode 100644 index 0000000000000..00cf6297548b9 --- /dev/null +++ b/store/tikv/unionstore/union_store.go @@ -0,0 +1,120 @@ +// Copyright 2021 PingCAP, Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package unionstore + +import ( + "context" + + tidbkv "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/store/tikv/kv" +) + +// KVUnionStore is an in-memory Store which contains a buffer for write and a +// snapshot for read. +type KVUnionStore struct { + memBuffer *MemDB + snapshot tidbkv.Snapshot + opts options +} + +// NewUnionStore builds a new unionStore. +func NewUnionStore(snapshot tidbkv.Snapshot) *KVUnionStore { + return &KVUnionStore{ + snapshot: snapshot, + memBuffer: newMemDB(), + opts: make(map[int]interface{}), + } +} + +// GetMemBuffer returns the MemBuffer bound to this unionStore. +func (us *KVUnionStore) GetMemBuffer() *MemDB { + return us.memBuffer +} + +// Get implements the Retriever interface. +func (us *KVUnionStore) Get(ctx context.Context, k tidbkv.Key) ([]byte, error) { + v, err := us.memBuffer.Get(ctx, k) + if tidbkv.IsErrNotFound(err) { + v, err = us.snapshot.Get(ctx, k) + } + if err != nil { + return v, err + } + if len(v) == 0 { + return nil, tidbkv.ErrNotExist + } + return v, nil +} + +// Iter implements the Retriever interface. +func (us *KVUnionStore) Iter(k tidbkv.Key, upperBound tidbkv.Key) (tidbkv.Iterator, error) { + bufferIt, err := us.memBuffer.Iter(k, upperBound) + if err != nil { + return nil, err + } + retrieverIt, err := us.snapshot.Iter(k, upperBound) + if err != nil { + return nil, err + } + return NewUnionIter(bufferIt, retrieverIt, false) +} + +// IterReverse implements the Retriever interface. +func (us *KVUnionStore) IterReverse(k tidbkv.Key) (tidbkv.Iterator, error) { + bufferIt, err := us.memBuffer.IterReverse(k) + if err != nil { + return nil, err + } + retrieverIt, err := us.snapshot.IterReverse(k) + if err != nil { + return nil, err + } + return NewUnionIter(bufferIt, retrieverIt, true) +} + +// HasPresumeKeyNotExists returns whether the key is presumed not to exist for the lazy existence check. +func (us *KVUnionStore) HasPresumeKeyNotExists(k tidbkv.Key) bool { + flags, err := us.memBuffer.GetFlags(k) + if err != nil { + return false + } + return flags.HasPresumeKeyNotExists() +} + +// UnmarkPresumeKeyNotExists clears the presume-key-not-exists flag used by the lazy existence check. +func (us *KVUnionStore) UnmarkPresumeKeyNotExists(k tidbkv.Key) { + us.memBuffer.UpdateFlags(k, kv.DelPresumeKeyNotExists) +} + +// SetOption implements the unionStore SetOption interface. +func (us *KVUnionStore) SetOption(opt int, val interface{}) { + us.opts[opt] = val +} + +// DelOption implements the unionStore DelOption interface. +func (us *KVUnionStore) DelOption(opt int) { + delete(us.opts, opt) +} + +// GetOption implements the unionStore GetOption interface.
+func (us *KVUnionStore) GetOption(opt int) interface{} { + return us.opts[opt] +} + +type options map[int]interface{} + +func (opts options) Get(opt int) (interface{}, bool) { + v, ok := opts[opt] + return v, ok +} diff --git a/kv/union_store_test.go b/store/tikv/unionstore/union_store_test.go similarity index 95% rename from kv/union_store_test.go rename to store/tikv/unionstore/union_store_test.go index 372f75eafd958..a87945574c3b6 100644 --- a/kv/union_store_test.go +++ b/store/tikv/unionstore/union_store_test.go @@ -11,20 +11,21 @@ // See the License for the specific language governing permissions and // limitations under the License. -package kv +package unionstore import ( "context" . "github.com/pingcap/check" + "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/util/testleak" ) var _ = Suite(&testUnionStoreSuite{}) type testUnionStoreSuite struct { - store MemBuffer - us UnionStore + store *MemDB + us *KVUnionStore } func (s *testUnionStoreSuite) SetUpTest(c *C) { @@ -55,7 +56,7 @@ func (s *testUnionStoreSuite) TestDelete(c *C) { err = s.us.GetMemBuffer().Delete([]byte("1")) c.Assert(err, IsNil) _, err = s.us.Get(context.TODO(), []byte("1")) - c.Assert(IsErrNotFound(err), IsTrue) + c.Assert(kv.IsErrNotFound(err), IsTrue) err = s.us.GetMemBuffer().Set([]byte("1"), []byte("2")) c.Assert(err, IsNil) @@ -124,7 +125,7 @@ func (s *testUnionStoreSuite) TestIterReverse(c *C) { checkIterator(c, iter, [][]byte{[]byte("2"), []byte("0")}, [][]byte{[]byte("2"), []byte("0")}) } -func checkIterator(c *C, iter Iterator, keys [][]byte, values [][]byte) { +func checkIterator(c *C, iter kv.Iterator, keys [][]byte, values [][]byte) { defer iter.Close() c.Assert(len(keys), Equals, len(values)) for i, k := range keys { diff --git a/table/tables/index.go b/table/tables/index.go index 744418d7df665..4567a9efb83d0 100644 --- a/table/tables/index.go +++ b/table/tables/index.go @@ -25,6 +25,7 @@ import ( "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/stmtctx" + tikvstore "github.com/pingcap/tidb/store/tikv/kv" "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/tablecodec" "github.com/pingcap/tidb/types" @@ -207,7 +208,7 @@ func (c *index) Create(sctx sessionctx.Context, txn kv.Transaction, indexedValue } if err != nil || len(value) == 0 { if sctx.GetSessionVars().LazyCheckKeyNotExists() && err != nil { - err = us.GetMemBuffer().SetWithFlags(key, idxVal, kv.SetPresumeKeyNotExists) + err = us.GetMemBuffer().SetWithFlags(key, idxVal, tikvstore.SetPresumeKeyNotExists) } else { err = us.GetMemBuffer().Set(key, idxVal) } @@ -228,7 +229,7 @@ func (c *index) Delete(sc *stmtctx.StatementContext, us kv.UnionStore, indexedVa return err } if distinct { - err = us.GetMemBuffer().DeleteWithFlags(key, kv.SetNeedLocked) + err = us.GetMemBuffer().DeleteWithFlags(key, tikvstore.SetNeedLocked) } else { err = us.GetMemBuffer().Delete(key) } diff --git a/table/tables/tables.go b/table/tables/tables.go index ee1a526a183ea..ac09dbb35d0bd 100644 --- a/table/tables/tables.go +++ b/table/tables/tables.go @@ -36,6 +36,7 @@ import ( "github.com/pingcap/tidb/sessionctx/binloginfo" "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/sessionctx/variable" + tikvstore "github.com/pingcap/tidb/store/tikv/kv" "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/tablecodec" "github.com/pingcap/tidb/types" @@ -754,7 +755,7 @@ func (t *TableCommon) AddRecord(sctx sessionctx.Context, r []types.Datum, opts . 
} if setPresume { - err = memBuffer.SetWithFlags(key, value, kv.SetPresumeKeyNotExists) + err = memBuffer.SetWithFlags(key, value, tikvstore.SetPresumeKeyNotExists) } else { err = memBuffer.Set(key, value) } diff --git a/types/datum.go b/types/datum.go index d6542f7507631..d8965b1b4bf44 100644 --- a/types/datum.go +++ b/types/datum.go @@ -252,6 +252,7 @@ func (d *Datum) GetMysqlBit() BinaryLiteral { func (d *Datum) SetBinaryLiteral(b BinaryLiteral) { d.k = KindBinaryLiteral d.b = b + d.collation = charset.CollationBin } // SetMysqlBit sets MysqlBit value diff --git a/util/admin/admin_integration_test.go b/util/admin/admin_integration_test.go index 586f865b26387..771b35f251333 100644 --- a/util/admin/admin_integration_test.go +++ b/util/admin/admin_integration_test.go @@ -20,6 +20,7 @@ import ( "github.com/pingcap/tidb/domain" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/session" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/store/mockstore" "github.com/pingcap/tidb/store/tikv/mockstore/cluster" "github.com/pingcap/tidb/util/testkit" @@ -109,7 +110,7 @@ func (s *testAdminSuite) TestAdminCheckTableClusterIndex(c *C) { tk.MustExec("create database admin_check_table_clustered_index;") tk.MustExec("use admin_check_table_clustered_index;") - tk.Se.GetSessionVars().EnableClusteredIndex = true + tk.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn tk.MustExec("create table t (a bigint, b varchar(255), c int, primary key (a, b), index idx_0(a, b), index idx_1(b, c));") tk.MustExec("insert into t values (1, '1', 1);") diff --git a/util/ranger/ranger_test.go b/util/ranger/ranger_test.go index 9ceaa99135380..4038e197f8721 100644 --- a/util/ranger/ranger_test.go +++ b/util/ranger/ranger_test.go @@ -28,6 +28,7 @@ import ( "github.com/pingcap/tidb/session" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/stmtctx" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/store/mockstore" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/collate" @@ -1219,7 +1220,7 @@ func (s *testRangerSuite) TestIndexRangeElimininatedProjection(c *C) { testKit := testkit.NewTestKit(c, store) testKit.MustExec("use test") testKit.MustExec("drop table if exists t") - testKit.Se.GetSessionVars().EnableClusteredIndex = false + testKit.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeIntOnly testKit.MustExec("create table t(a int not null, b int not null, primary key(a,b))") testKit.MustExec("insert into t values(1,2)") testKit.MustExec("analyze table t") @@ -1339,7 +1340,7 @@ func (s *testRangerSuite) TestCompIndexMultiColDNF1(c *C) { c.Assert(err, IsNil) testKit := testkit.NewTestKit(c, store) testKit.MustExec("use test") - testKit.Se.GetSessionVars().EnableClusteredIndex = true + testKit.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn testKit.MustExec("drop table if exists t") testKit.MustExec("create table t(a int, b int, c int, primary key(a,b));") testKit.MustExec("insert into t values(1,1,1),(2,2,3)") @@ -1373,7 +1374,7 @@ func (s *testRangerSuite) TestCompIndexMultiColDNF2(c *C) { c.Assert(err, IsNil) testKit := testkit.NewTestKit(c, store) testKit.MustExec("use test") - testKit.Se.GetSessionVars().EnableClusteredIndex = true + testKit.Se.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOn testKit.MustExec("drop table if exists t") testKit.MustExec("create table t(a int, b int, c int, primary key(a,b,c));") testKit.MustExec("insert into 
t values(1,1,1),(2,2,3)")
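A short illustration of the staging semantics that the relocated MemDB documents above (Staging/Release/Cleanup, persistent vs. non-persistent flags). This is not part of the patch; it is a minimal in-package sketch that assumes the unexported newMemDB constructor used by the tests in this diff, and it mirrors the behavior asserted in memdb_test.go. The helper name stagingSketch is hypothetical.

package unionstore

import (
	"context"
	"fmt"

	tidbkv "github.com/pingcap/tidb/kv"
	"github.com/pingcap/tidb/store/tikv/kv"
)

// stagingSketch walks through the Staging/Release/Cleanup workflow described
// by the MemDB doc comments. Hypothetical helper, not part of this PR.
func stagingSketch() {
	db := newMemDB() // unexported constructor, as used by the in-package tests

	_ = db.Set(tidbkv.Key("k1"), []byte("v1")) // write into the root buffer

	// Open a staging buffer; subsequent writes land there until Release or Cleanup.
	h := db.Staging()
	_ = db.Set(tidbkv.Key("k2"), []byte("v2"))
	_ = db.SetWithFlags(tidbkv.Key("k3"), []byte("v3"), kv.SetPresumeKeyNotExists)

	// Cleanup discards the staged writes. A key that only carried the
	// non-persistent PresumeKeyNotExists flag is dropped entirely; a key with a
	// persistent flag (e.g. SetKeyLocked) would survive as a flag-only node.
	db.Cleanup(h)
	_, err := db.Get(context.Background(), tidbkv.Key("k2"))
	fmt.Println(tidbkv.IsErrNotFound(err)) // true: the staged write was discarded

	// Release publishes staged writes to the upper level buffer.
	h = db.Staging()
	_ = db.Set(tidbkv.Key("k4"), []byte("v4"))
	db.Release(h)
	v, _ := db.Get(context.Background(), tidbkv.Key("k4"))
	fmt.Println(string(v)) // "v4"
}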
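Similarly, the read path of the new KVUnionStore in union_store.go (MemBuffer first, snapshot as fallback, tombstones surfacing as ErrNotExist) can be sketched with the in-package mockSnapshot added in mock.go. Again this is only an illustrative sketch that follows the assertions in union_store_test.go and assumes mockSnapshot satisfies the tidbkv.Snapshot interface as the in-package tests rely on; it is not code from the patch.

package unionstore

import (
	"context"
	"fmt"

	tidbkv "github.com/pingcap/tidb/kv"
)

// unionStoreSketch shows KVUnionStore.Get consulting the MemBuffer first and
// falling back to the snapshot, and a buffered Delete (tombstone) reading as
// not found. Hypothetical helper, not part of this PR.
func unionStoreSketch() {
	snap := newMemDB()
	_ = snap.Set(tidbkv.Key("a"), []byte("1"))

	us := NewUnionStore(&mockSnapshot{store: snap})

	// Not buffered yet: the value is served from the snapshot.
	v, _ := us.Get(context.Background(), tidbkv.Key("a"))
	fmt.Println(string(v)) // "1"

	// A buffered Delete writes a tombstone; Get now reports it as not found.
	_ = us.GetMemBuffer().Delete(tidbkv.Key("a"))
	_, err := us.Get(context.Background(), tidbkv.Key("a"))
	fmt.Println(tidbkv.IsErrNotFound(err)) // true
}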