From be8caa654965df2c3a8139f63bfe5cb264fd51af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E8=B6=85?= Date: Tue, 3 Jan 2023 15:50:19 +0800 Subject: [PATCH 1/9] ttl: disable ttl job when recover/flashback table/database/cluster (#40268) close pingcap/tidb#40265 --- ddl/cluster.go | 38 ++++++++++++++++++++++++--- ddl/cluster_test.go | 24 +++++++++++++++-- ddl/ddl_api.go | 3 ++- ddl/schema.go | 4 +++ ddl/serial_test.go | 64 +++++++++++++++++++++++++++++++++++++++++++++ ddl/table.go | 5 ++++ 6 files changed, 131 insertions(+), 7 deletions(-) diff --git a/ddl/cluster.go b/ddl/cluster.go index 227963b3951d5..cd0053f9e7e4f 100644 --- a/ddl/cluster.go +++ b/ddl/cluster.go @@ -68,6 +68,7 @@ const ( totalLockedRegionsOffset startTSOffset commitTSOffset + ttlJobEnableOffSet ) func closePDSchedule() error { @@ -124,6 +125,18 @@ func ValidateFlashbackTS(ctx context.Context, sctx sessionctx.Context, flashBack return gcutil.ValidateSnapshotWithGCSafePoint(flashBackTS, gcSafePoint) } +func getTiDBTTLJobEnable(sess sessionctx.Context) (string, error) { + val, err := sess.GetSessionVars().GlobalVarsAccessor.GetGlobalSysVar(variable.TiDBTTLJobEnable) + if err != nil { + return "", errors.Trace(err) + } + return val, nil +} + +func setTiDBTTLJobEnable(ctx context.Context, sess sessionctx.Context, value string) error { + return sess.GetSessionVars().GlobalVarsAccessor.SetGlobalSysVar(ctx, variable.TiDBTTLJobEnable, value) +} + func setTiDBEnableAutoAnalyze(ctx context.Context, sess sessionctx.Context, value string) error { return sess.GetSessionVars().GlobalVarsAccessor.SetGlobalSysVar(ctx, variable.TiDBEnableAutoAnalyze, value) } @@ -176,6 +189,9 @@ func checkAndSetFlashbackClusterInfo(sess sessionctx.Context, d *ddlCtx, t *meta if err = setTiDBSuperReadOnly(d.ctx, sess, variable.On); err != nil { return err } + if err = setTiDBTTLJobEnable(d.ctx, sess, variable.Off); err != nil { + return err + } nowSchemaVersion, err := t.GetSchemaVersion() if err != nil { @@ -553,9 +569,9 @@ func (w *worker) onFlashbackCluster(d *ddlCtx, t *meta.Meta, job *model.Job) (ve var flashbackTS, lockedRegions, startTS, commitTS uint64 var pdScheduleValue map[string]interface{} - var autoAnalyzeValue, readOnlyValue string + var autoAnalyzeValue, readOnlyValue, ttlJobEnableValue string var gcEnabledValue bool - if err := job.DecodeArgs(&flashbackTS, &pdScheduleValue, &gcEnabledValue, &autoAnalyzeValue, &readOnlyValue, &lockedRegions, &startTS, &commitTS); err != nil { + if err := job.DecodeArgs(&flashbackTS, &pdScheduleValue, &gcEnabledValue, &autoAnalyzeValue, &readOnlyValue, &lockedRegions, &startTS, &commitTS, &ttlJobEnableValue); err != nil { job.State = model.JobStateCancelled return ver, errors.Trace(err) } @@ -595,6 +611,12 @@ func (w *worker) onFlashbackCluster(d *ddlCtx, t *meta.Meta, job *model.Job) (ve return ver, errors.Trace(err) } job.Args[readOnlyOffset] = &readOnlyValue + ttlJobEnableValue, err = getTiDBTTLJobEnable(sess) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + job.Args[ttlJobEnableOffSet] = &ttlJobEnableValue job.SchemaState = model.StateDeleteOnly return ver, nil // Stage 2, check flashbackTS, close GC and PD schedule. 
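The flow added above follows one pattern for the TTL switch: read the current value of tidb_ttl_job_enable, stash it in job.Args at ttlJobEnableOffSet, force the variable to Off for the duration of the flashback, and write the saved value back only inside the job.IsCancelled() branch of finishFlashbackCluster (shown in the next hunk). Below is a minimal, self-contained Go sketch of that save / force-off / restore-on-cancel shape; GlobalVarAccessor, memVars and runWithTTLJobDisabled are hypothetical names introduced only for illustration and are not part of the patch.

package main

import "fmt"

// GlobalVarAccessor is a hypothetical stand-in for the GlobalVarsAccessor
// reached through sess.GetSessionVars() in the patch; it only exists to
// keep this sketch self-contained.
type GlobalVarAccessor interface {
	GetGlobalSysVar(name string) (string, error)
	SetGlobalSysVar(name, value string) error
}

// memVars is an in-memory implementation used for the example run.
type memVars map[string]string

func (m memVars) GetGlobalSysVar(name string) (string, error) { return m[name], nil }
func (m memVars) SetGlobalSysVar(name, value string) error    { m[name] = value; return nil }

// runWithTTLJobDisabled mirrors the shape of onFlashbackCluster plus
// finishFlashbackCluster for tidb_ttl_job_enable: capture the current
// value, force it to OFF while the work runs, and put the old value back
// only when the work reports cancellation.
func runWithTTLJobDisabled(acc GlobalVarAccessor, work func() (cancelled bool, err error)) error {
	saved, err := acc.GetGlobalSysVar("tidb_ttl_job_enable")
	if err != nil {
		return err
	}
	if err := acc.SetGlobalSysVar("tidb_ttl_job_enable", "OFF"); err != nil {
		return err
	}
	cancelled, err := work()
	if cancelled {
		// Corresponds to the job.IsCancelled() branch in the patch: only a
		// cancelled flashback restores the previous setting.
		if restoreErr := acc.SetGlobalSysVar("tidb_ttl_job_enable", saved); restoreErr != nil {
			return restoreErr
		}
	}
	return err
}

func main() {
	vars := memVars{"tidb_ttl_job_enable": "ON"}
	_ = runWithTTLJobDisabled(vars, func() (bool, error) { return true, nil })
	fmt.Println(vars["tidb_ttl_job_enable"]) // "ON": restored because the work was cancelled
}

In the surrounding patch, tidb_super_read_only and tidb_enable_auto_analyze follow the same capture-and-restore flow, except that finishFlashbackCluster restores them unconditionally rather than only on cancellation.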
@@ -694,10 +716,10 @@ func finishFlashbackCluster(w *worker, job *model.Job) error { var flashbackTS, lockedRegions, startTS, commitTS uint64 var pdScheduleValue map[string]interface{} - var autoAnalyzeValue, readOnlyValue string + var autoAnalyzeValue, readOnlyValue, ttlJobEnableValue string var gcEnabled bool - if err := job.DecodeArgs(&flashbackTS, &pdScheduleValue, &gcEnabled, &autoAnalyzeValue, &readOnlyValue, &lockedRegions, &startTS, &commitTS); err != nil { + if err := job.DecodeArgs(&flashbackTS, &pdScheduleValue, &gcEnabled, &autoAnalyzeValue, &readOnlyValue, &lockedRegions, &startTS, &commitTS, &ttlJobEnableValue); err != nil { return errors.Trace(err) } sess, err := w.sessPool.get() @@ -718,6 +740,14 @@ func finishFlashbackCluster(w *worker, job *model.Job) error { if err = setTiDBSuperReadOnly(w.ctx, sess, readOnlyValue); err != nil { return err } + + if job.IsCancelled() { + // only restore `tidb_ttl_job_enable` when flashback failed + if err = setTiDBTTLJobEnable(w.ctx, sess, ttlJobEnableValue); err != nil { + return err + } + } + return setTiDBEnableAutoAnalyze(w.ctx, sess, autoAnalyzeValue) }) if err != nil { diff --git a/ddl/cluster_test.go b/ddl/cluster_test.go index 12c77c42edafe..8bb1776d6d08f 100644 --- a/ddl/cluster_test.go +++ b/ddl/cluster_test.go @@ -209,12 +209,16 @@ func TestGlobalVariablesOnFlashback(t *testing.T) { rs, err = tk.Exec("show variables like 'tidb_super_read_only'") assert.NoError(t, err) assert.Equal(t, tk.ResultSetToResult(rs, "").Rows()[0][1], variable.On) + rs, err = tk.Exec("show variables like 'tidb_ttl_job_enable'") + assert.NoError(t, err) + assert.Equal(t, tk.ResultSetToResult(rs, "").Rows()[0][1], variable.Off) } } dom.DDL().SetHook(hook) - // first try with `tidb_gc_enable` = on and `tidb_super_read_only` = off + // first try with `tidb_gc_enable` = on and `tidb_super_read_only` = off and `tidb_ttl_job_enable` = on tk.MustExec("set global tidb_gc_enable = on") tk.MustExec("set global tidb_super_read_only = off") + tk.MustExec("set global tidb_ttl_job_enable = on") tk.MustExec(fmt.Sprintf("flashback cluster to timestamp '%s'", oracle.GetTimeFromTS(ts))) @@ -224,10 +228,14 @@ func TestGlobalVariablesOnFlashback(t *testing.T) { rs, err = tk.Exec("show variables like 'tidb_gc_enable'") require.NoError(t, err) require.Equal(t, tk.ResultSetToResult(rs, "").Rows()[0][1], variable.On) + rs, err = tk.Exec("show variables like 'tidb_ttl_job_enable'") + require.NoError(t, err) + require.Equal(t, tk.ResultSetToResult(rs, "").Rows()[0][1], variable.Off) - // second try with `tidb_gc_enable` = off and `tidb_super_read_only` = on + // second try with `tidb_gc_enable` = off and `tidb_super_read_only` = on and `tidb_ttl_job_enable` = off tk.MustExec("set global tidb_gc_enable = off") tk.MustExec("set global tidb_super_read_only = on") + tk.MustExec("set global tidb_ttl_job_enable = off") ts, err = tk.Session().GetStore().GetOracle().GetTimestamp(context.Background(), &oracle.Option{}) require.NoError(t, err) @@ -238,6 +246,9 @@ func TestGlobalVariablesOnFlashback(t *testing.T) { rs, err = tk.Exec("show variables like 'tidb_gc_enable'") require.NoError(t, err) require.Equal(t, tk.ResultSetToResult(rs, "").Rows()[0][1], variable.Off) + rs, err = tk.Exec("show variables like 'tidb_ttl_job_enable'") + assert.NoError(t, err) + assert.Equal(t, tk.ResultSetToResult(rs, "").Rows()[0][1], variable.Off) dom.DDL().SetHook(originHook) require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/ddl/mockFlashbackTest")) @@ -268,9 +279,14 @@ func 
TestCancelFlashbackCluster(t *testing.T) { return job.SchemaState == model.StateDeleteOnly }) dom.DDL().SetHook(hook) + tk.MustExec("set global tidb_ttl_job_enable = on") tk.MustGetErrCode(fmt.Sprintf("flashback cluster to timestamp '%s'", oracle.GetTimeFromTS(ts)), errno.ErrCancelledDDLJob) hook.MustCancelDone(t) + rs, err := tk.Exec("show variables like 'tidb_ttl_job_enable'") + assert.NoError(t, err) + assert.Equal(t, tk.ResultSetToResult(rs, "").Rows()[0][1], variable.On) + // Try canceled on StateWriteReorganization, cancel failed hook = newCancelJobHook(t, store, dom, func(job *model.Job) bool { return job.SchemaState == model.StateWriteReorganization @@ -279,6 +295,10 @@ func TestCancelFlashbackCluster(t *testing.T) { tk.MustExec(fmt.Sprintf("flashback cluster to timestamp '%s'", oracle.GetTimeFromTS(ts))) hook.MustCancelFailed(t) + rs, err = tk.Exec("show variables like 'tidb_ttl_job_enable'") + assert.NoError(t, err) + assert.Equal(t, tk.ResultSetToResult(rs, "").Rows()[0][1], variable.Off) + dom.DDL().SetHook(originHook) require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/ddl/mockFlashbackTest")) diff --git a/ddl/ddl_api.go b/ddl/ddl_api.go index 7cd65b47b170a..2a20ea69f3d39 100644 --- a/ddl/ddl_api.go +++ b/ddl/ddl_api.go @@ -2750,7 +2750,8 @@ func (d *ddl) FlashbackCluster(ctx sessionctx.Context, flashbackTS uint64) error variable.Off, /* tidb_super_read_only */ 0, /* totalRegions */ 0, /* startTS */ - 0 /* commitTS */}, + 0, /* commitTS */ + variable.On /* tidb_ttl_job_enable */}, } err = d.DoDDLJob(ctx, job) err = d.callHookOnChanged(job, err) diff --git a/ddl/schema.go b/ddl/schema.go index d9e86c30c5eaf..e9cb1e6579635 100644 --- a/ddl/schema.go +++ b/ddl/schema.go @@ -312,6 +312,10 @@ func (w *worker) onRecoverSchema(d *ddlCtx, t *meta.Meta, job *model.Job) (ver i return ver, errors.Trace(err) } for _, recoverInfo := range recoverSchemaInfo.RecoverTabsInfo { + if recoverInfo.TableInfo.TTLInfo != nil { + // force disable TTL job schedule for recovered table + recoverInfo.TableInfo.TTLInfo.Enable = false + } ver, err = w.recoverTable(t, job, recoverInfo) if err != nil { return ver, errors.Trace(err) diff --git a/ddl/serial_test.go b/ddl/serial_test.go index e3456124871e8..315cfca73e57c 100644 --- a/ddl/serial_test.go +++ b/ddl/serial_test.go @@ -462,6 +462,70 @@ func TestCancelAddIndexPanic(t *testing.T) { require.Truef(t, strings.HasPrefix(errMsg, "[ddl:8214]Cancelled DDL job"), "%v", errMsg) } +func TestRecoverTableWithTTL(t *testing.T) { + store, _ := createMockStoreAndDomain(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("create database if not exists test_recover") + tk.MustExec("use test_recover") + defer func(originGC bool) { + if originGC { + util.EmulatorGCEnable() + } else { + util.EmulatorGCDisable() + } + }(util.IsEmulatorGCEnable()) + + // disable emulator GC. + // Otherwise emulator GC will delete table record as soon as possible after execute drop table ddl. 
+ util.EmulatorGCDisable() + gcTimeFormat := "20060102-15:04:05 -0700 MST" + safePointSQL := `INSERT HIGH_PRIORITY INTO mysql.tidb VALUES ('tikv_gc_safe_point', '%[1]s', '') + ON DUPLICATE KEY + UPDATE variable_value = '%[1]s'` + tk.MustExec(fmt.Sprintf(safePointSQL, time.Now().Add(-time.Hour).Format(gcTimeFormat))) + getDDLJobID := func(table, tp string) int64 { + rs, err := tk.Exec("admin show ddl jobs") + require.NoError(t, err) + rows, err := session.GetRows4Test(context.Background(), tk.Session(), rs) + require.NoError(t, err) + for _, row := range rows { + if row.GetString(2) == table && row.GetString(3) == tp { + return row.GetInt64(0) + } + } + require.FailNowf(t, "can't find %s table of %s", tp, table) + return -1 + } + + // recover table + tk.MustExec("create table t_recover1 (t timestamp) TTL=`t`+INTERVAL 1 DAY") + tk.MustExec("drop table t_recover1") + tk.MustExec("recover table t_recover1") + tk.MustQuery("show create table t_recover1").Check(testkit.Rows("t_recover1 CREATE TABLE `t_recover1` (\n `t` timestamp NULL DEFAULT NULL\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin /*T![ttl] TTL=`t` + INTERVAL 1 DAY */ /*T![ttl] TTL_ENABLE='OFF' */")) + + // recover table with job id + tk.MustExec("create table t_recover2 (t timestamp) TTL=`t`+INTERVAL 1 DAY") + tk.MustExec("drop table t_recover2") + jobID := getDDLJobID("t_recover2", "drop table") + tk.MustExec(fmt.Sprintf("recover table BY JOB %d", jobID)) + tk.MustQuery("show create table t_recover2").Check(testkit.Rows("t_recover2 CREATE TABLE `t_recover2` (\n `t` timestamp NULL DEFAULT NULL\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin /*T![ttl] TTL=`t` + INTERVAL 1 DAY */ /*T![ttl] TTL_ENABLE='OFF' */")) + + // flashback table + tk.MustExec("create table t_recover3 (t timestamp) TTL=`t`+INTERVAL 1 DAY") + tk.MustExec("drop table t_recover3") + tk.MustExec("flashback table t_recover3") + tk.MustQuery("show create table t_recover3").Check(testkit.Rows("t_recover3 CREATE TABLE `t_recover3` (\n `t` timestamp NULL DEFAULT NULL\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin /*T![ttl] TTL=`t` + INTERVAL 1 DAY */ /*T![ttl] TTL_ENABLE='OFF' */")) + + // flashback database + tk.MustExec("create database if not exists test_recover2") + tk.MustExec("create table test_recover2.t1 (t timestamp) TTL=`t`+INTERVAL 1 DAY") + tk.MustExec("create table test_recover2.t2 (t timestamp) TTL=`t`+INTERVAL 1 DAY") + tk.MustExec("drop database test_recover2") + tk.MustExec("flashback database test_recover2") + tk.MustQuery("show create table test_recover2.t1").Check(testkit.Rows("t1 CREATE TABLE `t1` (\n `t` timestamp NULL DEFAULT NULL\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin /*T![ttl] TTL=`t` + INTERVAL 1 DAY */ /*T![ttl] TTL_ENABLE='OFF' */")) + tk.MustQuery("show create table test_recover2.t2").Check(testkit.Rows("t2 CREATE TABLE `t2` (\n `t` timestamp NULL DEFAULT NULL\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin /*T![ttl] TTL=`t` + INTERVAL 1 DAY */ /*T![ttl] TTL_ENABLE='OFF' */")) +} + func TestRecoverTableByJobID(t *testing.T) { store, _ := createMockStoreAndDomain(t) tk := testkit.NewTestKit(t, store) diff --git a/ddl/table.go b/ddl/table.go index 9e6fab762d3c5..a27eeb4df42fa 100644 --- a/ddl/table.go +++ b/ddl/table.go @@ -404,6 +404,11 @@ func (w *worker) onRecoverTable(d *ddlCtx, t *meta.Meta, job *model.Job) (ver in schemaID := recoverInfo.SchemaID tblInfo := recoverInfo.TableInfo + if tblInfo.TTLInfo != nil { + // force disable TTL job schedule for recovered table + 
tblInfo.TTLInfo.Enable = false + } + // check GC and safe point gcEnable, err := checkGCEnable(w) if err != nil { From 494672cb51989eddc65c177da6036e00483db1cb Mon Sep 17 00:00:00 2001 From: xiongjiwei Date: Tue, 3 Jan 2023 17:58:20 +0900 Subject: [PATCH 2/9] admin: impl admin check index for mv index (#40270) close pingcap/tidb#40272 --- executor/admin_test.go | 59 +++++++++++++++++++++++ executor/builder.go | 3 +- executor/distsql.go | 22 +++------ executor/executor.go | 19 +++++++- parser/types/field_type.go | 3 ++ table/tables/index.go | 81 ++++++++++++++++---------------- table/tables/mutation_checker.go | 50 +++++++++++--------- 7 files changed, 159 insertions(+), 78 deletions(-) diff --git a/executor/admin_test.go b/executor/admin_test.go index 0b2530e76d5a3..cd5c0664d031a 100644 --- a/executor/admin_test.go +++ b/executor/admin_test.go @@ -843,6 +843,65 @@ func TestClusteredAdminCleanupIndex(t *testing.T) { tk.MustExec("admin check table admin_test") } +func TestAdminCheckTableWithMultiValuedIndex(t *testing.T) { + store, domain := testkit.CreateMockStoreAndDomain(t) + + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(pk int primary key, a json, index idx((cast(a as signed array))))") + tk.MustExec("insert into t values (0, '[0,1,2]')") + tk.MustExec("insert into t values (1, '[1,2,3]')") + tk.MustExec("insert into t values (2, '[2,3,4]')") + tk.MustExec("insert into t values (3, '[3,4,5]')") + tk.MustExec("insert into t values (4, '[4,5,6]')") + tk.MustExec("admin check table t") + + // Make some corrupted index. Build the index information. + ctx := mock.NewContext() + ctx.Store = store + is := domain.InfoSchema() + dbName := model.NewCIStr("test") + tblName := model.NewCIStr("t") + tbl, err := is.TableByName(dbName, tblName) + require.NoError(t, err) + tblInfo := tbl.Meta() + idxInfo := tblInfo.Indices[0] + sc := ctx.GetSessionVars().StmtCtx + tk.Session().GetSessionVars().IndexLookupSize = 3 + tk.Session().GetSessionVars().MaxChunkSize = 3 + + cpIdx := idxInfo.Clone() + cpIdx.MVIndex = false + indexOpr := tables.NewIndex(tblInfo.ID, tblInfo, cpIdx) + txn, err := store.Begin() + require.NoError(t, err) + err = indexOpr.Delete(sc, txn, types.MakeDatums(0), kv.IntHandle(0)) + require.NoError(t, err) + err = txn.Commit(context.Background()) + require.NoError(t, err) + err = tk.ExecToErr("admin check table t") + require.Error(t, err) + require.True(t, consistency.ErrAdminCheckInconsistent.Equal(err)) + + txn, err = store.Begin() + require.NoError(t, err) + _, err = indexOpr.Create(ctx, txn, types.MakeDatums(0), kv.IntHandle(0), nil) + require.NoError(t, err) + err = txn.Commit(context.Background()) + require.NoError(t, err) + tk.MustExec("admin check table t") + + txn, err = store.Begin() + require.NoError(t, err) + _, err = indexOpr.Create(ctx, txn, types.MakeDatums(9), kv.IntHandle(9), nil) + require.NoError(t, err) + err = txn.Commit(context.Background()) + require.NoError(t, err) + err = tk.ExecToErr("admin check table t") + require.Error(t, err) +} + func TestAdminCheckPartitionTableFailed(t *testing.T) { store, domain := testkit.CreateMockStoreAndDomain(t) diff --git a/executor/builder.go b/executor/builder.go index c014893a0c86c..137fca5f0b8f5 100644 --- a/executor/builder.go +++ b/executor/builder.go @@ -433,7 +433,8 @@ func buildIndexLookUpChecker(b *executorBuilder, p *plannercore.PhysicalIndexLoo tps := make([]*types.FieldType, 0, fullColLen) for _, col := range is.Columns { - tps = 
append(tps, &(col.FieldType)) + // tps is used to decode the index, we should use the element type of the array if any. + tps = append(tps, col.FieldType.ArrayType()) } if !e.isCommonHandle() { diff --git a/executor/distsql.go b/executor/distsql.go index aab5067a81b6a..3b9a6a7d4b288 100644 --- a/executor/distsql.go +++ b/executor/distsql.go @@ -39,6 +39,7 @@ import ( "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/statistics" "github.com/pingcap/tidb/table" + "github.com/pingcap/tidb/table/tables" "github.com/pingcap/tidb/tablecodec" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util" @@ -1254,36 +1255,27 @@ func (w *tableWorker) compareData(ctx context.Context, task *lookupTableTask, ta sctx := w.idxLookup.ctx.GetSessionVars().StmtCtx for i := range vals { col := w.idxTblCols[i] - tp := &col.FieldType - idxVal := idxRow.GetDatum(i, tp) + idxVal := idxRow.GetDatum(i, w.idxColTps[i]) tablecodec.TruncateIndexValue(&idxVal, w.idxLookup.index.Columns[i], col.ColumnInfo) - cmpRes, err := idxVal.Compare(sctx, &vals[i], collators[i]) + cmpRes, err := tables.CompareIndexAndVal(sctx, vals[i], idxVal, collators[i], col.FieldType.IsArray() && vals[i].Kind() == types.KindMysqlJSON) if err != nil { - fts := make([]*types.FieldType, 0, len(w.idxTblCols)) - for _, c := range w.idxTblCols { - fts = append(fts, &c.FieldType) - } return ir().ReportAdminCheckInconsistentWithColInfo(ctx, handle, col.Name.O, - idxRow.GetDatum(i, tp), + idxVal, vals[i], err, - &consistency.RecordData{Handle: handle, Values: getDatumRow(&idxRow, fts)}, + &consistency.RecordData{Handle: handle, Values: getDatumRow(&idxRow, w.idxColTps)}, ) } if cmpRes != 0 { - fts := make([]*types.FieldType, 0, len(w.idxTblCols)) - for _, c := range w.idxTblCols { - fts = append(fts, &c.FieldType) - } return ir().ReportAdminCheckInconsistentWithColInfo(ctx, handle, col.Name.O, - idxRow.GetDatum(i, tp), + idxRow.GetDatum(i, w.idxColTps[i]), vals[i], err, - &consistency.RecordData{Handle: handle, Values: getDatumRow(&idxRow, fts)}, + &consistency.RecordData{Handle: handle, Values: getDatumRow(&idxRow, w.idxColTps)}, ) } } diff --git a/executor/executor.go b/executor/executor.go index 603996ad7764f..1679ed9e57e12 100644 --- a/executor/executor.go +++ b/executor/executor.go @@ -959,6 +959,9 @@ func (e *CheckTableExec) Next(ctx context.Context, req *chunk.Chunk) error { idxNames := make([]string, 0, len(e.indexInfos)) for _, idx := range e.indexInfos { + if idx.MVIndex { + continue + } idxNames = append(idxNames, idx.Name.O) } greater, idxOffset, err := admin.CheckIndicesCount(e.ctx, e.dbName, e.table.Meta().Name.O, idxNames) @@ -978,7 +981,13 @@ func (e *CheckTableExec) Next(ctx context.Context, req *chunk.Chunk) error { // The number of table rows is equal to the number of index rows. // TODO: Make the value of concurrency adjustable. And we can consider the number of records. 
if len(e.srcs) == 1 { - return e.checkIndexHandle(ctx, e.srcs[0]) + err = e.checkIndexHandle(ctx, e.srcs[0]) + if err == nil && e.srcs[0].index.MVIndex { + err = e.checkTableRecord(ctx, 0) + } + if err != nil { + return err + } } taskCh := make(chan *IndexLookUpExecutor, len(e.srcs)) failure := atomicutil.NewBool(false) @@ -997,6 +1006,14 @@ func (e *CheckTableExec) Next(ctx context.Context, req *chunk.Chunk) error { select { case src := <-taskCh: err1 := e.checkIndexHandle(ctx, src) + if err1 == nil && src.index.MVIndex { + for offset, idx := range e.indexInfos { + if idx.ID == src.index.ID { + err1 = e.checkTableRecord(ctx, offset) + break + } + } + } if err1 != nil { failure.Store(true) logutil.Logger(ctx).Info("check index handle failed", zap.Error(err1)) diff --git a/parser/types/field_type.go b/parser/types/field_type.go index 991dc3d087d75..464ba38a6cb7c 100644 --- a/parser/types/field_type.go +++ b/parser/types/field_type.go @@ -235,6 +235,9 @@ func (ft *FieldType) IsArray() bool { // ArrayType return the type of the array. func (ft *FieldType) ArrayType() *FieldType { + if !ft.array { + return ft + } clone := ft.Clone() clone.SetArray(false) return clone diff --git a/table/tables/index.go b/table/tables/index.go index 607afb9640aad..265fabf966f7a 100644 --- a/table/tables/index.go +++ b/table/tables/index.go @@ -438,55 +438,56 @@ func GenTempIdxKeyByState(indexInfo *model.IndexInfo, indexKey kv.Key) (key, tem return indexKey, nil, TempIndexKeyTypeNone } -func (c *index) Exist(sc *stmtctx.StatementContext, txn kv.Transaction, indexedValues []types.Datum, h kv.Handle) (bool, kv.Handle, error) { - key, distinct, err := c.GenIndexKey(sc, indexedValues, h, nil) - if err != nil { - return false, nil, err - } - - var ( - tempKey []byte - keyVer byte - ) - // If index current is in creating status and using ingest mode, we need first - // check key exist status in temp index. - key, tempKey, keyVer = GenTempIdxKeyByState(c.idxInfo, key) - if keyVer != TempIndexKeyTypeNone { - KeyExistInfo, h1, err1 := KeyExistInTempIndex(context.TODO(), txn, tempKey, distinct, h, c.tblInfo.IsCommonHandle) - if err1 != nil { +func (c *index) Exist(sc *stmtctx.StatementContext, txn kv.Transaction, indexedValue []types.Datum, h kv.Handle) (bool, kv.Handle, error) { + indexedValues := c.getIndexedValue(indexedValue) + for _, val := range indexedValues { + key, distinct, err := c.GenIndexKey(sc, val, h, nil) + if err != nil { return false, nil, err } - switch KeyExistInfo { - case KeyInTempIndexNotExist, KeyInTempIndexIsDeleted: - return false, nil, nil - case KeyInTempIndexConflict: - return true, h1, kv.ErrKeyExists - case KeyInTempIndexIsItself: - return true, h, nil - } - } - value, err := txn.Get(context.TODO(), key) - if kv.IsErrNotFound(err) { - return false, nil, nil - } - if err != nil { - return false, nil, err - } + var ( + tempKey []byte + keyVer byte + ) + // If index current is in creating status and using ingest mode, we need first + // check key exist status in temp index. 
+ key, tempKey, keyVer = GenTempIdxKeyByState(c.idxInfo, key) + if keyVer != TempIndexKeyTypeNone { + KeyExistInfo, h1, err1 := KeyExistInTempIndex(context.TODO(), txn, tempKey, distinct, h, c.tblInfo.IsCommonHandle) + if err1 != nil { + return false, nil, err + } + switch KeyExistInfo { + case KeyInTempIndexNotExist, KeyInTempIndexIsDeleted: + return false, nil, nil + case KeyInTempIndexConflict: + return true, h1, kv.ErrKeyExists + case KeyInTempIndexIsItself: + continue + } + } - // For distinct index, the value of key is handle. - if distinct { - var handle kv.Handle - handle, err := tablecodec.DecodeHandleInUniqueIndexValue(value, c.tblInfo.IsCommonHandle) + value, err := txn.Get(context.TODO(), key) + if kv.IsErrNotFound(err) { + return false, nil, nil + } if err != nil { return false, nil, err } - if !handle.Equal(h) { - return true, handle, kv.ErrKeyExists + + // For distinct index, the value of key is handle. + if distinct { + var handle kv.Handle + handle, err := tablecodec.DecodeHandleInUniqueIndexValue(value, c.tblInfo.IsCommonHandle) + if err != nil { + return false, nil, err + } + if !handle.Equal(h) { + return true, handle, kv.ErrKeyExists + } } - return true, handle, nil } - return true, h, nil } diff --git a/table/tables/mutation_checker.go b/table/tables/mutation_checker.go index 8445a266e3d2d..328989d88ad3f 100644 --- a/table/tables/mutation_checker.go +++ b/table/tables/mutation_checker.go @@ -347,27 +347,11 @@ func compareIndexData( cols[indexInfo.Columns[i].Offset].ColumnInfo, ) - var comparison int - var err error - // If it is multi-valued index, we should check the JSON contains the indexed value. - if cols[indexInfo.Columns[i].Offset].ColumnInfo.FieldType.IsArray() && expectedDatum.Kind() == types.KindMysqlJSON { - bj := expectedDatum.GetMysqlJSON() - count := bj.GetElemCount() - for elemIdx := 0; elemIdx < count; elemIdx++ { - jsonDatum := types.NewJSONDatum(bj.ArrayGetElem(elemIdx)) - comparison, err = jsonDatum.Compare(sc, &decodedMutationDatum, collate.GetBinaryCollator()) - if err != nil { - return errors.Trace(err) - } - if comparison == 0 { - break - } - } - } else { - comparison, err = decodedMutationDatum.Compare(sc, &expectedDatum, collate.GetCollator(decodedMutationDatum.Collation())) - if err != nil { - return errors.Trace(err) - } + comparison, err := CompareIndexAndVal(sc, expectedDatum, decodedMutationDatum, + collate.GetCollator(decodedMutationDatum.Collation()), + cols[indexInfo.Columns[i].Offset].ColumnInfo.FieldType.IsArray() && expectedDatum.Kind() == types.KindMysqlJSON) + if err != nil { + return errors.Trace(err) } if comparison != 0 { @@ -382,6 +366,30 @@ func compareIndexData( return nil } +// CompareIndexAndVal compare index valued and row value. +func CompareIndexAndVal(sctx *stmtctx.StatementContext, rowVal types.Datum, idxVal types.Datum, collator collate.Collator, cmpMVIndex bool) (int, error) { + var cmpRes int + var err error + if cmpMVIndex { + // If it is multi-valued index, we should check the JSON contains the indexed value. + bj := rowVal.GetMysqlJSON() + count := bj.GetElemCount() + for elemIdx := 0; elemIdx < count; elemIdx++ { + jsonDatum := types.NewJSONDatum(bj.ArrayGetElem(elemIdx)) + cmpRes, err = jsonDatum.Compare(sctx, &idxVal, collate.GetBinaryCollator()) + if err != nil { + return 0, errors.Trace(err) + } + if cmpRes == 0 { + break + } + } + } else { + cmpRes, err = idxVal.Compare(sctx, &rowVal, collator) + } + return cmpRes, err +} + // getColumnMaps tries to get the columnMaps from transaction options. 
If there isn't one, it builds one and stores it. // It saves redundant computations of the map. func getColumnMaps(txn kv.Transaction, t *TableCommon) columnMaps { From 5327d07afc73618ac802af5928e8f9ffdbedc8c3 Mon Sep 17 00:00:00 2001 From: Yuanjia Zhang Date: Tue, 3 Jan 2023 17:30:20 +0800 Subject: [PATCH 3/9] planner: refactor plan-cache UseCache flag (#40256) ref pingcap/tidb#36598 --- executor/executor_test.go | 16 ++++++++-------- parser/ast/misc.go | 1 - planner/core/plan_cache.go | 13 ++++++++----- planner/core/plan_cache_test.go | 16 ++++++++++++++++ planner/core/plan_cache_utils.go | 16 ++++++++++++---- planner/core/prepare_test.go | 2 +- 6 files changed, 45 insertions(+), 19 deletions(-) diff --git a/executor/executor_test.go b/executor/executor_test.go index 858a4cc9372f9..9b602d2c151f9 100644 --- a/executor/executor_test.go +++ b/executor/executor_test.go @@ -3621,10 +3621,10 @@ func TestPointGetPreparedPlan(t *testing.T) { pspk1Id, _, _, err := tk.Session().PrepareStmt("select * from t where a = ?") require.NoError(t, err) - tk.Session().GetSessionVars().PreparedStmts[pspk1Id].(*plannercore.PlanCacheStmt).PreparedAst.UseCache = false + tk.Session().GetSessionVars().PreparedStmts[pspk1Id].(*plannercore.PlanCacheStmt).StmtCacheable = false pspk2Id, _, _, err := tk.Session().PrepareStmt("select * from t where ? = a ") require.NoError(t, err) - tk.Session().GetSessionVars().PreparedStmts[pspk2Id].(*plannercore.PlanCacheStmt).PreparedAst.UseCache = false + tk.Session().GetSessionVars().PreparedStmts[pspk2Id].(*plannercore.PlanCacheStmt).StmtCacheable = false ctx := context.Background() // first time plan generated @@ -3664,7 +3664,7 @@ func TestPointGetPreparedPlan(t *testing.T) { // unique index psuk1Id, _, _, err := tk.Session().PrepareStmt("select * from t where b = ? ") require.NoError(t, err) - tk.Session().GetSessionVars().PreparedStmts[psuk1Id].(*plannercore.PlanCacheStmt).PreparedAst.UseCache = false + tk.Session().GetSessionVars().PreparedStmts[psuk1Id].(*plannercore.PlanCacheStmt).StmtCacheable = false rs, err = tk.Session().ExecutePreparedStmt(ctx, psuk1Id, expression.Args2Expressions4Test(1)) require.NoError(t, err) @@ -3782,7 +3782,7 @@ func TestPointGetPreparedPlanWithCommitMode(t *testing.T) { pspk1Id, _, _, err := tk1.Session().PrepareStmt("select * from t where a = ?") require.NoError(t, err) - tk1.Session().GetSessionVars().PreparedStmts[pspk1Id].(*plannercore.PlanCacheStmt).PreparedAst.UseCache = false + tk1.Session().GetSessionVars().PreparedStmts[pspk1Id].(*plannercore.PlanCacheStmt).StmtCacheable = false ctx := context.Background() // first time plan generated @@ -3848,11 +3848,11 @@ func TestPointUpdatePreparedPlan(t *testing.T) { updateID1, pc, _, err := tk.Session().PrepareStmt(`update t set c = c + 1 where a = ?`) require.NoError(t, err) - tk.Session().GetSessionVars().PreparedStmts[updateID1].(*plannercore.PlanCacheStmt).PreparedAst.UseCache = false + tk.Session().GetSessionVars().PreparedStmts[updateID1].(*plannercore.PlanCacheStmt).StmtCacheable = false require.Equal(t, 1, pc) updateID2, pc, _, err := tk.Session().PrepareStmt(`update t set c = c + 2 where ? 
= a`) require.NoError(t, err) - tk.Session().GetSessionVars().PreparedStmts[updateID2].(*plannercore.PlanCacheStmt).PreparedAst.UseCache = false + tk.Session().GetSessionVars().PreparedStmts[updateID2].(*plannercore.PlanCacheStmt).StmtCacheable = false require.Equal(t, 1, pc) ctx := context.Background() @@ -3887,7 +3887,7 @@ func TestPointUpdatePreparedPlan(t *testing.T) { // unique index updUkID1, _, _, err := tk.Session().PrepareStmt(`update t set c = c + 10 where b = ?`) require.NoError(t, err) - tk.Session().GetSessionVars().PreparedStmts[updUkID1].(*plannercore.PlanCacheStmt).PreparedAst.UseCache = false + tk.Session().GetSessionVars().PreparedStmts[updUkID1].(*plannercore.PlanCacheStmt).StmtCacheable = false rs, err = tk.Session().ExecutePreparedStmt(ctx, updUkID1, expression.Args2Expressions4Test(3)) require.Nil(t, rs) require.NoError(t, err) @@ -3956,7 +3956,7 @@ func TestPointUpdatePreparedPlanWithCommitMode(t *testing.T) { ctx := context.Background() updateID1, _, _, err := tk1.Session().PrepareStmt(`update t set c = c + 1 where a = ?`) - tk1.Session().GetSessionVars().PreparedStmts[updateID1].(*plannercore.PlanCacheStmt).PreparedAst.UseCache = false + tk1.Session().GetSessionVars().PreparedStmts[updateID1].(*plannercore.PlanCacheStmt).StmtCacheable = false require.NoError(t, err) // first time plan generated diff --git a/parser/ast/misc.go b/parser/ast/misc.go index 7a0e2fc7a1a50..bfa105365700d 100644 --- a/parser/ast/misc.go +++ b/parser/ast/misc.go @@ -520,7 +520,6 @@ type Prepared struct { StmtType string Params []ParamMarkerExpr SchemaVersion int64 - UseCache bool CachedPlan interface{} CachedNames interface{} } diff --git a/planner/core/plan_cache.go b/planner/core/plan_cache.go index 8036f4067ce65..84be5bcaf9ab9 100644 --- a/planner/core/plan_cache.go +++ b/planner/core/plan_cache.go @@ -100,7 +100,7 @@ func planCachePreprocess(ctx context.Context, sctx sessionctx.Context, isNonPrep // So we need to clear the current session's plan cache. // And update lastUpdateTime to the newest one. expiredTimeStamp4PC := domain.GetDomain(sctx).ExpiredTimeStamp4PC() - if stmtAst.UseCache && expiredTimeStamp4PC.Compare(vars.LastUpdateTime4PC) > 0 { + if stmt.StmtCacheable && expiredTimeStamp4PC.Compare(vars.LastUpdateTime4PC) > 0 { sctx.GetPlanCache(isNonPrepared).DeleteAll() stmtAst.CachedPlan = nil vars.LastUpdateTime4PC = expiredTimeStamp4PC @@ -127,7 +127,10 @@ func GetPlanFromSessionPlanCache(ctx context.Context, sctx sessionctx.Context, sessVars := sctx.GetSessionVars() stmtCtx := sessVars.StmtCtx stmtAst := stmt.PreparedAst - stmtCtx.UseCache = stmtAst.UseCache + stmtCtx.UseCache = stmt.StmtCacheable + if !stmt.StmtCacheable { + stmtCtx.SetSkipPlanCache(errors.Errorf("skip plan-cache: %s", stmt.UncacheableReason)) + } var bindSQL string var ignorePlanCache = false @@ -136,7 +139,7 @@ func GetPlanFromSessionPlanCache(ctx context.Context, sctx sessionctx.Context, // rebuild the plan. So we set this value in rc or for update read. In other cases, let it be 0. 
var latestSchemaVersion int64 - if stmtAst.UseCache { + if stmtCtx.UseCache { bindSQL, ignorePlanCache = GetBindSQL4PlanCache(sctx, stmt) if sctx.GetSessionVars().IsIsolation(ast.ReadCommitted) || stmt.ForUpdateRead { // In Rc or ForUpdateRead, we should check if the information schema has been changed since @@ -152,13 +155,13 @@ func GetPlanFromSessionPlanCache(ctx context.Context, sctx sessionctx.Context, paramNum, paramTypes := parseParamTypes(sctx, params) - if stmtAst.UseCache && stmtAst.CachedPlan != nil && !ignorePlanCache { // for point query plan + if stmtCtx.UseCache && stmtAst.CachedPlan != nil && !ignorePlanCache { // for point query plan if plan, names, ok, err := getCachedPointPlan(stmtAst, sessVars, stmtCtx); ok { return plan, names, err } } - if stmtAst.UseCache && !ignorePlanCache { // for non-point plans + if stmtCtx.UseCache && !ignorePlanCache { // for non-point plans if plan, names, ok, err := getCachedPlan(sctx, isNonPrepared, cacheKey, bindSQL, is, stmt, paramTypes); err != nil || ok { return plan, names, err diff --git a/planner/core/plan_cache_test.go b/planner/core/plan_cache_test.go index e25565a110e08..4fe0e6cf153dd 100644 --- a/planner/core/plan_cache_test.go +++ b/planner/core/plan_cache_test.go @@ -384,3 +384,19 @@ func TestPlanCacheDiagInfo(t *testing.T) { tk.MustExec("execute stmt using @a, @b") // a=1 and a=1 -> a=1 tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1105 skip plan-cache: some parameters may be overwritten")) } + +func TestUncacheableReason(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create table t (a int)") + + tk.MustExec("prepare st from 'select * from t limit ?'") + tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1105 skip plan-cache: query has 'limit ?' is un-cacheable")) + + tk.MustExec("set @a=1") + tk.MustQuery("execute st using @a").Check(testkit.Rows()) + tk.MustExec("prepare st from 'select * from t limit ?'") + // show the corresponding un-cacheable reason at execute-stage as well + tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1105 skip plan-cache: query has 'limit ?' is un-cacheable")) +} diff --git a/planner/core/plan_cache_utils.go b/planner/core/plan_cache_utils.go index 6408d269ef799..3fe4ee38bfe45 100644 --- a/planner/core/plan_cache_utils.go +++ b/planner/core/plan_cache_utils.go @@ -119,12 +119,14 @@ func GeneratePlanCacheStmtWithAST(ctx context.Context, sctx sessionctx.Context, var ( normalizedSQL4PC, digest4PC string selectStmtNode ast.StmtNode + cacheable bool + reason string ) if !vars.EnablePreparedPlanCache { - prepared.UseCache = false + cacheable = false + reason = "plan cache is disabled" } else { - cacheable, reason := CacheableWithCtx(sctx, paramStmt, ret.InfoSchema) - prepared.UseCache = cacheable + cacheable, reason = CacheableWithCtx(sctx, paramStmt, ret.InfoSchema) if !cacheable { sctx.GetSessionVars().StmtCtx.AppendWarning(errors.Errorf("skip plan-cache: " + reason)) } @@ -160,6 +162,8 @@ func GeneratePlanCacheStmtWithAST(ctx context.Context, sctx sessionctx.Context, SnapshotTSEvaluator: ret.SnapshotTSEvaluator, NormalizedSQL4PC: normalizedSQL4PC, SQLDigest4PC: digest4PC, + StmtCacheable: cacheable, + UncacheableReason: reason, } if err = CheckPreparedPriv(sctx, preparedObj, ret.InfoSchema); err != nil { return nil, nil, 0, err @@ -412,7 +416,11 @@ type PlanCacheStmt struct { // Executor is only used for point get scene. 
// Notice that we should only cache the PointGetExecutor that have a snapshot with MaxTS in it. // If the current plan is not PointGet or does not use MaxTS optimization, this value should be nil here. - Executor interface{} + Executor interface{} + + StmtCacheable bool // Whether this stmt is cacheable. + UncacheableReason string // Why this stmt is uncacheable. + NormalizedSQL string NormalizedPlan string SQLDigest *parser.Digest diff --git a/planner/core/prepare_test.go b/planner/core/prepare_test.go index 71eb4c997211d..656aed73ca189 100644 --- a/planner/core/prepare_test.go +++ b/planner/core/prepare_test.go @@ -60,7 +60,7 @@ func TestPointGetPreparedPlan4PlanCache(t *testing.T) { pspk1Id, _, _, err := tk1.Session().PrepareStmt("select * from t where a = ?") require.NoError(t, err) - tk1.Session().GetSessionVars().PreparedStmts[pspk1Id].(*core.PlanCacheStmt).PreparedAst.UseCache = false + tk1.Session().GetSessionVars().PreparedStmts[pspk1Id].(*core.PlanCacheStmt).StmtCacheable = false ctx := context.Background() // first time plan generated From 9aaa93e9ccd901f13dbf9e09b2c989ecdf29ae6a Mon Sep 17 00:00:00 2001 From: xiongjiwei Date: Tue, 3 Jan 2023 19:04:19 +0900 Subject: [PATCH 4/9] test: fix data race in cast as array (#40277) close pingcap/tidb#40276 --- expression/builtin_cast.go | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/expression/builtin_cast.go b/expression/builtin_cast.go index 545abd497a2da..b6fc16cda2ef4 100644 --- a/expression/builtin_cast.go +++ b/expression/builtin_cast.go @@ -458,6 +458,9 @@ func (b *castJSONAsArrayFunctionSig) Clone() builtinFunc { return newSig } +// fakeSctx is used to ignore the sql mode, `cast as array` should always return error if any. +var fakeSctx = &stmtctx.StatementContext{InInsertStmt: true} + func (b *castJSONAsArrayFunctionSig) evalJSON(row chunk.Row) (res types.BinaryJSON, isNull bool, err error) { val, isNull, err := b.args[0].EvalJSON(b.ctx, row) if isNull || err != nil { @@ -474,20 +477,8 @@ func (b *castJSONAsArrayFunctionSig) evalJSON(row chunk.Row) (res types.BinaryJS if f == nil { return types.BinaryJSON{}, false, ErrNotSupportedYet.GenWithStackByArgs("CAS-ing JSON to the target type") } - sc := b.ctx.GetSessionVars().StmtCtx - originalOverflowAsWarning := sc.OverflowAsWarning - originIgnoreTruncate := sc.IgnoreTruncate - originTruncateAsWarning := sc.TruncateAsWarning - sc.OverflowAsWarning = false - sc.IgnoreTruncate = false - sc.TruncateAsWarning = false - defer func() { - sc.OverflowAsWarning = originalOverflowAsWarning - sc.IgnoreTruncate = originIgnoreTruncate - sc.TruncateAsWarning = originTruncateAsWarning - }() for i := 0; i < val.GetElemCount(); i++ { - item, err := f(sc, val.ArrayGetElem(i), ft) + item, err := f(fakeSctx, val.ArrayGetElem(i), ft) if err != nil { return types.BinaryJSON{}, false, err } From 4f1adb9e7cd79a75c83e8b7a3ac2fc7f860c221e Mon Sep 17 00:00:00 2001 From: Yuanjia Zhang Date: Tue, 3 Jan 2023 19:14:19 +0800 Subject: [PATCH 5/9] planner: support converting `json_overlaps/contains` to IndexMerge to access MVIndex (#40195) ref pingcap/tidb#40191 --- planner/core/indexmerge_path.go | 78 ++++++++++- .../core/testdata/index_merge_suite_in.json | 14 +- .../core/testdata/index_merge_suite_out.json | 122 ++++++++++++++++++ 3 files changed, 210 insertions(+), 4 deletions(-) diff --git a/planner/core/indexmerge_path.go b/planner/core/indexmerge_path.go index f0ecf02a00231..4187e91395b0f 100644 --- a/planner/core/indexmerge_path.go +++ b/planner/core/indexmerge_path.go @@ 
-26,7 +26,9 @@ import ( "github.com/pingcap/tidb/parser/model" "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/planner/util" + "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/logutil" "github.com/pingcap/tidb/util/ranger" "go.uber.org/zap" @@ -552,6 +554,7 @@ func (ds *DataSource) generateIndexMergeJSONMVIndexPath(normalPathCnt int, filte var jsonPath expression.Expression var vals []expression.Expression + var indexMergeIsIntersection bool switch sf.FuncName.L { case ast.JSONMemberOf: // (1 member of a->'$.zip') jsonPath = sf.GetArgs()[1] @@ -560,10 +563,29 @@ func (ds *DataSource) generateIndexMergeJSONMVIndexPath(normalPathCnt int, filte continue } vals = append(vals, v) - case ast.JSONOverlaps: // (json_overlaps(a->'$.zip', '[1, 2, 3]') - continue // TODO: support json_overlaps case ast.JSONContains: // (json_contains(a->'$.zip', '[1, 2, 3]') - continue // TODO: support json_contains + indexMergeIsIntersection = true + jsonPath = sf.GetArgs()[0] + var ok bool + vals, ok = jsonArrayExpr2Exprs(ds.ctx, sf.GetArgs()[1]) + if !ok { + continue + } + case ast.JSONOverlaps: // (json_overlaps(a->'$.zip', '[1, 2, 3]') + var jsonPathIdx int + if sf.GetArgs()[0].Equal(ds.ctx, targetJSONPath) { + jsonPathIdx = 0 // (json_overlaps(a->'$.zip', '[1, 2, 3]') + } else if sf.GetArgs()[1].Equal(ds.ctx, targetJSONPath) { + jsonPathIdx = 1 // (json_overlaps('[1, 2, 3]', a->'$.zip') + } else { + continue + } + jsonPath = sf.GetArgs()[jsonPathIdx] + var ok bool + vals, ok = jsonArrayExpr2Exprs(ds.ctx, sf.GetArgs()[1-jsonPathIdx]) + if !ok { + continue + } default: continue } @@ -612,12 +634,62 @@ func (ds *DataSource) generateIndexMergeJSONMVIndexPath(normalPathCnt int, filte partialPaths = append(partialPaths, partialPath) } indexMergePath := ds.buildIndexMergeOrPath(filters, partialPaths, filterIdx) + indexMergePath.IndexMergeIsIntersection = indexMergeIsIntersection mvIndexPaths = append(mvIndexPaths, indexMergePath) } } return } +// jsonArrayExpr2Exprs converts a JsonArray expression to expression list: cast('[1, 2, 3]' as JSON) --> []expr{1, 2, 3} +func jsonArrayExpr2Exprs(sctx sessionctx.Context, jsonArrayExpr expression.Expression) ([]expression.Expression, bool) { + // only support cast(const as JSON) + arrayExpr, wrappedByJSONCast := unwrapJSONCast(jsonArrayExpr) + if !wrappedByJSONCast { + return nil, false + } + if _, isConst := arrayExpr.(*expression.Constant); !isConst { + return nil, false + } + if expression.IsMutableEffectsExpr(arrayExpr) { + return nil, false + } + + jsonArray, isNull, err := jsonArrayExpr.EvalJSON(sctx, chunk.Row{}) + if isNull || err != nil { + return nil, false + } + if jsonArray.TypeCode != types.JSONTypeCodeArray { + single, ok := jsonValue2Expr(jsonArray) // '1' -> []expr{1} + if ok { + return []expression.Expression{single}, true + } + return nil, false + } + var exprs []expression.Expression + for i := 0; i < jsonArray.GetElemCount(); i++ { // '[1, 2, 3]' -> []expr{1, 2, 3} + expr, ok := jsonValue2Expr(jsonArray.ArrayGetElem(i)) + if !ok { + return nil, false + } + exprs = append(exprs, expr) + } + return exprs, true +} + +func jsonValue2Expr(v types.BinaryJSON) (expression.Expression, bool) { + if v.TypeCode != types.JSONTypeCodeInt64 { + // only support INT now + // TODO: support more types + return nil, false + } + val := v.GetInt64() + return &expression.Constant{ + Value: types.NewDatum(val), + RetType: types.NewFieldType(mysql.TypeLonglong), + }, true +} + 
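To make the helper above concrete: jsonArrayExpr2Exprs only succeeds when its argument is a constant wrapped in a cast to JSON and every element is an integer; a scalar such as '1' is treated as a one-element array, and any non-integer element makes the function return false, in which case the IndexMerge rewrite is skipped for that filter. A rough standalone sketch of the same element-by-element conversion, using encoding/json instead of TiDB's BinaryJSON purely for illustration (intsFromJSONArray and toInt are made-up names, not part of the patch):

package main

import (
	"encoding/json"
	"fmt"
)

// intsFromJSONArray mirrors what jsonArrayExpr2Exprs/jsonValue2Expr accept:
// a JSON scalar becomes a one-element list, a JSON array is converted
// element by element, and any non-integer element aborts the conversion
// with ok=false.
func intsFromJSONArray(raw string) (vals []int64, ok bool) {
	var v interface{}
	if err := json.Unmarshal([]byte(raw), &v); err != nil {
		return nil, false
	}
	toInt := func(e interface{}) (int64, bool) {
		f, isNum := e.(float64) // encoding/json decodes JSON numbers as float64
		if !isNum || f != float64(int64(f)) {
			return 0, false
		}
		return int64(f), true
	}
	arr, isArr := v.([]interface{})
	if !isArr {
		n, good := toInt(v) // '1' -> [1]
		if !good {
			return nil, false
		}
		return []int64{n}, true
	}
	for _, e := range arr { // '[1, 2, 3]' -> [1 2 3]
		n, good := toInt(e)
		if !good {
			return nil, false
		}
		vals = append(vals, n)
	}
	return vals, true
}

func main() {
	fmt.Println(intsFromJSONArray(`[1, 2, 3]`)) // [1 2 3] true
	fmt.Println(intsFromJSONArray(`"a"`))       // [] false: non-integer element, rewrite skipped
}

This is why json_contains((j0->'$.path0'), '[1, 2, 3]') in the expected plans below expands into the three single-point ranges [1,1], [2,2] and [3,3].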
func unwrapJSONCast(expr expression.Expression) (expression.Expression, bool) { if expr == nil { return nil, false diff --git a/planner/core/testdata/index_merge_suite_in.json b/planner/core/testdata/index_merge_suite_in.json index 2841de33bae0c..8865be189d702 100644 --- a/planner/core/testdata/index_merge_suite_in.json +++ b/planner/core/testdata/index_merge_suite_in.json @@ -6,7 +6,19 @@ "select /*+ use_index_merge(t, j0_1) */ * from t where (1 member of (j0->'$.path1')) and a<10", "select /*+ use_index_merge(t, j0_1) */ * from t where (1 member of (j0->'$.XXX')) and a<10", "select /*+ use_index_merge(t, j0_1) */ * from t where (1 member of (j0->'$.path1')) and (2 member of (j1)) and a<10", - "select /*+ use_index_merge(t, j1) */ * from t where (1 member of (j0->'$.path1')) and (2 member of (j1)) and a<10" + "select /*+ use_index_merge(t, j1) */ * from t where (1 member of (j0->'$.path1')) and (2 member of (j1)) and a<10", + "select /*+ use_index_merge(t, j0_0) */ * from t where json_contains((j0->'$.path0'), '[1, 2, 3]')", + "select /*+ use_index_merge(t, j0_0) */ * from t where json_overlaps((j0->'$.path0'), '[1, 2, 3]')", + "select /*+ use_index_merge(t, j0_0) */ * from t where json_overlaps('[1, 2, 3]', (j0->'$.path0'))", + "select /*+ use_index_merge(t, j0_0) */ * from t where json_contains((j0->'$.path0'), '[1, 2, 3]') and a<10", + "select /*+ use_index_merge(t, j0_0) */ * from t where json_overlaps((j0->'$.path0'), '[1, 2, 3]') and a<10", + "select /*+ use_index_merge(t, j0_0) */ * from t where json_overlaps('[1, 2, 3]', (j0->'$.path0')) and a<10", + "select /*+ use_index_merge(t, j0_0) */ * from t where json_contains((j0->'$.path0'), '1')", + "select /*+ use_index_merge(t, j0_0) */ * from t where json_overlaps((j0->'$.path0'), '1')", + "select /*+ use_index_merge(t, j0_0) */ * from t where json_overlaps('1', (j0->'$.path0'))", + "select /*+ use_index_merge(t, j0_0) */ * from t where json_contains((j0->'$.path0'), '1') and a<10", + "select /*+ use_index_merge(t, j0_0) */ * from t where json_overlaps((j0->'$.path0'), '1') and a<10", + "select /*+ use_index_merge(t, j0_0) */ * from t where json_overlaps('1', (j0->'$.path0')) and a<10" ] }, { diff --git a/planner/core/testdata/index_merge_suite_out.json b/planner/core/testdata/index_merge_suite_out.json index 31427fbf4c7e0..e8d0b00a4fe1e 100644 --- a/planner/core/testdata/index_merge_suite_out.json +++ b/planner/core/testdata/index_merge_suite_out.json @@ -49,6 +49,128 @@ " └─Selection(Probe) 0.00 cop[tikv] lt(test.t.a, 10)", " └─TableRowIDScan 0.00 cop[tikv] table:t keep order:false, stats:pseudo" ] + }, + { + "SQL": "select /*+ use_index_merge(t, j0_0) */ * from t where json_contains((j0->'$.path0'), '[1, 2, 3]')", + "Plan": [ + "IndexMerge 0.00 root type: intersection", + "├─IndexRangeScan(Build) 10.00 cop[tikv] table:t, index:j0_0(cast(json_extract(`j0`, _utf8mb4'$.path0') as signed array)) range:[1,1], keep order:false, stats:pseudo", + "├─IndexRangeScan(Build) 10.00 cop[tikv] table:t, index:j0_0(cast(json_extract(`j0`, _utf8mb4'$.path0') as signed array)) range:[2,2], keep order:false, stats:pseudo", + "├─IndexRangeScan(Build) 10.00 cop[tikv] table:t, index:j0_0(cast(json_extract(`j0`, _utf8mb4'$.path0') as signed array)) range:[3,3], keep order:false, stats:pseudo", + "└─TableRowIDScan(Probe) 0.00 cop[tikv] table:t keep order:false, stats:pseudo" + ] + }, + { + "SQL": "select /*+ use_index_merge(t, j0_0) */ * from t where json_overlaps((j0->'$.path0'), '[1, 2, 3]')", + "Plan": [ + "Selection 0.00 root 
json_overlaps(json_extract(test.t.j0, \"$.path0\"), cast(\"[1, 2, 3]\", json BINARY))", + "└─IndexMerge 0.00 root type: union", + " ├─IndexRangeScan(Build) 10.00 cop[tikv] table:t, index:j0_0(cast(json_extract(`j0`, _utf8mb4'$.path0') as signed array)) range:[1,1], keep order:false, stats:pseudo", + " ├─IndexRangeScan(Build) 10.00 cop[tikv] table:t, index:j0_0(cast(json_extract(`j0`, _utf8mb4'$.path0') as signed array)) range:[2,2], keep order:false, stats:pseudo", + " ├─IndexRangeScan(Build) 10.00 cop[tikv] table:t, index:j0_0(cast(json_extract(`j0`, _utf8mb4'$.path0') as signed array)) range:[3,3], keep order:false, stats:pseudo", + " └─TableRowIDScan(Probe) 0.00 cop[tikv] table:t keep order:false, stats:pseudo" + ] + }, + { + "SQL": "select /*+ use_index_merge(t, j0_0) */ * from t where json_overlaps('[1, 2, 3]', (j0->'$.path0'))", + "Plan": [ + "Selection 0.00 root json_overlaps(cast(\"[1, 2, 3]\", json BINARY), json_extract(test.t.j0, \"$.path0\"))", + "└─IndexMerge 0.00 root type: union", + " ├─IndexRangeScan(Build) 10.00 cop[tikv] table:t, index:j0_0(cast(json_extract(`j0`, _utf8mb4'$.path0') as signed array)) range:[1,1], keep order:false, stats:pseudo", + " ├─IndexRangeScan(Build) 10.00 cop[tikv] table:t, index:j0_0(cast(json_extract(`j0`, _utf8mb4'$.path0') as signed array)) range:[2,2], keep order:false, stats:pseudo", + " ├─IndexRangeScan(Build) 10.00 cop[tikv] table:t, index:j0_0(cast(json_extract(`j0`, _utf8mb4'$.path0') as signed array)) range:[3,3], keep order:false, stats:pseudo", + " └─TableRowIDScan(Probe) 0.00 cop[tikv] table:t keep order:false, stats:pseudo" + ] + }, + { + "SQL": "select /*+ use_index_merge(t, j0_0) */ * from t where json_contains((j0->'$.path0'), '[1, 2, 3]') and a<10", + "Plan": [ + "IndexMerge 0.00 root type: intersection", + "├─IndexRangeScan(Build) 10.00 cop[tikv] table:t, index:j0_0(cast(json_extract(`j0`, _utf8mb4'$.path0') as signed array)) range:[1,1], keep order:false, stats:pseudo", + "├─IndexRangeScan(Build) 10.00 cop[tikv] table:t, index:j0_0(cast(json_extract(`j0`, _utf8mb4'$.path0') as signed array)) range:[2,2], keep order:false, stats:pseudo", + "├─IndexRangeScan(Build) 10.00 cop[tikv] table:t, index:j0_0(cast(json_extract(`j0`, _utf8mb4'$.path0') as signed array)) range:[3,3], keep order:false, stats:pseudo", + "└─Selection(Probe) 0.00 cop[tikv] lt(test.t.a, 10)", + " └─TableRowIDScan 0.00 cop[tikv] table:t keep order:false, stats:pseudo" + ] + }, + { + "SQL": "select /*+ use_index_merge(t, j0_0) */ * from t where json_overlaps((j0->'$.path0'), '[1, 2, 3]') and a<10", + "Plan": [ + "Selection 0.00 root json_overlaps(json_extract(test.t.j0, \"$.path0\"), cast(\"[1, 2, 3]\", json BINARY))", + "└─IndexMerge 0.00 root type: union", + " ├─IndexRangeScan(Build) 10.00 cop[tikv] table:t, index:j0_0(cast(json_extract(`j0`, _utf8mb4'$.path0') as signed array)) range:[1,1], keep order:false, stats:pseudo", + " ├─IndexRangeScan(Build) 10.00 cop[tikv] table:t, index:j0_0(cast(json_extract(`j0`, _utf8mb4'$.path0') as signed array)) range:[2,2], keep order:false, stats:pseudo", + " ├─IndexRangeScan(Build) 10.00 cop[tikv] table:t, index:j0_0(cast(json_extract(`j0`, _utf8mb4'$.path0') as signed array)) range:[3,3], keep order:false, stats:pseudo", + " └─Selection(Probe) 0.00 cop[tikv] lt(test.t.a, 10)", + " └─TableRowIDScan 0.00 cop[tikv] table:t keep order:false, stats:pseudo" + ] + }, + { + "SQL": "select /*+ use_index_merge(t, j0_0) */ * from t where json_overlaps('[1, 2, 3]', (j0->'$.path0')) and a<10", + "Plan": [ + "Selection 0.00 root 
json_overlaps(cast(\"[1, 2, 3]\", json BINARY), json_extract(test.t.j0, \"$.path0\"))", + "└─IndexMerge 0.00 root type: union", + " ├─IndexRangeScan(Build) 10.00 cop[tikv] table:t, index:j0_0(cast(json_extract(`j0`, _utf8mb4'$.path0') as signed array)) range:[1,1], keep order:false, stats:pseudo", + " ├─IndexRangeScan(Build) 10.00 cop[tikv] table:t, index:j0_0(cast(json_extract(`j0`, _utf8mb4'$.path0') as signed array)) range:[2,2], keep order:false, stats:pseudo", + " ├─IndexRangeScan(Build) 10.00 cop[tikv] table:t, index:j0_0(cast(json_extract(`j0`, _utf8mb4'$.path0') as signed array)) range:[3,3], keep order:false, stats:pseudo", + " └─Selection(Probe) 0.00 cop[tikv] lt(test.t.a, 10)", + " └─TableRowIDScan 0.00 cop[tikv] table:t keep order:false, stats:pseudo" + ] + }, + { + "SQL": "select /*+ use_index_merge(t, j0_0) */ * from t where json_contains((j0->'$.path0'), '1')", + "Plan": [ + "IndexMerge 0.00 root type: intersection", + "├─IndexRangeScan(Build) 10.00 cop[tikv] table:t, index:j0_0(cast(json_extract(`j0`, _utf8mb4'$.path0') as signed array)) range:[1,1], keep order:false, stats:pseudo", + "└─TableRowIDScan(Probe) 0.00 cop[tikv] table:t keep order:false, stats:pseudo" + ] + }, + { + "SQL": "select /*+ use_index_merge(t, j0_0) */ * from t where json_overlaps((j0->'$.path0'), '1')", + "Plan": [ + "Selection 0.00 root json_overlaps(json_extract(test.t.j0, \"$.path0\"), cast(\"1\", json BINARY))", + "└─IndexMerge 0.00 root type: union", + " ├─IndexRangeScan(Build) 10.00 cop[tikv] table:t, index:j0_0(cast(json_extract(`j0`, _utf8mb4'$.path0') as signed array)) range:[1,1], keep order:false, stats:pseudo", + " └─TableRowIDScan(Probe) 0.00 cop[tikv] table:t keep order:false, stats:pseudo" + ] + }, + { + "SQL": "select /*+ use_index_merge(t, j0_0) */ * from t where json_overlaps('1', (j0->'$.path0'))", + "Plan": [ + "Selection 0.00 root json_overlaps(cast(\"1\", json BINARY), json_extract(test.t.j0, \"$.path0\"))", + "└─IndexMerge 0.00 root type: union", + " ├─IndexRangeScan(Build) 10.00 cop[tikv] table:t, index:j0_0(cast(json_extract(`j0`, _utf8mb4'$.path0') as signed array)) range:[1,1], keep order:false, stats:pseudo", + " └─TableRowIDScan(Probe) 0.00 cop[tikv] table:t keep order:false, stats:pseudo" + ] + }, + { + "SQL": "select /*+ use_index_merge(t, j0_0) */ * from t where json_contains((j0->'$.path0'), '1') and a<10", + "Plan": [ + "IndexMerge 0.00 root type: intersection", + "├─IndexRangeScan(Build) 10.00 cop[tikv] table:t, index:j0_0(cast(json_extract(`j0`, _utf8mb4'$.path0') as signed array)) range:[1,1], keep order:false, stats:pseudo", + "└─Selection(Probe) 0.00 cop[tikv] lt(test.t.a, 10)", + " └─TableRowIDScan 0.00 cop[tikv] table:t keep order:false, stats:pseudo" + ] + }, + { + "SQL": "select /*+ use_index_merge(t, j0_0) */ * from t where json_overlaps((j0->'$.path0'), '1') and a<10", + "Plan": [ + "Selection 0.00 root json_overlaps(json_extract(test.t.j0, \"$.path0\"), cast(\"1\", json BINARY))", + "└─IndexMerge 0.00 root type: union", + " ├─IndexRangeScan(Build) 10.00 cop[tikv] table:t, index:j0_0(cast(json_extract(`j0`, _utf8mb4'$.path0') as signed array)) range:[1,1], keep order:false, stats:pseudo", + " └─Selection(Probe) 0.00 cop[tikv] lt(test.t.a, 10)", + " └─TableRowIDScan 0.00 cop[tikv] table:t keep order:false, stats:pseudo" + ] + }, + { + "SQL": "select /*+ use_index_merge(t, j0_0) */ * from t where json_overlaps('1', (j0->'$.path0')) and a<10", + "Plan": [ + "Selection 0.00 root json_overlaps(cast(\"1\", json BINARY), json_extract(test.t.j0, \"$.path0\"))", + 
"└─IndexMerge 0.00 root type: union", + " ├─IndexRangeScan(Build) 10.00 cop[tikv] table:t, index:j0_0(cast(json_extract(`j0`, _utf8mb4'$.path0') as signed array)) range:[1,1], keep order:false, stats:pseudo", + " └─Selection(Probe) 0.00 cop[tikv] lt(test.t.a, 10)", + " └─TableRowIDScan 0.00 cop[tikv] table:t keep order:false, stats:pseudo" + ] } ] }, From 1f344ba1081ebf66a8e45ee01c09853f79aa1458 Mon Sep 17 00:00:00 2001 From: tiancaiamao Date: Tue, 3 Jan 2023 19:40:19 +0800 Subject: [PATCH 6/9] autoid_service: add unit test for the package (#40193) --- autoid_service/BUILD.bazel | 18 +- autoid_service/autoid.go | 13 +- autoid_service/autoid_test.go | 202 ++++++++++++++++++++++ ddl/BUILD.bazel | 1 + ddl/db_integration_test.go | 1 + executor/autoidtest/BUILD.bazel | 1 + executor/autoidtest/autoid_test.go | 1 + executor/issuetest/BUILD.bazel | 1 + executor/issuetest/executor_issue_test.go | 1 + executor/showtest/BUILD.bazel | 1 + executor/showtest/show_test.go | 1 + meta/autoid/BUILD.bazel | 1 - meta/autoid/autoid.go | 9 +- session/BUILD.bazel | 1 + session/bench_test.go | 1 + sessionctx/binloginfo/BUILD.bazel | 1 + sessionctx/binloginfo/binloginfo_test.go | 1 + telemetry/BUILD.bazel | 1 + telemetry/data_feature_usage_test.go | 1 + 19 files changed, 251 insertions(+), 6 deletions(-) create mode 100644 autoid_service/autoid_test.go diff --git a/autoid_service/BUILD.bazel b/autoid_service/BUILD.bazel index 6f1a13742ca80..26eb992c89474 100644 --- a/autoid_service/BUILD.bazel +++ b/autoid_service/BUILD.bazel @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") go_library( name = "autoid_service", @@ -9,6 +9,7 @@ go_library( "//config", "//kv", "//meta", + "//meta/autoid", "//metrics", "//owner", "//parser/model", @@ -23,3 +24,18 @@ go_library( "@org_uber_go_zap//:zap", ], ) + +go_test( + name = "autoid_service_test", + srcs = ["autoid_test.go"], + embed = [":autoid_service"], + deps = [ + "//parser/model", + "//testkit", + "@com_github_pingcap_kvproto//pkg/autoid", + "@com_github_stretchr_testify//require", + "@io_etcd_go_etcd_tests_v3//integration", + "@org_golang_google_grpc//:grpc", + "@org_golang_google_grpc//credentials/insecure", + ], +) diff --git a/autoid_service/autoid.go b/autoid_service/autoid.go index 1a4d2b426263e..aa6c487cb0b48 100644 --- a/autoid_service/autoid.go +++ b/autoid_service/autoid.go @@ -27,6 +27,7 @@ import ( "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/meta" + autoid1 "github.com/pingcap/tidb/meta/autoid" "github.com/pingcap/tidb/metrics" "github.com/pingcap/tidb/owner" "github.com/pingcap/tidb/parser/model" @@ -253,6 +254,7 @@ type Service struct { func New(selfAddr string, etcdAddr []string, store kv.Storage, tlsConfig *tls.Config) *Service { cfg := config.GetGlobalConfig() etcdLogCfg := zap.NewProductionConfig() + cli, err := clientv3.New(clientv3.Config{ LogConfig: &etcdLogCfg, Endpoints: etcdAddr, @@ -270,9 +272,12 @@ func New(selfAddr string, etcdAddr []string, store kv.Storage, tlsConfig *tls.Co if err != nil { panic(err) } + return newWithCli(selfAddr, cli, store) +} +func newWithCli(selfAddr string, cli *clientv3.Client, store kv.Storage) *Service { l := owner.NewOwnerManager(context.Background(), cli, "autoid", selfAddr, autoIDLeaderPath) - err = l.CampaignOwner() + err := l.CampaignOwner() if err != nil { panic(err) } @@ -299,7 +304,7 @@ func (m *mockClient) Rebase(ctx context.Context, in *autoid.RebaseRequest, opts var global = 
make(map[string]*mockClient) // MockForTest is used for testing, the UT test and unistore use this. -func MockForTest(store kv.Storage) *mockClient { +func MockForTest(store kv.Storage) autoid.AutoIDAllocClient { uuid := store.UUID() ret, ok := global[uuid] if !ok { @@ -515,3 +520,7 @@ func (s *Service) Rebase(ctx context.Context, req *autoid.RebaseRequest) (*autoi } return &autoid.RebaseResponse{}, nil } + +func init() { + autoid1.MockForTest = MockForTest +} diff --git a/autoid_service/autoid_test.go b/autoid_service/autoid_test.go new file mode 100644 index 0000000000000..df2722309cf6e --- /dev/null +++ b/autoid_service/autoid_test.go @@ -0,0 +1,202 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package autoid + +import ( + "context" + "fmt" + "math" + "net" + "testing" + "time" + + "github.com/pingcap/kvproto/pkg/autoid" + "github.com/pingcap/tidb/parser/model" + "github.com/pingcap/tidb/testkit" + "github.com/stretchr/testify/require" + "go.etcd.io/etcd/tests/v3/integration" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" +) + +type autoIDResp struct { + *autoid.AutoIDResponse + error + *testing.T +} + +func (resp autoIDResp) check(min, max int64) { + require.NoError(resp.T, resp.error) + require.Equal(resp.T, resp.AutoIDResponse, &autoid.AutoIDResponse{Min: min, Max: max}) +} + +func (resp autoIDResp) checkErrmsg() { + require.NoError(resp.T, resp.error) + require.True(resp.T, len(resp.GetErrmsg()) > 0) +} + +type rebaseResp struct { + *autoid.RebaseResponse + error + *testing.T +} + +func (resp rebaseResp) check(msg string) { + require.NoError(resp.T, resp.error) + require.Equal(resp.T, string(resp.RebaseResponse.GetErrmsg()), msg) +} + +func TestAPI(t *testing.T) { + store, dom := testkit.CreateMockStoreAndDomain(t) + tk := testkit.NewTestKit(t, store) + cli := MockForTest(store) + tk.MustExec("use test") + tk.MustExec("create table t (id int key auto_increment);") + is := dom.InfoSchema() + dbInfo, ok := is.SchemaByName(model.NewCIStr("test")) + require.True(t, ok) + + tbl, err := is.TableByName(model.NewCIStr("test"), model.NewCIStr("t")) + require.NoError(t, err) + tbInfo := tbl.Meta() + + ctx := context.Background() + checkCurrValue := func(t *testing.T, cli autoid.AutoIDAllocClient, min, max int64) { + req := &autoid.AutoIDRequest{DbID: dbInfo.ID, TblID: tbInfo.ID, N: 0} + resp, err := cli.AllocAutoID(ctx, req) + require.NoError(t, err) + require.Equal(t, resp, &autoid.AutoIDResponse{Min: min, Max: max}) + } + autoIDRequest := func(t *testing.T, cli autoid.AutoIDAllocClient, unsigned bool, n uint64, more ...int64) autoIDResp { + increment := int64(1) + offset := int64(1) + if len(more) >= 1 { + increment = more[0] + } + if len(more) >= 2 { + offset = more[1] + } + req := &autoid.AutoIDRequest{DbID: dbInfo.ID, TblID: tbInfo.ID, IsUnsigned: unsigned, N: n, Increment: increment, Offset: offset} + resp, err := cli.AllocAutoID(ctx, req) + return autoIDResp{resp, err, t} + } + rebaseRequest := func(t 
*testing.T, cli autoid.AutoIDAllocClient, unsigned bool, n int64, force ...struct{}) rebaseResp { + req := &autoid.RebaseRequest{ + DbID: dbInfo.ID, + TblID: tbInfo.ID, + Base: n, + IsUnsigned: unsigned, + Force: len(force) > 0, + } + resp, err := cli.Rebase(ctx, req) + return rebaseResp{resp, err, t} + } + var force = struct{}{} + + // basic auto id operation + autoIDRequest(t, cli, false, 1).check(0, 1) + autoIDRequest(t, cli, false, 10).check(1, 11) + checkCurrValue(t, cli, 11, 11) + autoIDRequest(t, cli, false, 128).check(11, 139) + autoIDRequest(t, cli, false, 1, 10, 5).check(139, 145) + + // basic rebase operation + rebaseRequest(t, cli, false, 666).check("") + autoIDRequest(t, cli, false, 1).check(666, 667) + + rebaseRequest(t, cli, false, 6666).check("") + autoIDRequest(t, cli, false, 1).check(6666, 6667) + + // rebase will not decrease the value without 'force' + rebaseRequest(t, cli, false, 44).check("") + checkCurrValue(t, cli, 6667, 6667) + rebaseRequest(t, cli, false, 44, force).check("") + checkCurrValue(t, cli, 44, 44) + + // max increase 1 + rebaseRequest(t, cli, false, math.MaxInt64, force).check("") + checkCurrValue(t, cli, math.MaxInt64, math.MaxInt64) + autoIDRequest(t, cli, false, 1).checkErrmsg() + + rebaseRequest(t, cli, true, 0, force).check("") + checkCurrValue(t, cli, 0, 0) + autoIDRequest(t, cli, true, 1).check(0, 1) + autoIDRequest(t, cli, true, 10).check(1, 11) + autoIDRequest(t, cli, true, 128).check(11, 139) + autoIDRequest(t, cli, true, 1, 10, 5).check(139, 145) + + // max increase 1 + rebaseRequest(t, cli, true, math.MaxInt64).check("") + checkCurrValue(t, cli, math.MaxInt64, math.MaxInt64) + autoIDRequest(t, cli, true, 1).check(math.MaxInt64, math.MinInt64) + autoIDRequest(t, cli, true, 1).check(math.MinInt64, math.MinInt64+1) + + rebaseRequest(t, cli, true, -1).check("") + checkCurrValue(t, cli, -1, -1) + autoIDRequest(t, cli, true, 1).check(-1, 0) +} + +func TestGRPC(t *testing.T) { + integration.BeforeTestExternal(t) + store := testkit.CreateMockStore(t) + cluster := integration.NewClusterV3(t, &integration.ClusterConfig{Size: 1}) + defer cluster.Terminate(t) + etcdCli := cluster.RandClient() + + var addr string + var listener net.Listener + for port := 10080; ; port++ { + var err error + addr = fmt.Sprintf("127.0.0.1:%d", port) + listener, err = net.Listen("tcp", addr) + if err == nil { + break + } + } + defer listener.Close() + + service := newWithCli(addr, etcdCli, store) + defer service.Close() + + var i int + for !service.leaderShip.IsOwner() { + time.Sleep(100 * time.Millisecond) + i++ + if i >= 20 { + break + } + } + require.Less(t, i, 20) + + grpcServer := grpc.NewServer() + autoid.RegisterAutoIDAllocServer(grpcServer, service) + go func() { + grpcServer.Serve(listener) + }() + defer grpcServer.Stop() + + grpcConn, err := grpc.Dial("127.0.0.1:10080", grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + cli := autoid.NewAutoIDAllocClient(grpcConn) + _, err = cli.AllocAutoID(context.Background(), &autoid.AutoIDRequest{ + DbID: 0, + TblID: 0, + N: 1, + Increment: 1, + Offset: 1, + IsUnsigned: false, + }) + require.NoError(t, err) +} diff --git a/ddl/BUILD.bazel b/ddl/BUILD.bazel index dc179250ad4bd..dce469fe15321 100644 --- a/ddl/BUILD.bazel +++ b/ddl/BUILD.bazel @@ -202,6 +202,7 @@ go_test( flaky = True, shard_count = 50, deps = [ + "//autoid_service", "//config", "//ddl/ingest", "//ddl/placement", diff --git a/ddl/db_integration_test.go b/ddl/db_integration_test.go index cc9cc657fdc6f..1d482f8cecada 100644 --- 
a/ddl/db_integration_test.go +++ b/ddl/db_integration_test.go @@ -26,6 +26,7 @@ import ( "time" "github.com/pingcap/errors" + _ "github.com/pingcap/tidb/autoid_service" "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/ddl" "github.com/pingcap/tidb/ddl/schematracker" diff --git a/executor/autoidtest/BUILD.bazel b/executor/autoidtest/BUILD.bazel index 0f5bf6c434f91..cd04b266fa2e3 100644 --- a/executor/autoidtest/BUILD.bazel +++ b/executor/autoidtest/BUILD.bazel @@ -9,6 +9,7 @@ go_test( flaky = True, race = "on", deps = [ + "//autoid_service", "//config", "//ddl/testutil", "//meta/autoid", diff --git a/executor/autoidtest/autoid_test.go b/executor/autoidtest/autoid_test.go index 7823a7488bf98..eb8cc3f874159 100644 --- a/executor/autoidtest/autoid_test.go +++ b/executor/autoidtest/autoid_test.go @@ -22,6 +22,7 @@ import ( "testing" "github.com/pingcap/failpoint" + _ "github.com/pingcap/tidb/autoid_service" ddltestutil "github.com/pingcap/tidb/ddl/testutil" "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/session" diff --git a/executor/issuetest/BUILD.bazel b/executor/issuetest/BUILD.bazel index 77bfaf7f11290..8d930738d2b7c 100644 --- a/executor/issuetest/BUILD.bazel +++ b/executor/issuetest/BUILD.bazel @@ -9,6 +9,7 @@ go_test( flaky = True, shard_count = 50, deps = [ + "//autoid_service", "//config", "//kv", "//meta/autoid", diff --git a/executor/issuetest/executor_issue_test.go b/executor/issuetest/executor_issue_test.go index 9e28feca1530f..8dcbf251cdf89 100644 --- a/executor/issuetest/executor_issue_test.go +++ b/executor/issuetest/executor_issue_test.go @@ -22,6 +22,7 @@ import ( "testing" "github.com/pingcap/failpoint" + _ "github.com/pingcap/tidb/autoid_service" "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/parser/auth" diff --git a/executor/showtest/BUILD.bazel b/executor/showtest/BUILD.bazel index 807e00c8e88ec..1882c92e0627d 100644 --- a/executor/showtest/BUILD.bazel +++ b/executor/showtest/BUILD.bazel @@ -11,6 +11,7 @@ go_test( race = "on", shard_count = 45, deps = [ + "//autoid_service", "//config", "//executor", "//infoschema", diff --git a/executor/showtest/show_test.go b/executor/showtest/show_test.go index 0573de30137f6..c327eb474e6d6 100644 --- a/executor/showtest/show_test.go +++ b/executor/showtest/show_test.go @@ -21,6 +21,7 @@ import ( "testing" "github.com/pingcap/failpoint" + _ "github.com/pingcap/tidb/autoid_service" "github.com/pingcap/tidb/executor" "github.com/pingcap/tidb/infoschema" "github.com/pingcap/tidb/parser/auth" diff --git a/meta/autoid/BUILD.bazel b/meta/autoid/BUILD.bazel index 50e53258f305b..b67f7f7c223c7 100644 --- a/meta/autoid/BUILD.bazel +++ b/meta/autoid/BUILD.bazel @@ -11,7 +11,6 @@ go_library( importpath = "github.com/pingcap/tidb/meta/autoid", visibility = ["//visibility:public"], deps = [ - "//autoid_service", "//config", "//errno", "//kv", diff --git a/meta/autoid/autoid.go b/meta/autoid/autoid.go index aba2ad565b617..1f5ffeb2fd094 100644 --- a/meta/autoid/autoid.go +++ b/meta/autoid/autoid.go @@ -26,7 +26,7 @@ import ( "github.com/opentracing/opentracing-go" "github.com/pingcap/errors" "github.com/pingcap/failpoint" - autoid "github.com/pingcap/tidb/autoid_service" + "github.com/pingcap/kvproto/pkg/autoid" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/meta" "github.com/pingcap/tidb/metrics" @@ -558,6 +558,11 @@ func NextStep(curStep int64, consumeDur time.Duration) int64 { return res } +// MockForTest is exported for testing. 
+// The actual implementation is in github.com/pingcap/tidb/autoid_service because of the +// package circle depending issue. +var MockForTest func(kv.Storage) autoid.AutoIDAllocClient + func newSinglePointAlloc(store kv.Storage, dbID, tblID int64, isUnsigned bool) *singlePointAlloc { ebd, ok := store.(kv.EtcdBackend) if !ok { @@ -587,7 +592,7 @@ func newSinglePointAlloc(store kv.Storage, dbID, tblID int64, isUnsigned bool) * spa.clientDiscover = clientDiscover{etcdCli: etcdCli} } else { spa.clientDiscover = clientDiscover{} - spa.mu.AutoIDAllocClient = autoid.MockForTest(store) + spa.mu.AutoIDAllocClient = MockForTest(store) } // mockAutoIDChange failpoint is not implemented in this allocator, so fallback to use the default one. diff --git a/session/BUILD.bazel b/session/BUILD.bazel index dc3106abdfe63..63118a3ea701a 100644 --- a/session/BUILD.bazel +++ b/session/BUILD.bazel @@ -128,6 +128,7 @@ go_test( race = "on", shard_count = 50, deps = [ + "//autoid_service", "//bindinfo", "//config", "//ddl", diff --git a/session/bench_test.go b/session/bench_test.go index ece43c39cdc77..04c86b9227f8d 100644 --- a/session/bench_test.go +++ b/session/bench_test.go @@ -24,6 +24,7 @@ import ( "time" "github.com/pingcap/log" + _ "github.com/pingcap/tidb/autoid_service" "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/domain" "github.com/pingcap/tidb/executor" diff --git a/sessionctx/binloginfo/BUILD.bazel b/sessionctx/binloginfo/BUILD.bazel index 6d5a600b9e68c..7a843495273ea 100644 --- a/sessionctx/binloginfo/BUILD.bazel +++ b/sessionctx/binloginfo/BUILD.bazel @@ -33,6 +33,7 @@ go_test( embed = [":binloginfo"], flaky = True, deps = [ + "//autoid_service", "//ddl", "//domain", "//kv", diff --git a/sessionctx/binloginfo/binloginfo_test.go b/sessionctx/binloginfo/binloginfo_test.go index 3c777a9436234..28235b5184b68 100644 --- a/sessionctx/binloginfo/binloginfo_test.go +++ b/sessionctx/binloginfo/binloginfo_test.go @@ -26,6 +26,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/failpoint" + _ "github.com/pingcap/tidb/autoid_service" "github.com/pingcap/tidb/ddl" "github.com/pingcap/tidb/domain" "github.com/pingcap/tidb/kv" diff --git a/telemetry/BUILD.bazel b/telemetry/BUILD.bazel index df7de986555a6..1f032aa3f237a 100644 --- a/telemetry/BUILD.bazel +++ b/telemetry/BUILD.bazel @@ -61,6 +61,7 @@ go_test( embed = [":telemetry"], flaky = True, deps = [ + "//autoid_service", "//config", "//ddl", "//domain", diff --git a/telemetry/data_feature_usage_test.go b/telemetry/data_feature_usage_test.go index cb3272d110b29..c303c53f3006b 100644 --- a/telemetry/data_feature_usage_test.go +++ b/telemetry/data_feature_usage_test.go @@ -18,6 +18,7 @@ import ( "fmt" "testing" + _ "github.com/pingcap/tidb/autoid_service" "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/ddl" "github.com/pingcap/tidb/sessionctx/variable" From 3ccff46aa3bc2dfabfda69ef8d718410af29522a Mon Sep 17 00:00:00 2001 From: Jk Xu <54522439+Dousir9@users.noreply.github.com> Date: Tue, 3 Jan 2023 21:28:20 +0800 Subject: [PATCH 7/9] executor: special handling is required when an "auto id out of range" error occurs in `insert ignore into ... 
on duplicate ...` (#39847) close pingcap/tidb#38950 --- executor/insert_common.go | 12 +++++++++++- executor/writetest/write_test.go | 19 +++++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/executor/insert_common.go b/executor/insert_common.go index 8440242f1dad5..862c82a88da90 100644 --- a/executor/insert_common.go +++ b/executor/insert_common.go @@ -32,6 +32,7 @@ import ( "github.com/pingcap/tidb/parser/ast" "github.com/pingcap/tidb/parser/model" "github.com/pingcap/tidb/parser/mysql" + "github.com/pingcap/tidb/planner/core" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/sessiontxn" @@ -771,7 +772,16 @@ func setDatumAutoIDAndCast(ctx sessionctx.Context, d *types.Datum, id int64, col var err error *d, err = table.CastValue(ctx, *d, col.ToInfo(), false, false) if err == nil && d.GetInt64() < id { - // Auto ID is out of range, the truncated ID is possible to duplicate with an existing ID. + // Auto ID is out of range. + sc := ctx.GetSessionVars().StmtCtx + insertPlan, ok := sc.GetPlan().(*core.Insert) + if ok && sc.TruncateAsWarning && len(insertPlan.OnDuplicate) > 0 { + // Fix issue #38950: AUTO_INCREMENT is incompatible with mysql + // An auto id out of range error occurs in `insert ignore into ... on duplicate ...`. + // We should allow the SQL to be executed successfully. + return nil + } + // The truncated ID is possible to duplicate with an existing ID. // To prevent updating unrelated rows in the REPLACE statement, it is better to throw an error. return autoid.ErrAutoincReadFailed } diff --git a/executor/writetest/write_test.go b/executor/writetest/write_test.go index ebeaaa388e269..32939a1b16033 100644 --- a/executor/writetest/write_test.go +++ b/executor/writetest/write_test.go @@ -592,6 +592,25 @@ commit;` tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1526 Table has no partition for value 3")) } +func TestIssue38950(t *testing.T) { + store := testkit.CreateMockStore(t) + var cfg kv.InjectionConfig + tk := testkit.NewTestKit(t, kv.NewInjectedStore(store, &cfg)) + tk.MustExec("use test;") + tk.MustExec("drop table if exists t; create table t (id smallint auto_increment primary key);") + tk.MustExec("alter table t add column c1 int default 1;") + tk.MustExec("insert ignore into t(id) values (194626268);") + require.Empty(t, tk.Session().LastMessage()) + + tk.MustQuery("select * from t").Check(testkit.Rows("32767 1")) + + tk.MustExec("insert ignore into t(id) values ('*') on duplicate key update c1 = 2;") + require.Equal(t, int64(2), int64(tk.Session().AffectedRows())) + require.Empty(t, tk.Session().LastMessage()) + + tk.MustQuery("select * from t").Check(testkit.Rows("32767 2")) +} + func TestInsertOnDup(t *testing.T) { store := testkit.CreateMockStore(t) var cfg kv.InjectionConfig From 3e65e9b5c77aded6ebd7c1db65aa482bf5915b94 Mon Sep 17 00:00:00 2001 From: Weizhen Wang Date: Wed, 4 Jan 2023 10:34:20 +0800 Subject: [PATCH 8/9] util: goroutine pool (#39872) close pingcap/tidb#38039 --- resourcemanager/pooltask/BUILD.bazel | 8 + resourcemanager/pooltask/task.go | 132 +++++++ util/gpool/BUILD.bazel | 11 + util/gpool/gpool.go | 69 ++++ util/gpool/spinlock.go | 47 +++ util/gpool/spmc/BUILD.bazel | 43 +++ util/gpool/spmc/main_test.go | 27 ++ util/gpool/spmc/option.go | 138 +++++++ util/gpool/spmc/spmcpool.go | 420 +++++++++++++++++++++ util/gpool/spmc/spmcpool_benchmark_test.go | 111 ++++++ util/gpool/spmc/spmcpool_test.go | 283 ++++++++++++++ util/gpool/spmc/worker.go | 74 ++++
util/gpool/spmc/worker_loop_queue.go | 192 ++++++++++ util/gpool/spmc/worker_loop_queue_test.go | 184 +++++++++ 14 files changed, 1739 insertions(+) create mode 100644 resourcemanager/pooltask/BUILD.bazel create mode 100644 resourcemanager/pooltask/task.go create mode 100644 util/gpool/BUILD.bazel create mode 100644 util/gpool/gpool.go create mode 100644 util/gpool/spinlock.go create mode 100644 util/gpool/spmc/BUILD.bazel create mode 100644 util/gpool/spmc/main_test.go create mode 100644 util/gpool/spmc/option.go create mode 100644 util/gpool/spmc/spmcpool.go create mode 100644 util/gpool/spmc/spmcpool_benchmark_test.go create mode 100644 util/gpool/spmc/spmcpool_test.go create mode 100644 util/gpool/spmc/worker.go create mode 100644 util/gpool/spmc/worker_loop_queue.go create mode 100644 util/gpool/spmc/worker_loop_queue_test.go diff --git a/resourcemanager/pooltask/BUILD.bazel b/resourcemanager/pooltask/BUILD.bazel new file mode 100644 index 0000000000000..c9e37436562ee --- /dev/null +++ b/resourcemanager/pooltask/BUILD.bazel @@ -0,0 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +go_library( + name = "pooltask", + srcs = ["task.go"], + importpath = "github.com/pingcap/tidb/resourcemanager/pooltask", + visibility = ["//visibility:public"], +) diff --git a/resourcemanager/pooltask/task.go b/resourcemanager/pooltask/task.go new file mode 100644 index 0000000000000..ef9b046c8ccba --- /dev/null +++ b/resourcemanager/pooltask/task.go @@ -0,0 +1,132 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pooltask + +import ( + "sync" +) + +// Context is a interface that can be used to create a context. +type Context[T any] interface { + GetContext() T +} + +// NilContext is to create a nil as context +type NilContext struct{} + +// GetContext is to get a nil as context +func (NilContext) GetContext() any { + return nil +} + +// TaskBox is a box which contains all info about pooltask. +type TaskBox[T any, U any, C any, CT any, TF Context[CT]] struct { + constArgs C + contextFunc TF + wg *sync.WaitGroup + task chan Task[T] + resultCh chan U + taskID uint64 +} + +// NewTaskBox is to create a task box for pool. +func NewTaskBox[T any, U any, C any, CT any, TF Context[CT]](constArgs C, contextFunc TF, wg *sync.WaitGroup, taskCh chan Task[T], resultCh chan U, taskID uint64) TaskBox[T, U, C, CT, TF] { + return TaskBox[T, U, C, CT, TF]{ + constArgs: constArgs, + contextFunc: contextFunc, + wg: wg, + task: taskCh, + resultCh: resultCh, + taskID: taskID, + } +} + +// TaskID is to get the task id. +func (t TaskBox[T, U, C, CT, TF]) TaskID() uint64 { + return t.taskID +} + +// ConstArgs is to get the const args. +func (t *TaskBox[T, U, C, CT, TF]) ConstArgs() C { + return t.constArgs +} + +// GetTaskCh is to get the task channel. 
+func (t *TaskBox[T, U, C, CT, TF]) GetTaskCh() chan Task[T] { + return t.task +} + +// GetResultCh is to get result channel +func (t *TaskBox[T, U, C, CT, TF]) GetResultCh() chan U { + return t.resultCh +} + +// GetContextFunc is to get context func. +func (t *TaskBox[T, U, C, CT, TF]) GetContextFunc() TF { + return t.contextFunc +} + +// Done is to set the pooltask status to complete. +func (t *TaskBox[T, U, C, CT, TF]) Done() { + t.wg.Done() +} + +// Clone is to copy the box +func (t *TaskBox[T, U, C, CT, TF]) Clone() *TaskBox[T, U, C, CT, TF] { + newBox := NewTaskBox[T, U, C, CT, TF](t.constArgs, t.contextFunc, t.wg, t.task, t.resultCh, t.taskID) + return &newBox +} + +// GPool is a goroutine pool. +type GPool[T any, U any, C any, CT any, TF Context[CT]] interface { + Tune(size int) +} + +// TaskController is a controller that can control or watch the pool. +type TaskController[T any, U any, C any, CT any, TF Context[CT]] struct { + pool GPool[T, U, C, CT, TF] + close chan struct{} + wg *sync.WaitGroup + taskID uint64 + resultCh chan U +} + +// NewTaskController create a controller to deal with pooltask's status. +func NewTaskController[T any, U any, C any, CT any, TF Context[CT]](p GPool[T, U, C, CT, TF], taskID uint64, closeCh chan struct{}, wg *sync.WaitGroup, resultCh chan U) TaskController[T, U, C, CT, TF] { + return TaskController[T, U, C, CT, TF]{ + pool: p, + taskID: taskID, + close: closeCh, + wg: wg, + resultCh: resultCh, + } +} + +// Wait is to wait the pool task to stop. +func (t *TaskController[T, U, C, CT, TF]) Wait() { + <-t.close + t.wg.Wait() + close(t.resultCh) +} + +// TaskID is to get the task id. +func (t *TaskController[T, U, C, CT, TF]) TaskID() uint64 { + return t.taskID +} + +// Task is a task that can be executed. +type Task[T any] struct { + Task T +} diff --git a/util/gpool/BUILD.bazel b/util/gpool/BUILD.bazel new file mode 100644 index 0000000000000..04a3dc25e7cd0 --- /dev/null +++ b/util/gpool/BUILD.bazel @@ -0,0 +1,11 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +go_library( + name = "gpool", + srcs = [ + "gpool.go", + "spinlock.go", + ], + importpath = "github.com/pingcap/tidb/util/gpool", + visibility = ["//visibility:public"], +) diff --git a/util/gpool/gpool.go b/util/gpool/gpool.go new file mode 100644 index 0000000000000..7611d29542a31 --- /dev/null +++ b/util/gpool/gpool.go @@ -0,0 +1,69 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gpool + +import ( + "errors" + "sync/atomic" + "time" +) + +const ( + // DefaultCleanIntervalTime is the interval time to clean up goroutines. + DefaultCleanIntervalTime = 5 * time.Second + + // OPENED represents that the pool is opened. + OPENED = iota + + // CLOSED represents that the pool is closed. + CLOSED +) + +var ( + // ErrPoolClosed will be returned when submitting task to a closed pool. + ErrPoolClosed = errors.New("this pool has been closed") + + // ErrPoolOverload will be returned when the pool is full and no workers available. 
+ ErrPoolOverload = errors.New("too many goroutines blocked on submit or Nonblocking is set") + + // ErrProducerClosed will be returned when the producer is closed. + ErrProducerClosed = errors.New("this producer has been closed") +) + +// BasePool is base class of pool +type BasePool struct { + name string + generator atomic.Uint64 +} + +// NewBasePool is to create a new BasePool. +func NewBasePool() BasePool { + return BasePool{} +} + +// SetName is to set name. +func (p *BasePool) SetName(name string) { + p.name = name +} + +// Name is to get name. +func (p *BasePool) Name() string { + return p.name +} + +// NewTaskID is to get a new task ID. +func (p *BasePool) NewTaskID() uint64 { + return p.generator.Add(1) +} diff --git a/util/gpool/spinlock.go b/util/gpool/spinlock.go new file mode 100644 index 0000000000000..acf7d15192416 --- /dev/null +++ b/util/gpool/spinlock.go @@ -0,0 +1,47 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gpool + +import ( + "runtime" + "sync" + "sync/atomic" +) + +type spinLock uint32 + +const maxBackoff = 16 + +func (sl *spinLock) Lock() { + backoff := 1 + for !atomic.CompareAndSwapUint32((*uint32)(sl), 0, 1) { + // Leverage the exponential backoff algorithm, see https://en.wikipedia.org/wiki/Exponential_backoff. + for i := 0; i < backoff; i++ { + runtime.Gosched() + } + if backoff < maxBackoff { + backoff <<= 1 + } + } +} + +func (sl *spinLock) Unlock() { + atomic.StoreUint32((*uint32)(sl), 0) +} + +// NewSpinLock instantiates a spin-lock. +func NewSpinLock() sync.Locker { + return new(spinLock) +} diff --git a/util/gpool/spmc/BUILD.bazel b/util/gpool/spmc/BUILD.bazel new file mode 100644 index 0000000000000..db48d9771cb17 --- /dev/null +++ b/util/gpool/spmc/BUILD.bazel @@ -0,0 +1,43 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "spmc", + srcs = [ + "option.go", + "spmcpool.go", + "worker.go", + "worker_loop_queue.go", + ], + importpath = "github.com/pingcap/tidb/util/gpool/spmc", + visibility = ["//visibility:public"], + deps = [ + "//resourcemanager/pooltask", + "//util/gpool", + "//util/logutil", + "@com_github_pingcap_errors//:errors", + "@com_github_pingcap_log//:log", + "@org_uber_go_atomic//:atomic", + "@org_uber_go_zap//:zap", + ], +) + +go_test( + name = "spmc_test", + srcs = [ + "main_test.go", + "spmcpool_benchmark_test.go", + "spmcpool_test.go", + "worker_loop_queue_test.go", + ], + embed = [":spmc"], + race = "on", + deps = [ + "//resourcemanager/pooltask", + "//testkit/testsetup", + "//util", + "//util/gpool", + "@com_github_stretchr_testify//require", + "@org_uber_go_atomic//:atomic", + "@org_uber_go_goleak//:goleak", + ], +) diff --git a/util/gpool/spmc/main_test.go b/util/gpool/spmc/main_test.go new file mode 100644 index 0000000000000..381e5302598d5 --- /dev/null +++ b/util/gpool/spmc/main_test.go @@ -0,0 +1,27 @@ +// Copyright 2022 PingCAP, Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package spmc + +import ( + "testing" + + "github.com/pingcap/tidb/testkit/testsetup" + "go.uber.org/goleak" +) + +func TestMain(m *testing.M) { + testsetup.SetupForCommonTest() + goleak.VerifyTestMain(m) +} diff --git a/util/gpool/spmc/option.go b/util/gpool/spmc/option.go new file mode 100644 index 0000000000000..e317ce157b93d --- /dev/null +++ b/util/gpool/spmc/option.go @@ -0,0 +1,138 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package spmc + +import ( + "time" +) + +// Option represents the optional function. +type Option func(opts *Options) + +func loadOptions(options ...Option) *Options { + opts := DefaultOption() + for _, option := range options { + option(opts) + } + return opts +} + +// Options contains all options which will be applied when instantiating an pool. +type Options struct { + // PanicHandler is used to handle panics from each worker goroutine. + // if nil, panics will be thrown out again from worker goroutines. + PanicHandler func(interface{}) + + // ExpiryDuration is a period for the scavenger goroutine to clean up those expired workers, + // the scavenger scans all workers every `ExpiryDuration` and clean up those workers that haven't been + // used for more than `ExpiryDuration`. + ExpiryDuration time.Duration + + // LimitDuration is a period in the limit mode. + LimitDuration time.Duration + + // Max number of goroutine blocking on pool.Submit. + // 0 (default value) means no such limit. + MaxBlockingTasks int + + // When Nonblocking is true, Pool.AddProduce will never be blocked. + // ErrPoolOverload will be returned when Pool.Submit cannot be done at once. + // When Nonblocking is true, MaxBlockingTasks is inoperative. + Nonblocking bool +} + +// DefaultOption is the default option. +func DefaultOption() *Options { + return &Options{ + LimitDuration: 200 * time.Millisecond, + Nonblocking: true, + } +} + +// WithExpiryDuration sets up the interval time of cleaning up goroutines. +func WithExpiryDuration(expiryDuration time.Duration) Option { + return func(opts *Options) { + opts.ExpiryDuration = expiryDuration + } +} + +// WithMaxBlockingTasks sets up the maximum number of goroutines that are blocked when it reaches the capacity of pool. 
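A brief aside on how these helpers fit together: the Option values defined in this file configure the pool itself at construction time, while the TaskOption values further down configure an individual producer/consumer task. A minimal, hypothetical sketch modeled on the calls in spmcpool_test.go later in this patch (the pool name, type arguments, and durations are illustrative only):

package main

import (
	"time"

	"github.com/pingcap/tidb/resourcemanager/pooltask"
	"github.com/pingcap/tidb/util/gpool/spmc"
)

func main() {
	// Pool-level options (Option) are passed to NewSPMCPool; task-level options
	// (TaskOption: WithConcurrency, WithResultChanLen, WithTaskChanLen) are passed
	// to AddProducer / AddProduceBySlice instead.
	p, err := spmc.NewSPMCPool[struct{}, struct{}, int, any, pooltask.NilContext](
		"example-pool", 10,
		spmc.WithExpiryDuration(10*time.Second), // scavenge workers idle for more than 10s
		spmc.WithNonblocking(true),              // fail fast instead of blocking callers
	)
	if err != nil {
		panic(err)
	}
	defer p.ReleaseAndWait()
}

SetConsumerFunc and AddProducer (defined in spmcpool.go below) would still be called before any work is submitted.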
+func WithMaxBlockingTasks(maxBlockingTasks int) Option { + return func(opts *Options) { + opts.MaxBlockingTasks = maxBlockingTasks + } +} + +// WithNonblocking indicates that pool will return nil when there is no available workers. +func WithNonblocking(nonblocking bool) Option { + return func(opts *Options) { + opts.Nonblocking = nonblocking + } +} + +// WithPanicHandler sets up panic handler. +func WithPanicHandler(panicHandler func(interface{})) Option { + return func(opts *Options) { + opts.PanicHandler = panicHandler + } +} + +// TaskOption represents the optional function. +type TaskOption func(opts *TaskOptions) + +func loadTaskOptions(options ...TaskOption) *TaskOptions { + opts := new(TaskOptions) + for _, option := range options { + option(opts) + } + if opts.Concurrency == 0 { + opts.Concurrency = 1 + } + if opts.ResultChanLen == 0 { + opts.ResultChanLen = uint64(opts.Concurrency) + } + if opts.ResultChanLen == 0 { + opts.ResultChanLen = uint64(opts.Concurrency) + } + return opts +} + +// TaskOptions contains all options +type TaskOptions struct { + Concurrency int + ResultChanLen uint64 + TaskChanLen uint64 +} + +// WithResultChanLen is to set the length of result channel. +func WithResultChanLen(resultChanLen uint64) TaskOption { + return func(opts *TaskOptions) { + opts.ResultChanLen = resultChanLen + } +} + +// WithTaskChanLen is to set the length of task channel. +func WithTaskChanLen(taskChanLen uint64) TaskOption { + return func(opts *TaskOptions) { + opts.TaskChanLen = taskChanLen + } +} + +// WithConcurrency is to set the concurrency of task. +func WithConcurrency(c int) TaskOption { + return func(opts *TaskOptions) { + opts.Concurrency = c + } +} diff --git a/util/gpool/spmc/spmcpool.go b/util/gpool/spmc/spmcpool.go new file mode 100644 index 0000000000000..b69c7a05e0eca --- /dev/null +++ b/util/gpool/spmc/spmcpool.go @@ -0,0 +1,420 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package spmc + +import ( + "errors" + "sync" + "sync/atomic" + "time" + + "github.com/pingcap/log" + "github.com/pingcap/tidb/resourcemanager/pooltask" + "github.com/pingcap/tidb/util/gpool" + "github.com/pingcap/tidb/util/logutil" + atomicutil "go.uber.org/atomic" + "go.uber.org/zap" +) + +// Pool is a single producer, multiple consumer goroutine pool. +// T is the type of the task. We can treat it as input. +// U is the type of the result. We can treat it as output. +// C is the type of the const parameter. if Our task look like y = ax + b, C acts like b as const parameter. +// CT is the type of the context. It needs to be read/written parallel. +// TF is the type of the context getter. It is used to get a context. +// if we don't need to use CT/TF, we can define CT as any and TF as NilContext. 
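To make the five type parameters in the declaration just below concrete, here is how they are bound by TestPool in spmcpool_test.go later in this patch; the package, function, and ConstArgs names are reproduced or invented purely for illustration:

package spmcexample

import (
	"github.com/pingcap/tidb/resourcemanager/pooltask"
	"github.com/pingcap/tidb/util/gpool/spmc"
)

// ConstArgs plays the role of C: a constant argument shared by every task of one producer.
type ConstArgs struct{ a int }

// newDemoPool binds T = int (task input), U = int (result), C = ConstArgs,
// and CT = any with TF = pooltask.NilContext, since no per-task context is needed.
func newDemoPool() (*spmc.Pool[int, int, ConstArgs, any, pooltask.NilContext], error) {
	pool, err := spmc.NewSPMCPool[int, int, ConstArgs, any, pooltask.NilContext]("demo", 10)
	if err != nil {
		return nil, err
	}
	// The consumer receives (T, C, CT) and returns U: here result = task + constArgs.a.
	pool.SetConsumerFunc(func(task int, constArgs ConstArgs, _ any) int {
		return task + constArgs.a
	})
	return pool, nil
}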
+type Pool[T any, U any, C any, CT any, TF pooltask.Context[CT]] struct { + gpool.BasePool + workerCache sync.Pool + workers *loopQueue[T, U, C, CT, TF] + lock sync.Locker + cond *sync.Cond + taskCh chan *pooltask.TaskBox[T, U, C, CT, TF] + options *Options + stopCh chan struct{} + consumerFunc func(T, C, CT) U + capacity atomic.Int32 + running atomic.Int32 + state atomic.Int32 + waiting atomic.Int32 // waiting is the number of goroutines that are waiting for the pool to be available. + heartbeatDone atomic.Bool + + waitingTask atomicutil.Uint32 // waitingTask is the number of tasks that are waiting for the pool to be available. +} + +// NewSPMCPool create a single producer, multiple consumer goroutine pool. +func NewSPMCPool[T any, U any, C any, CT any, TF pooltask.Context[CT]](name string, size int32, options ...Option) (*Pool[T, U, C, CT, TF], error) { + opts := loadOptions(options...) + if expiry := opts.ExpiryDuration; expiry <= 0 { + opts.ExpiryDuration = gpool.DefaultCleanIntervalTime + } + result := &Pool[T, U, C, CT, TF]{ + BasePool: gpool.NewBasePool(), + taskCh: make(chan *pooltask.TaskBox[T, U, C, CT, TF], 128), + stopCh: make(chan struct{}), + lock: gpool.NewSpinLock(), + options: opts, + } + result.SetName(name) + result.state.Store(int32(gpool.OPENED)) + result.workerCache.New = func() interface{} { + return &goWorker[T, U, C, CT, TF]{ + pool: result, + } + } + result.capacity.Add(size) + result.workers = newWorkerLoopQueue[T, U, C, CT, TF](int(size)) + result.cond = sync.NewCond(result.lock) + // Start a goroutine to clean up expired workers periodically. + go result.purgePeriodically() + return result, nil +} + +// purgePeriodically clears expired workers periodically which runs in an individual goroutine, as a scavenger. +func (p *Pool[T, U, C, CT, TF]) purgePeriodically() { + heartbeat := time.NewTicker(p.options.ExpiryDuration) + defer func() { + heartbeat.Stop() + p.heartbeatDone.Store(true) + }() + for { + select { + case <-heartbeat.C: + case <-p.stopCh: + return + } + + if p.IsClosed() { + break + } + + p.lock.Lock() + expiredWorkers := p.workers.retrieveExpiry(p.options.ExpiryDuration) + p.lock.Unlock() + + // Notify obsolete workers to stop. + // This notification must be outside the p.lock, since w.task + // may be blocking and may consume a lot of time if many workers + // are located on non-local CPUs. + for i := range expiredWorkers { + expiredWorkers[i].taskBoxCh <- nil + expiredWorkers[i] = nil + } + + // There might be a situation where all workers have been cleaned up(no worker is running), + // or another case where the pool capacity has been Tuned up, + // while some invokers still get stuck in "p.cond.Wait()", + // then it ought to wake all those invokers. + if p.Running() == 0 || (p.Waiting() > 0 && p.Free() > 0) || p.waitingTask.Load() > 0 { + p.cond.Broadcast() + } + } +} + +// Tune changes the capacity of this pool, note that it is noneffective to the infinite or pre-allocation pool. +func (p *Pool[T, U, C, CT, TF]) Tune(size int) { + capacity := p.Cap() + if capacity == -1 || size <= 0 || size == capacity { + return + } + p.capacity.Store(int32(size)) + if size > capacity { + // boost + if size-capacity == 1 { + p.cond.Signal() + return + } + p.cond.Broadcast() + } +} + +// Running returns the number of workers currently running. +func (p *Pool[T, U, C, CT, TF]) Running() int { + return int(p.running.Load()) +} + +// Free returns the number of available goroutines to work, -1 indicates this pool is unlimited. 
+func (p *Pool[T, U, C, CT, TF]) Free() int { + c := p.Cap() + if c < 0 { + return -1 + } + return c - p.Running() +} + +// Waiting returns the number of tasks which are waiting be executed. +func (p *Pool[T, U, C, CT, TF]) Waiting() int { + return int(p.waiting.Load()) +} + +// IsClosed indicates whether the pool is closed. +func (p *Pool[T, U, C, CT, TF]) IsClosed() bool { + return p.state.Load() == gpool.CLOSED +} + +// Cap returns the capacity of this pool. +func (p *Pool[T, U, C, CT, TF]) Cap() int { + return int(p.capacity.Load()) +} + +func (p *Pool[T, U, C, CT, TF]) addRunning(delta int) { + p.running.Add(int32(delta)) +} + +func (p *Pool[T, U, C, CT, TF]) addWaiting(delta int) { + p.waiting.Add(int32(delta)) +} + +func (p *Pool[T, U, C, CT, TF]) addWaitingTask() { + p.waitingTask.Inc() +} + +func (p *Pool[T, U, C, CT, TF]) subWaitingTask() { + p.waitingTask.Dec() +} + +// release closes this pool and releases the worker queue. +func (p *Pool[T, U, C, CT, TF]) release() { + if !p.state.CompareAndSwap(gpool.OPENED, gpool.CLOSED) { + return + } + p.lock.Lock() + p.workers.reset() + p.lock.Unlock() + // There might be some callers waiting in retrieveWorker(), so we need to wake them up to prevent + // those callers blocking infinitely. + p.cond.Broadcast() + close(p.taskCh) +} + +func isClose(exitCh chan struct{}) bool { + select { + case <-exitCh: + return true + default: + } + return false +} + +// ReleaseAndWait is like Release, it waits all workers to exit. +func (p *Pool[T, U, C, CT, TF]) ReleaseAndWait() { + if p.IsClosed() || isClose(p.stopCh) { + return + } + + close(p.stopCh) + p.release() + for { + // Wait for all workers to exit and all task to be completed. + if p.Running() == 0 && p.heartbeatDone.Load() && p.waitingTask.Load() == 0 { + return + } + } +} + +// SetConsumerFunc is to set ConsumerFunc which is to process the task. +func (p *Pool[T, U, C, CT, TF]) SetConsumerFunc(consumerFunc func(T, C, CT) U) { + p.consumerFunc = consumerFunc +} + +// AddProduceBySlice is to add Produce by a slice. +// Producer need to return ErrProducerClosed when to exit. +func (p *Pool[T, U, C, CT, TF]) AddProduceBySlice(producer func() ([]T, error), constArg C, contextFn TF, options ...TaskOption) (<-chan U, pooltask.TaskController[T, U, C, CT, TF]) { + opt := loadTaskOptions(options...) + taskID := p.NewTaskID() + var wg sync.WaitGroup + result := make(chan U, opt.ResultChanLen) + closeCh := make(chan struct{}) + inputCh := make(chan pooltask.Task[T], opt.TaskChanLen) + tc := pooltask.NewTaskController[T, U, C, CT, TF](p, taskID, closeCh, &wg, result) + for i := 0; i < opt.Concurrency; i++ { + err := p.run() + if err == gpool.ErrPoolClosed { + break + } + taskBox := pooltask.NewTaskBox[T, U, C, CT, TF](constArg, contextFn, &wg, inputCh, result, taskID) + p.addWaitingTask() + p.taskCh <- &taskBox + } + go func() { + defer func() { + if r := recover(); r != nil { + logutil.BgLogger().Error("producer panic", zap.Any("recover", r), zap.Stack("stack")) + } + close(closeCh) + close(inputCh) + }() + for { + tasks, err := producer() + if err != nil { + if errors.Is(err, gpool.ErrProducerClosed) { + return + } + log.Error("producer error", zap.Error(err)) + return + } + for _, task := range tasks { + wg.Add(1) + task := pooltask.Task[T]{ + Task: task, + } + inputCh <- task + } + } + }() + return result, tc +} + +// AddProducer is to add producer. +// Producer need to return ErrProducerClosed when to exit. 
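The contract stated just above has a concrete shape: the producer callback is invoked repeatedly until it returns gpool.ErrProducerClosed, results stream out on the returned channel, and the TaskController's Wait returns only after the producer has exited and every dispatched task has called Done, closing the result channel on the way out. A minimal, hypothetical sketch in the style of TestPool later in this patch (the function name and constArg value are illustrative; the pool's consumer is assumed to have been set with SetConsumerFunc):

package spmcexample

import (
	"fmt"

	"github.com/pingcap/tidb/resourcemanager/pooltask"
	"github.com/pingcap/tidb/util/gpool"
	"github.com/pingcap/tidb/util/gpool/spmc"
)

// runBatch pushes the numbers 1..10 through an already-configured pool.
func runBatch(pool *spmc.Pool[int, int, int, any, pooltask.NilContext]) {
	taskCh := make(chan int, 10)
	for i := 1; i <= 10; i++ {
		taskCh <- i
	}
	producer := func() (int, error) {
		select {
		case task := <-taskCh:
			return task, nil
		default:
			// Returning ErrProducerClosed tells the pool there is no more input.
			return 0, gpool.ErrProducerClosed
		}
	}
	resultCh, ctl := pool.AddProducer(producer, 100, pooltask.NilContext{}, spmc.WithConcurrency(4))
	go func() {
		for r := range resultCh { // resultCh is closed by ctl.Wait once all tasks finish
			fmt.Println(r)
		}
	}()
	ctl.Wait() // blocks until the producer exits and every task has called Done
}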
+func (p *Pool[T, U, C, CT, TF]) AddProducer(producer func() (T, error), constArg C, contextFn TF, options ...TaskOption) (<-chan U, pooltask.TaskController[T, U, C, CT, TF]) { + opt := loadTaskOptions(options...) + taskID := p.NewTaskID() + var wg sync.WaitGroup + result := make(chan U, opt.ResultChanLen) + closeCh := make(chan struct{}) + inputCh := make(chan pooltask.Task[T], opt.TaskChanLen) + tc := pooltask.NewTaskController[T, U, C, CT, TF](p, taskID, closeCh, &wg, result) + for i := 0; i < opt.Concurrency; i++ { + err := p.run() + if err == gpool.ErrPoolClosed { + break + } + p.addWaitingTask() + taskBox := pooltask.NewTaskBox[T, U, C, CT, TF](constArg, contextFn, &wg, inputCh, result, taskID) + p.taskCh <- &taskBox + } + go func() { + defer func() { + if r := recover(); r != nil { + logutil.BgLogger().Error("producer panic", zap.Any("recover", r), zap.Stack("stack")) + } + close(closeCh) + close(inputCh) + }() + for { + task, err := producer() + if err != nil { + if errors.Is(err, gpool.ErrProducerClosed) { + return + } + log.Error("producer error", zap.Error(err)) + return + } + wg.Add(1) + t := pooltask.Task[T]{ + Task: task, + } + inputCh <- t + } + }() + return result, tc +} + +func (p *Pool[T, U, C, CT, TF]) run() error { + if p.IsClosed() { + return gpool.ErrPoolClosed + } + var w *goWorker[T, U, C, CT, TF] + if w = p.retrieveWorker(); w == nil { + return gpool.ErrPoolOverload + } + return nil +} + +// retrieveWorker returns an available worker to run the tasks. +func (p *Pool[T, U, C, CT, TF]) retrieveWorker() (w *goWorker[T, U, C, CT, TF]) { + spawnWorker := func() { + w = p.workerCache.Get().(*goWorker[T, U, C, CT, TF]) + w.taskBoxCh = p.taskCh + w.run() + } + + p.lock.Lock() + + w = p.workers.detach() + if w != nil { // first try to fetch the worker from the queue + p.lock.Unlock() + } else if capacity := p.Cap(); capacity == -1 || capacity > p.Running() { + // if the worker queue is empty and we don't run out of the pool capacity, + // then just spawn a new worker goroutine. + p.lock.Unlock() + spawnWorker() + } else { // otherwise, we'll have to keep them blocked and wait for at least one worker to be put back into pool. + if p.options.Nonblocking { + p.lock.Unlock() + return + } + retry: + if p.options.MaxBlockingTasks != 0 && p.Waiting() >= p.options.MaxBlockingTasks { + p.lock.Unlock() + return + } + p.addWaiting(1) + p.cond.Wait() // block and wait for an available worker + p.addWaiting(-1) + + if p.IsClosed() { + p.lock.Unlock() + return + } + + var nw int + if nw = p.Running(); nw == 0 { // awakened by the scavenger + p.lock.Unlock() + spawnWorker() + return + } + if w = p.workers.detach(); w == nil { + if nw < p.Cap() { + p.lock.Unlock() + spawnWorker() + return + } + goto retry + } + p.lock.Unlock() + } + return +} + +// revertWorker puts a worker back into free pool, recycling the goroutines. +func (p *Pool[T, U, C, CT, TF]) revertWorker(worker *goWorker[T, U, C, CT, TF]) bool { + if capacity := p.Cap(); capacity > 0 && p.Running() > capacity || p.IsClosed() { + p.cond.Broadcast() + return false + } + worker.recycleTime.Store(time.Now()) + p.lock.Lock() + + if p.IsClosed() { + p.lock.Unlock() + return false + } + + err := p.workers.insert(worker) + if err != nil { + p.lock.Unlock() + if err == errQueueIsFull && p.waitingTask.Load() > 0 { + return true + } + return false + } + + // Notify the invoker stuck in 'retrieveWorker()' of there is an available worker in the worker queue. 
+ p.cond.Signal() + p.lock.Unlock() + return true +} diff --git a/util/gpool/spmc/spmcpool_benchmark_test.go b/util/gpool/spmc/spmcpool_benchmark_test.go new file mode 100644 index 0000000000000..db3a4f0824e78 --- /dev/null +++ b/util/gpool/spmc/spmcpool_benchmark_test.go @@ -0,0 +1,111 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package spmc + +import ( + "testing" + "time" + + "github.com/pingcap/tidb/resourcemanager/pooltask" + "github.com/pingcap/tidb/util" + "github.com/pingcap/tidb/util/gpool" +) + +const ( + RunTimes = 10000 + DefaultExpiredTime = 10 * time.Second +) + +func BenchmarkGPool(b *testing.B) { + p, err := NewSPMCPool[struct{}, struct{}, int, any, pooltask.NilContext]("test", 10) + if err != nil { + b.Fatal(err) + } + defer p.ReleaseAndWait() + p.SetConsumerFunc(func(a struct{}, b int, c any) struct{} { + return struct{}{} + }) + b.ResetTimer() + for i := 0; i < b.N; i++ { + sema := make(chan struct{}, 10) + var wg util.WaitGroupWrapper + wg.Run(func() { + for j := 0; j < RunTimes; j++ { + sema <- struct{}{} + } + close(sema) + }) + producerFunc := func() (struct{}, error) { + _, ok := <-sema + if ok { + return struct{}{}, nil + } + return struct{}{}, gpool.ErrProducerClosed + } + resultCh, ctl := p.AddProducer(producerFunc, RunTimes, pooltask.NilContext{}, WithConcurrency(6), WithResultChanLen(10)) + exitCh := make(chan struct{}) + wg.Run(func() { + for { + select { + case <-resultCh: + case <-exitCh: + return + } + } + }) + ctl.Wait() + close(exitCh) + wg.Wait() + } +} + +func BenchmarkGoCommon(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + var wg util.WaitGroupWrapper + var wgp util.WaitGroupWrapper + sema := make(chan struct{}, 10) + result := make(chan struct{}, 10) + wg.Run(func() { + for j := 0; j < RunTimes; j++ { + sema <- struct{}{} + } + close(sema) + }) + + for n := 0; n < 6; n++ { + wg.Run(func() { + item, ok := <-sema + if !ok { + return + } + result <- item + }) + } + exitCh := make(chan struct{}) + wgp.Run(func() { + for { + select { + case <-result: + case <-exitCh: + return + } + } + }) + wg.Wait() + close(exitCh) + wgp.Wait() + } +} diff --git a/util/gpool/spmc/spmcpool_test.go b/util/gpool/spmc/spmcpool_test.go new file mode 100644 index 0000000000000..984f501789c47 --- /dev/null +++ b/util/gpool/spmc/spmcpool_test.go @@ -0,0 +1,283 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package spmc + +import ( + "sync" + "sync/atomic" + "testing" + + "github.com/pingcap/tidb/resourcemanager/pooltask" + "github.com/pingcap/tidb/util" + "github.com/pingcap/tidb/util/gpool" + "github.com/stretchr/testify/require" +) + +func TestPool(t *testing.T) { + type ConstArgs struct { + a int + } + myArgs := ConstArgs{a: 10} + // init the pool + // input type, output type, constArgs type + pool, err := NewSPMCPool[int, int, ConstArgs, any, pooltask.NilContext]("TestPool", 10) + require.NoError(t, err) + pool.SetConsumerFunc(func(task int, constArgs ConstArgs, ctx any) int { + return task + constArgs.a + }) + taskCh := make(chan int, 10) + for i := 1; i < 11; i++ { + taskCh <- i + } + pfunc := func() (int, error) { + select { + case task := <-taskCh: + return task, nil + default: + return 0, gpool.ErrProducerClosed + } + } + // add new task + resultCh, control := pool.AddProducer(pfunc, myArgs, pooltask.NilContext{}, WithConcurrency(4)) + + var count atomic.Uint32 + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + for result := range resultCh { + count.Add(1) + require.Greater(t, result, 10) + } + }() + // Waiting task finishing + control.Wait() + wg.Wait() + require.Equal(t, uint32(10), count.Load()) + // close pool + pool.ReleaseAndWait() +} + +func TestPoolWithEnoughCapacity(t *testing.T) { + const ( + RunTimes = 1000 + poolsize = 30 + concurrency = 6 + ) + p, err := NewSPMCPool[struct{}, struct{}, int, any, pooltask.NilContext]("TestPoolWithEnoughCapa", poolsize, WithExpiryDuration(DefaultExpiredTime)) + require.NoError(t, err) + defer p.ReleaseAndWait() + p.SetConsumerFunc(func(a struct{}, b int, c any) struct{} { + return struct{}{} + }) + var twg util.WaitGroupWrapper + for i := 0; i < 3; i++ { + twg.Run(func() { + sema := make(chan struct{}, 10) + var wg util.WaitGroupWrapper + exitCh := make(chan struct{}) + wg.Run(func() { + for j := 0; j < RunTimes; j++ { + sema <- struct{}{} + } + close(exitCh) + }) + producerFunc := func() (struct{}, error) { + for { + select { + case <-sema: + return struct{}{}, nil + default: + select { + case <-exitCh: + return struct{}{}, gpool.ErrProducerClosed + default: + } + } + } + } + resultCh, ctl := p.AddProducer(producerFunc, RunTimes, pooltask.NilContext{}, WithConcurrency(concurrency)) + wg.Add(1) + go func() { + defer wg.Done() + for range resultCh { + } + }() + ctl.Wait() + wg.Wait() + }) + } + twg.Wait() +} + +func TestPoolWithoutEnoughCapacity(t *testing.T) { + const ( + RunTimes = 5 + concurrency = 2 + poolsize = 2 + ) + p, err := NewSPMCPool[struct{}, struct{}, int, any, pooltask.NilContext]("TestPoolWithoutEnoughCapa", poolsize, + WithExpiryDuration(DefaultExpiredTime)) + require.NoError(t, err) + defer p.ReleaseAndWait() + p.SetConsumerFunc(func(a struct{}, b int, c any) struct{} { + return struct{}{} + }) + var twg sync.WaitGroup + for i := 0; i < 10; i++ { + func() { + sema := make(chan struct{}, 10) + var wg util.WaitGroupWrapper + exitCh := make(chan struct{}) + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < RunTimes; j++ { + sema <- struct{}{} + } + close(exitCh) + }() + producerFunc := func() (struct{}, error) { + for { + select { + case <-sema: + return struct{}{}, nil + default: + select { + case <-exitCh: + return struct{}{}, gpool.ErrProducerClosed + default: + } + } + } + } + resultCh, ctl := p.AddProducer(producerFunc, RunTimes, pooltask.NilContext{}, WithConcurrency(concurrency)) + + wg.Add(1) + go func() { + defer wg.Done() + for range resultCh { + } + }() + ctl.Wait() + wg.Wait() + 
}() + } + twg.Wait() +} + +func TestPoolWithoutEnoughCapacityParallel(t *testing.T) { + const ( + RunTimes = 5 + concurrency = 2 + poolsize = 2 + ) + p, err := NewSPMCPool[struct{}, struct{}, int, any, pooltask.NilContext]("TestPoolWithoutEnoughCapa", poolsize, + WithExpiryDuration(DefaultExpiredTime), WithNonblocking(true)) + require.NoError(t, err) + defer p.ReleaseAndWait() + p.SetConsumerFunc(func(a struct{}, b int, c any) struct{} { + return struct{}{} + }) + var twg sync.WaitGroup + for i := 0; i < 10; i++ { + twg.Add(1) + go func() { + defer twg.Done() + sema := make(chan struct{}, 10) + var wg sync.WaitGroup + exitCh := make(chan struct{}) + wg.Add(1) + go func() { + wg.Done() + for j := 0; j < RunTimes; j++ { + sema <- struct{}{} + } + close(exitCh) + }() + producerFunc := func() (struct{}, error) { + for { + select { + case <-sema: + return struct{}{}, nil + default: + select { + case <-exitCh: + return struct{}{}, gpool.ErrProducerClosed + default: + } + } + } + } + resultCh, ctl := p.AddProducer(producerFunc, RunTimes, pooltask.NilContext{}, WithConcurrency(concurrency)) + wg.Add(1) + go func() { + defer wg.Done() + for range resultCh { + } + }() + ctl.Wait() + wg.Wait() + }() + } + twg.Wait() +} + +func TestBenchPool(t *testing.T) { + p, err := NewSPMCPool[struct{}, struct{}, int, any, pooltask.NilContext]("TestBenchPool", 10, WithExpiryDuration(DefaultExpiredTime)) + require.NoError(t, err) + defer p.ReleaseAndWait() + p.SetConsumerFunc(func(a struct{}, b int, c any) struct{} { + return struct{}{} + }) + + for i := 0; i < 1000; i++ { + sema := make(chan struct{}, 10) + var wg sync.WaitGroup + exitCh := make(chan struct{}) + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < RunTimes; j++ { + sema <- struct{}{} + } + close(exitCh) + }() + producerFunc := func() (struct{}, error) { + for { + select { + case <-sema: + return struct{}{}, nil + default: + select { + case <-exitCh: + return struct{}{}, gpool.ErrProducerClosed + default: + } + } + } + } + resultCh, ctl := p.AddProducer(producerFunc, RunTimes, pooltask.NilContext{}, WithConcurrency(6)) + wg.Add(1) + go func() { + defer wg.Done() + for range resultCh { + } + }() + ctl.Wait() + wg.Wait() + } + p.ReleaseAndWait() +} diff --git a/util/gpool/spmc/worker.go b/util/gpool/spmc/worker.go new file mode 100644 index 0000000000000..32ff56a790dbd --- /dev/null +++ b/util/gpool/spmc/worker.go @@ -0,0 +1,74 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package spmc + +import ( + "github.com/pingcap/log" + "github.com/pingcap/tidb/resourcemanager/pooltask" + atomicutil "go.uber.org/atomic" + "go.uber.org/zap" +) + +// goWorker is the actual executor who runs the tasks, +// it starts a goroutine that accepts tasks and +// performs function calls. +type goWorker[T any, U any, C any, CT any, TF pooltask.Context[CT]] struct { + // pool who owns this worker. + pool *Pool[T, U, C, CT, TF] + + // taskBoxCh is a job should be done. 
+ taskBoxCh chan *pooltask.TaskBox[T, U, C, CT, TF] + + // recycleTime will be updated when putting a worker back into queue. + recycleTime atomicutil.Time +} + +// run starts a goroutine to repeat the process +// that performs the function calls. +func (w *goWorker[T, U, C, CT, TF]) run() { + w.pool.addRunning(1) + go func() { + defer func() { + w.pool.addRunning(-1) + w.pool.workerCache.Put(w) + if p := recover(); p != nil { + if ph := w.pool.options.PanicHandler; ph != nil { + ph(p) + } else { + log.Error("worker exits from a panic", zap.Any("recover", p), zap.Stack("stack")) + } + } + // Call Signal() here in case there are goroutines waiting for available workers. + w.pool.cond.Signal() + }() + + for f := range w.taskBoxCh { + if f == nil { + return + } + w.pool.subWaitingTask() + ctx := f.GetContextFunc().GetContext() + if f.GetResultCh() != nil { + for t := range f.GetTaskCh() { + f.GetResultCh() <- w.pool.consumerFunc(t.Task, f.ConstArgs(), ctx) + f.Done() + } + } + if ok := w.pool.revertWorker(w); !ok { + return + } + } + }() +} diff --git a/util/gpool/spmc/worker_loop_queue.go b/util/gpool/spmc/worker_loop_queue.go new file mode 100644 index 0000000000000..59c7b97fd425a --- /dev/null +++ b/util/gpool/spmc/worker_loop_queue.go @@ -0,0 +1,192 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package spmc + +import ( + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/tidb/resourcemanager/pooltask" +) + +var ( + // errQueueIsFull will be returned when the worker queue is full. + errQueueIsFull = errors.New("the queue is full") + + // errQueueIsReleased will be returned when trying to insert item to a released worker queue. 
+ errQueueIsReleased = errors.New("the queue is released could not accept item anymore") +) + +type loopQueue[T any, U any, C any, CT any, TF pooltask.Context[CT]] struct { + items []*goWorker[T, U, C, CT, TF] + expiry []*goWorker[T, U, C, CT, TF] + head int + tail int + size int + isFull bool +} + +func newWorkerLoopQueue[T any, U any, C any, CT any, TF pooltask.Context[CT]](size int) *loopQueue[T, U, C, CT, TF] { + return &loopQueue[T, U, C, CT, TF]{ + items: make([]*goWorker[T, U, C, CT, TF], size), + size: size, + } +} + +func (wq *loopQueue[T, U, C, CT, TF]) len() int { + if wq.size == 0 { + return 0 + } + + if wq.head == wq.tail { + if wq.isFull { + return wq.size + } + return 0 + } + + if wq.tail > wq.head { + return wq.tail - wq.head + } + + return wq.size - wq.head + wq.tail +} + +func (wq *loopQueue[T, U, C, CT, TF]) isEmpty() bool { + return wq.head == wq.tail && !wq.isFull +} + +func (wq *loopQueue[T, U, C, CT, TF]) insert(worker *goWorker[T, U, C, CT, TF]) error { + if wq.size == 0 { + return errQueueIsReleased + } + + if wq.isFull { + return errQueueIsFull + } + wq.items[wq.tail] = worker + wq.tail++ + + if wq.tail == wq.size { + wq.tail = 0 + } + if wq.tail == wq.head { + wq.isFull = true + } + + return nil +} + +func (wq *loopQueue[T, U, C, CT, TF]) detach() *goWorker[T, U, C, CT, TF] { + if wq.isEmpty() { + return nil + } + + w := wq.items[wq.head] + wq.items[wq.head] = nil + wq.head++ + if wq.head == wq.size { + wq.head = 0 + } + wq.isFull = false + + return w +} + +func (wq *loopQueue[T, U, C, CT, TF]) retrieveExpiry(duration time.Duration) []*goWorker[T, U, C, CT, TF] { + expiryTime := time.Now().Add(-duration) + index := wq.binarySearch(expiryTime) + if index == -1 { + return nil + } + wq.expiry = wq.expiry[:0] + + if wq.head <= index { + wq.expiry = append(wq.expiry, wq.items[wq.head:index+1]...) + for i := wq.head; i < index+1; i++ { + wq.items[i] = nil + } + } else { + wq.expiry = append(wq.expiry, wq.items[0:index+1]...) + wq.expiry = append(wq.expiry, wq.items[wq.head:]...) + for i := 0; i < index+1; i++ { + wq.items[i] = nil + } + for i := wq.head; i < wq.size; i++ { + wq.items[i] = nil + } + } + head := (index + 1) % wq.size + wq.head = head + if len(wq.expiry) > 0 { + wq.isFull = false + } + + return wq.expiry +} + +// binarySearch is to find the first worker which is idle for more than duration. 
+func (wq *loopQueue[T, U, C, CT, TF]) binarySearch(expiryTime time.Time) int { + var mid, nlen, basel, tmid int + nlen = len(wq.items) + + // if no need to remove work, return -1 + if wq.isEmpty() || expiryTime.Before(wq.items[wq.head].recycleTime.Load()) { + return -1 + } + + // example + // size = 8, head = 7, tail = 4 + // [ 2, 3, 4, 5, nil, nil, nil, 1] true position + // 0 1 2 3 4 5 6 7 + // tail head + // + // 1 2 3 4 nil nil nil 0 mapped position + // r l + + // base algorithm is a copy from worker_stack + // map head and tail to effective left and right + r := (wq.tail - 1 - wq.head + nlen) % nlen + basel = wq.head + l := 0 + for l <= r { + mid = l + ((r - l) >> 1) + // calculate true mid position from mapped mid position + tmid = (mid + basel + nlen) % nlen + if expiryTime.Before(wq.items[tmid].recycleTime.Load()) { + r = mid - 1 + } else { + l = mid + 1 + } + } + // return true position from mapped position + return (r + basel + nlen) % nlen +} + +func (wq *loopQueue[T, U, C, CT, TF]) reset() { + if wq.isEmpty() { + return + } + +Releasing: + if w := wq.detach(); w != nil { + w.taskBoxCh <- nil + goto Releasing + } + wq.items = wq.items[:0] + wq.size = 0 + wq.head = 0 + wq.tail = 0 +} diff --git a/util/gpool/spmc/worker_loop_queue_test.go b/util/gpool/spmc/worker_loop_queue_test.go new file mode 100644 index 0000000000000..da9bdc8dbc36c --- /dev/null +++ b/util/gpool/spmc/worker_loop_queue_test.go @@ -0,0 +1,184 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
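binarySearch above handles the rotation by mapping the occupied slots onto a contiguous logical window [0, count-1], probing in that window, and converting each probe back to a physical slot with modulo arithmetic; the search is valid because recycle times are non-decreasing in insertion order. The standalone restatement below (function and variable names are invented for this note) performs the same search on a plain slice of recycle times, which may make the index mapping easier to follow:

package main

import (
    "fmt"
    "time"
)

// ringLastExpired returns the physical index of the last worker whose recycle
// time is not after expiry, searching the count occupied slots that start at
// head in the ring buffer recycle; it returns -1 when no worker has expired.
func ringLastExpired(recycle []time.Time, head, count int, expiry time.Time) int {
    n := len(recycle)
    if count == 0 || expiry.Before(recycle[head]) {
        return -1 // queue empty, or even the oldest worker is still fresh
    }
    l, r := 0, count-1 // logical window over the occupied slots
    for l <= r {
        mid := l + (r-l)/2
        phys := (mid + head) % n // logical position -> physical slot
        if expiry.Before(recycle[phys]) {
            r = mid - 1 // worker at mid is still fresh, go left
        } else {
            l = mid + 1 // worker at mid has expired, look further right
        }
    }
    return (r + head) % n // r is the logical index of the last expired worker
}

func main() {
    base := time.Now()
    // Physical layout with head=3 and five occupied slots wrapping the ring:
    // physical indexes 3,4,0,1,2 hold base, +1s, +2s, +3s, +4s respectively.
    recycle := []time.Time{
        base.Add(2 * time.Second), base.Add(3 * time.Second), base.Add(4 * time.Second),
        base, base.Add(1 * time.Second),
    }
    fmt.Println(ringLastExpired(recycle, 3, 5, base.Add(2500*time.Millisecond))) // prints 0
}

retrieveExpiry then clears every slot from head through the returned index, which is why the search only needs the position of the last expired worker; the rotated layouts constructed in the test file that follows exercise exactly these wrap-around cases.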
+ +package spmc + +import ( + "testing" + "time" + + "github.com/pingcap/tidb/resourcemanager/pooltask" + "github.com/stretchr/testify/require" + atomicutil "go.uber.org/atomic" +) + +func TestNewLoopQueue(t *testing.T) { + size := 100 + q := newWorkerLoopQueue[struct{}, struct{}, int, any, pooltask.NilContext](size) + require.EqualValues(t, 0, q.len(), "Len error") + require.Equal(t, true, q.isEmpty(), "IsEmpty error") + require.Nil(t, q.detach(), "Dequeue error") +} + +func TestLoopQueue(t *testing.T) { + size := 10 + q := newWorkerLoopQueue[struct{}, struct{}, int, any, pooltask.NilContext](size) + + for i := 0; i < 5; i++ { + err := q.insert(&goWorker[struct{}, struct{}, int, any, pooltask.NilContext]{recycleTime: *atomicutil.NewTime(time.Now())}) + if err != nil { + break + } + } + require.EqualValues(t, 5, q.len(), "Len error") + _ = q.detach() + require.EqualValues(t, 4, q.len(), "Len error") + + time.Sleep(time.Second) + + for i := 0; i < 6; i++ { + err := q.insert(&goWorker[struct{}, struct{}, int, any, pooltask.NilContext]{recycleTime: *atomicutil.NewTime(time.Now())}) + if err != nil { + break + } + } + require.EqualValues(t, 10, q.len(), "Len error") + + err := q.insert(&goWorker[struct{}, struct{}, int, any, pooltask.NilContext]{recycleTime: *atomicutil.NewTime(time.Now())}) + require.Error(t, err, "Enqueue, error") + + q.retrieveExpiry(time.Second) + require.EqualValuesf(t, 6, q.len(), "Len error: %d", q.len()) +} + +func TestRotatedArraySearch(t *testing.T) { + size := 10 + q := newWorkerLoopQueue[struct{}, struct{}, int, any, pooltask.NilContext](size) + + expiry1 := time.Now() + + _ = q.insert(&goWorker[struct{}, struct{}, int, any, pooltask.NilContext]{recycleTime: *atomicutil.NewTime(time.Now())}) + + require.EqualValues(t, 0, q.binarySearch(time.Now()), "index should be 0") + require.EqualValues(t, -1, q.binarySearch(expiry1), "index should be -1") + + expiry2 := time.Now() + _ = q.insert(&goWorker[struct{}, struct{}, int, any, pooltask.NilContext]{recycleTime: *atomicutil.NewTime(time.Now())}) + require.EqualValues(t, -1, q.binarySearch(expiry1), "index should be -1") + require.EqualValues(t, 0, q.binarySearch(expiry2), "index should be 0") + require.EqualValues(t, 1, q.binarySearch(time.Now()), "index should be 1") + + for i := 0; i < 5; i++ { + _ = q.insert(&goWorker[struct{}, struct{}, int, any, pooltask.NilContext]{recycleTime: *atomicutil.NewTime(time.Now())}) + } + + expiry3 := time.Now() + _ = q.insert(&goWorker[struct{}, struct{}, int, any, pooltask.NilContext]{recycleTime: *atomicutil.NewTime(expiry3)}) + + var err error + for err != errQueueIsFull { + err = q.insert(&goWorker[struct{}, struct{}, int, any, pooltask.NilContext]{recycleTime: *atomicutil.NewTime(time.Now())}) + } + + require.EqualValues(t, 7, q.binarySearch(expiry3), "index should be 7") + + // rotate + for i := 0; i < 6; i++ { + _ = q.detach() + } + + expiry4 := time.Now() + _ = q.insert(&goWorker[struct{}, struct{}, int, any, pooltask.NilContext]{recycleTime: *atomicutil.NewTime(expiry4)}) + + for i := 0; i < 4; i++ { + _ = q.insert(&goWorker[struct{}, struct{}, int, any, pooltask.NilContext]{recycleTime: *atomicutil.NewTime(time.Now())}) + } + // head = 6, tail = 5, insert direction -> + // [expiry4, time, time, time, time, nil/tail, time/head, time, time, time] + require.EqualValues(t, 0, q.binarySearch(expiry4), "index should be 0") + + for i := 0; i < 3; i++ { + _ = q.detach() + } + expiry5 := time.Now() + _ = q.insert(&goWorker[struct{}, struct{}, int, any, 
pooltask.NilContext]{recycleTime: *atomicutil.NewTime(expiry5)}) + + // head = 6, tail = 5, insert direction -> + // [expiry4, time, time, time, time, expiry5, nil/tail, nil, nil, time/head] + require.EqualValues(t, 5, q.binarySearch(expiry5), "index should be 5") + + for i := 0; i < 3; i++ { + _ = q.insert(&goWorker[struct{}, struct{}, int, any, pooltask.NilContext]{recycleTime: *atomicutil.NewTime(time.Now())}) + } + // head = 9, tail = 9, insert direction -> + // [expiry4, time, time, time, time, expiry5, time, time, time, time/head/tail] + require.EqualValues(t, -1, q.binarySearch(expiry2), "index should be -1") + + require.EqualValues(t, 9, q.binarySearch(q.items[9].recycleTime.Load()), "index should be 9") + require.EqualValues(t, 8, q.binarySearch(time.Now()), "index should be 8") +} + +func TestRetrieveExpiry(t *testing.T) { + size := 10 + q := newWorkerLoopQueue[struct{}, struct{}, int, any, pooltask.NilContext](size) + expirew := make([]*goWorker[struct{}, struct{}, int, any, pooltask.NilContext], 0) + u, _ := time.ParseDuration("1s") + + // test [ time+1s, time+1s, time+1s, time+1s, time+1s, time, time, time, time, time] + for i := 0; i < size/2; i++ { + _ = q.insert(&goWorker[struct{}, struct{}, int, any, pooltask.NilContext]{recycleTime: *atomicutil.NewTime(time.Now())}) + } + expirew = append(expirew, q.items[:size/2]...) + time.Sleep(u) + + for i := 0; i < size/2; i++ { + _ = q.insert(&goWorker[struct{}, struct{}, int, any, pooltask.NilContext]{recycleTime: *atomicutil.NewTime(time.Now())}) + } + workers := q.retrieveExpiry(u) + + require.EqualValues(t, expirew, workers, "expired workers aren't right") + + // test [ time, time, time, time, time, time+1s, time+1s, time+1s, time+1s, time+1s] + time.Sleep(u) + + for i := 0; i < size/2; i++ { + _ = q.insert(&goWorker[struct{}, struct{}, int, any, pooltask.NilContext]{recycleTime: *atomicutil.NewTime(time.Now())}) + } + expirew = expirew[:0] + expirew = append(expirew, q.items[size/2:]...) + + workers2 := q.retrieveExpiry(u) + + require.EqualValues(t, expirew, workers2, "expired workers aren't right") + + // test [ time+1s, time+1s, time+1s, nil, nil, time+1s, time+1s, time+1s, time+1s, time+1s] + for i := 0; i < size/2; i++ { + _ = q.insert(&goWorker[struct{}, struct{}, int, any, pooltask.NilContext]{recycleTime: *atomicutil.NewTime(time.Now())}) + } + for i := 0; i < size/2; i++ { + _ = q.detach() + } + for i := 0; i < 3; i++ { + _ = q.insert(&goWorker[struct{}, struct{}, int, any, pooltask.NilContext]{recycleTime: *atomicutil.NewTime(time.Now())}) + } + time.Sleep(u) + + expirew = expirew[:0] + expirew = append(expirew, q.items[0:3]...) + expirew = append(expirew, q.items[size/2:]...) 
+ + workers3 := q.retrieveExpiry(u) + + require.EqualValues(t, expirew, workers3, "expired workers aren't right") +} From f483b39c34b91f2f0260748f40d952e0d9a366f2 Mon Sep 17 00:00:00 2001 From: Song Gao Date: Wed, 4 Jan 2023 10:50:20 +0800 Subject: [PATCH 9/9] metrics: add metrics for plan replayer and historical stats (#40271) --- domain/historical_stats.go | 8 ++ domain/plan_replayer.go | 12 +- domain/plan_replayer_dump.go | 9 ++ metrics/grafana/tidb.json | 229 +++++++++++++++++++++++++++++++++++ metrics/metrics.go | 4 + metrics/stats.go | 21 ++++ statistics/handle/dump.go | 16 ++- 7 files changed, 297 insertions(+), 2 deletions(-) diff --git a/domain/historical_stats.go b/domain/historical_stats.go index 04d50608c58c4..ca68319c31ba8 100644 --- a/domain/historical_stats.go +++ b/domain/historical_stats.go @@ -16,10 +16,16 @@ package domain import ( "github.com/pingcap/errors" + "github.com/pingcap/tidb/metrics" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/statistics/handle" ) +var ( + generateHistoricalStatsSuccessCounter = metrics.HistoricalStatsCounter.WithLabelValues("generate", "success") + generateHistoricalStatsFailedCounter = metrics.HistoricalStatsCounter.WithLabelValues("generate", "fail") +) + // HistoricalStatsWorker indicates for dump historical stats type HistoricalStatsWorker struct { tblCH chan int64 @@ -52,8 +58,10 @@ func (w *HistoricalStatsWorker) DumpHistoricalStats(tableID int64, statsHandle * return errors.Errorf("cannot get DBInfo by TableID %d", tableID) } if _, err := statsHandle.RecordHistoricalStatsToStorage(dbInfo.Name.O, tblInfo); err != nil { + generateHistoricalStatsFailedCounter.Inc() return errors.Errorf("record table %s.%s's historical stats failed", dbInfo.Name.O, tblInfo.Name.O) } + generateHistoricalStatsSuccessCounter.Inc() return nil } diff --git a/domain/plan_replayer.go b/domain/plan_replayer.go index 54c109cc34dc3..2bbb15772d56c 100644 --- a/domain/plan_replayer.go +++ b/domain/plan_replayer.go @@ -29,6 +29,7 @@ import ( "github.com/pingcap/tidb/bindinfo" "github.com/pingcap/tidb/domain/infosync" "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/metrics" "github.com/pingcap/tidb/parser" "github.com/pingcap/tidb/parser/ast" "github.com/pingcap/tidb/parser/terror" @@ -167,6 +168,13 @@ func insertPlanReplayerSuccessStatusRecord(ctx context.Context, sctx sessionctx. 
} } +var ( + planReplayerCaptureTaskSendCounter = metrics.PlanReplayerTaskCounter.WithLabelValues("capture", "send") + planReplayerCaptureTaskDiscardCounter = metrics.PlanReplayerTaskCounter.WithLabelValues("capture", "discard") + + planReplayerRegisterTaskGauge = metrics.PlanReplayerRegisterTaskGauge +) + type planReplayerHandle struct { *planReplayerTaskCollectorHandle *planReplayerTaskDumpHandle @@ -181,9 +189,10 @@ func (h *planReplayerHandle) SendTask(task *PlanReplayerDumpTask) bool { if !task.IsContinuesCapture { h.planReplayerTaskCollectorHandle.removeTask(task.PlanReplayerTaskKey) } + planReplayerCaptureTaskSendCounter.Inc() return true default: - // TODO: add metrics here + planReplayerCaptureTaskDiscardCounter.Inc() // directly discard the task if the task channel is full in order not to block the query process logutil.BgLogger().Warn("discard one plan replayer dump task", zap.String("sql-digest", task.SQLDigest), zap.String("plan-digest", task.PlanDigest)) @@ -221,6 +230,7 @@ func (h *planReplayerTaskCollectorHandle) CollectPlanReplayerTask() error { } } h.setupTasks(tasks) + planReplayerRegisterTaskGauge.Set(float64(len(tasks))) return nil } diff --git a/domain/plan_replayer_dump.go b/domain/plan_replayer_dump.go index cad0898c81ef2..bd121b26dd388 100644 --- a/domain/plan_replayer_dump.go +++ b/domain/plan_replayer_dump.go @@ -28,6 +28,7 @@ import ( "github.com/pingcap/tidb/bindinfo" "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/infoschema" + "github.com/pingcap/tidb/metrics" "github.com/pingcap/tidb/parser/ast" "github.com/pingcap/tidb/parser/model" "github.com/pingcap/tidb/sessionctx" @@ -145,6 +146,11 @@ func (tne *tableNameExtractor) handleIsView(t *ast.TableName) (bool, error) { return true, nil } +var ( + planReplayerDumpTaskSuccess = metrics.PlanReplayerTaskCounter.WithLabelValues("dump", "success") + planReplayerDumpTaskFailed = metrics.PlanReplayerTaskCounter.WithLabelValues("dump", "fail") +) + // DumpPlanReplayerInfo will dump the information about sqls. 
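SendTask above uses a select with a default branch so that a full task channel can never block the query that captured the plan: the task is counted as discarded and dropped instead. The generic sketch below (names invented; not the patch's API) isolates that non-blocking hand-off:

package main

import "fmt"

// trySend attempts a non-blocking hand-off: it enqueues the task when the
// channel has room and reports a drop otherwise, so the caller never stalls.
func trySend[T any](ch chan T, task T, onSent, onDropped func()) bool {
    select {
    case ch <- task:
        onSent()
        return true
    default:
        onDropped() // channel full: drop rather than block the caller
        return false
    }
}

func main() {
    ch := make(chan int, 1)
    sent := func() { fmt.Println("sent") }
    dropped := func() { fmt.Println("dropped") }
    trySend(ch, 1, sent, dropped) // sent
    trySend(ch, 2, sent, dropped) // dropped: the buffer of one is already full
}

Dropping under pressure trades completeness of capture for predictable query latency, which is consistent with the warning the handle logs whenever it discards a dump task.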
// The files will be organized into the following format: /* @@ -212,6 +218,9 @@ func DumpPlanReplayerInfo(ctx context.Context, sctx sessionctx.Context, zap.Strings("sqls", sqls)) } errMsg = err.Error() + planReplayerDumpTaskFailed.Inc() + } else { + planReplayerDumpTaskSuccess.Inc() } err1 := zw.Close() if err1 != nil { diff --git a/metrics/grafana/tidb.json b/metrics/grafana/tidb.json index 0e2ce93934f7c..9637940d4dbd2 100644 --- a/metrics/grafana/tidb.json +++ b/metrics/grafana/tidb.json @@ -14337,6 +14337,235 @@ "align": false, "alignLevel": null } + }, + { + "aliasColors": {}, + "bars": false, + "dashLength": 10, + "dashes": false, + "datasource": "${DS_TEST-CLUSTER}", + "description": "", + "fieldConfig": { + "defaults": {}, + "overrides": [] + }, + "fill": 1, + "fillGradient": 0, + "gridPos": { + "h": 7, + "w": 8, + "x": 0, + "y": 184 + }, + "hiddenSeries": false, + "id": 236, + "legend": { + "avg": false, + "current": false, + "max": false, + "min": false, + "show": true, + "total": false, + "values": false + }, + "lines": true, + "linewidth": 1, + "links": [], + "nullPointMode": "null", + "options": { + "alertThreshold": true + }, + "percentage": false, + "pluginVersion": "7.5.11", + "pointradius": 5, + "points": false, + "renderer": "flot", + "seriesOverrides": [], + "spaceLength": 10, + "stack": false, + "steppedLine": false, + "targets": [ + { + "exemplar": true, + "expr": "sum(rate(tidb_plan_replayer_task{k8s_cluster=\"$k8s_cluster\", tidb_cluster=\"$tidb_cluster\", instance=~\"$instance\", type=~\"dump\"}[1m])) by (result)", + "format": "time_series", + "interval": "", + "intervalFactor": 2, + "legendFormat": "dump-task-{{result}}", + "refId": "A", + "step": 30 + }, + { + "exemplar": true, + "expr": "sum(rate(tidb_plan_replayer_task{k8s_cluster=\"$k8s_cluster\", tidb_cluster=\"$tidb_cluster\", instance=~\"$instance\", type=~\"capture\"}[1m])) by (result)", + "format": "time_series", + "hide": false, + "interval": "", + "intervalFactor": 2, + "legendFormat": "capture-task-{{result}}", + "refId": "B", + "step": 30 + }, + { + "exemplar": true, + "expr": "avg(tidb_plan_replayer_register_task{k8s_cluster=\"$k8s_cluster\", tidb_cluster=\"$tidb_cluster\", instance=~\"$instance\"})", + "hide": false, + "interval": "", + "intervalFactor": 2, + "legendFormat": "register-task", + "refId": "C" + } + ], + "thresholds": [], + "timeFrom": null, + "timeRegions": [], + "timeShift": null, + "title": "Plan Replayer Task OPM", + "tooltip": { + "shared": true, + "sort": 0, + "value_type": "individual" + }, + "type": "graph", + "xaxis": { + "buckets": null, + "mode": "time", + "name": null, + "show": true, + "values": [] + }, + "yaxes": [ + { + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + }, + { + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + } + ], + "yaxis": { + "align": false, + "alignLevel": null + } + }, + { + "aliasColors": {}, + "bars": false, + "dashLength": 10, + "dashes": false, + "datasource": "${DS_TEST-CLUSTER}", + "description": "", + "fieldConfig": { + "defaults": {}, + "overrides": [] + }, + "fill": 1, + "fillGradient": 0, + "gridPos": { + "h": 7, + "w": 8, + "x": 8, + "y": 184 + }, + "hiddenSeries": false, + "id": 237, + "legend": { + "avg": false, + "current": false, + "max": false, + "min": false, + "show": true, + "total": false, + "values": false + }, + "lines": true, + "linewidth": 1, + "links": [], + "nullPointMode": "null", + "options": { + "alertThreshold": true + }, 
+ "percentage": false, + "pluginVersion": "7.5.11", + "pointradius": 5, + "points": false, + "renderer": "flot", + "seriesOverrides": [], + "spaceLength": 10, + "stack": false, + "steppedLine": false, + "targets": [ + { + "exemplar": true, + "expr": "sum(rate(tidb_statistics_historical_stats{k8s_cluster=\"$k8s_cluster\", tidb_cluster=\"$tidb_cluster\", instance=~\"$instance\", type=~\"generate\"}[1m])) by (result)", + "format": "time_series", + "interval": "", + "intervalFactor": 2, + "legendFormat": "generate-{{result}}", + "refId": "A", + "step": 30 + }, + { + "exemplar": true, + "expr": "sum(rate(tidb_statistics_historical_stats{k8s_cluster=\"$k8s_cluster\", tidb_cluster=\"$tidb_cluster\", instance=~\"$instance\", type=~\"dump\"}[1m])) by (result)", + "format": "time_series", + "hide": false, + "interval": "", + "intervalFactor": 2, + "legendFormat": "dump-{{result}}", + "refId": "B", + "step": 30 + } + ], + "thresholds": [], + "timeFrom": null, + "timeRegions": [], + "timeShift": null, + "title": "Historical Stats OPM", + "tooltip": { + "shared": true, + "sort": 0, + "value_type": "individual" + }, + "type": "graph", + "xaxis": { + "buckets": null, + "mode": "time", + "name": null, + "show": true, + "values": [] + }, + "yaxes": [ + { + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + }, + { + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + } + ], + "yaxis": { + "align": false, + "alignLevel": null + } } ], "repeat": null, diff --git a/metrics/metrics.go b/metrics/metrics.go index 8f303ba58180e..889f4c5996481 100644 --- a/metrics/metrics.go +++ b/metrics/metrics.go @@ -217,6 +217,10 @@ func RegisterMetrics() { prometheus.MustRegister(EMACPUUsageGauge) + prometheus.MustRegister(HistoricalStatsCounter) + prometheus.MustRegister(PlanReplayerTaskCounter) + prometheus.MustRegister(PlanReplayerRegisterTaskGauge) + tikvmetrics.InitMetrics(TiDB, TiKVClient) tikvmetrics.RegisterMetrics() tikvmetrics.TiKVPanicCounter = PanicCounter // reset tidb metrics for tikv metrics diff --git a/metrics/stats.go b/metrics/stats.go index 76bd1ec7a936b..5d73753f5669c 100644 --- a/metrics/stats.go +++ b/metrics/stats.go @@ -150,4 +150,25 @@ var ( Name: "stats_healthy", Help: "Gauge of stats healthy", }, []string{LblType}) + + HistoricalStatsCounter = prometheus.NewCounterVec(prometheus.CounterOpts{ + Namespace: "tidb", + Subsystem: "statistics", + Name: "historical_stats", + Help: "counter of the historical stats operation", + }, []string{LblType, LblResult}) + + PlanReplayerTaskCounter = prometheus.NewCounterVec(prometheus.CounterOpts{ + Namespace: "tidb", + Subsystem: "plan_replayer", + Name: "task", + Help: "counter of plan replayer captured task", + }, []string{LblType, LblResult}) + + PlanReplayerRegisterTaskGauge = prometheus.NewGauge(prometheus.GaugeOpts{ + Namespace: "tidb", + Subsystem: "plan_replayer", + Name: "register_task", + Help: "gauge of plan replayer registered task", + }) ) diff --git a/statistics/handle/dump.go b/statistics/handle/dump.go index daaf28ead7573..75f4ee9ea958a 100644 --- a/statistics/handle/dump.go +++ b/statistics/handle/dump.go @@ -24,6 +24,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/tidb/infoschema" + "github.com/pingcap/tidb/metrics" "github.com/pingcap/tidb/parser/model" "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/sessionctx" @@ -131,8 +132,21 @@ func (h *Handle) DumpStatsToJSON(dbName string, tableInfo *model.TableInfo, return 
h.DumpStatsToJSONBySnapshot(dbName, tableInfo, snapshot, dumpPartitionStats) } +var ( + dumpHistoricalStatsSuccessCounter = metrics.HistoricalStatsCounter.WithLabelValues("dump", "success") + dumpHistoricalStatsFailedCounter = metrics.HistoricalStatsCounter.WithLabelValues("dump", "fail") +) + // DumpHistoricalStatsBySnapshot dumped json tables from mysql.stats_meta_history and mysql.stats_history -func (h *Handle) DumpHistoricalStatsBySnapshot(dbName string, tableInfo *model.TableInfo, snapshot uint64) (*JSONTable, error) { +func (h *Handle) DumpHistoricalStatsBySnapshot(dbName string, tableInfo *model.TableInfo, snapshot uint64) (jt *JSONTable, err error) { + defer func() { + if err == nil { + dumpHistoricalStatsSuccessCounter.Inc() + } else { + dumpHistoricalStatsFailedCounter.Inc() + } + }() + pi := tableInfo.GetPartitionInfo() if pi == nil { return h.tableHistoricalStatsToJSON(tableInfo.ID, snapshot)
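The change to DumpHistoricalStatsBySnapshot above switches to named return values so that one deferred closure can attribute every exit path, including early returns deeper in the function, to either the success or the failure counter. A compact restatement of that pattern with the counters stubbed out (names invented) is:

package main

import "fmt"

var successCount, failCount int

// withOutcomeAccounting mirrors the deferred bookkeeping added to
// DumpHistoricalStatsBySnapshot: because err is a named return, the deferred
// closure observes whatever value the function finally returns, no matter
// where in the body that value was set.
func withOutcomeAccounting(do func() error) (err error) {
    defer func() {
        if err == nil {
            successCount++
        } else {
            failCount++
        }
    }()
    return do()
}

func main() {
    _ = withOutcomeAccounting(func() error { return nil })
    _ = withOutcomeAccounting(func() error { return fmt.Errorf("boom") })
    fmt.Println(successCount, failCount) // 1 1
}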