From e8bef5f67ce17b982c8347f9ea9d92b401273448 Mon Sep 17 00:00:00 2001 From: JmPotato Date: Tue, 17 Dec 2024 14:49:54 +0800 Subject: [PATCH] client/batch: allow tokenCh of batch controller to be nil (#8903) ref tikv/pd#8690 Allow `tokenCh` of batch controller be nil. Signed-off-by: JmPotato --- client/pkg/batch/batch_controller.go | 37 +++++++++----- client/pkg/batch/batch_controller_test.go | 61 ++++++++++++++++++++++- 2 files changed, 83 insertions(+), 15 deletions(-) diff --git a/client/pkg/batch/batch_controller.go b/client/pkg/batch/batch_controller.go index 32f0aaba1ae..322502b754a 100644 --- a/client/pkg/batch/batch_controller.go +++ b/client/pkg/batch/batch_controller.go @@ -60,13 +60,18 @@ func NewController[T any](maxBatchSize int, finisher FinisherFunc[T], bestBatchO // It returns nil error if everything goes well, otherwise a non-nil error which means we should stop the service. // It's guaranteed that if this function failed after collecting some requests, then these requests will be cancelled // when the function returns, so the caller don't need to clear them manually. +// `tokenCh` is an optional parameter: +// - If it's nil, the batching process will not wait for the token to arrive to continue. +// - If it's not nil, the batching process will wait for a token to arrive before continuing. +// The token will be given back if any error occurs, otherwise it's the caller's responsibility +// to decide when to recycle the signal. func (bc *Controller[T]) FetchPendingRequests(ctx context.Context, requestCh <-chan T, tokenCh chan struct{}, maxBatchWaitInterval time.Duration) (errRet error) { var tokenAcquired bool defer func() { if errRet != nil { // Something went wrong when collecting a batch of requests. Release the token and cancel collected requests // if any. - if tokenAcquired { + if tokenAcquired && tokenCh != nil { tokenCh <- struct{}{} } bc.FinishCollectedRequests(bc.finisher, errRet) @@ -80,6 +85,9 @@ func (bc *Controller[T]) FetchPendingRequests(ctx context.Context, requestCh <-c // If the batch size reaches the maxBatchSize limit but the token haven't arrived yet, don't receive more // requests, and return when token is ready. if bc.collectedRequestCount >= bc.maxBatchSize && !tokenAcquired { + if tokenCh == nil { + return nil + } select { case <-ctx.Done(): return ctx.Err() @@ -88,20 +96,23 @@ func (bc *Controller[T]) FetchPendingRequests(ctx context.Context, requestCh <-c } } - select { - case <-ctx.Done(): - return ctx.Err() - case req := <-requestCh: - // Start to batch when the first request arrives. - bc.pushRequest(req) - // A request arrives but the token is not ready yet. Continue waiting, and also allowing collecting the next - // request if it arrives. - continue - case <-tokenCh: - tokenAcquired = true + if tokenCh != nil { + select { + case <-ctx.Done(): + return ctx.Err() + case req := <-requestCh: + // Start to batch when the first request arrives. + bc.pushRequest(req) + // A request arrives but the token is not ready yet. Continue waiting, and also allowing collecting the next + // request if it arrives. + continue + case <-tokenCh: + tokenAcquired = true + } } - // The token is ready. If the first request didn't arrive, wait for it. + // After the token is ready or it's working without token, + // wait for the first request to arrive. if bc.collectedRequestCount == 0 { select { case <-ctx.Done(): diff --git a/client/pkg/batch/batch_controller_test.go b/client/pkg/batch/batch_controller_test.go index 7c9ffa6944f..92aef14bd35 100644 --- a/client/pkg/batch/batch_controller_test.go +++ b/client/pkg/batch/batch_controller_test.go @@ -21,9 +21,11 @@ import ( "github.com/stretchr/testify/require" ) +const testMaxBatchSize = 20 + func TestAdjustBestBatchSize(t *testing.T) { re := require.New(t) - bc := NewController[int](20, nil, nil) + bc := NewController[int](testMaxBatchSize, nil, nil) re.Equal(defaultBestBatchSize, bc.bestBatchSize) bc.AdjustBestBatchSize() re.Equal(defaultBestBatchSize-1, bc.bestBatchSize) @@ -52,7 +54,7 @@ type testRequest struct { func TestFinishCollectedRequests(t *testing.T) { re := require.New(t) - bc := NewController[*testRequest](20, nil, nil) + bc := NewController[*testRequest](testMaxBatchSize, nil, nil) // Finish with zero request count. re.Zero(bc.collectedRequestCount) bc.FinishCollectedRequests(nil, nil) @@ -81,3 +83,58 @@ func TestFinishCollectedRequests(t *testing.T) { re.Equal(context.Canceled, requests[i].err) } } + +func TestFetchPendingRequests(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + re := require.New(t) + bc := NewController[int](testMaxBatchSize, nil, nil) + requestCh := make(chan int, testMaxBatchSize+1) + // Fetch a nil `tokenCh`. + requestCh <- 1 + re.NoError(bc.FetchPendingRequests(ctx, requestCh, nil, 0)) + re.Empty(requestCh) + re.Equal(1, bc.collectedRequestCount) + // Fetch a nil `tokenCh` with max batch size. + for i := range testMaxBatchSize { + requestCh <- i + } + re.NoError(bc.FetchPendingRequests(ctx, requestCh, nil, 0)) + re.Empty(requestCh) + re.Equal(testMaxBatchSize, bc.collectedRequestCount) + // Fetch a nil `tokenCh` with max batch size + 1. + for i := range testMaxBatchSize + 1 { + requestCh <- i + } + re.NoError(bc.FetchPendingRequests(ctx, requestCh, nil, 0)) + re.Len(requestCh, 1) + re.Equal(testMaxBatchSize, bc.collectedRequestCount) + // Drain the requestCh. + <-requestCh + // Fetch a non-nil `tokenCh`. + tokenCh := make(chan struct{}, 1) + requestCh <- 1 + tokenCh <- struct{}{} + re.NoError(bc.FetchPendingRequests(ctx, requestCh, tokenCh, 0)) + re.Empty(requestCh) + re.Equal(1, bc.collectedRequestCount) + // Fetch a non-nil `tokenCh` with max batch size. + for i := range testMaxBatchSize { + requestCh <- i + } + tokenCh <- struct{}{} + re.NoError(bc.FetchPendingRequests(ctx, requestCh, tokenCh, 0)) + re.Empty(requestCh) + re.Equal(testMaxBatchSize, bc.collectedRequestCount) + // Fetch a non-nil `tokenCh` with max batch size + 1. + for i := range testMaxBatchSize + 1 { + requestCh <- i + } + tokenCh <- struct{}{} + re.NoError(bc.FetchPendingRequests(ctx, requestCh, tokenCh, 0)) + re.Len(requestCh, 1) + re.Equal(testMaxBatchSize, bc.collectedRequestCount) + // Drain the requestCh. + <-requestCh +}