Skip to content

Commit

Permalink
Separate the TSO client implementation
Browse files Browse the repository at this point in the history
Signed-off-by: JmPotato <ghzpotato@gmail.com>
  • Loading branch information
JmPotato committed Nov 25, 2024
1 parent ec77762 commit 1a6e4ed
Show file tree
Hide file tree
Showing 15 changed files with 242 additions and 174 deletions.
48 changes: 31 additions & 17 deletions client/batch_controller.go → client/batch/batch_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

package pd
package batch

import (
"context"
Expand All @@ -24,10 +24,11 @@ import (
// Starting from a low value is necessary because we need to make sure it will be converged to (current_batch_size - 4).
const defaultBestBatchSize = 8

// finisherFunc is used to finish a request, it accepts the index of the request in the batch, the request itself and an error.
type finisherFunc[T any] func(int, T, error)
// FinisherFunc is used to finish a request, it accepts the index of the request in the batch, the request itself and an error.
type FinisherFunc[T any] func(int, T, error)

type batchController[T any] struct {
// Controller is used to batch requests.
type Controller[T any] struct {
maxBatchSize int
// bestBatchSize is a dynamic size that changed based on the current batch effect.
bestBatchSize int
Expand All @@ -36,15 +37,16 @@ type batchController[T any] struct {
collectedRequestCount int

// The finisher function to cancel collected requests when an internal error occurs.
finisher finisherFunc[T]
finisher FinisherFunc[T]
// The observer to record the best batch size.
bestBatchObserver prometheus.Histogram
// The time after getting the first request and the token, and before performing extra batching.
extraBatchingStartTime time.Time
}

func newBatchController[T any](maxBatchSize int, finisher finisherFunc[T], bestBatchObserver prometheus.Histogram) *batchController[T] {
return &batchController[T]{
// NewController creates a new batch controller.
func NewController[T any](maxBatchSize int, finisher FinisherFunc[T], bestBatchObserver prometheus.Histogram) *Controller[T] {
return &Controller[T]{
maxBatchSize: maxBatchSize,
bestBatchSize: defaultBestBatchSize,
collectedRequests: make([]T, maxBatchSize+1),
Expand All @@ -54,11 +56,11 @@ func newBatchController[T any](maxBatchSize int, finisher finisherFunc[T], bestB
}
}

// fetchPendingRequests will start a new round of the batch collecting from the channel.
// FetchPendingRequests will start a new round of the batch collecting from the channel.
// It returns nil error if everything goes well, otherwise a non-nil error which means we should stop the service.
// It's guaranteed that if this function failed after collecting some requests, then these requests will be cancelled
// when the function returns, so the caller don't need to clear them manually.
func (bc *batchController[T]) fetchPendingRequests(ctx context.Context, requestCh <-chan T, tokenCh chan struct{}, maxBatchWaitInterval time.Duration) (errRet error) {
func (bc *Controller[T]) FetchPendingRequests(ctx context.Context, requestCh <-chan T, tokenCh chan struct{}, maxBatchWaitInterval time.Duration) (errRet error) {
var tokenAcquired bool
defer func() {
if errRet != nil {
Expand All @@ -67,7 +69,7 @@ func (bc *batchController[T]) fetchPendingRequests(ctx context.Context, requestC
if tokenAcquired {
tokenCh <- struct{}{}
}
bc.finishCollectedRequests(bc.finisher, errRet)
bc.FinishCollectedRequests(bc.finisher, errRet)
}
}()

Expand Down Expand Up @@ -167,9 +169,9 @@ fetchPendingRequestsLoop:
return nil
}

// fetchRequestsWithTimer tries to fetch requests until the given timer ticks. The caller must set the timer properly
// FetchRequestsWithTimer tries to fetch requests until the given timer ticks. The caller must set the timer properly
// before calling this function.
func (bc *batchController[T]) fetchRequestsWithTimer(ctx context.Context, requestCh <-chan T, timer *time.Timer) error {
func (bc *Controller[T]) FetchRequestsWithTimer(ctx context.Context, requestCh <-chan T, timer *time.Timer) error {
batchingLoop:
for bc.collectedRequestCount < bc.maxBatchSize {
select {
Expand Down Expand Up @@ -198,17 +200,23 @@ nonWaitingBatchLoop:
return nil
}

func (bc *batchController[T]) pushRequest(req T) {
func (bc *Controller[T]) pushRequest(req T) {
bc.collectedRequests[bc.collectedRequestCount] = req
bc.collectedRequestCount++
}

func (bc *batchController[T]) getCollectedRequests() []T {
// GetCollectedRequests returns the collected requests.
func (bc *Controller[T]) GetCollectedRequests() []T {
return bc.collectedRequests[:bc.collectedRequestCount]
}

// adjustBestBatchSize stabilizes the latency with the AIAD algorithm.
func (bc *batchController[T]) adjustBestBatchSize() {
// GetCollectedRequestCount returns the number of collected requests.
func (bc *Controller[T]) GetCollectedRequestCount() int {
return bc.collectedRequestCount
}

// AdjustBestBatchSize stabilizes the latency with the AIAD algorithm.
func (bc *Controller[T]) AdjustBestBatchSize() {
if bc.bestBatchObserver != nil {
bc.bestBatchObserver.Observe(float64(bc.bestBatchSize))
}
Expand All @@ -222,7 +230,8 @@ func (bc *batchController[T]) adjustBestBatchSize() {
}
}

func (bc *batchController[T]) finishCollectedRequests(finisher finisherFunc[T], err error) {
// FinishCollectedRequests finishes the collected requests.
func (bc *Controller[T]) FinishCollectedRequests(finisher FinisherFunc[T], err error) {
if finisher == nil {
finisher = bc.finisher
}
Expand All @@ -234,3 +243,8 @@ func (bc *batchController[T]) finishCollectedRequests(finisher finisherFunc[T],
// Prevent the finished requests from being processed again.
bc.collectedRequestCount = 0
}

// GetExtraBatchingStartTime returns the extra batching start time.
func (bc *Controller[T]) GetExtraBatchingStartTime() time.Time {
return bc.extraBatchingStartTime
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

package pd
package batch

import (
"context"
Expand All @@ -23,26 +23,26 @@ import (

func TestAdjustBestBatchSize(t *testing.T) {
re := require.New(t)
bc := newBatchController[int](20, nil, nil)
bc := NewController[int](20, nil, nil)
re.Equal(defaultBestBatchSize, bc.bestBatchSize)
bc.adjustBestBatchSize()
bc.AdjustBestBatchSize()
re.Equal(defaultBestBatchSize-1, bc.bestBatchSize)
// Clear the collected requests.
bc.finishCollectedRequests(nil, nil)
bc.FinishCollectedRequests(nil, nil)
// Push 10 requests - do not increase the best batch size.
for i := range 10 {
bc.pushRequest(i)
}
bc.adjustBestBatchSize()
bc.AdjustBestBatchSize()
re.Equal(defaultBestBatchSize-1, bc.bestBatchSize)
bc.finishCollectedRequests(nil, nil)
bc.FinishCollectedRequests(nil, nil)
// Push 15 requests, increase the best batch size.
for i := range 15 {
bc.pushRequest(i)
}
bc.adjustBestBatchSize()
bc.AdjustBestBatchSize()
re.Equal(defaultBestBatchSize, bc.bestBatchSize)
bc.finishCollectedRequests(nil, nil)
bc.FinishCollectedRequests(nil, nil)
}

type testRequest struct {
Expand All @@ -52,10 +52,10 @@ type testRequest struct {

func TestFinishCollectedRequests(t *testing.T) {
re := require.New(t)
bc := newBatchController[*testRequest](20, nil, nil)
bc := NewController[*testRequest](20, nil, nil)
// Finish with zero request count.
re.Zero(bc.collectedRequestCount)
bc.finishCollectedRequests(nil, nil)
bc.FinishCollectedRequests(nil, nil)
re.Zero(bc.collectedRequestCount)
// Finish with non-zero request count.
requests := make([]*testRequest, 10)
Expand All @@ -64,14 +64,14 @@ func TestFinishCollectedRequests(t *testing.T) {
bc.pushRequest(requests[i])
}
re.Equal(10, bc.collectedRequestCount)
bc.finishCollectedRequests(nil, nil)
bc.FinishCollectedRequests(nil, nil)
re.Zero(bc.collectedRequestCount)
// Finish with custom finisher.
for i := range 10 {
requests[i] = &testRequest{}
bc.pushRequest(requests[i])
}
bc.finishCollectedRequests(func(idx int, tr *testRequest, err error) {
bc.FinishCollectedRequests(func(idx int, tr *testRequest, err error) {
tr.idx = idx
tr.err = err
}, context.Canceled)
Expand Down
12 changes: 6 additions & 6 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
"github.com/prometheus/client_golang/prometheus"
"github.com/tikv/pd/client/caller"
"github.com/tikv/pd/client/clients/metastorage"
"github.com/tikv/pd/client/clients/tso"
"github.com/tikv/pd/client/constants"
"github.com/tikv/pd/client/errs"
"github.com/tikv/pd/client/metrics"
Expand Down Expand Up @@ -140,8 +141,7 @@ type RPCClient interface {
// on your needs.
WithCallerComponent(callerComponent caller.Component) RPCClient

// TSOClient is the TSO client.
TSOClient
tso.Client
metastorage.Client
// KeyspaceClient manages keyspace metadata.
KeyspaceClient
Expand Down Expand Up @@ -179,7 +179,7 @@ type serviceModeKeeper struct {
// triggering service mode switching concurrently.
sync.RWMutex
serviceMode pdpb.ServiceMode
tsoClient *tsoClient
tsoClient *tso.Cli
tsoSvcDiscovery sd.ServiceDiscovery
}

Expand All @@ -191,7 +191,7 @@ func (k *serviceModeKeeper) close() {
k.tsoSvcDiscovery.Close()
fallthrough
case pdpb.ServiceMode_PD_SVC_MODE:
k.tsoClient.close()
k.tsoClient.Close()
case pdpb.ServiceMode_UNKNOWN_SVC_MODE:
}
}
Expand Down Expand Up @@ -557,7 +557,7 @@ func (c *client) getClientAndContext(ctx context.Context) (pdpb.PDClient, contex
}

// GetTSAsync implements the TSOClient interface.
func (c *client) GetTSAsync(ctx context.Context) TSFuture {
func (c *client) GetTSAsync(ctx context.Context) tso.TSFuture {
defer trace.StartRegion(ctx, "pdclient.GetTSAsync").End()
if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil {
span = span.Tracer().StartSpan("pdclient.GetTSAsync", opentracing.ChildOf(span.Context()))
Expand All @@ -570,7 +570,7 @@ func (c *client) GetTSAsync(ctx context.Context) TSFuture {
//
// Deprecated: Local TSO will be completely removed in the future. Currently, regardless of the
// parameters passed in, this method will default to returning the global TSO.
func (c *client) GetLocalTSAsync(ctx context.Context, _ string) TSFuture {
func (c *client) GetLocalTSAsync(ctx context.Context, _ string) tso.TSFuture {
return c.GetTSAsync(ctx)
}

Expand Down
28 changes: 0 additions & 28 deletions client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import (
"testing"
"time"

"github.com/pingcap/errors"
"github.com/stretchr/testify/require"
"github.com/tikv/pd/client/caller"
"github.com/tikv/pd/client/opt"
Expand Down Expand Up @@ -62,30 +61,3 @@ func TestClientWithRetry(t *testing.T) {
re.Error(err)
re.Less(time.Since(start), time.Second*10)
}

func TestTsoRequestWait(t *testing.T) {
re := require.New(t)
ctx, cancel := context.WithCancel(context.Background())
req := &tsoRequest{
done: make(chan error, 1),
physical: 0,
logical: 0,
requestCtx: context.TODO(),
clientCtx: ctx,
}
cancel()
_, _, err := req.Wait()
re.ErrorIs(errors.Cause(err), context.Canceled)

ctx, cancel = context.WithCancel(context.Background())
req = &tsoRequest{
done: make(chan error, 1),
physical: 0,
logical: 0,
requestCtx: ctx,
clientCtx: context.TODO(),
}
cancel()
_, _, err = req.Wait()
re.ErrorIs(errors.Cause(err), context.Canceled)
}
Loading

0 comments on commit 1a6e4ed

Please sign in to comment.