diff --git a/pkg/planner/core/plan_cache.go b/pkg/planner/core/plan_cache.go new file mode 100644 index 0000000000000..35c204c777ebe --- /dev/null +++ b/pkg/planner/core/plan_cache.go @@ -0,0 +1,366 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package core + +import ( + "context" + + "github.com/pingcap/errors" + "github.com/pingcap/tidb/pkg/bindinfo" + "github.com/pingcap/tidb/pkg/domain" + "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/infoschema" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/metrics" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/planner/core/base" + core_metrics "github.com/pingcap/tidb/pkg/planner/core/metrics" + "github.com/pingcap/tidb/pkg/planner/util/debugtrace" + "github.com/pingcap/tidb/pkg/privilege" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessiontxn/staleread" + "github.com/pingcap/tidb/pkg/types" + driver "github.com/pingcap/tidb/pkg/types/parser_driver" + "github.com/pingcap/tidb/pkg/util/chunk" + contextutil "github.com/pingcap/tidb/pkg/util/context" + "github.com/pingcap/tidb/pkg/util/dbterror/plannererrors" + "github.com/pingcap/tidb/pkg/util/kvcache" +) + +// PlanCacheKeyTestIssue43667 is only for test. +type PlanCacheKeyTestIssue43667 struct{} + +// PlanCacheKeyTestIssue46760 is only for test. +type PlanCacheKeyTestIssue46760 struct{} + +// PlanCacheKeyTestIssue47133 is only for test. +type PlanCacheKeyTestIssue47133 struct{} + +// SetParameterValuesIntoSCtx sets these parameters into session context. +func SetParameterValuesIntoSCtx(sctx base.PlanContext, isNonPrep bool, markers []ast.ParamMarkerExpr, params []expression.Expression) error { + vars := sctx.GetSessionVars() + vars.PlanCacheParams.Reset() + for i, usingParam := range params { + val, err := usingParam.Eval(sctx.GetExprCtx().GetEvalCtx(), chunk.Row{}) + if err != nil { + return err + } + if isGetVarBinaryLiteral(sctx, usingParam) { + binVal, convErr := val.ToBytes() + if convErr != nil { + return convErr + } + val.SetBinaryLiteral(binVal) + } + if markers != nil { + param := markers[i].(*driver.ParamMarkerExpr) + param.Datum = val + param.InExecute = true + } + vars.PlanCacheParams.Append(val) + } + if vars.StmtCtx.EnableOptimizerDebugTrace && len(vars.PlanCacheParams.AllParamValues()) > 0 { + vals := vars.PlanCacheParams.AllParamValues() + valStrs := make([]string, len(vals)) + for i, val := range vals { + valStrs[i] = val.String() + } + debugtrace.RecordAnyValuesWithNames(sctx, "Parameter datums for EXECUTE", valStrs) + } + vars.PlanCacheParams.SetForNonPrepCache(isNonPrep) + return nil +} + +func planCachePreprocess(ctx context.Context, sctx sessionctx.Context, isNonPrepared bool, is infoschema.InfoSchema, stmt *PlanCacheStmt, params []expression.Expression) error { + vars := sctx.GetSessionVars() + stmtAst := stmt.PreparedAst + vars.StmtCtx.StmtType = stmtAst.StmtType + + // step 1: check parameter number + if len(stmt.Params) != len(params) { + return errors.Trace(plannererrors.ErrWrongParamCount) + } + + // step 2: set parameter values + if err := SetParameterValuesIntoSCtx(sctx.GetPlanCtx(), isNonPrepared, stmt.Params, params); err != nil { + return errors.Trace(err) + } + + // step 3: add metadata lock and check each table's schema version + schemaNotMatch := false + for i := 0; i < len(stmt.dbName); i++ { + tbl, ok := is.TableByID(stmt.tbls[i].Meta().ID) + if !ok { + tblByName, err := is.TableByName(stmt.dbName[i], stmt.tbls[i].Meta().Name) + if err != nil { + return plannererrors.ErrSchemaChanged.GenWithStack("Schema change caused error: %s", err.Error()) + } + delete(stmt.RelateVersion, stmt.tbls[i].Meta().ID) + stmt.tbls[i] = tblByName + stmt.RelateVersion[tblByName.Meta().ID] = tblByName.Meta().Revision + } + newTbl, err := tryLockMDLAndUpdateSchemaIfNecessary(sctx.GetPlanCtx(), stmt.dbName[i], stmt.tbls[i], is) + if err != nil { + schemaNotMatch = true + continue + } + // The revision of tbl and newTbl may not be the same. + // Example: + // The version of stmt.tbls[i] is taken from the prepare statement and is revision v1. + // When stmt.tbls[i] is locked in MDL, the revision of newTbl is also v1. + // The revision of tbl is v2. The reason may have other statements trigger "tryLockMDLAndUpdateSchemaIfNecessary" before, leading to tbl revision update. + if stmt.tbls[i].Meta().Revision != newTbl.Meta().Revision || (tbl != nil && tbl.Meta().Revision != newTbl.Meta().Revision) { + schemaNotMatch = true + } + stmt.tbls[i] = newTbl + stmt.RelateVersion[newTbl.Meta().ID] = newTbl.Meta().Revision + } + + // step 4: check schema version + if schemaNotMatch || stmt.SchemaVersion != is.SchemaMetaVersion() { + // In order to avoid some correctness issues, we have to clear the + // cached plan once the schema version is changed. + // Cached plan in prepared struct does NOT have a "cache key" with + // schema version like prepared plan cache key + stmt.PointGet.pointPlan = nil + stmt.PointGet.columnNames = nil + stmt.PointGet.pointPlanHints = nil + stmt.PointGet.Executor = nil + stmt.PointGet.ColumnInfos = nil + // If the schema version has changed we need to preprocess it again, + // if this time it failed, the real reason for the error is schema changed. + // Example: + // When running update in prepared statement's schema version distinguished from the one of execute statement + // We should reset the tableRefs in the prepared update statements, otherwise, the ast nodes still hold the old + // tableRefs columnInfo which will cause chaos in logic of trying point get plan. (should ban non-public column) + ret := &PreprocessorReturn{InfoSchema: is} + err := Preprocess(ctx, sctx, stmtAst.Stmt, InPrepare, WithPreprocessorReturn(ret)) + if err != nil { + return plannererrors.ErrSchemaChanged.GenWithStack("Schema change caused error: %s", err.Error()) + } + stmt.SchemaVersion = is.SchemaMetaVersion() + } + + // step 5: handle expiration + // If the lastUpdateTime less than expiredTimeStamp4PC, + // it means other sessions have executed 'admin flush instance plan_cache'. + // So we need to clear the current session's plan cache. + // And update lastUpdateTime to the newest one. + expiredTimeStamp4PC := domain.GetDomain(sctx).ExpiredTimeStamp4PC() + if stmt.StmtCacheable && expiredTimeStamp4PC.Compare(vars.LastUpdateTime4PC) > 0 { + sctx.GetSessionPlanCache().DeleteAll() + vars.LastUpdateTime4PC = expiredTimeStamp4PC + } + + return nil +} + +// GetPlanFromPlanCache is the entry point of Plan Cache. +// It tries to get a valid cached plan from plan cache. +// If there is no such a plan, it'll call the optimizer to generate a new one. +// isNonPrepared indicates whether to use the non-prepared plan cache or the prepared plan cache. +func GetPlanFromPlanCache(ctx context.Context, sctx sessionctx.Context, + isNonPrepared bool, is infoschema.InfoSchema, stmt *PlanCacheStmt, + params []expression.Expression) (plan base.Plan, names []*types.FieldName, err error) { + if err := planCachePreprocess(ctx, sctx, isNonPrepared, is, stmt, params); err != nil { + return nil, nil, err + } + + var cacheKey string + sessVars := sctx.GetSessionVars() + stmtCtx := sessVars.StmtCtx + cacheEnabled := false + if isNonPrepared { + stmtCtx.SetCacheType(contextutil.SessionNonPrepared) + cacheEnabled = sctx.GetSessionVars().EnableNonPreparedPlanCache // plan-cache might be disabled after prepare. + } else { + stmtCtx.SetCacheType(contextutil.SessionPrepared) + cacheEnabled = sctx.GetSessionVars().EnablePreparedPlanCache + } + if stmt.StmtCacheable && cacheEnabled { + stmtCtx.EnablePlanCache() + } + if stmt.UncacheableReason != "" { + stmtCtx.WarnSkipPlanCache(stmt.UncacheableReason) + } + + var bindSQL string + if stmtCtx.UseCache() { + var ignoreByBinding bool + bindSQL, ignoreByBinding = bindinfo.MatchSQLBindingForPlanCache(sctx, stmt.PreparedAst.Stmt, &stmt.BindingInfo) + if ignoreByBinding { + stmtCtx.SetSkipPlanCache("ignore plan cache by binding") + } + } + + // In rc or for update read, we need the latest schema version to decide whether we need to + // rebuild the plan. So we set this value in rc or for update read. In other cases, let it be 0. + var latestSchemaVersion int64 + + if stmtCtx.UseCache() { + if sctx.GetSessionVars().IsIsolation(ast.ReadCommitted) || stmt.ForUpdateRead { + // In Rc or ForUpdateRead, we should check if the information schema has been changed since + // last time. If it changed, we should rebuild the plan. Here, we use a different and more + // up-to-date schema version which can lead plan cache miss and thus, the plan will be rebuilt. + latestSchemaVersion = domain.GetDomain(sctx).InfoSchema().SchemaMetaVersion() + } + if cacheKey, err = NewPlanCacheKey(sctx.GetSessionVars(), stmt.StmtText, + stmt.StmtDB, stmt.SchemaVersion, latestSchemaVersion, bindSQL, expression.ExprPushDownBlackListReloadTimeStamp.Load(), stmt.RelateVersion); err != nil { + return nil, nil, err + } + } + + var matchOpts *PlanCacheMatchOpts + if stmtCtx.UseCache() { + var cacheVal kvcache.Value + var hit, isPointPlan bool + if stmt.PointGet.pointPlan != nil { // if it's PointGet Plan, no need to use MatchOpts + cacheVal = &PlanCacheValue{ + Plan: stmt.PointGet.pointPlan, + OutputColumns: stmt.PointGet.columnNames, + stmtHints: stmt.PointGet.pointPlanHints, + } + isPointPlan, hit = true, true + } else { + matchOpts = GetMatchOpts(sctx, is, stmt, params) + // TODO: consider instance-level plan cache + cacheVal, hit = sctx.GetSessionPlanCache().Get(cacheKey, matchOpts) + } + if hit { + if plan, names, ok, err := adjustCachedPlan(sctx, cacheVal.(*PlanCacheValue), isNonPrepared, isPointPlan, cacheKey, bindSQL, is, stmt); err != nil || ok { + return plan, names, err + } + } + } + if matchOpts == nil { + matchOpts = GetMatchOpts(sctx, is, stmt, params) + } + + return generateNewPlan(ctx, sctx, isNonPrepared, is, stmt, cacheKey, latestSchemaVersion, bindSQL, matchOpts) +} + +func adjustCachedPlan(sctx sessionctx.Context, cachedVal *PlanCacheValue, isNonPrepared, isPointPlan bool, + cacheKey string, bindSQL string, is infoschema.InfoSchema, stmt *PlanCacheStmt) (base.Plan, + []*types.FieldName, bool, error) { + sessVars := sctx.GetSessionVars() + stmtCtx := sessVars.StmtCtx + if !isPointPlan { // keep the prior behavior + if err := checkPreparedPriv(sctx, stmt, is); err != nil { + return nil, nil, false, err + } + } + for tblInfo, unionScan := range cachedVal.TblInfo2UnionScan { + if !unionScan && tableHasDirtyContent(sctx.GetPlanCtx(), tblInfo) { + // TODO we can inject UnionScan into cached plan to avoid invalidating it, though + // rebuilding the filters in UnionScan is pretty trivial. + sctx.GetSessionPlanCache().Delete(cacheKey) + return nil, nil, false, nil + } + } + if !RebuildPlan4CachedPlan(cachedVal.Plan) { + return nil, nil, false, nil + } + sessVars.FoundInPlanCache = true + if len(bindSQL) > 0 { + // When the `len(bindSQL) > 0`, it means we use the binding. + // So we need to record this. + sessVars.FoundInBinding = true + } + if metrics.ResettablePlanCacheCounterFortTest { + metrics.PlanCacheCounter.WithLabelValues("prepare").Inc() + } else { + core_metrics.GetPlanCacheHitCounter(isNonPrepared).Inc() + } + stmtCtx.SetPlanDigest(stmt.NormalizedPlan, stmt.PlanDigest) + stmtCtx.StmtHints = *cachedVal.stmtHints + return cachedVal.Plan, cachedVal.OutputColumns, true, nil +} + +// generateNewPlan call the optimizer to generate a new plan for current statement +// and try to add it to cache +func generateNewPlan(ctx context.Context, sctx sessionctx.Context, isNonPrepared bool, is infoschema.InfoSchema, + stmt *PlanCacheStmt, cacheKey string, latestSchemaVersion int64, bindSQL string, + matchOpts *PlanCacheMatchOpts) (base.Plan, []*types.FieldName, error) { + stmtAst := stmt.PreparedAst + sessVars := sctx.GetSessionVars() + stmtCtx := sessVars.StmtCtx + + core_metrics.GetPlanCacheMissCounter(isNonPrepared).Inc() + sctx.GetSessionVars().StmtCtx.InPreparedPlanBuilding = true + p, names, err := OptimizeAstNode(ctx, sctx, stmtAst.Stmt, is) + sctx.GetSessionVars().StmtCtx.InPreparedPlanBuilding = false + if err != nil { + return nil, nil, err + } + + // check whether this plan is cacheable. + if stmtCtx.UseCache() { + if cacheable, reason := isPlanCacheable(sctx.GetPlanCtx(), p, len(matchOpts.ParamTypes), len(matchOpts.LimitOffsetAndCount), matchOpts.HasSubQuery); !cacheable { + stmtCtx.SetSkipPlanCache(reason) + } + } + + // put this plan into the plan cache. + if stmtCtx.UseCache() { + // rebuild key to exclude kv.TiFlash when stmt is not read only + if _, isolationReadContainTiFlash := sessVars.IsolationReadEngines[kv.TiFlash]; isolationReadContainTiFlash && !IsReadOnly(stmtAst.Stmt, sessVars) { + delete(sessVars.IsolationReadEngines, kv.TiFlash) + if cacheKey, err = NewPlanCacheKey(sessVars, stmt.StmtText, stmt.StmtDB, + stmt.SchemaVersion, latestSchemaVersion, bindSQL, expression.ExprPushDownBlackListReloadTimeStamp.Load(), stmt.RelateVersion); err != nil { + return nil, nil, err + } + sessVars.IsolationReadEngines[kv.TiFlash] = struct{}{} + } + cached := NewPlanCacheValue(p, names, stmtCtx.TblInfo2UnionScan, matchOpts, &stmtCtx.StmtHints) + stmt.NormalizedPlan, stmt.PlanDigest = NormalizePlan(p) + stmtCtx.SetPlan(p) + stmtCtx.SetPlanDigest(stmt.NormalizedPlan, stmt.PlanDigest) + sctx.GetSessionPlanCache().Put(cacheKey, cached, matchOpts) + if _, ok := p.(*PointGetPlan); ok { + stmt.PointGet.pointPlan = p + stmt.PointGet.columnNames = names + stmt.PointGet.pointPlanHints = stmtCtx.StmtHints.Clone() + } + } + sessVars.FoundInPlanCache = false + return p, names, err +} + +// checkPreparedPriv checks the privilege of the prepared statement +func checkPreparedPriv(sctx sessionctx.Context, stmt *PlanCacheStmt, is infoschema.InfoSchema) error { + if pm := privilege.GetPrivilegeManager(sctx); pm != nil { + visitInfo := VisitInfo4PrivCheck(is, stmt.PreparedAst.Stmt, stmt.VisitInfos) + if err := CheckPrivilege(sctx.GetSessionVars().ActiveRoles, pm, visitInfo); err != nil { + return err + } + } + err := CheckTableLock(sctx, is, stmt.VisitInfos) + return err +} + +// IsSafeToReusePointGetExecutor checks whether this is a PointGet Plan and safe to reuse its executor. +func IsSafeToReusePointGetExecutor(sctx sessionctx.Context, is infoschema.InfoSchema, stmt *PlanCacheStmt) bool { + if staleread.IsStmtStaleness(sctx) { + return false + } + // check auto commit + if !IsAutoCommitTxn(sctx.GetSessionVars()) { + return false + } + if stmt.SchemaVersion != is.SchemaMetaVersion() { + return false + } + return true +} diff --git a/pkg/server/internal/testserverclient/BUILD.bazel b/pkg/server/internal/testserverclient/BUILD.bazel new file mode 100644 index 0000000000000..322b6a91f1446 --- /dev/null +++ b/pkg/server/internal/testserverclient/BUILD.bazel @@ -0,0 +1,29 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +go_library( + name = "testserverclient", + srcs = ["server_client.go"], + importpath = "github.com/pingcap/tidb/pkg/server/internal/testserverclient", + visibility = ["//pkg/server:__subpackages__"], + deps = [ + "//pkg/ddl/util/callback", + "//pkg/domain", + "//pkg/errno", + "//pkg/kv", + "//pkg/metrics", + "//pkg/parser/model", + "//pkg/parser/mysql", + "//pkg/server", + "//pkg/testkit", + "//pkg/testkit/testenv", + "//pkg/util/versioninfo", + "@com_github_go_sql_driver_mysql//:mysql", + "@com_github_pingcap_errors//:errors", + "@com_github_pingcap_failpoint//:failpoint", + "@com_github_pingcap_log//:log", + "@com_github_prometheus_client_model//go", + "@com_github_stretchr_testify//require", + "@org_golang_x_text//encoding/simplifiedchinese", + "@org_uber_go_zap//:zap", + ], +) diff --git a/pkg/server/tests/servertestkit/BUILD.bazel b/pkg/server/tests/servertestkit/BUILD.bazel new file mode 100644 index 0000000000000..65010930b6cfd --- /dev/null +++ b/pkg/server/tests/servertestkit/BUILD.bazel @@ -0,0 +1,27 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +go_library( + name = "servertestkit", + srcs = ["testkit.go"], + importpath = "github.com/pingcap/tidb/pkg/server/tests/servertestkit", + visibility = ["//visibility:public"], + deps = [ + "//pkg/config", + "//pkg/domain", + "//pkg/kv", + "//pkg/server", + "//pkg/server/internal/testserverclient", + "//pkg/server/internal/testutil", + "//pkg/server/internal/util", + "//pkg/session", + "//pkg/store/mockstore", + "//pkg/testkit", + "//pkg/util/cpuprofile", + "//pkg/util/topsql/collector/mock", + "//pkg/util/topsql/state", + "@com_github_cockroachdb_errors//:errors", + "@com_github_stretchr_testify//require", + "@io_opencensus_go//stats/view", + "@org_uber_go_zap//:zap", + ], +) diff --git a/pkg/server/tests/servertestkit/testkit.go b/pkg/server/tests/servertestkit/testkit.go new file mode 100644 index 0000000000000..348ac91a3bee4 --- /dev/null +++ b/pkg/server/tests/servertestkit/testkit.go @@ -0,0 +1,200 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package servertestkit + +import ( + "context" + "database/sql" + "sync" + "testing" + "time" + + "github.com/cockroachdb/errors" + "github.com/pingcap/tidb/pkg/config" + "github.com/pingcap/tidb/pkg/domain" + "github.com/pingcap/tidb/pkg/kv" + srv "github.com/pingcap/tidb/pkg/server" + "github.com/pingcap/tidb/pkg/server/internal/testserverclient" + "github.com/pingcap/tidb/pkg/server/internal/testutil" + "github.com/pingcap/tidb/pkg/server/internal/util" + "github.com/pingcap/tidb/pkg/session" + "github.com/pingcap/tidb/pkg/store/mockstore" + "github.com/pingcap/tidb/pkg/testkit" + "github.com/pingcap/tidb/pkg/util/cpuprofile" + "github.com/pingcap/tidb/pkg/util/topsql/collector/mock" + topsqlstate "github.com/pingcap/tidb/pkg/util/topsql/state" + "github.com/stretchr/testify/require" + "go.opencensus.io/stats/view" + "go.uber.org/zap" +) + +// TidbTestSuite is a test suite for tidb +type TidbTestSuite struct { + *testserverclient.TestServerClient + Tidbdrv *srv.TiDBDriver + Server *srv.Server + Domain *domain.Domain + Store kv.Storage +} + +// CreateTidbTestSuite creates a test suite for tidb +func CreateTidbTestSuite(t *testing.T) *TidbTestSuite { + cfg := newTestConfig() + return CreateTidbTestSuiteWithCfg(t, cfg) +} + +// CreateTidbTestSuiteWithDDLLease creates a test suite with DDL lease for tidb. +func CreateTidbTestSuiteWithDDLLease(t *testing.T, ddlLease string) *TidbTestSuite { + cfg := newTestConfig() + cfg.Lease = ddlLease + return CreateTidbTestSuiteWithCfg(t, cfg) +} + +func newTestConfig() *config.Config { + cfg := util.NewTestConfig() + cfg.Port = 0 + cfg.Status.ReportStatus = true + cfg.Status.StatusPort = 0 + cfg.Status.RecordDBLabel = true + cfg.Performance.TCPKeepAlive = true + return cfg +} + +// parseDuration parses lease argument string. +func parseDuration(lease string) (time.Duration, error) { + dur, err := time.ParseDuration(lease) + if err != nil { + dur, err = time.ParseDuration(lease + "s") + } + if err != nil || dur < 0 { + return 0, errors.Newf("invalid lease duration", zap.String("lease", lease)) + } + return dur, nil +} + +// CreateTidbTestSuiteWithCfg creates a test suite for tidb with config +func CreateTidbTestSuiteWithCfg(t *testing.T, cfg *config.Config) *TidbTestSuite { + ts := &TidbTestSuite{TestServerClient: testserverclient.NewTestServerClient()} + + // setup tidbTestSuite + var err error + ts.Store, err = mockstore.NewMockStore() + session.DisableStats4Test() + require.NoError(t, err) + ddlLeaseDuration, err := parseDuration(cfg.Lease) + require.NoError(t, err) + session.SetSchemaLease(ddlLeaseDuration) + ts.Domain, err = session.BootstrapSession(ts.Store) + require.NoError(t, err) + ts.Tidbdrv = srv.NewTiDBDriver(ts.Store) + + srv.RunInGoTestChan = make(chan struct{}) + server, err := srv.NewServer(cfg, ts.Tidbdrv) + require.NoError(t, err) + + ts.Server = server + ts.Server.SetDomain(ts.Domain) + ts.Domain.InfoSyncer().SetSessionManager(ts.Server) + go func() { + err := ts.Server.Run(nil) + require.NoError(t, err) + }() + <-srv.RunInGoTestChan + ts.Port = testutil.GetPortFromTCPAddr(server.ListenAddr()) + ts.StatusPort = testutil.GetPortFromTCPAddr(server.StatusListenerAddr()) + ts.WaitUntilServerOnline() + + t.Cleanup(func() { + if ts.Domain != nil { + ts.Domain.Close() + } + if ts.Server != nil { + ts.Server.Close() + } + if ts.Store != nil { + require.NoError(t, ts.Store.Close()) + } + view.Stop() + }) + return ts +} + +type tidbTestTopSQLSuite struct { + *TidbTestSuite +} + +// CreateTidbTestTopSQLSuite creates a test suite for top-sql test. +func CreateTidbTestTopSQLSuite(t *testing.T) *tidbTestTopSQLSuite { + base := CreateTidbTestSuite(t) + + ts := &tidbTestTopSQLSuite{base} + + // Initialize global variable for top-sql test. + db, err := sql.Open("mysql", ts.GetDSN()) + require.NoError(t, err) + defer func() { + err := db.Close() + require.NoError(t, err) + }() + + dbt := testkit.NewDBTestKit(t, db) + topsqlstate.GlobalState.PrecisionSeconds.Store(1) + topsqlstate.GlobalState.ReportIntervalSeconds.Store(2) + dbt.MustExec("set @@global.tidb_top_sql_max_time_series_count=5;") + + require.NoError(t, cpuprofile.StartCPUProfiler()) + t.Cleanup(func() { + cpuprofile.StopCPUProfiler() + topsqlstate.GlobalState.PrecisionSeconds.Store(topsqlstate.DefTiDBTopSQLPrecisionSeconds) + topsqlstate.GlobalState.ReportIntervalSeconds.Store(topsqlstate.DefTiDBTopSQLReportIntervalSeconds) + view.Stop() + }) + return ts +} + +// TestCase is to run the test case for top-sql test. +func (ts *tidbTestTopSQLSuite) TestCase(t *testing.T, mc *mock.TopSQLCollector, execFn func(db *sql.DB), checkFn func()) { + var wg sync.WaitGroup + ctx, cancel := context.WithCancel(context.Background()) + wg.Add(1) + go func() { + defer wg.Done() + ts.loopExec(ctx, t, execFn) + }() + + checkFn() + cancel() + wg.Wait() + mc.Reset() +} + +func (ts *tidbTestTopSQLSuite) loopExec(ctx context.Context, t *testing.T, fn func(db *sql.DB)) { + db, err := sql.Open("mysql", ts.GetDSN()) + require.NoError(t, err, "Error connecting") + defer func() { + err := db.Close() + require.NoError(t, err) + }() + dbt := testkit.NewDBTestKit(t, db) + dbt.MustExec("use topsql;") + for { + select { + case <-ctx.Done(): + return + default: + } + fn(db) + } +} diff --git a/server/server_test.go b/server/server_test.go index 202b6fe0609c8..2d46b19a98a82 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -35,12 +35,27 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/failpoint" "github.com/pingcap/log" +<<<<<<< HEAD:server/server_test.go "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/errno" "github.com/pingcap/tidb/kv" tmysql "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/testkit" "github.com/pingcap/tidb/util/versioninfo" +======= + "github.com/pingcap/tidb/pkg/ddl/util/callback" + "github.com/pingcap/tidb/pkg/domain" + "github.com/pingcap/tidb/pkg/errno" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/metrics" + "github.com/pingcap/tidb/pkg/parser/model" + tmysql "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/server" + "github.com/pingcap/tidb/pkg/testkit" + "github.com/pingcap/tidb/pkg/testkit/testenv" + "github.com/pingcap/tidb/pkg/util/versioninfo" + dto "github.com/prometheus/client_model/go" +>>>>>>> 9aeaa76c5cb (*: fix a bug that update statement uses point get and update plan with different tblInfo (#54183)):pkg/server/internal/testserverclient/server_client.go "github.com/stretchr/testify/require" "go.uber.org/zap" ) @@ -2535,6 +2550,7 @@ func (cli *testServerClient) runTestLoadDataReplace(t *testing.T) { }) } +<<<<<<< HEAD:server/server_test.go func (cli *testServerClient) runTestLoadDataReplaceNonclusteredPK(t *testing.T) { fp1, err := os.CreateTemp("", "a.dat") require.NoError(t, err) @@ -2660,3 +2676,184 @@ func (cli *testServerClient) RunTestStmtCountLimit(t *testing.T) { require.Equal(t, 5, count) }) } +======= +func (cli *TestServerClient) getNewDB(t *testing.T, overrider configOverrider) *testkit.DBTestKit { + db, err := sql.Open("mysql", cli.GetDSN(overrider)) + require.NoError(t, err) + + return testkit.NewDBTestKit(t, db) +} + +func MustExec(ctx context.Context, t *testing.T, conn *sql.Conn, sql string) { + _, err := conn.QueryContext(ctx, sql) + require.NoError(t, err) +} + +func MustQuery(ctx context.Context, t *testing.T, cli *TestServerClient, conn *sql.Conn, sql string) { + rs, err := conn.QueryContext(ctx, sql) + require.NoError(t, err) + if rs != nil { + cli.Rows(t, rs) + rs.Close() + } +} + +type sqlWithErr struct { + stmt *sql.Stmt + sql string +} + +type expectQuery struct { + sql string + rows []string +} + +func (cli *TestServerClient) RunTestIssue53634(t *testing.T, dom *domain.Domain) { + cli.RunTests(t, func(config *mysql.Config) { + config.MaxAllowedPacket = 1024 + }, func(dbt *testkit.DBTestKit) { + ctx := context.Background() + + conn, err := dbt.GetDB().Conn(ctx) + require.NoError(t, err) + MustExec(ctx, t, conn, "create database test_db_state default charset utf8 default collate utf8_bin") + MustExec(ctx, t, conn, "use test_db_state") + MustExec(ctx, t, conn, `CREATE TABLE stock ( + a int NOT NULL, + b char(30) NOT NULL, + c int, + d char(64), + PRIMARY KEY(a,b) +) ENGINE=InnoDB DEFAULT CHARSET=latin1 COLLATE=latin1_bin COMMENT='…comment'; +`) + MustExec(ctx, t, conn, "insert into stock values(1, 'a', 11, 'x'), (2, 'b', 22, 'y')") + MustExec(ctx, t, conn, "alter table stock add column cct_1 int default 10") + MustExec(ctx, t, conn, "alter table stock modify cct_1 json") + MustExec(ctx, t, conn, "alter table stock add column adc_1 smallint") + defer MustExec(ctx, t, conn, "drop database test_db_state") + + sqls := make([]sqlWithErr, 5) + sqls[0] = sqlWithErr{nil, "begin"} + sqls[1] = sqlWithErr{nil, "SELECT a, c, d from stock where (a, b) IN ((?, ?),(?, ?)) FOR UPDATE"} + sqls[2] = sqlWithErr{nil, "UPDATE stock SET c = ? WHERE a= ? AND b = 'a'"} + sqls[3] = sqlWithErr{nil, "UPDATE stock SET c = ?, d = 'z' WHERE a= ? AND b = 'b'"} + sqls[4] = sqlWithErr{nil, "commit"} + dropColumnSQL := "alter table stock drop column cct_1" + query := &expectQuery{sql: "select * from stock;", rows: []string{"1 a 101 x \n2 b 102 z "}} + runTestInSchemaState(t, conn, cli, dom, model.StateWriteReorganization, true, dropColumnSQL, sqls, query) + }) +} + +func runTestInSchemaState( + t *testing.T, + conn *sql.Conn, + cli *TestServerClient, + dom *domain.Domain, + state model.SchemaState, + isOnJobUpdated bool, + dropColumnSQL string, + sqlWithErrs []sqlWithErr, + expectQuery *expectQuery, +) { + ctx := context.Background() + MustExec(ctx, t, conn, "use test_db_state") + + callback := &callback.TestDDLCallback{Do: dom} + prevState := model.StateNone + var checkErr error + dbt := cli.getNewDB(t, func(config *mysql.Config) { + config.MaxAllowedPacket = 1024 + }) + conn1, err := dbt.GetDB().Conn(ctx) + require.NoError(t, err) + defer func() { + err := dbt.GetDB().Close() + require.NoError(t, err) + }() + MustExec(ctx, t, conn1, "use test_db_state") + + for i, sqlWithErr := range sqlWithErrs { + // Start the test txn. + // Step 1: begin(when i = 0). + if i == 0 || i == len(sqlWithErrs)-1 { + sqlWithErr := sqlWithErrs[i] + MustExec(ctx, t, conn1, sqlWithErr.sql) + } else { + // Step 2: prepare stmts. + // SELECT a, c, d from stock where (a, b) IN ((?, ?),(?, ?)) FOR UPDATE + // UPDATE stock SET c = ? WHERE a= ? AND b = 'a' + // UPDATE stock SET c = ?, d = 'z' WHERE a= ? AND b = 'b' + stmt, err := conn1.PrepareContext(ctx, sqlWithErr.sql) + require.NoError(t, err) + sqlWithErr.stmt = stmt + sqlWithErrs[i] = sqlWithErr + } + } + + // Step 3: begin. + sqlWithErr := sqlWithErrs[0] + MustExec(ctx, t, conn1, sqlWithErr.sql) + + prevState = model.StateNone + state = model.StateWriteOnly + cbFunc1 := func(job *model.Job) { + if jobStateOrLastSubJobState(job) == prevState || checkErr != nil { + return + } + prevState = jobStateOrLastSubJobState(job) + if prevState != state { + return + } + // Step 4: exec stmts in write-only state (dropping a colum). + // SELECT a, c, d from stock where (a, b) IN ((?, ?),(?, ?)) FOR UPDATE, args:(1,"a"),(2,"b") + // UPDATE stock SET c = ? WHERE a= ? AND b = 'a', args:(100+1, 1) + // UPDATE stock SET c = ?, d = 'z' WHERE a= ? AND b = 'b', args:(100+2, 2) + // commit. + sqls := sqlWithErrs[1:] + for i, sqlWithErr := range sqls { + if i == 0 { + _, err = sqlWithErr.stmt.ExecContext(ctx, 1, "a", 2, "b") + require.NoError(t, err) + } else if i == 1 || i == 2 { + _, err = sqlWithErr.stmt.ExecContext(ctx, 100+i, i) + require.NoError(t, err) + } else { + MustQuery(ctx, t, cli, conn1, sqlWithErr.sql) + } + } + } + if isOnJobUpdated { + callback.OnJobUpdatedExported.Store(&cbFunc1) + } else { + callback.OnJobRunBeforeExported = cbFunc1 + } + d := dom.DDL() + originalCallback := d.GetHook() + d.SetHook(callback) + MustExec(ctx, t, conn, dropColumnSQL) + require.NoError(t, checkErr) + + // Check the result. + // select * from stock + if expectQuery != nil { + rs, err := conn.QueryContext(ctx, expectQuery.sql) + require.NoError(t, err) + if expectQuery.rows == nil { + require.Nil(t, rs) + } else { + cli.CheckRows(t, rs, expectQuery.rows[0]) + } + } + d.SetHook(originalCallback) +} + +func jobStateOrLastSubJobState(job *model.Job) model.SchemaState { + if job.Type == model.ActionMultiSchemaChange && job.MultiSchemaInfo != nil { + subs := job.MultiSchemaInfo.SubJobs + return subs[len(subs)-1].SchemaState + } + return job.SchemaState +} + +//revive:enable:exported +>>>>>>> 9aeaa76c5cb (*: fix a bug that update statement uses point get and update plan with different tblInfo (#54183)):pkg/server/internal/testserverclient/server_client.go diff --git a/server/tidb_test.go b/server/tidb_test.go index 259942d929a52..0facbca073655 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -3224,8 +3224,183 @@ func TestProxyProtocolWithIpNoFallbackable(t *testing.T) { db.Close() } +<<<<<<< HEAD:server/tidb_test.go func TestLoadData(t *testing.T) { ts := createTidbTestSuite(t) ts.runTestLoadDataReplace(t) ts.runTestLoadDataReplaceNonclusteredPK(t) +======= +func TestConnectionWillNotLeak(t *testing.T) { + cfg := util2.NewTestConfig() + cfg.Port = 0 + cfg.Status.ReportStatus = false + // Setup proxy protocol config + cfg.ProxyProtocol.Networks = "*" + cfg.ProxyProtocol.Fallbackable = false + + ts := servertestkit.CreateTidbTestSuite(t) + + cli := testserverclient.NewTestServerClient() + cli.Port = testutil.GetPortFromTCPAddr(ts.Server.ListenAddr()) + dsn := cli.GetDSN(func(config *mysql.Config) { + config.User = "root" + config.DBName = "test" + }) + db, err := sql.Open("mysql", dsn) + require.Nil(t, err) + db.SetMaxOpenConns(100) + db.SetMaxIdleConns(0) + + // create 100 connections + conns := make([]*sql.Conn, 0, 100) + for len(conns) < 100 { + conn, err := db.Conn(context.Background()) + require.NoError(t, err) + conns = append(conns, conn) + } + require.Eventually(t, func() bool { + runtime.GC() + return server2.ConnectionInMemCounterForTest.Load() == int64(100) + }, time.Minute, time.Millisecond*100) + + // run a simple query on each connection and close it + // this cannot ensure the connection will not leak for any kinds of requests + var wg sync.WaitGroup + for _, conn := range conns { + wg.Add(1) + conn := conn + go func() { + rows, err := conn.QueryContext(context.Background(), "SELECT 2023") + require.NoError(t, err) + var result int + require.True(t, rows.Next()) + require.NoError(t, rows.Scan(&result)) + require.Equal(t, result, 2023) + require.NoError(t, rows.Close()) + // `db.Close` will not close already grabbed connection, so it's still needed to close the connection here. + require.NoError(t, conn.Close()) + wg.Done() + }() + } + wg.Wait() + + require.NoError(t, db.Close()) + require.Eventually(t, func() bool { + runtime.GC() + count := server2.ConnectionInMemCounterForTest.Load() + return count == 0 + }, time.Minute, time.Millisecond*100) +} + +func TestPrepareCount(t *testing.T) { + ts := servertestkit.CreateTidbTestSuite(t) + + qctx, err := ts.Tidbdrv.OpenCtx(uint64(0), 0, uint8(tmysql.DefaultCollationID), "test", nil, nil) + require.NoError(t, err) + prepareCnt := atomic.LoadInt64(&variable.PreparedStmtCount) + ctx := context.Background() + _, err = Execute(ctx, qctx, "use test;") + require.NoError(t, err) + _, err = Execute(ctx, qctx, "drop table if exists t1") + require.NoError(t, err) + _, err = Execute(ctx, qctx, "create table t1 (id int)") + require.NoError(t, err) + stmt, _, _, err := qctx.Prepare("insert into t1 values (?)") + require.NoError(t, err) + require.Equal(t, prepareCnt+1, atomic.LoadInt64(&variable.PreparedStmtCount)) + require.NoError(t, err) + err = qctx.GetStatement(stmt.ID()).Close() + require.NoError(t, err) + require.Equal(t, prepareCnt, atomic.LoadInt64(&variable.PreparedStmtCount)) + require.NoError(t, qctx.Close()) +} + +func TestSQLModeIsLoadedBeforeQuery(t *testing.T) { + ts := servertestkit.CreateTidbTestSuite(t) + ts.RunTestSQLModeIsLoadedBeforeQuery(t) +} + +func TestConnectionCount(t *testing.T) { + ts := servertestkit.CreateTidbTestSuite(t) + ts.RunTestConnectionCount(t) +} + +func TestTypeAndCharsetOfSendLongData(t *testing.T) { + ts := servertestkit.CreateTidbTestSuite(t) + ts.RunTestTypeAndCharsetOfSendLongData(t) +} + +func TestIssue53634(t *testing.T) { + ts := servertestkit.CreateTidbTestSuiteWithDDLLease(t, "20s") + ts.RunTestIssue53634(t, ts.Domain) +} + +func TestAuthSocket(t *testing.T) { + defer server2.ClearOSUserForAuthSocket() + + cfg := util2.NewTestConfig() + cfg.Socket = filepath.Join(t.TempDir(), "authsock.sock") + cfg.Port = 0 + cfg.Status.StatusPort = 0 + ts := servertestkit.CreateTidbTestSuiteWithCfg(t, cfg) + ts.WaitUntilServerCanConnect() + + ts.RunTests(t, nil, func(dbt *testkit.DBTestKit) { + dbt.MustExec("CREATE USER 'u1'@'%' IDENTIFIED WITH auth_socket;") + dbt.MustExec("CREATE USER 'u2'@'%' IDENTIFIED WITH auth_socket AS 'sockuser'") + dbt.MustExec("CREATE USER 'sockuser'@'%' IDENTIFIED WITH auth_socket;") + }) + + // network login should be denied + for _, uname := range []string{"u1", "u2", "u3"} { + server2.MockOSUserForAuthSocket(uname) + db, err := sql.Open("mysql", ts.GetDSN(func(config *mysql.Config) { + config.User = uname + })) + require.NoError(t, err) + _, err = db.Conn(context.TODO()) + require.EqualError(t, + err, + fmt.Sprintf("Error 1045 (28000): Access denied for user '%s'@'127.0.0.1' (using password: NO)", uname), + ) + require.NoError(t, db.Close()) + } + + socketAuthConf := func(user string) func(*mysql.Config) { + return func(config *mysql.Config) { + config.User = user + config.Net = "unix" + config.Addr = cfg.Socket + config.DBName = "" + } + } + + server2.MockOSUserForAuthSocket("sockuser") + + // mysql username that is different with the OS user should be rejected. + db, err := sql.Open("mysql", ts.GetDSN(socketAuthConf("u1"))) + require.NoError(t, err) + _, err = db.Conn(context.TODO()) + require.EqualError(t, err, "Error 1045 (28000): Access denied for user 'u1'@'localhost' (using password: YES)") + require.NoError(t, db.Close()) + + // mysql username that is the same with the OS user should be accepted. + ts.RunTests(t, socketAuthConf("sockuser"), func(dbt *testkit.DBTestKit) { + rows := dbt.MustQuery("select current_user();") + ts.CheckRows(t, rows, "sockuser@%") + }) + + // When a user is created with `IDENTIFIED WITH auth_socket AS ...`. + // It should be accepted when username or as string is the same with OS user. + ts.RunTests(t, socketAuthConf("u2"), func(dbt *testkit.DBTestKit) { + rows := dbt.MustQuery("select current_user();") + ts.CheckRows(t, rows, "u2@%") + }) + + server2.MockOSUserForAuthSocket("u2") + ts.RunTests(t, socketAuthConf("u2"), func(dbt *testkit.DBTestKit) { + rows := dbt.MustQuery("select current_user();") + ts.CheckRows(t, rows, "u2@%") + }) +>>>>>>> 9aeaa76c5cb (*: fix a bug that update statement uses point get and update plan with different tblInfo (#54183)):pkg/server/tests/commontest/tidb_test.go }