Skip to content

Commit

Permalink
*: Use strict validation for stale read ts & flashback ts (pingcap#57050
Browse files Browse the repository at this point in the history
  • Loading branch information
MyonKeminta committed Dec 12, 2024
1 parent 2061937 commit 9296a17
Show file tree
Hide file tree
Showing 10 changed files with 73 additions and 89 deletions.
13 changes: 4 additions & 9 deletions pkg/ddl/cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ import (
"github.com/pingcap/tidb/pkg/infoschema"
"github.com/pingcap/tidb/pkg/kv"
"github.com/pingcap/tidb/pkg/meta"
"github.com/pingcap/tidb/pkg/metrics"
"github.com/pingcap/tidb/pkg/parser/model"
"github.com/pingcap/tidb/pkg/sessionctx"
"github.com/pingcap/tidb/pkg/sessionctx/variable"
Expand Down Expand Up @@ -112,16 +111,12 @@ func getStoreGlobalMinSafeTS(s kv.Storage) time.Time {

// ValidateFlashbackTS validates that flashBackTS in range [gcSafePoint, currentTS).
func ValidateFlashbackTS(ctx context.Context, sctx sessionctx.Context, flashBackTS uint64) error {
currentTS, err := sctx.GetStore().GetOracle().GetStaleTimestamp(ctx, oracle.GlobalTxnScope, 0)
// If we fail to calculate currentTS from local time, fallback to get a timestamp from PD.
currentVer, err := sctx.GetStore().CurrentVersion(oracle.GlobalTxnScope)
if err != nil {
metrics.ValidateReadTSFromPDCount.Inc()
currentVer, err := sctx.GetStore().CurrentVersion(oracle.GlobalTxnScope)
if err != nil {
return errors.Errorf("fail to validate flashback timestamp: %v", err)
}
currentTS = currentVer.Ver
return errors.Errorf("fail to validate flashback timestamp: %v", err)
}
currentTS := currentVer.Ver

oracleFlashbackTS := oracle.GetTimeFromTS(flashBackTS)
if oracleFlashbackTS.After(oracle.GetTimeFromTS(currentTS)) {
return errors.Errorf("cannot set flashback timestamp to future time")
Expand Down
6 changes: 2 additions & 4 deletions pkg/executor/set.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,10 +213,8 @@ func (e *SetExecutor) setSysVariable(ctx context.Context, name string, v *expres
newSnapshotTS := getSnapshotTSByName()
newSnapshotIsSet := newSnapshotTS > 0 && newSnapshotTS != oldSnapshotTS
if newSnapshotIsSet {
if name == variable.TiDBTxnReadTS {
err = sessionctx.ValidateStaleReadTS(ctx, e.Ctx(), newSnapshotTS)
} else {
err = sessionctx.ValidateSnapshotReadTS(ctx, e.Ctx(), newSnapshotTS)
err = sessionctx.ValidateSnapshotReadTS(ctx, e.Ctx().GetStore(), newSnapshotTS)
if name != variable.TiDBTxnReadTS {
// Also check gc safe point for snapshot read.
// We don't check snapshot with gc safe point for read_ts
// Client-go will automatically check the snapshotTS with gc safe point. It's unnecessary to check gc safe point during set executor.
Expand Down
25 changes: 21 additions & 4 deletions pkg/executor/stale_txn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package executor_test
import (
"context"
"fmt"
"strconv"
"testing"
"time"

Expand Down Expand Up @@ -1409,14 +1410,30 @@ func TestStaleTSO(t *testing.T) {
tk.MustExec("create table t (id int)")

tk.MustExec("insert into t values(1)")
ts1, err := strconv.ParseUint(tk.MustQuery("select json_extract(@@tidb_last_txn_info, '$.commit_ts')").Rows()[0][0].(string), 10, 64)
require.NoError(t, err)

// Wait until the physical advances for 1s
var currentTS uint64
for {
tk.MustExec("begin")
currentTS, err = strconv.ParseUint(tk.MustQuery("select @@tidb_current_ts").Rows()[0][0].(string), 10, 64)
require.NoError(t, err)
tk.MustExec("rollback")
if oracle.GetTimeFromTS(currentTS).After(oracle.GetTimeFromTS(ts1).Add(time.Second)) {
break
}
time.Sleep(time.Millisecond * 100)
}

asOfExprs := []string{
"now(3) - interval 1 second",
"current_time() - interval 1 second",
"curtime() - interval 1 second",
"now(3) - interval 10 second",
"current_time() - interval 10 second",
"curtime() - interval 10 second",
}

nextTSO := oracle.GoTimeToTS(time.Now().Add(2 * time.Second))
nextPhysical := oracle.GetPhysical(oracle.GetTimeFromTS(currentTS).Add(10 * time.Second))
nextTSO := oracle.ComposeTS(nextPhysical, oracle.ExtractLogical(currentTS))
require.Nil(t, failpoint.Enable("github.com/pingcap/tidb/pkg/sessiontxn/staleread/mockStaleReadTSO", fmt.Sprintf("return(%d)", nextTSO)))
defer failpoint.Disable("github.com/pingcap/tidb/pkg/sessiontxn/staleread/mockStaleReadTSO")
for _, expr := range asOfExprs {
Expand Down
2 changes: 1 addition & 1 deletion pkg/planner/core/plan_cache_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,7 @@ type PlanCacheStmt struct {
SQLDigest *parser.Digest
PlanDigest *parser.Digest
ForUpdateRead bool
SnapshotTSEvaluator func(sessionctx.Context) (uint64, error)
SnapshotTSEvaluator func(context.Context, sessionctx.Context) (uint64, error)
NormalizedSQL4PC string
SQLDigest4PC string

Expand Down
4 changes: 2 additions & 2 deletions pkg/planner/core/planbuilder.go
Original file line number Diff line number Diff line change
Expand Up @@ -3693,7 +3693,7 @@ func (b *PlanBuilder) buildSimple(ctx context.Context, node ast.StmtNode) (Plan,
if err != nil {
return nil, err
}
if err := sessionctx.ValidateStaleReadTS(ctx, b.ctx, startTS); err != nil {
if err := sessionctx.ValidateSnapshotReadTS(ctx, b.ctx.GetStore(), startTS); err != nil {
return nil, err
}
p.StaleTxnStartTS = startTS
Expand All @@ -3707,7 +3707,7 @@ func (b *PlanBuilder) buildSimple(ctx context.Context, node ast.StmtNode) (Plan,
if err != nil {
return nil, err
}
if err := sessionctx.ValidateStaleReadTS(ctx, b.ctx, startTS); err != nil {
if err := sessionctx.ValidateSnapshotReadTS(ctx, b.ctx.GetStore(), startTS); err != nil {
return nil, err
}
p.StaleTxnStartTS = startTS
Expand Down
2 changes: 1 addition & 1 deletion pkg/planner/core/preprocess.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ var _ = PreprocessorReturn{}.initedLastSnapshotTS
type PreprocessorReturn struct {
initedLastSnapshotTS bool
IsStaleness bool
SnapshotTSEvaluator func(sessionctx.Context) (uint64, error)
SnapshotTSEvaluator func(context.Context, sessionctx.Context) (uint64, error)
// LastSnapshotTS is the last evaluated snapshotTS if any
// otherwise it defaults to zero
LastSnapshotTS uint64
Expand Down
43 changes: 2 additions & 41 deletions pkg/sessionctx/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,10 @@ package sessionctx
import (
"context"
"fmt"
"time"

"github.com/pingcap/errors"
"github.com/pingcap/kvproto/pkg/kvrpcpb"
"github.com/pingcap/tidb/pkg/extension"
"github.com/pingcap/tidb/pkg/kv"
"github.com/pingcap/tidb/pkg/metrics"
"github.com/pingcap/tidb/pkg/parser/model"
"github.com/pingcap/tidb/pkg/sessionctx/sessionstates"
"github.com/pingcap/tidb/pkg/sessionctx/variable"
Expand Down Expand Up @@ -217,44 +214,8 @@ const (
)

// ValidateSnapshotReadTS strictly validates that readTS does not exceed the PD timestamp
func ValidateSnapshotReadTS(ctx context.Context, sctx Context, readTS uint64) error {
latestTS, err := sctx.GetStore().GetOracle().GetLowResolutionTimestamp(ctx, &oracle.Option{TxnScope: oracle.GlobalTxnScope})
// If we fail to get latestTS or the readTS exceeds it, get a timestamp from PD to double check
if err != nil || readTS > latestTS {
metrics.ValidateReadTSFromPDCount.Inc()
currentVer, err := sctx.GetStore().CurrentVersion(oracle.GlobalTxnScope)
if err != nil {
return errors.Errorf("fail to validate read timestamp: %v", err)
}
if readTS > currentVer.Ver {
return errors.Errorf("cannot set read timestamp to a future time")
}
}
return nil
}

// How far future from now ValidateStaleReadTS allows at most
const allowedTimeFromNow = 100 * time.Millisecond

// ValidateStaleReadTS validates that readTS does not exceed the current time not strictly.
func ValidateStaleReadTS(ctx context.Context, sctx Context, readTS uint64) error {
currentTS, err := sctx.GetSessionVars().StmtCtx.GetStaleTSO()
if currentTS == 0 || err != nil {
currentTS, err = sctx.GetStore().GetOracle().GetStaleTimestamp(ctx, oracle.GlobalTxnScope, 0)
}
// If we fail to calculate currentTS from local time, fallback to get a timestamp from PD
if err != nil {
metrics.ValidateReadTSFromPDCount.Inc()
currentVer, err := sctx.GetStore().CurrentVersion(oracle.GlobalTxnScope)
if err != nil {
return errors.Errorf("fail to validate read timestamp: %v", err)
}
currentTS = currentVer.Ver
}
if oracle.GetTimeFromTS(readTS).After(oracle.GetTimeFromTS(currentTS).Add(allowedTimeFromNow)) {
return errors.Errorf("cannot set read timestamp to a future time")
}
return nil
func ValidateSnapshotReadTS(ctx context.Context, store kv.Storage, readTS uint64) error {
return store.GetOracle().ValidateSnapshotReadTS(ctx, readTS, &oracle.Option{TxnScope: oracle.GlobalTxnScope})
}

// SysProcTracker is used to track background sys processes
Expand Down
20 changes: 10 additions & 10 deletions pkg/sessiontxn/staleread/processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import (
var _ Processor = &staleReadProcessor{}

// StalenessTSEvaluator is a function to get staleness ts
type StalenessTSEvaluator func(sctx sessionctx.Context) (uint64, error)
type StalenessTSEvaluator func(ctx context.Context, sctx sessionctx.Context) (uint64, error)

// Processor is an interface used to process stale read
type Processor interface {
Expand Down Expand Up @@ -100,7 +100,7 @@ func (p *baseProcessor) setEvaluatedTS(ts uint64) (err error) {
return err
}

return p.setEvaluatedValues(ts, is, func(sctx sessionctx.Context) (uint64, error) {
return p.setEvaluatedValues(ts, is, func(_ context.Context, sctx sessionctx.Context) (uint64, error) {
return ts, nil
})
}
Expand All @@ -116,7 +116,7 @@ func (p *baseProcessor) setEvaluatedTSWithoutEvaluator(ts uint64) (err error) {
}

func (p *baseProcessor) setEvaluatedEvaluator(evaluator StalenessTSEvaluator) error {
ts, err := evaluator(p.sctx)
ts, err := evaluator(p.ctx, p.sctx)
if err != nil {
return err
}
Expand Down Expand Up @@ -167,10 +167,10 @@ func (p *staleReadProcessor) OnSelectTable(tn *ast.TableName) error {
}

// If `stmtAsOfTS` is not 0, it means we use 'select ... from xxx as of timestamp ...'
evaluateTS := func(sctx sessionctx.Context) (uint64, error) {
return parseAndValidateAsOf(context.Background(), p.sctx, tn.AsOf)
evaluateTS := func(ctx context.Context, sctx sessionctx.Context) (uint64, error) {
return parseAndValidateAsOf(ctx, p.sctx, tn.AsOf)
}
stmtAsOfTS, err := evaluateTS(p.sctx)
stmtAsOfTS, err := evaluateTS(p.ctx, p.sctx)
if err != nil {
return err
}
Expand Down Expand Up @@ -200,7 +200,7 @@ func (p *staleReadProcessor) OnExecutePreparedStmt(preparedTSEvaluator Staleness
var stmtTS uint64
if preparedTSEvaluator != nil {
// If the `preparedTSEvaluator` is not nil, it means the prepared statement is stale read
if stmtTS, err = preparedTSEvaluator(p.sctx); err != nil {
if stmtTS, err = preparedTSEvaluator(p.ctx, p.sctx); err != nil {
return err
}
}
Expand Down Expand Up @@ -285,7 +285,7 @@ func parseAndValidateAsOf(ctx context.Context, sctx sessionctx.Context, asOf *as
return 0, err
}

if err = sessionctx.ValidateStaleReadTS(ctx, sctx, ts); err != nil {
if err = sessionctx.ValidateSnapshotReadTS(ctx, sctx.GetStore(), ts); err != nil {
return 0, err
}

Expand All @@ -298,8 +298,8 @@ func getTsEvaluatorFromReadStaleness(sctx sessionctx.Context) StalenessTSEvaluat
return nil
}

return func(sctx sessionctx.Context) (uint64, error) {
return CalculateTsWithReadStaleness(sctx, readStaleness)
return func(ctx context.Context, sctx sessionctx.Context) (uint64, error) {
return CalculateTsWithReadStaleness(ctx, sctx, readStaleness)
}
}

Expand Down
30 changes: 16 additions & 14 deletions pkg/sessiontxn/staleread/processor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func (p *staleReadPoint) checkMatchProcessor(t *testing.T, processor staleread.P
evaluator := processor.GetStalenessTSEvaluatorForPrepare()
if hasEvaluator {
require.NotNil(t, evaluator)
ts, err := evaluator(p.tk.Session())
ts, err := evaluator(context.Background(), p.tk.Session())
require.NoError(t, err)
require.Equal(t, processor.GetStalenessReadTS(), ts)
} else {
Expand Down Expand Up @@ -108,6 +108,7 @@ func TestStaleReadProcessorWithSelectTable(t *testing.T) {
tn := astTableWithAsOf(t, "")
p1 := genStaleReadPoint(t, tk)
p2 := genStaleReadPoint(t, tk)
ctx := context.Background()

// create local temporary table to check processor's infoschema will consider temporary table
tk.MustExec("create temporary table test.t2(a int)")
Expand Down Expand Up @@ -157,19 +158,19 @@ func TestStaleReadProcessorWithSelectTable(t *testing.T) {
err = processor.OnSelectTable(tn)
require.True(t, processor.IsStaleness())
require.Equal(t, int64(0), processor.GetStalenessInfoSchema().SchemaMetaVersion())
expectedTS, err := staleread.CalculateTsWithReadStaleness(tk.Session(), -100*time.Second)
expectedTS, err := staleread.CalculateTsWithReadStaleness(ctx, tk.Session(), -100*time.Second)
require.NoError(t, err)
require.Equal(t, expectedTS, processor.GetStalenessReadTS())
evaluator := processor.GetStalenessTSEvaluatorForPrepare()
evaluatorTS, err := evaluator(tk.Session())
evaluatorTS, err := evaluator(ctx, tk.Session())
require.NoError(t, err)
require.Equal(t, expectedTS, evaluatorTS)
tk.MustExec("set @@tidb_read_staleness=''")

tk.MustExec("do sleep(0.01)")
evaluatorTS, err = evaluator(tk.Session())
evaluatorTS, err = evaluator(ctx, tk.Session())
require.NoError(t, err)
expectedTS2, err := staleread.CalculateTsWithReadStaleness(tk.Session(), -100*time.Second)
expectedTS2, err := staleread.CalculateTsWithReadStaleness(ctx, tk.Session(), -100*time.Second)
require.NoError(t, err)
require.Equal(t, expectedTS2, evaluatorTS)

Expand Down Expand Up @@ -216,11 +217,11 @@ func TestStaleReadProcessorWithSelectTable(t *testing.T) {
err = processor.OnSelectTable(tn)
require.True(t, processor.IsStaleness())
require.Equal(t, int64(0), processor.GetStalenessInfoSchema().SchemaMetaVersion())
expectedTS, err = staleread.CalculateTsWithReadStaleness(tk.Session(), -5*time.Second)
expectedTS, err = staleread.CalculateTsWithReadStaleness(ctx, tk.Session(), -5*time.Second)
require.NoError(t, err)
require.Equal(t, expectedTS, processor.GetStalenessReadTS())
evaluator = processor.GetStalenessTSEvaluatorForPrepare()
evaluatorTS, err = evaluator(tk.Session())
evaluatorTS, err = evaluator(ctx, tk.Session())
require.NoError(t, err)
require.Equal(t, expectedTS, evaluatorTS)
tk.MustExec("set @@tidb_read_staleness=''")
Expand All @@ -233,21 +234,22 @@ func TestStaleReadProcessorWithExecutePreparedStmt(t *testing.T) {
tk := testkit.NewTestKit(t, store)
p1 := genStaleReadPoint(t, tk)
//p2 := genStaleReadPoint(t, tk)
ctx := context.Background()

// create local temporary table to check processor's infoschema will consider temporary table
tk.MustExec("create temporary table test.t2(a int)")

// execute prepared stmt with ts evaluator
processor := createProcessor(t, tk.Session())
err := processor.OnExecutePreparedStmt(func(sctx sessionctx.Context) (uint64, error) {
err := processor.OnExecutePreparedStmt(func(_ctx context.Context, sctx sessionctx.Context) (uint64, error) {
return p1.ts, nil
})
require.NoError(t, err)
p1.checkMatchProcessor(t, processor, true)

// will get an error when ts evaluator fails
processor = createProcessor(t, tk.Session())
err = processor.OnExecutePreparedStmt(func(sctx sessionctx.Context) (uint64, error) {
err = processor.OnExecutePreparedStmt(func(_ctx context.Context, sctx sessionctx.Context) (uint64, error) {
return 0, errors.New("mock error")
})
require.Error(t, err)
Expand All @@ -272,7 +274,7 @@ func TestStaleReadProcessorWithExecutePreparedStmt(t *testing.T) {
// prepared ts is not allowed when @@txn_read_ts is set
tk.MustExec(fmt.Sprintf("SET TRANSACTION READ ONLY AS OF TIMESTAMP '%s'", p1.dt))
processor = createProcessor(t, tk.Session())
err = processor.OnExecutePreparedStmt(func(sctx sessionctx.Context) (uint64, error) {
err = processor.OnExecutePreparedStmt(func(_ctx context.Context, sctx sessionctx.Context) (uint64, error) {
return p1.ts, nil
})
require.Error(t, err)
Expand All @@ -285,15 +287,15 @@ func TestStaleReadProcessorWithExecutePreparedStmt(t *testing.T) {
err = processor.OnExecutePreparedStmt(nil)
require.True(t, processor.IsStaleness())
require.Equal(t, int64(0), processor.GetStalenessInfoSchema().SchemaMetaVersion())
expectedTS, err := staleread.CalculateTsWithReadStaleness(tk.Session(), -100*time.Second)
expectedTS, err := staleread.CalculateTsWithReadStaleness(ctx, tk.Session(), -100*time.Second)
require.NoError(t, err)
require.Equal(t, expectedTS, processor.GetStalenessReadTS())
tk.MustExec("set @@tidb_read_staleness=''")

// `@@tidb_read_staleness` will be ignored when `as of` or `@@tx_read_ts`
tk.MustExec("set @@tidb_read_staleness=-100")
processor = createProcessor(t, tk.Session())
err = processor.OnExecutePreparedStmt(func(sctx sessionctx.Context) (uint64, error) {
err = processor.OnExecutePreparedStmt(func(_ctx context.Context, sctx sessionctx.Context) (uint64, error) {
return p1.ts, nil
})
require.NoError(t, err)
Expand Down Expand Up @@ -336,7 +338,7 @@ func TestStaleReadProcessorWithExecutePreparedStmt(t *testing.T) {
err = processor.OnExecutePreparedStmt(nil)
require.True(t, processor.IsStaleness())
require.Equal(t, int64(0), processor.GetStalenessInfoSchema().SchemaMetaVersion())
expectedTS, err = staleread.CalculateTsWithReadStaleness(tk.Session(), -5*time.Second)
expectedTS, err = staleread.CalculateTsWithReadStaleness(ctx, tk.Session(), -5*time.Second)
require.NoError(t, err)
require.Equal(t, expectedTS, processor.GetStalenessReadTS())
tk.MustExec("set @@tidb_read_staleness=''")
Expand Down Expand Up @@ -376,7 +378,7 @@ func TestStaleReadProcessorInTxn(t *testing.T) {

// return an error when execute prepared stmt with as of
processor = createProcessor(t, tk.Session())
err = processor.OnExecutePreparedStmt(func(sctx sessionctx.Context) (uint64, error) {
err = processor.OnExecutePreparedStmt(func(_ctx context.Context, sctx sessionctx.Context) (uint64, error) {
return p1.ts, nil
})
require.Error(t, err)
Expand Down
Loading

0 comments on commit 9296a17

Please sign in to comment.