client: genericize the batch controller (#8793)
ref #8690

Genericize the batch controller for later reuse.

Signed-off-by: JmPotato <ghzpotato@gmail.com>
Showing 4 changed files with 366 additions and 257 deletions.
@@ -0,0 +1,236 @@
// Copyright 2024 TiKV Project Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package pd

import (
	"context"
	"time"

	"github.com/prometheus/client_golang/prometheus"
)
// Starting from a low value is necessary because we need to make sure it converges to (current_batch_size - 4).
const defaultBestBatchSize = 8

// finisherFunc is used to finish a request. It accepts the index of the request in the batch, the request itself and an error.
type finisherFunc[T any] func(int, T, error)

type batchController[T any] struct {
	maxBatchSize int
	// bestBatchSize is a dynamic size that changes based on the current batching effect.
	bestBatchSize int

	collectedRequests     []T
	collectedRequestCount int

	// The finisher function to cancel collected requests when an internal error occurs.
	finisher finisherFunc[T]
	// The observer to record the best batch size.
	bestBatchObserver prometheus.Histogram
	// The time after getting the first request and the token, and before performing extra batching.
	extraBatchingStartTime time.Time
}

func newBatchController[T any](maxBatchSize int, finisher finisherFunc[T], bestBatchObserver prometheus.Histogram) *batchController[T] {
	return &batchController[T]{
		maxBatchSize:          maxBatchSize,
		bestBatchSize:         defaultBestBatchSize,
		collectedRequests:     make([]T, maxBatchSize+1),
		collectedRequestCount: 0,
		finisher:              finisher,
		bestBatchObserver:     bestBatchObserver,
	}
}
// fetchPendingRequests starts a new round of batch collecting from the channel.
// It returns a nil error if everything goes well; otherwise, a non-nil error means the service should be stopped.
// It's guaranteed that if this function fails after collecting some requests, those requests will be cancelled
// when the function returns, so the caller doesn't need to clear them manually.
func (bc *batchController[T]) fetchPendingRequests(ctx context.Context, requestCh <-chan T, tokenCh chan struct{}, maxBatchWaitInterval time.Duration) (errRet error) {
	var tokenAcquired bool
	defer func() {
		if errRet != nil {
			// Something went wrong when collecting a batch of requests. Release the token and cancel collected requests
			// if any.
			if tokenAcquired {
				tokenCh <- struct{}{}
			}
			bc.finishCollectedRequests(bc.finisher, errRet)
		}
	}()

	// Wait until BOTH the first request and the token have arrived.
	// TODO: `bc.collectedRequestCount` should never be non-zero here. Consider adding an assertion.
	bc.collectedRequestCount = 0
	for {
		// If the batch size reaches the maxBatchSize limit but the token hasn't arrived yet, don't receive more
		// requests, and return once the token is ready.
		if bc.collectedRequestCount >= bc.maxBatchSize && !tokenAcquired {
			select {
			case <-ctx.Done():
				return ctx.Err()
			case <-tokenCh:
				return nil
			}
		}

		select {
		case <-ctx.Done():
			return ctx.Err()
		case req := <-requestCh:
			// Start to batch when the first request arrives.
			bc.pushRequest(req)
			// A request has arrived but the token is not ready yet. Continue waiting, while also allowing the next
			// request to be collected if it arrives.
			continue
		case <-tokenCh:
			tokenAcquired = true
		}

		// The token is ready. If the first request hasn't arrived yet, wait for it.
		if bc.collectedRequestCount == 0 {
			select {
			case <-ctx.Done():
				return ctx.Err()
			case firstRequest := <-requestCh:
				bc.pushRequest(firstRequest)
			}
		}

		// Both the token and the first request have arrived.
		break
	}

	bc.extraBatchingStartTime = time.Now()

	// This loop tries its best to collect more requests, so we use `bc.maxBatchSize` here.
fetchPendingRequestsLoop:
	for bc.collectedRequestCount < bc.maxBatchSize {
		select {
		case req := <-requestCh:
			bc.pushRequest(req)
		case <-ctx.Done():
			return ctx.Err()
		default:
			break fetchPendingRequestsLoop
		}
	}

	// Check whether we should fetch more pending requests from the channel.
	if bc.collectedRequestCount >= bc.maxBatchSize || maxBatchWaitInterval <= 0 {
		return nil
	}

	// Fetch more pending requests from the channel.
	// Try to collect `bc.bestBatchSize` requests, or wait up to `maxBatchWaitInterval`
	// while `bc.collectedRequestCount` is less than `bc.bestBatchSize`.
	if bc.collectedRequestCount < bc.bestBatchSize {
		after := time.NewTimer(maxBatchWaitInterval)
		defer after.Stop()
		for bc.collectedRequestCount < bc.bestBatchSize {
			select {
			case req := <-requestCh:
				bc.pushRequest(req)
			case <-ctx.Done():
				return ctx.Err()
			case <-after.C:
				return nil
			}
		}
	}

	// Do an additional non-blocking try. Here we test the length against `bc.maxBatchSize` instead
	// of `bc.bestBatchSize` because trying our best to fetch more requests is necessary so that
	// we can adjust `bc.bestBatchSize` dynamically later.
	for bc.collectedRequestCount < bc.maxBatchSize {
		select {
		case req := <-requestCh:
			bc.pushRequest(req)
		case <-ctx.Done():
			return ctx.Err()
		default:
			return nil
		}
	}
	return nil
}
// fetchRequestsWithTimer tries to fetch requests until the given timer ticks. The caller must set the timer properly
// before calling this function.
func (bc *batchController[T]) fetchRequestsWithTimer(ctx context.Context, requestCh <-chan T, timer *time.Timer) error {
batchingLoop:
	for bc.collectedRequestCount < bc.maxBatchSize {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case req := <-requestCh:
			bc.pushRequest(req)
		case <-timer.C:
			break batchingLoop
		}
	}

	// Try to collect more requests in a non-blocking way.
nonWaitingBatchLoop:
	for bc.collectedRequestCount < bc.maxBatchSize {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case req := <-requestCh:
			bc.pushRequest(req)
		default:
			break nonWaitingBatchLoop
		}
	}

	return nil
}
func (bc *batchController[T]) pushRequest(req T) {
	bc.collectedRequests[bc.collectedRequestCount] = req
	bc.collectedRequestCount++
}

func (bc *batchController[T]) getCollectedRequests() []T {
	return bc.collectedRequests[:bc.collectedRequestCount]
}
// adjustBestBatchSize stabilizes the latency with the AIAD algorithm.
func (bc *batchController[T]) adjustBestBatchSize() {
	if bc.bestBatchObserver != nil {
		bc.bestBatchObserver.Observe(float64(bc.bestBatchSize))
	}
	length := bc.collectedRequestCount
	if length < bc.bestBatchSize && bc.bestBatchSize > 1 {
		// We waited too long to collect requests, so reduce the target batch size.
		bc.bestBatchSize--
	} else if length > bc.bestBatchSize+4 /* Hard-coded number, in order to make `bc.bestBatchSize` stable */ &&
		bc.bestBatchSize < bc.maxBatchSize {
		bc.bestBatchSize++
	}
}
func (bc *batchController[T]) finishCollectedRequests(finisher finisherFunc[T], err error) {
	if finisher == nil {
		finisher = bc.finisher
	}
	if finisher != nil {
		for i := range bc.collectedRequestCount {
			finisher(i, bc.collectedRequests[i], err)
		}
	}
	// Prevent the finished requests from being processed again.
	bc.collectedRequestCount = 0
}
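Taken together, the controller leaves the batching lifecycle to its caller: fetch a batch, process it, adjust the target size, finish the requests, and release the token. Below is a minimal sketch of such a dispatcher loop; the names `exampleRequest`, `runBatchLoop`, and `processBatch`, as well as the token hand-back at the end of each round, are illustrative assumptions, not part of this commit.

// exampleRequest is a hypothetical request type used only for this sketch.
type exampleRequest struct {
	done chan error
}

// runBatchLoop is a hypothetical dispatcher built on top of batchController.
func runBatchLoop(
	ctx context.Context,
	requestCh <-chan *exampleRequest,
	tokenCh chan struct{},
	processBatch func([]*exampleRequest) error,
) {
	// The finisher reports a result (or cancellation) back to each waiter.
	finish := func(_ int, req *exampleRequest, err error) { req.done <- err }
	bc := newBatchController[*exampleRequest](64 /* maxBatchSize */, finish, nil /* no histogram */)
	for {
		// Block until the first request and a token arrive, then batch greedily.
		// On error, the controller has already cancelled the collected requests.
		if err := bc.fetchPendingRequests(ctx, requestCh, tokenCh, 0); err != nil {
			return
		}
		err := processBatch(bc.getCollectedRequests())
		// Let the AIAD adjustment converge the target batch size, then finish
		// the whole batch with the processing result.
		bc.adjustBestBatchSize()
		bc.finishCollectedRequests(nil, err)
		// Hand the token back so the next round can acquire it again
		// (an assumption about the token lifecycle, for illustration only).
		tokenCh <- struct{}{}
	}
}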
@@ -0,0 +1,83 @@
// Copyright 2024 TiKV Project Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package pd

import (
	"context"
	"testing"

	"github.com/stretchr/testify/require"
)
func TestAdjustBestBatchSize(t *testing.T) {
	re := require.New(t)
	bc := newBatchController[int](20, nil, nil)
	re.Equal(defaultBestBatchSize, bc.bestBatchSize)
	bc.adjustBestBatchSize()
	re.Equal(defaultBestBatchSize-1, bc.bestBatchSize)
	// Clear the collected requests.
	bc.finishCollectedRequests(nil, nil)
	// Push 10 requests, which should not increase the best batch size.
	for i := range 10 {
		bc.pushRequest(i)
	}
	bc.adjustBestBatchSize()
	re.Equal(defaultBestBatchSize-1, bc.bestBatchSize)
	bc.finishCollectedRequests(nil, nil)
	// Push 15 requests, which should increase the best batch size.
	for i := range 15 {
		bc.pushRequest(i)
	}
	bc.adjustBestBatchSize()
	re.Equal(defaultBestBatchSize, bc.bestBatchSize)
	bc.finishCollectedRequests(nil, nil)
}
type testRequest struct {
	idx int
	err error
}
func TestFinishCollectedRequests(t *testing.T) {
	re := require.New(t)
	bc := newBatchController[*testRequest](20, nil, nil)
	// Finish with a zero request count.
	re.Zero(bc.collectedRequestCount)
	bc.finishCollectedRequests(nil, nil)
	re.Zero(bc.collectedRequestCount)
	// Finish with a non-zero request count.
	requests := make([]*testRequest, 10)
	for i := range 10 {
		requests[i] = &testRequest{}
		bc.pushRequest(requests[i])
	}
	re.Equal(10, bc.collectedRequestCount)
	bc.finishCollectedRequests(nil, nil)
	re.Zero(bc.collectedRequestCount)
	// Finish with a custom finisher.
	for i := range 10 {
		requests[i] = &testRequest{}
		bc.pushRequest(requests[i])
	}
	bc.finishCollectedRequests(func(idx int, tr *testRequest, err error) {
		tr.idx = idx
		tr.err = err
	}, context.Canceled)
	re.Zero(bc.collectedRequestCount)
	for i := range 10 {
		re.Equal(i, requests[i].idx)
		re.Equal(context.Canceled, requests[i].err)
	}
}
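The tests above cover the size-adjustment and finishing paths. A sketch of how the collection path could be exercised with buffered channels follows; this test is illustrative only and not part of the commit. It relies on the documented behavior that, with a zero wait interval, fetchPendingRequests returns as soon as the buffered requests are drained.

// TestFetchPendingRequestsSketch is a hypothetical test, not part of this commit.
func TestFetchPendingRequestsSketch(t *testing.T) {
	re := require.New(t)
	bc := newBatchController[int](20, nil, nil)
	requestCh := make(chan int, 5)
	tokenCh := make(chan struct{}, 1)
	// Buffer a few requests and one token so the call never blocks.
	for i := range 5 {
		requestCh <- i
	}
	tokenCh <- struct{}{}
	// With a zero wait interval, the call returns right after draining
	// whatever is already buffered in the channel.
	re.NoError(bc.fetchPendingRequests(context.Background(), requestCh, tokenCh, 0))
	re.Len(bc.getCollectedRequests(), 5)
}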