diff --git a/DEPS.bzl b/DEPS.bzl index 042f4025e5db0..322bc9d8025ce 100644 --- a/DEPS.bzl +++ b/DEPS.bzl @@ -3603,8 +3603,8 @@ def go_deps(): name = "com_github_tikv_client_go_v2", build_file_proto_mode = "disable_global", importpath = "github.com/tikv/client-go/v2", - sum = "h1:0YcirnuxtXC9eQRb231im1M5w/n7JFuOo0IgE/K9ffM=", - version = "v2.0.4-0.20241125064444-5f59e4e34c62", + sum = "h1:P6bhZG2yFFuKYvOpfltUbt89sbHohq4BAv2P4GB3fL8=", + version = "v2.0.4-0.20250109055446-ccec7efbf0f7", ) go_repository( name = "com_github_tikv_pd_client", diff --git a/ddl/cluster.go b/ddl/cluster.go index 598f1cda7de51..cf482cb6e8efe 100644 --- a/ddl/cluster.go +++ b/ddl/cluster.go @@ -30,7 +30,6 @@ import ( "github.com/pingcap/tidb/infoschema" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/meta" - "github.com/pingcap/tidb/metrics" "github.com/pingcap/tidb/parser/model" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/variable" @@ -108,16 +107,12 @@ func getStoreGlobalMinSafeTS(s kv.Storage) time.Time { // ValidateFlashbackTS validates that flashBackTS in range [gcSafePoint, currentTS). func ValidateFlashbackTS(ctx context.Context, sctx sessionctx.Context, flashBackTS uint64) error { - currentTS, err := sctx.GetStore().GetOracle().GetStaleTimestamp(ctx, oracle.GlobalTxnScope, 0) - // If we fail to calculate currentTS from local time, fallback to get a timestamp from PD. + currentVer, err := sctx.GetStore().CurrentVersion(oracle.GlobalTxnScope) if err != nil { - metrics.ValidateReadTSFromPDCount.Inc() - currentVer, err := sctx.GetStore().CurrentVersion(oracle.GlobalTxnScope) - if err != nil { - return errors.Errorf("fail to validate flashback timestamp: %v", err) - } - currentTS = currentVer.Ver + return errors.Errorf("fail to validate flashback timestamp: %v", err) } + currentTS := currentVer.Ver + oracleFlashbackTS := oracle.GetTimeFromTS(flashBackTS) if oracleFlashbackTS.After(oracle.GetTimeFromTS(currentTS)) { return errors.Errorf("cannot set flashback timestamp to future time") diff --git a/executor/executor_test.go b/executor/executor_test.go index 1cca04941e212..2cd7dca04fb20 100644 --- a/executor/executor_test.go +++ b/executor/executor_test.go @@ -4776,7 +4776,7 @@ func TestStaleReadAtFutureTime(t *testing.T) { tk := testkit.NewTestKit(t, store) // Setting tx_read_ts to a time in the future will fail. (One day before the 2038 problem) - tk.MustGetErrMsg("set @@tx_read_ts = '2038-01-18 03:14:07'", "cannot set read timestamp to a future time") + tk.MustContainErrMsg("set @@tx_read_ts = '2038-01-18 03:14:07'", "cannot set read timestamp to a future time") // TxnReadTS Is not updated if check failed. require.Zero(t, tk.Session().GetSessionVars().TxnReadTS.PeakTxnReadTS()) } diff --git a/executor/set.go b/executor/set.go index 75e4938d41725..da6e1a58198b7 100644 --- a/executor/set.go +++ b/executor/set.go @@ -197,10 +197,8 @@ func (e *SetExecutor) setSysVariable(ctx context.Context, name string, v *expres newSnapshotTS := getSnapshotTSByName() newSnapshotIsSet := newSnapshotTS > 0 && newSnapshotTS != oldSnapshotTS if newSnapshotIsSet { - if name == variable.TiDBTxnReadTS { - err = sessionctx.ValidateStaleReadTS(ctx, e.ctx, newSnapshotTS) - } else { - err = sessionctx.ValidateSnapshotReadTS(ctx, e.ctx, newSnapshotTS) + err = sessionctx.ValidateSnapshotReadTS(ctx, e.ctx.GetStore(), newSnapshotTS) + if name != variable.TiDBTxnReadTS { // Also check gc safe point for snapshot read. // We don't check snapshot with gc safe point for read_ts // Client-go will automatically check the snapshotTS with gc safe point. It's unnecessary to check gc safe point during set executor. diff --git a/executor/stale_txn_test.go b/executor/stale_txn_test.go index e621c33ccc675..5b141f75e443a 100644 --- a/executor/stale_txn_test.go +++ b/executor/stale_txn_test.go @@ -17,6 +17,7 @@ package executor_test import ( "context" "fmt" + "strconv" "testing" "time" @@ -1406,14 +1407,30 @@ func TestStaleTSO(t *testing.T) { tk.MustExec("create table t (id int)") tk.MustExec("insert into t values(1)") + ts1, err := strconv.ParseUint(tk.MustQuery("select json_extract(@@tidb_last_txn_info, '$.commit_ts')").Rows()[0][0].(string), 10, 64) + require.NoError(t, err) + + // Wait until the physical advances for 1s + var currentTS uint64 + for { + tk.MustExec("begin") + currentTS, err = strconv.ParseUint(tk.MustQuery("select @@tidb_current_ts").Rows()[0][0].(string), 10, 64) + require.NoError(t, err) + tk.MustExec("rollback") + if oracle.GetTimeFromTS(currentTS).After(oracle.GetTimeFromTS(ts1).Add(time.Second)) { + break + } + time.Sleep(time.Millisecond * 100) + } asOfExprs := []string{ - "now(3) - interval 1 second", - "current_time() - interval 1 second", - "curtime() - interval 1 second", + "now(3) - interval 10 second", + "current_time() - interval 10 second", + "curtime() - interval 10 second", } - nextTSO := oracle.GoTimeToTS(time.Now().Add(2 * time.Second)) + nextPhysical := oracle.GetPhysical(oracle.GetTimeFromTS(currentTS).Add(10 * time.Second)) + nextTSO := oracle.ComposeTS(nextPhysical, oracle.ExtractLogical(currentTS)) require.Nil(t, failpoint.Enable("github.com/pingcap/tidb/sessiontxn/staleread/mockStaleReadTSO", fmt.Sprintf("return(%d)", nextTSO))) defer failpoint.Disable("github.com/pingcap/tidb/sessiontxn/staleread/mockStaleReadTSO") for _, expr := range asOfExprs { diff --git a/go.mod b/go.mod index 6945c5367ea86..b891b04dabb67 100644 --- a/go.mod +++ b/go.mod @@ -90,7 +90,7 @@ require ( github.com/stretchr/testify v1.8.4 github.com/tdakkota/asciicheck v0.1.1 github.com/tiancaiamao/appdash v0.0.0-20181126055449-889f96f722a2 - github.com/tikv/client-go/v2 v2.0.4-0.20241125064444-5f59e4e34c62 + github.com/tikv/client-go/v2 v2.0.4-0.20250109055446-ccec7efbf0f7 github.com/tikv/pd/client v0.0.0-20230904040343-947701a32c05 github.com/timakin/bodyclose v0.0.0-20210704033933-f49887972144 github.com/twmb/murmur3 v1.1.3 diff --git a/go.sum b/go.sum index 3e24ce3c05608..a8833b4242b28 100644 --- a/go.sum +++ b/go.sum @@ -948,8 +948,8 @@ github.com/tenntenn/text/transform v0.0.0-20200319021203-7eef512accb3 h1:f+jULpR github.com/tenntenn/text/transform v0.0.0-20200319021203-7eef512accb3/go.mod h1:ON8b8w4BN/kE1EOhwT0o+d62W65a6aPw1nouo9LMgyY= github.com/tiancaiamao/appdash v0.0.0-20181126055449-889f96f722a2 h1:mbAskLJ0oJfDRtkanvQPiooDH8HvJ2FBh+iKT/OmiQQ= github.com/tiancaiamao/appdash v0.0.0-20181126055449-889f96f722a2/go.mod h1:2PfKggNGDuadAa0LElHrByyrz4JPZ9fFx6Gs7nx7ZZU= -github.com/tikv/client-go/v2 v2.0.4-0.20241125064444-5f59e4e34c62 h1:0YcirnuxtXC9eQRb231im1M5w/n7JFuOo0IgE/K9ffM= -github.com/tikv/client-go/v2 v2.0.4-0.20241125064444-5f59e4e34c62/go.mod h1:mmVCLP2OqWvQJPOIevQPZvGphzh/oq9vv8J5LDfpadQ= +github.com/tikv/client-go/v2 v2.0.4-0.20250109055446-ccec7efbf0f7 h1:P6bhZG2yFFuKYvOpfltUbt89sbHohq4BAv2P4GB3fL8= +github.com/tikv/client-go/v2 v2.0.4-0.20250109055446-ccec7efbf0f7/go.mod h1:mmVCLP2OqWvQJPOIevQPZvGphzh/oq9vv8J5LDfpadQ= github.com/tikv/pd/client v0.0.0-20230904040343-947701a32c05 h1:e4hLUKfgfPeJPZwOfU+/I/03G0sn6IZqVcbX/5o+hvM= github.com/tikv/pd/client v0.0.0-20230904040343-947701a32c05/go.mod h1:MLIl+d2WbOF4A3U88WKtyXrQQW417wZDDvBcq2IW9bQ= github.com/timakin/bodyclose v0.0.0-20210704033933-f49887972144 h1:kl4KhGNsJIbDHS9/4U9yQo1UcPQM0kOMJHn29EoH/Ro= diff --git a/pkg/sessionctx/BUILD.bazel b/pkg/sessionctx/BUILD.bazel new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/planner/core/plan_cache_utils.go b/planner/core/plan_cache_utils.go index ae2ed489cb7df..5aa35357eb1cd 100644 --- a/planner/core/plan_cache_utils.go +++ b/planner/core/plan_cache_utils.go @@ -515,7 +515,7 @@ type PlanCacheStmt struct { SQLDigest *parser.Digest PlanDigest *parser.Digest ForUpdateRead bool - SnapshotTSEvaluator func(sessionctx.Context) (uint64, error) + SnapshotTSEvaluator func(context.Context, sessionctx.Context) (uint64, error) NormalizedSQL4PC string SQLDigest4PC string diff --git a/planner/core/planbuilder.go b/planner/core/planbuilder.go index e6b0ce5e9e4f5..8ce34e4ce3a26 100644 --- a/planner/core/planbuilder.go +++ b/planner/core/planbuilder.go @@ -3399,7 +3399,7 @@ func (b *PlanBuilder) buildSimple(ctx context.Context, node ast.StmtNode) (Plan, if err != nil { return nil, err } - if err := sessionctx.ValidateStaleReadTS(ctx, b.ctx, startTS); err != nil { + if err := sessionctx.ValidateSnapshotReadTS(ctx, b.ctx.GetStore(), startTS); err != nil { return nil, err } p.StaleTxnStartTS = startTS @@ -3413,7 +3413,7 @@ func (b *PlanBuilder) buildSimple(ctx context.Context, node ast.StmtNode) (Plan, if err != nil { return nil, err } - if err := sessionctx.ValidateStaleReadTS(ctx, b.ctx, startTS); err != nil { + if err := sessionctx.ValidateSnapshotReadTS(ctx, b.ctx.GetStore(), startTS); err != nil { return nil, err } p.StaleTxnStartTS = startTS diff --git a/planner/core/preprocess.go b/planner/core/preprocess.go index 785ff61a615f3..ae64a9f23b188 100644 --- a/planner/core/preprocess.go +++ b/planner/core/preprocess.go @@ -168,7 +168,7 @@ var _ = PreprocessorReturn{}.initedLastSnapshotTS type PreprocessorReturn struct { initedLastSnapshotTS bool IsStaleness bool - SnapshotTSEvaluator func(sessionctx.Context) (uint64, error) + SnapshotTSEvaluator func(context.Context, sessionctx.Context) (uint64, error) // LastSnapshotTS is the last evaluated snapshotTS if any // otherwise it defaults to zero LastSnapshotTS uint64 diff --git a/sessionctx/BUILD.bazel b/sessionctx/BUILD.bazel index 800001fd426b3..7f893d0412173 100644 --- a/sessionctx/BUILD.bazel +++ b/sessionctx/BUILD.bazel @@ -8,7 +8,6 @@ go_library( deps = [ "//extension", "//kv", - "//metrics", "//parser/model", "//sessionctx/sessionstates", "//sessionctx/variable", @@ -17,7 +16,6 @@ go_library( "//util/kvcache", "//util/sli", "//util/topsql/stmtstats", - "@com_github_pingcap_errors//:errors", "@com_github_pingcap_kvproto//pkg/kvrpcpb", "@com_github_pingcap_tipb//go-binlog", "@com_github_tikv_client_go_v2//oracle", diff --git a/sessionctx/context.go b/sessionctx/context.go index 35eb7ba68ca1d..281c3f8d24a19 100644 --- a/sessionctx/context.go +++ b/sessionctx/context.go @@ -17,13 +17,10 @@ package sessionctx import ( "context" "fmt" - "time" - "github.com/pingcap/errors" "github.com/pingcap/kvproto/pkg/kvrpcpb" "github.com/pingcap/tidb/extension" "github.com/pingcap/tidb/kv" - "github.com/pingcap/tidb/metrics" "github.com/pingcap/tidb/parser/model" "github.com/pingcap/tidb/sessionctx/sessionstates" "github.com/pingcap/tidb/sessionctx/variable" @@ -223,44 +220,8 @@ const ( ) // ValidateSnapshotReadTS strictly validates that readTS does not exceed the PD timestamp -func ValidateSnapshotReadTS(ctx context.Context, sctx Context, readTS uint64) error { - latestTS, err := sctx.GetStore().GetOracle().GetLowResolutionTimestamp(ctx, &oracle.Option{TxnScope: oracle.GlobalTxnScope}) - // If we fail to get latestTS or the readTS exceeds it, get a timestamp from PD to double check - if err != nil || readTS > latestTS { - metrics.ValidateReadTSFromPDCount.Inc() - currentVer, err := sctx.GetStore().CurrentVersion(oracle.GlobalTxnScope) - if err != nil { - return errors.Errorf("fail to validate read timestamp: %v", err) - } - if readTS > currentVer.Ver { - return errors.Errorf("cannot set read timestamp to a future time") - } - } - return nil -} - -// How far future from now ValidateStaleReadTS allows at most -const allowedTimeFromNow = 100 * time.Millisecond - -// ValidateStaleReadTS validates that readTS does not exceed the current time not strictly. -func ValidateStaleReadTS(ctx context.Context, sctx Context, readTS uint64) error { - currentTS, err := sctx.GetSessionVars().StmtCtx.GetStaleTSO() - if currentTS == 0 || err != nil { - currentTS, err = sctx.GetStore().GetOracle().GetStaleTimestamp(ctx, oracle.GlobalTxnScope, 0) - } - // If we fail to calculate currentTS from local time, fallback to get a timestamp from PD - if err != nil { - metrics.ValidateReadTSFromPDCount.Inc() - currentVer, err := sctx.GetStore().CurrentVersion(oracle.GlobalTxnScope) - if err != nil { - return errors.Errorf("fail to validate read timestamp: %v", err) - } - currentTS = currentVer.Ver - } - if oracle.GetTimeFromTS(readTS).After(oracle.GetTimeFromTS(currentTS).Add(allowedTimeFromNow)) { - return errors.Errorf("cannot set read timestamp to a future time") - } - return nil +func ValidateSnapshotReadTS(ctx context.Context, store kv.Storage, readTS uint64) error { + return store.GetOracle().ValidateSnapshotReadTS(ctx, readTS, &oracle.Option{TxnScope: oracle.GlobalTxnScope}) } // SysProcTracker is used to track background sys processes diff --git a/sessiontxn/staleread/processor.go b/sessiontxn/staleread/processor.go index af91ffd1b175e..278d1b158a599 100644 --- a/sessiontxn/staleread/processor.go +++ b/sessiontxn/staleread/processor.go @@ -30,7 +30,7 @@ import ( var _ Processor = &staleReadProcessor{} // StalenessTSEvaluator is a function to get staleness ts -type StalenessTSEvaluator func(sctx sessionctx.Context) (uint64, error) +type StalenessTSEvaluator func(ctx context.Context, sctx sessionctx.Context) (uint64, error) // Processor is an interface used to process stale read type Processor interface { @@ -100,7 +100,7 @@ func (p *baseProcessor) setEvaluatedTS(ts uint64) (err error) { return err } - return p.setEvaluatedValues(ts, is, func(sctx sessionctx.Context) (uint64, error) { + return p.setEvaluatedValues(ts, is, func(_ context.Context, sctx sessionctx.Context) (uint64, error) { return ts, nil }) } @@ -116,7 +116,7 @@ func (p *baseProcessor) setEvaluatedTSWithoutEvaluator(ts uint64) (err error) { } func (p *baseProcessor) setEvaluatedEvaluator(evaluator StalenessTSEvaluator) error { - ts, err := evaluator(p.sctx) + ts, err := evaluator(p.ctx, p.sctx) if err != nil { return err } @@ -167,10 +167,10 @@ func (p *staleReadProcessor) OnSelectTable(tn *ast.TableName) error { } // If `stmtAsOfTS` is not 0, it means we use 'select ... from xxx as of timestamp ...' - evaluateTS := func(sctx sessionctx.Context) (uint64, error) { - return parseAndValidateAsOf(context.Background(), p.sctx, tn.AsOf) + evaluateTS := func(ctx context.Context, sctx sessionctx.Context) (uint64, error) { + return parseAndValidateAsOf(ctx, p.sctx, tn.AsOf) } - stmtAsOfTS, err := evaluateTS(p.sctx) + stmtAsOfTS, err := evaluateTS(p.ctx, p.sctx) if err != nil { return err } @@ -200,7 +200,7 @@ func (p *staleReadProcessor) OnExecutePreparedStmt(preparedTSEvaluator Staleness var stmtTS uint64 if preparedTSEvaluator != nil { // If the `preparedTSEvaluator` is not nil, it means the prepared statement is stale read - if stmtTS, err = preparedTSEvaluator(p.sctx); err != nil { + if stmtTS, err = preparedTSEvaluator(p.ctx, p.sctx); err != nil { return err } } @@ -285,7 +285,7 @@ func parseAndValidateAsOf(ctx context.Context, sctx sessionctx.Context, asOf *as return 0, err } - if err = sessionctx.ValidateStaleReadTS(ctx, sctx, ts); err != nil { + if err = sessionctx.ValidateSnapshotReadTS(ctx, sctx.GetStore(), ts); err != nil { return 0, err } @@ -298,8 +298,8 @@ func getTsEvaluatorFromReadStaleness(sctx sessionctx.Context) StalenessTSEvaluat return nil } - return func(sctx sessionctx.Context) (uint64, error) { - return CalculateTsWithReadStaleness(sctx, readStaleness) + return func(ctx context.Context, sctx sessionctx.Context) (uint64, error) { + return CalculateTsWithReadStaleness(ctx, sctx, readStaleness) } } diff --git a/sessiontxn/staleread/processor_test.go b/sessiontxn/staleread/processor_test.go index 204bb63a3d8de..de336ae933358 100644 --- a/sessiontxn/staleread/processor_test.go +++ b/sessiontxn/staleread/processor_test.go @@ -51,7 +51,7 @@ func (p *staleReadPoint) checkMatchProcessor(t *testing.T, processor staleread.P evaluator := processor.GetStalenessTSEvaluatorForPrepare() if hasEvaluator { require.NotNil(t, evaluator) - ts, err := evaluator(p.tk.Session()) + ts, err := evaluator(context.Background(), p.tk.Session()) require.NoError(t, err) require.Equal(t, processor.GetStalenessReadTS(), ts) } else { @@ -108,6 +108,7 @@ func TestStaleReadProcessorWithSelectTable(t *testing.T) { tn := astTableWithAsOf(t, "") p1 := genStaleReadPoint(t, tk) p2 := genStaleReadPoint(t, tk) + ctx := context.Background() // create local temporary table to check processor's infoschema will consider temporary table tk.MustExec("create temporary table test.t2(a int)") @@ -157,19 +158,19 @@ func TestStaleReadProcessorWithSelectTable(t *testing.T) { err = processor.OnSelectTable(tn) require.True(t, processor.IsStaleness()) require.Equal(t, int64(0), processor.GetStalenessInfoSchema().SchemaMetaVersion()) - expectedTS, err := staleread.CalculateTsWithReadStaleness(tk.Session(), -5*time.Second) + expectedTS, err := staleread.CalculateTsWithReadStaleness(ctx, tk.Session(), -5*time.Second) require.NoError(t, err) require.Equal(t, expectedTS, processor.GetStalenessReadTS()) evaluator := processor.GetStalenessTSEvaluatorForPrepare() - evaluatorTS, err := evaluator(tk.Session()) + evaluatorTS, err := evaluator(ctx, tk.Session()) require.NoError(t, err) require.Equal(t, expectedTS, evaluatorTS) tk.MustExec("set @@tidb_read_staleness=''") tk.MustExec("do sleep(0.01)") - evaluatorTS, err = evaluator(tk.Session()) + evaluatorTS, err = evaluator(ctx, tk.Session()) require.NoError(t, err) - expectedTS2, err := staleread.CalculateTsWithReadStaleness(tk.Session(), -5*time.Second) + expectedTS2, err := staleread.CalculateTsWithReadStaleness(ctx, tk.Session(), -5*time.Second) require.NoError(t, err) require.Equal(t, expectedTS2, evaluatorTS) @@ -216,11 +217,11 @@ func TestStaleReadProcessorWithSelectTable(t *testing.T) { err = processor.OnSelectTable(tn) require.True(t, processor.IsStaleness()) require.Equal(t, int64(0), processor.GetStalenessInfoSchema().SchemaMetaVersion()) - expectedTS, err = staleread.CalculateTsWithReadStaleness(tk.Session(), -5*time.Second) + expectedTS, err = staleread.CalculateTsWithReadStaleness(ctx, tk.Session(), -5*time.Second) require.NoError(t, err) require.Equal(t, expectedTS, processor.GetStalenessReadTS()) evaluator = processor.GetStalenessTSEvaluatorForPrepare() - evaluatorTS, err = evaluator(tk.Session()) + evaluatorTS, err = evaluator(ctx, tk.Session()) require.NoError(t, err) require.Equal(t, expectedTS, evaluatorTS) tk.MustExec("set @@tidb_read_staleness=''") @@ -233,13 +234,14 @@ func TestStaleReadProcessorWithExecutePreparedStmt(t *testing.T) { tk := testkit.NewTestKit(t, store) p1 := genStaleReadPoint(t, tk) //p2 := genStaleReadPoint(t, tk) + ctx := context.Background() // create local temporary table to check processor's infoschema will consider temporary table tk.MustExec("create temporary table test.t2(a int)") // execute prepared stmt with ts evaluator processor := createProcessor(t, tk.Session()) - err := processor.OnExecutePreparedStmt(func(sctx sessionctx.Context) (uint64, error) { + err := processor.OnExecutePreparedStmt(func(_ctx context.Context, sctx sessionctx.Context) (uint64, error) { return p1.ts, nil }) require.NoError(t, err) @@ -247,7 +249,7 @@ func TestStaleReadProcessorWithExecutePreparedStmt(t *testing.T) { // will get an error when ts evaluator fails processor = createProcessor(t, tk.Session()) - err = processor.OnExecutePreparedStmt(func(sctx sessionctx.Context) (uint64, error) { + err = processor.OnExecutePreparedStmt(func(_ctx context.Context, sctx sessionctx.Context) (uint64, error) { return 0, errors.New("mock error") }) require.Error(t, err) @@ -272,7 +274,7 @@ func TestStaleReadProcessorWithExecutePreparedStmt(t *testing.T) { // prepared ts is not allowed when @@txn_read_ts is set tk.MustExec(fmt.Sprintf("SET TRANSACTION READ ONLY AS OF TIMESTAMP '%s'", p1.dt)) processor = createProcessor(t, tk.Session()) - err = processor.OnExecutePreparedStmt(func(sctx sessionctx.Context) (uint64, error) { + err = processor.OnExecutePreparedStmt(func(_ctx context.Context, sctx sessionctx.Context) (uint64, error) { return p1.ts, nil }) require.Error(t, err) @@ -285,7 +287,7 @@ func TestStaleReadProcessorWithExecutePreparedStmt(t *testing.T) { err = processor.OnExecutePreparedStmt(nil) require.True(t, processor.IsStaleness()) require.Equal(t, int64(0), processor.GetStalenessInfoSchema().SchemaMetaVersion()) - expectedTS, err := staleread.CalculateTsWithReadStaleness(tk.Session(), -5*time.Second) + expectedTS, err := staleread.CalculateTsWithReadStaleness(ctx, tk.Session(), -5*time.Second) require.NoError(t, err) require.Equal(t, expectedTS, processor.GetStalenessReadTS()) tk.MustExec("set @@tidb_read_staleness=''") @@ -293,7 +295,7 @@ func TestStaleReadProcessorWithExecutePreparedStmt(t *testing.T) { // `@@tidb_read_staleness` will be ignored when `as of` or `@@tx_read_ts` tk.MustExec("set @@tidb_read_staleness=-5") processor = createProcessor(t, tk.Session()) - err = processor.OnExecutePreparedStmt(func(sctx sessionctx.Context) (uint64, error) { + err = processor.OnExecutePreparedStmt(func(_ctx context.Context, sctx sessionctx.Context) (uint64, error) { return p1.ts, nil }) require.NoError(t, err) @@ -336,7 +338,7 @@ func TestStaleReadProcessorWithExecutePreparedStmt(t *testing.T) { err = processor.OnExecutePreparedStmt(nil) require.True(t, processor.IsStaleness()) require.Equal(t, int64(0), processor.GetStalenessInfoSchema().SchemaMetaVersion()) - expectedTS, err = staleread.CalculateTsWithReadStaleness(tk.Session(), -5*time.Second) + expectedTS, err = staleread.CalculateTsWithReadStaleness(ctx, tk.Session(), -5*time.Second) require.NoError(t, err) require.Equal(t, expectedTS, processor.GetStalenessReadTS()) tk.MustExec("set @@tidb_read_staleness=''") @@ -376,7 +378,7 @@ func TestStaleReadProcessorInTxn(t *testing.T) { // return an error when execute prepared stmt with as of processor = createProcessor(t, tk.Session()) - err = processor.OnExecutePreparedStmt(func(sctx sessionctx.Context) (uint64, error) { + err = processor.OnExecutePreparedStmt(func(_ctx context.Context, sctx sessionctx.Context) (uint64, error) { return p1.ts, nil }) require.Error(t, err) diff --git a/sessiontxn/staleread/util.go b/sessiontxn/staleread/util.go index d2cc7e4863446..30a446bbb1817 100644 --- a/sessiontxn/staleread/util.go +++ b/sessiontxn/staleread/util.go @@ -71,14 +71,25 @@ func CalculateAsOfTsExpr(ctx context.Context, sctx sessionctx.Context, tsExpr as } // CalculateTsWithReadStaleness calculates the TsExpr for readStaleness duration -func CalculateTsWithReadStaleness(sctx sessionctx.Context, readStaleness time.Duration) (uint64, error) { +func CalculateTsWithReadStaleness(ctx context.Context, sctx sessionctx.Context, readStaleness time.Duration) (uint64, error) { nowVal, err := expression.GetStmtTimestamp(sctx) if err != nil { return 0, err } tsVal := nowVal.Add(readStaleness) - minTsVal := expression.GetMinSafeTime(sctx) - return oracle.GoTimeToTS(expression.CalAppropriateTime(tsVal, nowVal, minTsVal)), nil + minSafeTSVal := expression.GetMinSafeTime(sctx) + calculatedTime := expression.CalAppropriateTime(tsVal, nowVal, minSafeTSVal) + readTS := oracle.GoTimeToTS(calculatedTime) + if calculatedTime.After(minSafeTSVal) { + // If the final calculated exceeds the min safe ts, we are not sure whether the ts is safe to read (note that + // reading with a ts larger than PD's max allocated ts + 1 is unsafe and may break linearizability). + // So in this case, do an extra check on it. + err = sessionctx.ValidateSnapshotReadTS(ctx, sctx.GetStore(), readTS) + if err != nil { + return 0, err + } + } + return readTS, nil } // IsStmtStaleness indicates whether the current statement is staleness or not