Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tso: merge lastTS and lastArrivalTS into an atomic pointer #1054

Merged
merged 6 commits into from
Nov 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 3 additions & 15 deletions oracle/oracles/export_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,20 +63,8 @@ func NewEmptyPDOracle() oracle.Oracle {
func SetEmptyPDOracleLastTs(oc oracle.Oracle, ts uint64) {
switch o := oc.(type) {
case *pdOracle:
lastTSInterface, _ := o.lastTSMap.LoadOrStore(oracle.GlobalTxnScope, new(uint64))
lastTSPointer := lastTSInterface.(*uint64)
atomic.StoreUint64(lastTSPointer, ts)
lasTSArrivalInterface, _ := o.lastArrivalTSMap.LoadOrStore(oracle.GlobalTxnScope, new(uint64))
lasTSArrivalPointer := lasTSArrivalInterface.(*uint64)
atomic.StoreUint64(lasTSArrivalPointer, uint64(time.Now().Unix()*1000))
}
setEmptyPDOracleLastArrivalTs(oc, ts)
}

// setEmptyPDOracleLastArrivalTs exports PD oracle's global last ts to test.
func setEmptyPDOracleLastArrivalTs(oc oracle.Oracle, ts uint64) {
switch o := oc.(type) {
case *pdOracle:
o.setLastArrivalTS(ts, oracle.GlobalTxnScope)
lastTSInterface, _ := o.lastTSMap.LoadOrStore(oracle.GlobalTxnScope, &atomic.Pointer[lastTSO]{})
lastTSPointer := lastTSInterface.(*atomic.Pointer[lastTSO])
lastTSPointer.Store(&lastTSO{tso: ts, arrival: ts})
}
}
82 changes: 35 additions & 47 deletions oracle/oracles/pd.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,15 @@ const slowDist = 30 * time.Millisecond
// pdOracle is an Oracle that uses a placement driver client as source.
type pdOracle struct {
c pd.Client
// txn_scope (string) -> lastTSPointer (*uint64)
// txn_scope (string) -> lastTSPointer (*atomic.Pointer[lastTSO])
lastTSMap sync.Map
// txn_scope (string) -> lastArrivalTSPointer (*uint64)
lastArrivalTSMap sync.Map
quit chan struct{}
quit chan struct{}
}

// lastTSO stores the last timestamp oracle gets from PD server and the local time when the TSO is fetched.
type lastTSO struct {
tso uint64
arrival uint64
}

// NewPdOracle create an Oracle that uses a pd client source.
Expand Down Expand Up @@ -163,63 +167,51 @@ func (o *pdOracle) setLastTS(ts uint64, txnScope string) {
if txnScope == "" {
txnScope = oracle.GlobalTxnScope
}
lastTSInterface, ok := o.lastTSMap.Load(txnScope)
if !ok {
lastTSInterface, _ = o.lastTSMap.LoadOrStore(txnScope, new(uint64))
}
lastTSPointer := lastTSInterface.(*uint64)
for {
lastTS := atomic.LoadUint64(lastTSPointer)
if ts <= lastTS {
return
}
if atomic.CompareAndSwapUint64(lastTSPointer, lastTS, ts) {
break
}
}
o.setLastArrivalTS(o.getArrivalTimestamp(), txnScope)
}

func (o *pdOracle) setLastArrivalTS(ts uint64, txnScope string) {
if txnScope == "" {
txnScope = oracle.GlobalTxnScope
current := &lastTSO{
tso: ts,
arrival: o.getArrivalTimestamp(),
}
lastTSInterface, ok := o.lastArrivalTSMap.Load(txnScope)
lastTSInterface, ok := o.lastTSMap.Load(txnScope)
if !ok {
lastTSInterface, _ = o.lastArrivalTSMap.LoadOrStore(txnScope, new(uint64))
pointer := &atomic.Pointer[lastTSO]{}
pointer.Store(current)
// do not handle the stored case, because it only runs once.
lastTSInterface, _ = o.lastTSMap.LoadOrStore(txnScope, pointer)
}
lastTSPointer := lastTSInterface.(*uint64)
lastTSPointer := lastTSInterface.(*atomic.Pointer[lastTSO])
for {
lastTS := atomic.LoadUint64(lastTSPointer)
if ts <= lastTS {
last := lastTSPointer.Load()
if current.tso <= last.tso || current.arrival <= last.arrival {
return
}
if atomic.CompareAndSwapUint64(lastTSPointer, lastTS, ts) {
if lastTSPointer.CompareAndSwap(last, current) {
return
}
}
}

func (o *pdOracle) getLastTS(txnScope string) (uint64, bool) {
if txnScope == "" {
txnScope = oracle.GlobalTxnScope
}
lastTSInterface, ok := o.lastTSMap.Load(txnScope)
if !ok {
last, exist := o.getLastTSWithArrivalTS(txnScope)
if !exist {
return 0, false
}
return atomic.LoadUint64(lastTSInterface.(*uint64)), true
return last.tso, true
}

func (o *pdOracle) getLastArrivalTS(txnScope string) (uint64, bool) {
func (o *pdOracle) getLastTSWithArrivalTS(txnScope string) (*lastTSO, bool) {
if txnScope == "" {
txnScope = oracle.GlobalTxnScope
}
lastArrivalTSInterface, ok := o.lastArrivalTSMap.Load(txnScope)
lastTSInterface, ok := o.lastTSMap.Load(txnScope)
if !ok {
return 0, false
return nil, false
}
lastTSPointer := lastTSInterface.(*atomic.Pointer[lastTSO])
last := lastTSPointer.Load()
if last == nil {
return nil, false
}
return atomic.LoadUint64(lastArrivalTSInterface.(*uint64)), true
return last, true
}

func (o *pdOracle) updateTS(ctx context.Context, interval time.Duration) {
Expand Down Expand Up @@ -293,22 +285,18 @@ func (o *pdOracle) GetLowResolutionTimestampAsync(ctx context.Context, opt *orac
}

func (o *pdOracle) getStaleTimestamp(txnScope string, prevSecond uint64) (uint64, error) {
ts, ok := o.getLastTS(txnScope)
last, ok := o.getLastTSWithArrivalTS(txnScope)
if !ok {
return 0, errors.Errorf("get stale timestamp fail, txnScope: %s", txnScope)
}
arrivalTS, ok := o.getLastArrivalTS(txnScope)
if !ok {
return 0, errors.Errorf("get stale arrival timestamp fail, txnScope: %s", txnScope)
}
ts, arrivalTS := last.tso, last.arrival
arrivalTime := oracle.GetTimeFromTS(arrivalTS)
physicalTime := oracle.GetTimeFromTS(ts)
if uint64(physicalTime.Unix()) <= prevSecond {
return 0, errors.Errorf("invalid prevSecond %v", prevSecond)
}

staleTime := physicalTime.Add(-arrivalTime.Sub(time.Now().Add(-time.Duration(prevSecond) * time.Second)))

staleTime := physicalTime.Add(time.Now().Add(-time.Duration(prevSecond) * time.Second).Sub(arrivalTime))
return oracle.GoTimeToTS(staleTime), nil
}

Expand Down
32 changes: 32 additions & 0 deletions oracle/oracles/pd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,35 @@ func TestPdOracle_GetStaleTimestamp(t *testing.T) {
assert.NotNil(t, err)
assert.Regexp(t, ".*invalid prevSecond.*", err.Error())
}

func TestNonFutureStaleTSO(t *testing.T) {
o := oracles.NewEmptyPDOracle()
oracles.SetEmptyPDOracleLastTs(o, oracle.GoTimeToTS(time.Now()))
for i := 0; i < 100; i++ {
time.Sleep(10 * time.Millisecond)
now := time.Now()
upperBound := now.Add(5 * time.Millisecond) // allow 5ms time drift

closeCh := make(chan struct{})
go func() {
time.Sleep(100 * time.Microsecond)
oracles.SetEmptyPDOracleLastTs(o, oracle.GoTimeToTS(now))
close(closeCh)
}()
CHECK:
for {
select {
case <-closeCh:
break CHECK
default:
ts, err := o.GetStaleTimestamp(context.Background(), oracle.GlobalTxnScope, 0)
assert.Nil(t, err)
staleTime := oracle.GetTimeFromTS(ts)
if staleTime.After(upperBound) && time.Since(now) < time.Millisecond /* only check staleTime within 1ms */ {
assert.Less(t, staleTime, upperBound, i)
t.FailNow()
}
}
}
}
}
Loading