diff --git a/client/clients/tso/dispatcher.go b/client/clients/tso/dispatcher.go index c05ab27d755..1cc2b2aa940 100644 --- a/client/clients/tso/dispatcher.go +++ b/client/clients/tso/dispatcher.go @@ -36,33 +36,12 @@ import ( "github.com/tikv/pd/client/opt" "github.com/tikv/pd/client/pkg/batch" cctx "github.com/tikv/pd/client/pkg/connectionctx" + "github.com/tikv/pd/client/pkg/deadline" "github.com/tikv/pd/client/pkg/retry" - "github.com/tikv/pd/client/pkg/utils/timerutil" "github.com/tikv/pd/client/pkg/utils/tsoutil" sd "github.com/tikv/pd/client/servicediscovery" ) -// deadline is used to control the TS request timeout manually, -// it will be sent to the `tsDeadlineCh` to be handled by the `watchTSDeadline` goroutine. -type deadline struct { - timer *time.Timer - done chan struct{} - cancel context.CancelFunc -} - -func newTSDeadline( - timeout time.Duration, - done chan struct{}, - cancel context.CancelFunc, -) *deadline { - timer := timerutil.GlobalTimerPool.Get(timeout) - return &deadline{ - timer: timer, - done: done, - cancel: cancel, - } -} - type tsoInfo struct { tsoServer string reqKeyspaceGroupID uint32 @@ -86,10 +65,10 @@ type tsoDispatcher struct { ctx context.Context cancel context.CancelFunc - provider tsoServiceProvider - tsoRequestCh chan *Request - tsDeadlineCh chan *deadline - latestTSOInfo atomic.Pointer[tsoInfo] + provider tsoServiceProvider + tsoRequestCh chan *Request + deadlineWatcher *deadline.Watcher + latestTSOInfo atomic.Pointer[tsoInfo] // For reusing `*batchController` objects batchBufferPool *sync.Pool @@ -119,11 +98,11 @@ func newTSODispatcher( tokenCh := make(chan struct{}, tokenChCapacity) td := &tsoDispatcher{ - ctx: dispatcherCtx, - cancel: dispatcherCancel, - provider: provider, - tsoRequestCh: tsoRequestCh, - tsDeadlineCh: make(chan *deadline, tokenChCapacity), + ctx: dispatcherCtx, + cancel: dispatcherCancel, + provider: provider, + tsoRequestCh: tsoRequestCh, + deadlineWatcher: deadline.NewWatcher(dispatcherCtx, 
tokenChCapacity, "tso"), batchBufferPool: &sync.Pool{ New: func() any { return batch.NewController[*Request]( @@ -135,34 +114,9 @@ func newTSODispatcher( }, tokenCh: tokenCh, } - go td.watchTSDeadline() return td } -func (td *tsoDispatcher) watchTSDeadline() { - log.Info("[tso] start tso deadline watcher") - defer log.Info("[tso] exit tso deadline watcher") - for { - select { - case d := <-td.tsDeadlineCh: - select { - case <-d.timer.C: - log.Error("[tso] tso request is canceled due to timeout", - errs.ZapError(errs.ErrClientGetTSOTimeout)) - d.cancel() - timerutil.GlobalTimerPool.Put(d.timer) - case <-d.done: - timerutil.GlobalTimerPool.Put(d.timer) - case <-td.ctx.Done(): - timerutil.GlobalTimerPool.Put(d.timer) - return - } - case <-td.ctx.Done(): - return - } - } -} - func (td *tsoDispatcher) revokePendingRequests(err error) { for range len(td.tsoRequestCh) { req := <-td.tsoRequestCh @@ -378,14 +332,11 @@ tsoBatchLoop: } } - done := make(chan struct{}) - dl := newTSDeadline(option.Timeout, done, cancel) - select { - case <-ctx.Done(): + done := td.deadlineWatcher.Start(ctx, option.Timeout, cancel) + if done == nil { // Finish the collected requests if the context is canceled. td.cancelCollectedRequests(tsoBatchController, invalidStreamID, errors.WithStack(ctx.Err())) return - case td.tsDeadlineCh <- dl: } // processRequests guarantees that the collected requests could be finished properly. 
err = td.processRequests(stream, tsoBatchController, done) diff --git a/client/errs/errno.go b/client/errs/errno.go index 25665f01017..99a426d0776 100644 --- a/client/errs/errno.go +++ b/client/errs/errno.go @@ -56,7 +56,6 @@ var ( ErrClientGetMetaStorageClient = errors.Normalize("failed to get meta storage client", errors.RFCCodeText("PD:client:ErrClientGetMetaStorageClient")) ErrClientCreateTSOStream = errors.Normalize("create TSO stream failed, %s", errors.RFCCodeText("PD:client:ErrClientCreateTSOStream")) ErrClientTSOStreamClosed = errors.Normalize("encountered TSO stream being closed unexpectedly", errors.RFCCodeText("PD:client:ErrClientTSOStreamClosed")) - ErrClientGetTSOTimeout = errors.Normalize("get TSO timeout", errors.RFCCodeText("PD:client:ErrClientGetTSOTimeout")) ErrClientGetTSO = errors.Normalize("get TSO failed, %v", errors.RFCCodeText("PD:client:ErrClientGetTSO")) ErrClientGetMinTSO = errors.Normalize("get min TSO failed, %v", errors.RFCCodeText("PD:client:ErrClientGetMinTSO")) ErrClientGetLeader = errors.Normalize("get leader failed, %v", errors.RFCCodeText("PD:client:ErrClientGetLeader")) diff --git a/client/pkg/deadline/watcher.go b/client/pkg/deadline/watcher.go new file mode 100644 index 00000000000..b40857edbfd --- /dev/null +++ b/client/pkg/deadline/watcher.go @@ -0,0 +1,111 @@ +// Copyright 2024 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package deadline + +import ( + "context" + "time" + + "go.uber.org/zap" + + "github.com/pingcap/log" + + "github.com/tikv/pd/client/pkg/utils/timerutil" +) + +// deadline describes one watched timeout: the `cancel` function will be invoked once the specified `timeout` elapses without receiving a `done` signal. +type deadline struct { + timer *time.Timer + done chan struct{} + cancel context.CancelFunc +} + +// Watcher is used to watch and manage the deadlines, which are handled one at a time by the `Watch` loop. +type Watcher struct { + ctx context.Context + source string + Ch chan *deadline +} + +// NewWatcher is used to create a new deadline watcher. It buffers up to `capacity` pending deadlines and starts the `Watch` loop in a new goroutine, which exits once `ctx` is done. +func NewWatcher(ctx context.Context, capacity int, source string) *Watcher { + watcher := &Watcher{ + ctx: ctx, + source: source, + Ch: make(chan *deadline, capacity), + } + go watcher.Watch() + return watcher +} + +// Watch is used to watch the deadlines received from `Ch` one by one and invoke the `cancel` function of a deadline if its timer fires before its `done` signal is received. +// It returns nothing and only exits once the watcher context is done; the timer of every handled deadline is put back into the global timer pool. +func (w *Watcher) Watch() { + log.Info("[pd] start the deadline watcher", zap.String("source", w.source)) + defer log.Info("[pd] exit the deadline watcher", zap.String("source", w.source)) + for { + select { + case d := <-w.Ch: + select { + case <-d.timer.C: + log.Error("[pd] the deadline is reached", zap.String("source", w.source)) + d.cancel() + timerutil.GlobalTimerPool.Put(d.timer) + case <-d.done: + timerutil.GlobalTimerPool.Put(d.timer) + case <-w.ctx.Done(): + timerutil.GlobalTimerPool.Put(d.timer) + return + } + case <-w.ctx.Done(): + return + } + } +} + +// Start is used to start a deadline. It returns the `done` channel of the deadline, which the caller signals to tell the watcher that the work finished in time; if the timer expires first, the watcher invokes `cancel`. +// Returns nil if the deadline is not started because the watcher context or the given `ctx` is done. +func (w *Watcher) Start( + ctx context.Context, + timeout time.Duration, + cancel context.CancelFunc, +) chan struct{} { + // Check if the watcher or the caller context is already canceled. + select { + case <-w.ctx.Done(): + return nil + case <-ctx.Done(): + return nil + default: + } + // Initialize the deadline. 
+ timer := timerutil.GlobalTimerPool.Get(timeout) + d := &deadline{ + timer: timer, + done: make(chan struct{}), + cancel: cancel, + } + // Send the deadline to the watcher, or give the timer back to the pool if either context is done first. + select { + case <-w.ctx.Done(): + timerutil.GlobalTimerPool.Put(timer) + return nil + case <-ctx.Done(): + timerutil.GlobalTimerPool.Put(timer) + return nil + case w.Ch <- d: + return d.done + } +} diff --git a/client/pkg/deadline/watcher_test.go b/client/pkg/deadline/watcher_test.go new file mode 100644 index 00000000000..b93987b8874 --- /dev/null +++ b/client/pkg/deadline/watcher_test.go @@ -0,0 +1,58 @@ +// Copyright 2024 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package deadline + +import ( + "context" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestWatcher(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + watcher := NewWatcher(ctx, 10, "test") + var deadlineReached atomic.Bool + done := watcher.Start(ctx, time.Millisecond, func() { + deadlineReached.Store(true) + }) + re.NotNil(done) + time.Sleep(5 * time.Millisecond) + re.True(deadlineReached.Load()) /* the 1ms deadline expired without a done signal, so the cancel callback ran */ + + deadlineReached.Store(false) + done = watcher.Start(ctx, 500*time.Millisecond, func() { + deadlineReached.Store(true) + }) + re.NotNil(done) + done <- struct{}{} + time.Sleep(time.Second) + re.False(deadlineReached.Load()) /* done was signaled before the 500ms deadline, so the cancel callback never ran */ + + deadCtx, deadCancel := context.WithCancel(ctx) + deadCancel() + deadlineReached.Store(false) + done = watcher.Start(deadCtx, time.Millisecond, func() { + deadlineReached.Store(true) + }) + re.Nil(done) /* Start returns nil for a context that is already canceled */ + time.Sleep(5 * time.Millisecond) + re.False(deadlineReached.Load()) +} diff --git a/errors.toml b/errors.toml index 2ab3b014f5a..9980a98ab14 100644 --- a/errors.toml +++ b/errors.toml @@ -131,11 +131,6 @@ error = ''' get TSO failed ''' -["PD:client:ErrClientGetTSOTimeout"] -error = ''' -get TSO timeout -''' - ["PD:cluster:ErrInvalidStoreID"] error = ''' invalid store id %d, not found diff --git a/pkg/errs/errno.go b/pkg/errs/errno.go index ee24b4d0673..834bf4f824e 100644 --- a/pkg/errs/errno.go +++ b/pkg/errs/errno.go @@ -144,7 +144,6 @@ var ( // client errors var ( ErrClientCreateTSOStream = errors.Normalize("create TSO stream failed, %s", errors.RFCCodeText("PD:client:ErrClientCreateTSOStream")) - ErrClientGetTSOTimeout = errors.Normalize("get TSO timeout", errors.RFCCodeText("PD:client:ErrClientGetTSOTimeout")) ErrClientGetTSO = errors.Normalize("get TSO failed", errors.RFCCodeText("PD:client:ErrClientGetTSO")) ErrClientGetLeader = errors.Normalize("get leader failed, %v", errors.RFCCodeText("PD:client:ErrClientGetLeader")) 
ErrClientGetMember = errors.Normalize("get member failed", errors.RFCCodeText("PD:client:ErrClientGetMember"))