client: separate the TSO client implementation #8848

Merged 3 commits on Nov 26, 2024
48 changes: 31 additions & 17 deletions client/batch_controller.go → client/batch/batch_controller.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

package pd
package batch

import (
"context"
@@ -24,10 +24,11 @@
// Starting from a low value is necessary because we need to make sure it will converge to (current_batch_size - 4).
const defaultBestBatchSize = 8

// finisherFunc is used to finish a request, it accepts the index of the request in the batch, the request itself and an error.
type finisherFunc[T any] func(int, T, error)
// FinisherFunc is used to finish a request, it accepts the index of the request in the batch, the request itself and an error.
type FinisherFunc[T any] func(int, T, error)

type batchController[T any] struct {
// Controller is used to batch requests.
type Controller[T any] struct {
maxBatchSize int
// bestBatchSize is a dynamic size that changes based on the current batch effect.
bestBatchSize int
@@ -36,15 +37,16 @@
collectedRequestCount int

// The finisher function to cancel collected requests when an internal error occurs.
finisher finisherFunc[T]
finisher FinisherFunc[T]
// The observer to record the best batch size.
bestBatchObserver prometheus.Histogram
// The time after getting the first request and the token, and before performing extra batching.
extraBatchingStartTime time.Time
}

func newBatchController[T any](maxBatchSize int, finisher finisherFunc[T], bestBatchObserver prometheus.Histogram) *batchController[T] {
return &batchController[T]{
// NewController creates a new batch controller.
func NewController[T any](maxBatchSize int, finisher FinisherFunc[T], bestBatchObserver prometheus.Histogram) *Controller[T] {
return &Controller[T]{
maxBatchSize: maxBatchSize,
bestBatchSize: defaultBestBatchSize,
collectedRequests: make([]T, maxBatchSize+1),
@@ -54,11 +56,11 @@
}
}

// fetchPendingRequests will start a new round of the batch collecting from the channel.
// FetchPendingRequests will start a new round of the batch collecting from the channel.
// It returns nil error if everything goes well, otherwise a non-nil error which means we should stop the service.
// It's guaranteed that if this function fails after collecting some requests, then these requests will be cancelled
// when the function returns, so the caller doesn't need to clear them manually.
func (bc *batchController[T]) fetchPendingRequests(ctx context.Context, requestCh <-chan T, tokenCh chan struct{}, maxBatchWaitInterval time.Duration) (errRet error) {
func (bc *Controller[T]) FetchPendingRequests(ctx context.Context, requestCh <-chan T, tokenCh chan struct{}, maxBatchWaitInterval time.Duration) (errRet error) {
var tokenAcquired bool
defer func() {
if errRet != nil {
@@ -67,7 +69,7 @@
if tokenAcquired {
tokenCh <- struct{}{}
}
bc.finishCollectedRequests(bc.finisher, errRet)
bc.FinishCollectedRequests(bc.finisher, errRet)
}
}()

@@ -167,9 +169,9 @@
return nil
}

// fetchRequestsWithTimer tries to fetch requests until the given timer ticks. The caller must set the timer properly
// FetchRequestsWithTimer tries to fetch requests until the given timer ticks. The caller must set the timer properly
// before calling this function.
func (bc *batchController[T]) fetchRequestsWithTimer(ctx context.Context, requestCh <-chan T, timer *time.Timer) error {
func (bc *Controller[T]) FetchRequestsWithTimer(ctx context.Context, requestCh <-chan T, timer *time.Timer) error {

(Codecov check: added line client/batch/batch_controller.go#L174 is not covered by tests)
batchingLoop:
for bc.collectedRequestCount < bc.maxBatchSize {
select {
@@ -198,17 +200,23 @@
return nil
}

func (bc *batchController[T]) pushRequest(req T) {
func (bc *Controller[T]) pushRequest(req T) {
bc.collectedRequests[bc.collectedRequestCount] = req
bc.collectedRequestCount++
}

func (bc *batchController[T]) getCollectedRequests() []T {
// GetCollectedRequests returns the collected requests.
func (bc *Controller[T]) GetCollectedRequests() []T {
return bc.collectedRequests[:bc.collectedRequestCount]
}

// adjustBestBatchSize stabilizes the latency with the AIAD algorithm.
func (bc *batchController[T]) adjustBestBatchSize() {
// GetCollectedRequestCount returns the number of collected requests.
func (bc *Controller[T]) GetCollectedRequestCount() int {
return bc.collectedRequestCount
}

// AdjustBestBatchSize stabilizes the latency with the AIAD algorithm.
func (bc *Controller[T]) AdjustBestBatchSize() {
if bc.bestBatchObserver != nil {
bc.bestBatchObserver.Observe(float64(bc.bestBatchSize))
}
@@ -222,7 +230,8 @@
}
}

func (bc *batchController[T]) finishCollectedRequests(finisher finisherFunc[T], err error) {
// FinishCollectedRequests finishes the collected requests.
func (bc *Controller[T]) FinishCollectedRequests(finisher FinisherFunc[T], err error) {
if finisher == nil {
finisher = bc.finisher
}
@@ -234,3 +243,8 @@
// Prevent the finished requests from being processed again.
bc.collectedRequestCount = 0
}

// GetExtraBatchingStartTime returns the extra batching start time.
func (bc *Controller[T]) GetExtraBatchingStartTime() time.Time {
return bc.extraBatchingStartTime
}
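
For orientation, here is a minimal sketch of how a dispatcher outside the `batch` package could drive the newly exported `Controller` API. The `request` type, the channels, the 1ms extra-batching window, and the token hand-back at the end of each iteration are illustrative assumptions about the caller, not code from this PR.

```go
package example

import (
	"context"
	"time"

	"github.com/tikv/pd/client/batch"
)

// request is a hypothetical batched request that reports its outcome on done.
type request struct {
	done chan error
}

// dispatcherLoop runs one batching round per iteration using the exported
// Controller API, loosely mirroring how a TSO-style dispatcher might use it.
func dispatcherLoop(ctx context.Context, reqCh chan *request, tokenCh chan struct{}) {
	// The default finisher delivers the per-request error (nil on success).
	finisher := func(_ int, r *request, err error) {
		r.done <- err
	}
	ctl := batch.NewController[*request](64 /* maxBatchSize */, finisher, nil /* no histogram */)

	for {
		// Collect a batch; on error the controller has already cancelled
		// whatever it collected, so the loop can simply stop.
		if err := ctl.FetchPendingRequests(ctx, reqCh, tokenCh, time.Millisecond); err != nil {
			return
		}
		// Send one RPC covering all collected requests here...
		_ = ctl.GetCollectedRequests()
		// ...then resolve them, adapt the target batch size, and return the token.
		ctl.FinishCollectedRequests(nil, nil)
		ctl.AdjustBestBatchSize()
		tokenCh <- struct{}{}
	}
}
```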
client/batch_controller_test.go → client/batch/batch_controller_test.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

package pd
package batch

import (
"context"
@@ -23,26 +23,26 @@ import (

func TestAdjustBestBatchSize(t *testing.T) {
re := require.New(t)
bc := newBatchController[int](20, nil, nil)
bc := NewController[int](20, nil, nil)
re.Equal(defaultBestBatchSize, bc.bestBatchSize)
bc.adjustBestBatchSize()
bc.AdjustBestBatchSize()
re.Equal(defaultBestBatchSize-1, bc.bestBatchSize)
// Clear the collected requests.
bc.finishCollectedRequests(nil, nil)
bc.FinishCollectedRequests(nil, nil)
// Push 10 requests - do not increase the best batch size.
for i := range 10 {
bc.pushRequest(i)
}
bc.adjustBestBatchSize()
bc.AdjustBestBatchSize()
re.Equal(defaultBestBatchSize-1, bc.bestBatchSize)
bc.finishCollectedRequests(nil, nil)
bc.FinishCollectedRequests(nil, nil)
// Push 15 requests, increase the best batch size.
for i := range 15 {
bc.pushRequest(i)
}
bc.adjustBestBatchSize()
bc.AdjustBestBatchSize()
re.Equal(defaultBestBatchSize, bc.bestBatchSize)
bc.finishCollectedRequests(nil, nil)
bc.FinishCollectedRequests(nil, nil)
}

type testRequest struct {
@@ -52,10 +52,10 @@ type testRequest struct {

func TestFinishCollectedRequests(t *testing.T) {
re := require.New(t)
bc := newBatchController[*testRequest](20, nil, nil)
bc := NewController[*testRequest](20, nil, nil)
// Finish with zero request count.
re.Zero(bc.collectedRequestCount)
bc.finishCollectedRequests(nil, nil)
bc.FinishCollectedRequests(nil, nil)
re.Zero(bc.collectedRequestCount)
// Finish with non-zero request count.
requests := make([]*testRequest, 10)
@@ -64,14 +64,14 @@ func TestFinishCollectedRequests(t *testing.T) {
bc.pushRequest(requests[i])
}
re.Equal(10, bc.collectedRequestCount)
bc.finishCollectedRequests(nil, nil)
bc.FinishCollectedRequests(nil, nil)
re.Zero(bc.collectedRequestCount)
// Finish with custom finisher.
for i := range 10 {
requests[i] = &testRequest{}
bc.pushRequest(requests[i])
}
bc.finishCollectedRequests(func(idx int, tr *testRequest, err error) {
bc.FinishCollectedRequests(func(idx int, tr *testRequest, err error) {
tr.idx = idx
tr.err = err
}, context.Canceled)
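
The custom-finisher case in the test above is the pattern a caller would use to fail a whole batch at once while still addressing each request individually. A small, hedged sketch of that usage follows; `tsoLikeRequest` is an illustrative stand-in, not a type from this PR.

```go
package example

import "github.com/tikv/pd/client/batch"

// tsoLikeRequest is a hypothetical request carrying its own completion channel.
type tsoLikeRequest struct {
	done chan error
}

// failBatch fails every collected request with the same cause, e.g. when the
// stream that was meant to serve the batch breaks before a response arrives.
func failBatch(ctl *batch.Controller[*tsoLikeRequest], cause error) {
	ctl.FinishCollectedRequests(func(_ int, req *tsoLikeRequest, err error) {
		req.done <- err // err is `cause` for every request in this batch
	}, cause)
}
```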
12 changes: 6 additions & 6 deletions client/client.go
@@ -33,6 +33,7 @@
"github.com/prometheus/client_golang/prometheus"
"github.com/tikv/pd/client/caller"
"github.com/tikv/pd/client/clients/metastorage"
"github.com/tikv/pd/client/clients/tso"
"github.com/tikv/pd/client/constants"
"github.com/tikv/pd/client/errs"
"github.com/tikv/pd/client/metrics"
@@ -140,8 +141,7 @@
// on your needs.
WithCallerComponent(callerComponent caller.Component) RPCClient

// TSOClient is the TSO client.
TSOClient
tso.Client
metastorage.Client
// KeyspaceClient manages keyspace metadata.
KeyspaceClient
@@ -179,7 +179,7 @@
// triggering service mode switching concurrently.
sync.RWMutex
serviceMode pdpb.ServiceMode
tsoClient *tsoClient
tsoClient *tso.Cli
tsoSvcDiscovery sd.ServiceDiscovery
}

@@ -191,7 +191,7 @@
k.tsoSvcDiscovery.Close()
fallthrough
case pdpb.ServiceMode_PD_SVC_MODE:
k.tsoClient.close()
k.tsoClient.Close()
case pdpb.ServiceMode_UNKNOWN_SVC_MODE:
}
}
@@ -557,7 +557,7 @@
}

// GetTSAsync implements the TSOClient interface.
func (c *client) GetTSAsync(ctx context.Context) TSFuture {
func (c *client) GetTSAsync(ctx context.Context) tso.TSFuture {
defer trace.StartRegion(ctx, "pdclient.GetTSAsync").End()
if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil {
span = span.Tracer().StartSpan("pdclient.GetTSAsync", opentracing.ChildOf(span.Context()))
@@ -570,7 +570,7 @@
//
// Deprecated: Local TSO will be completely removed in the future. Currently, regardless of the
// parameters passed in, this method will default to returning the global TSO.
func (c *client) GetLocalTSAsync(ctx context.Context, _ string) TSFuture {
func (c *client) GetLocalTSAsync(ctx context.Context, _ string) tso.TSFuture {

(Codecov check: added line client/client.go#L573 is not covered by tests)
return c.GetTSAsync(ctx)
}

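From the caller's side the split is mostly invisible: the pd client still embeds the TSO interface (now `tso.Client`), and `GetTSAsync` still returns a future, now typed as `tso.TSFuture` from `client/clients/tso`. A rough usage sketch, assuming an already-constructed `pd.Client`:

```go
package example

import (
	"context"
	"log"

	pd "github.com/tikv/pd/client"
)

// fetchTimestamp shows the caller-side view after the refactor: GetTSAsync
// hands back a future and Wait resolves it into a physical/logical pair.
func fetchTimestamp(ctx context.Context, cli pd.Client) {
	future := cli.GetTSAsync(ctx)
	physical, logical, err := future.Wait()
	if err != nil {
		log.Printf("tso request failed: %v", err)
		return
	}
	log.Printf("got tso: physical=%d, logical=%d", physical, logical)
}
```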
28 changes: 0 additions & 28 deletions client/client_test.go
@@ -19,7 +19,6 @@ import (
"testing"
"time"

"github.com/pingcap/errors"
"github.com/stretchr/testify/require"
"github.com/tikv/pd/client/caller"
"github.com/tikv/pd/client/opt"
@@ -62,30 +61,3 @@ func TestClientWithRetry(t *testing.T) {
re.Error(err)
re.Less(time.Since(start), time.Second*10)
}

func TestTsoRequestWait(t *testing.T) {
re := require.New(t)
ctx, cancel := context.WithCancel(context.Background())
req := &tsoRequest{
done: make(chan error, 1),
physical: 0,
logical: 0,
requestCtx: context.TODO(),
clientCtx: ctx,
}
cancel()
_, _, err := req.Wait()
re.ErrorIs(errors.Cause(err), context.Canceled)

ctx, cancel = context.WithCancel(context.Background())
req = &tsoRequest{
done: make(chan error, 1),
physical: 0,
logical: 0,
requestCtx: ctx,
clientCtx: context.TODO(),
}
cancel()
_, _, err = req.Wait()
re.ErrorIs(errors.Cause(err), context.Canceled)
}