diff --git a/pkg/utils/etcdutil/etcdutil.go b/pkg/utils/etcdutil/etcdutil.go
index db5c3e82d91f..957930c0b1c8 100644
--- a/pkg/utils/etcdutil/etcdutil.go
+++ b/pkg/utils/etcdutil/etcdutil.go
@@ -350,6 +350,7 @@ const (
     defaultLoadFromEtcdRetryTimes   = int(defaultLoadDataFromEtcdTimeout / defaultLoadFromEtcdRetryInterval)
     defaultLoadBatchSize            = 400
     defaultWatchChangeRetryInterval = 1 * time.Second
+    defaultForceLoadMinimalInterval = 200 * time.Millisecond
 )
 
 // LoopWatcher loads data from etcd and sets a watcher for it.
@@ -376,6 +377,11 @@ type LoopWatcher struct {
     // postEventFn is used to call after handling all events.
     postEventFn func() error
 
+    // forceLoadMu is used to ensure two force loads have minimal interval.
+    forceLoadMu sync.Mutex
+    // lastTimeForceLoad is used to record the last time force loading data from etcd.
+    lastTimeForceLoad time.Time
+
     // loadTimeout is used to set the timeout for loading data from etcd.
     loadTimeout time.Duration
     // loadRetryTimes is used to set the retry times for loading data from etcd.
@@ -405,6 +411,7 @@ func NewLoopWatcher(ctx context.Context, wg *sync.WaitGroup, client *clientv3.Cl
         deleteFn:          deleteFn,
         postEventFn:       postEventFn,
         opts:              opts,
+        lastTimeForceLoad: time.Now(),
         loadTimeout:       defaultLoadDataFromEtcdTimeout,
         loadRetryTimes:    defaultLoadFromEtcdRetryTimes,
         loadBatchSize:     defaultLoadBatchSize,
@@ -601,6 +608,14 @@ func (lw *LoopWatcher) load(ctx context.Context) (nextRevision int64, err error)
 
 // ForceLoad forces to load the key.
 func (lw *LoopWatcher) ForceLoad() {
+    lw.forceLoadMu.Lock()
+    if time.Since(lw.lastTimeForceLoad) < defaultForceLoadMinimalInterval {
+        lw.forceLoadMu.Unlock()
+        return
+    }
+    lw.lastTimeForceLoad = time.Now()
+    lw.forceLoadMu.Unlock()
+
     select {
     case lw.forceLoadCh <- struct{}{}:
     default:
diff --git a/pkg/utils/tsoutil/tso_request.go b/pkg/utils/tsoutil/tso_request.go
index b2459ccb371b..b690927b09b6 100644
--- a/pkg/utils/tsoutil/tso_request.go
+++ b/pkg/utils/tsoutil/tso_request.go
@@ -31,7 +31,7 @@ type Request interface {
     // getCount returns the count of timestamps to retrieve
     getCount() uint32
     // process sends request and receive response via stream.
-    // count defins the count of timestamps to retrieve.
+    // count defines the count of timestamps to retrieve.
     process(forwardStream stream, count uint32, tsoProtoFactory ProtoFactory) (tsoResp, error)
     // postProcess sends the response back to the sender of the request
     postProcess(countSum, physical, firstLogical int64, suffixBits uint32) (int64, error)
@@ -50,7 +50,7 @@ type TSOProtoRequest struct {
     stream        tsopb.TSO_TsoServer
 }
 
-// NewTSOProtoRequest creats a TSOProtoRequest and returns as a Request
+// NewTSOProtoRequest creates a TSOProtoRequest and returns as a Request
 func NewTSOProtoRequest(forwardedHost string, clientConn *grpc.ClientConn, request *tsopb.TsoRequest, stream tsopb.TSO_TsoServer) Request {
     tsoRequest := &TSOProtoRequest{
         forwardedHost: forwardedHost,
@@ -77,7 +77,7 @@ func (r *TSOProtoRequest) getCount() uint32 {
 }
 
 // process sends request and receive response via stream.
-// count defins the count of timestamps to retrieve.
+// count defines the count of timestamps to retrieve.
 func (r *TSOProtoRequest) process(forwardStream stream, count uint32, tsoProtoFactory ProtoFactory) (tsoResp, error) {
     return forwardStream.process(r.request.GetHeader().GetClusterId(), count,
         r.request.GetHeader().GetKeyspaceId(), r.request.GetHeader().GetKeyspaceGroupId(), r.request.GetDcLocation())
@@ -111,7 +111,7 @@ type PDProtoRequest struct {
     stream        pdpb.PD_TsoServer
 }
 
-// NewPDProtoRequest creats a PDProtoRequest and returns as a Request
+// NewPDProtoRequest creates a PDProtoRequest and returns as a Request
 func NewPDProtoRequest(forwardedHost string, clientConn *grpc.ClientConn, request *pdpb.TsoRequest, stream pdpb.PD_TsoServer) Request {
     tsoRequest := &PDProtoRequest{
         forwardedHost: forwardedHost,
@@ -138,7 +138,7 @@ func (r *PDProtoRequest) getCount() uint32 {
 }
 
 // process sends request and receive response via stream.
-// count defins the count of timestamps to retrieve.
+// count defines the count of timestamps to retrieve.
 func (r *PDProtoRequest) process(forwardStream stream, count uint32, tsoProtoFactory ProtoFactory) (tsoResp, error) {
     return forwardStream.process(r.request.GetHeader().GetClusterId(), count,
         utils.DefaultKeyspaceID, utils.DefaultKeyspaceGroupID, r.request.GetDcLocation())
diff --git a/server/grpc_service.go b/server/grpc_service.go
index 9b349f99335f..1d1ddc9e23fb 100644
--- a/server/grpc_service.go
+++ b/server/grpc_service.go
@@ -64,6 +64,7 @@ var (
     ErrNotStarted           = status.Errorf(codes.Unavailable, "server not started")
     ErrSendHeartbeatTimeout = status.Errorf(codes.DeadlineExceeded, "send heartbeat timeout")
     ErrNotFoundTSOAddr      = status.Errorf(codes.NotFound, "not found tso address")
+    ErrForwardTSOTimeout    = status.Errorf(codes.DeadlineExceeded, "forward tso request timeout")
 )
 
 // GrpcServer wraps Server to provide grpc service.
@@ -324,6 +325,10 @@ func (s *GrpcServer) GetMembers(context.Context, *pdpb.GetMembersRequest) (*pdpb
 
 // Tso implements gRPC PDServer.
 func (s *GrpcServer) Tso(stream pdpb.PD_TsoServer) error {
+    if s.IsAPIServiceMode() {
+        return s.forwardTSO(stream)
+    }
+
     var (
         doneCh chan struct{}
         errCh  chan error
@@ -361,15 +366,8 @@ func (s *GrpcServer) Tso(stream pdpb.PD_TsoServer) error {
             errCh = make(chan error)
         }
 
-        var tsoProtoFactory tsoutil.ProtoFactory
-        if s.IsAPIServiceMode() {
-            tsoProtoFactory = s.tsoProtoFactory
-        } else {
-            tsoProtoFactory = s.pdProtoFactory
-        }
-
         tsoRequest := tsoutil.NewPDProtoRequest(forwardedHost, clientConn, request, stream)
-        s.tsoDispatcher.DispatchRequest(ctx, tsoRequest, tsoProtoFactory, doneCh, errCh, s.tsoPrimaryWatcher)
+        s.tsoDispatcher.DispatchRequest(ctx, tsoRequest, s.pdProtoFactory, doneCh, errCh, s.tsoPrimaryWatcher)
         continue
     }
@@ -379,7 +377,8 @@ func (s *GrpcServer) Tso(stream pdpb.PD_TsoServer) error {
         return status.Errorf(codes.Unknown, "server not started")
     }
     if request.GetHeader().GetClusterId() != s.clusterID {
-        return status.Errorf(codes.FailedPrecondition, "mismatch cluster id, need %d but got %d", s.clusterID, request.GetHeader().GetClusterId())
+        return status.Errorf(codes.FailedPrecondition,
+            "mismatch cluster id, need %d but got %d", s.clusterID, request.GetHeader().GetClusterId())
     }
     count := request.GetCount()
     ts, err := s.tsoAllocatorManager.HandleRequest(request.GetDcLocation(), count)
@@ -398,6 +397,162 @@ func (s *GrpcServer) Tso(stream pdpb.PD_TsoServer) error {
     }
 }
 
+// forwardTSO forwards the TSO requests to the TSO service.
+func (s *GrpcServer) forwardTSO(stream pdpb.PD_TsoServer) error {
+    var (
+        server            = &tsoServer{stream: stream}
+        forwardStream     tsopb.TSO_TsoClient
+        cancel            context.CancelFunc
+        lastForwardedHost string
+    )
+    defer func() {
+        if forwardStream != nil {
+            forwardStream.CloseSend()
+        }
+        // cancel the forward stream
+        if cancel != nil {
+            cancel()
+        }
+    }()
+
+    for {
+        select {
+        case <-s.ctx.Done():
+            return errors.WithStack(s.ctx.Err())
+        case <-stream.Context().Done():
+            return stream.Context().Err()
+        default:
+        }
+
+        request, err := server.Recv()
+        if err == io.EOF {
+            return nil
+        }
+        if err != nil {
+            return errors.WithStack(err)
+        }
+        if request.GetCount() == 0 {
+            err = errs.ErrGenerateTimestamp.FastGenByArgs("tso count should be positive")
+            return status.Errorf(codes.Unknown, err.Error())
+        }
+
+        forwardedHost, ok := s.GetServicePrimaryAddr(stream.Context(), utils.TSOServiceName)
+        if !ok || len(forwardedHost) == 0 {
+            return errors.WithStack(ErrNotFoundTSOAddr)
+        }
+        if forwardStream == nil || lastForwardedHost != forwardedHost {
+            if forwardStream != nil {
+                forwardStream.CloseSend()
+            }
+            if cancel != nil {
+                cancel()
+            }
+
+            clientConn, err := s.getDelegateClient(s.ctx, forwardedHost)
+            if err != nil {
+                return errors.WithStack(err)
+            }
+            forwardStream, cancel, err = s.createTSOForwardStream(clientConn)
+            if err != nil {
+                return errors.WithStack(err)
+            }
+            lastForwardedHost = forwardedHost
+        }
+
+        tsoReq := &tsopb.TsoRequest{
+            Header: &tsopb.RequestHeader{
+                ClusterId:       request.GetHeader().GetClusterId(),
+                SenderId:        request.GetHeader().GetSenderId(),
+                KeyspaceId:      utils.DefaultKeyspaceID,
+                KeyspaceGroupId: utils.DefaultKeyspaceGroupID,
+            },
+            Count:      request.GetCount(),
+            DcLocation: request.GetDcLocation(),
+        }
+        if err := forwardStream.Send(tsoReq); err != nil {
+            return errors.WithStack(err)
+        }
+
+        tsopbResp, err := forwardStream.Recv()
+        if err != nil {
+            if strings.Contains(err.Error(), errs.NotLeaderErr) {
+                s.tsoPrimaryWatcher.ForceLoad()
+            }
+            return errors.WithStack(err)
+        }
+
+        // The error types defined for tsopb and pdpb are different, so we need to convert them.
+        var pdpbErr *pdpb.Error
+        tsopbErr := tsopbResp.GetHeader().GetError()
+        if tsopbErr != nil {
+            if tsopbErr.Type == tsopb.ErrorType_OK {
+                pdpbErr = &pdpb.Error{
+                    Type:    pdpb.ErrorType_OK,
+                    Message: tsopbErr.GetMessage(),
+                }
+            } else {
+                // TODO: specify FORWARD FAILURE error type instead of UNKNOWN.
+                pdpbErr = &pdpb.Error{
+                    Type:    pdpb.ErrorType_UNKNOWN,
+                    Message: tsopbErr.GetMessage(),
+                }
+            }
+        }
+
+        response := &pdpb.TsoResponse{
+            Header: &pdpb.ResponseHeader{
+                ClusterId: tsopbResp.GetHeader().GetClusterId(),
+                Error:     pdpbErr,
+            },
+            Count:     tsopbResp.GetCount(),
+            Timestamp: tsopbResp.GetTimestamp(),
+        }
+        if err := server.Send(response); err != nil {
+            return errors.WithStack(err)
+        }
+    }
+}
+
+// tsoServer wraps PD_TsoServer to ensure that when any error
+// occurs on Send() or Recv(), both endpoints are closed.
+type tsoServer struct {
+    stream pdpb.PD_TsoServer
+    closed int32
+}
+
+func (s *tsoServer) Send(m *pdpb.TsoResponse) error {
+    if atomic.LoadInt32(&s.closed) == 1 {
+        return io.EOF
+    }
+    done := make(chan error, 1)
+    go func() {
+        defer logutil.LogPanic()
+        done <- s.stream.Send(m)
+    }()
+    select {
+    case err := <-done:
+        if err != nil {
+            atomic.StoreInt32(&s.closed, 1)
+        }
+        return errors.WithStack(err)
+    case <-time.After(tsoutil.DefaultTSOProxyTimeout):
+        atomic.StoreInt32(&s.closed, 1)
+        return ErrForwardTSOTimeout
+    }
+}
+
+func (s *tsoServer) Recv() (*pdpb.TsoRequest, error) {
+    if atomic.LoadInt32(&s.closed) == 1 {
+        return nil, io.EOF
+    }
+    req, err := s.stream.Recv()
+    if err != nil {
+        atomic.StoreInt32(&s.closed, 1)
+        return nil, errors.WithStack(err)
+    }
+    return req, nil
+}
+
 func (s *GrpcServer) getForwardedHost(ctx, streamCtx context.Context) (forwardedHost string, err error) {
     if s.IsAPIServiceMode() {
         var ok bool
@@ -1875,6 +2030,15 @@ func forwardRegionHeartbeatClientToServer(forwardStream pdpb.PD_RegionHeartbeatC
     }
 }
 
+func (s *GrpcServer) createTSOForwardStream(client *grpc.ClientConn) (tsopb.TSO_TsoClient, context.CancelFunc, error) {
+    done := make(chan struct{})
+    ctx, cancel := context.WithCancel(s.ctx)
+    go checkStream(ctx, cancel, done)
+    forwardStream, err := tsopb.NewTSOClient(client).Tso(ctx)
+    done <- struct{}{}
+    return forwardStream, cancel, err
+}
+
 func (s *GrpcServer) createReportBucketsForwardStream(client *grpc.ClientConn) (pdpb.PD_ReportBucketsClient, context.CancelFunc, error) {
     done := make(chan struct{})
     ctx, cancel := context.WithCancel(s.ctx)
diff --git a/tests/integrations/mcs/tso/proxy_test.go b/tests/integrations/mcs/tso/proxy_test.go
new file mode 100644
index 000000000000..f08e5e363e77
--- /dev/null
+++ b/tests/integrations/mcs/tso/proxy_test.go
@@ -0,0 +1,376 @@
+// Copyright 2023 TiKV Project Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tso
+
+import (
+    "context"
+    "fmt"
+    "math/rand"
+    "strings"
+    "sync"
+    "testing"
+    "time"
+
+    "github.com/pingcap/kvproto/pkg/pdpb"
+    "github.com/stretchr/testify/require"
+    "github.com/stretchr/testify/suite"
+    "github.com/tikv/pd/client/tsoutil"
+    "github.com/tikv/pd/tests"
+    "github.com/tikv/pd/tests/integrations/mcs"
+    "google.golang.org/grpc"
+)
+
+type tsoProxyTestSuite struct {
+    suite.Suite
+    ctx              context.Context
+    cancel           context.CancelFunc
+    apiCluster       *tests.TestCluster
+    apiLeader        *tests.TestServer
+    backendEndpoints string
+    tsoCluster       *mcs.TestTSOCluster
+    defaultReq       *pdpb.TsoRequest
+    grpcClientConns  []*grpc.ClientConn
+    streams          []pdpb.PD_TsoClient
+    cancelFuncs      []context.CancelFunc
+}
+
+func TestTSOProxyTestSuite(t *testing.T) {
+    suite.Run(t, new(tsoProxyTestSuite))
+}
+
+func (s *tsoProxyTestSuite) SetupSuite() {
+    re := s.Require()
+
+    var err error
+    s.ctx, s.cancel = context.WithCancel(context.Background())
+    // Create an API cluster with 1 server
+    s.apiCluster, err = tests.NewTestAPICluster(s.ctx, 1)
+    re.NoError(err)
+    err = s.apiCluster.RunInitialServers()
+    re.NoError(err)
+    leaderName := s.apiCluster.WaitLeader()
+    s.apiLeader = s.apiCluster.GetServer(leaderName)
+    s.backendEndpoints = s.apiLeader.GetAddr()
+    s.NoError(s.apiLeader.BootstrapCluster())
+
+    // Create a TSO cluster with 2 servers
+    s.tsoCluster, err = mcs.NewTestTSOCluster(s.ctx, 2, s.backendEndpoints)
+    re.NoError(err)
+    s.tsoCluster.WaitForDefaultPrimaryServing(re)
+
+    s.defaultReq = &pdpb.TsoRequest{
+        Header: &pdpb.RequestHeader{ClusterId: s.apiLeader.GetClusterID()},
+        Count:  1,
+    }
+
+    // Create some TSO client streams with the same context.
+    s.grpcClientConns, s.streams, s.cancelFuncs = createTSOStreams(re, s.ctx, s.backendEndpoints, 100, true)
+    // Create some TSO client streams with different contexts.
+    grpcClientConns, streams, cancelFuncs := createTSOStreams(re, s.ctx, s.backendEndpoints, 100, false)
+    s.grpcClientConns = append(s.grpcClientConns, grpcClientConns...)
+    s.streams = append(s.streams, streams...)
+    s.cancelFuncs = append(s.cancelFuncs, cancelFuncs...)
+}
+
+func (s *tsoProxyTestSuite) TearDownSuite() {
+    s.cleanupGRPCStreams(s.grpcClientConns, s.streams, s.cancelFuncs)
+    s.tsoCluster.Destroy()
+    s.apiCluster.Destroy()
+    s.cancel()
+}
+
+// TestTSOProxyBasic tests the TSO Proxy's basic function to forward TSO requests to the TSO microservice.
+// It also verifies the correctness of the TSO Proxy's TSO response, such as the count of timestamps
+// to retrieve in one TSO request and the monotonicity of the returned timestamps.
+func (s *tsoProxyTestSuite) TestTSOProxyBasic() {
+    for i := 0; i < 10; i++ {
+        s.verifyTSOProxy(s.streams, 100, true)
+    }
+}
+
+// TestTSOProxyWorksWithCancellation tests that the TSO Proxy still works correctly while some gRPC
+// streams are being cancelled and the others keep working.
+func (s *tsoProxyTestSuite) TestTSOProxyWorksWithCancellation() {
+    re := s.Require()
+    wg := &sync.WaitGroup{}
+    wg.Add(2)
+    go func() {
+        defer wg.Done()
+        go func() {
+            defer wg.Done()
+            for i := 0; i < 5; i++ {
+                grpcClientConns, streams, cancelFuncs := createTSOStreams(re, s.ctx, s.backendEndpoints, 10, false)
+                for j := 0; j < 10; j++ {
+                    s.verifyTSOProxy(streams, 10, true)
+                }
+                s.cleanupGRPCStreams(grpcClientConns, streams, cancelFuncs)
+            }
+        }()
+        for i := 0; i < 20; i++ {
+            s.verifyTSOProxy(s.streams, 100, true)
+        }
+    }()
+    wg.Wait()
+}
+
+// TestTSOProxyStress tests that the TSO Proxy can work correctly under stress.
+// gRPC and TSO failures are allowed, but the TSO Proxy should not panic, block, or deadlock, and if it
+// returns a timestamp, it should be a valid, monotonically increasing timestamp. After the stress, the
+// TSO Proxy should still work correctly.
+func (s *tsoProxyTestSuite) TestTSOProxyStress() {
+    s.T().Skip("skip the stress test temporarily")
+    re := s.Require()
+    // Add 1000 concurrent clients each round; 2 rounds in total, so 2000 concurrent clients are created in total.
+    grpcClientConns := make([]*grpc.ClientConn, 0)
+    streams := make([]pdpb.PD_TsoClient, 0)
+    cancelFuncs := make([]context.CancelFunc, 0)
+    for i := 0; i < 2; i++ {
+        fmt.Printf("Start the %dth round of stress test with %d concurrent clients.\n", i, len(streams)+1000)
+        grpcClientConnsTemp, streamsTemp, cancelFuncsTemp := createTSOStreams(re, s.ctx, s.backendEndpoints, 1000, false)
+        grpcClientConns = append(grpcClientConns, grpcClientConnsTemp...)
+        streams = append(streams, streamsTemp...)
+        cancelFuncs = append(cancelFuncs, cancelFuncsTemp...)
+        s.verifyTSOProxy(streams, 50, false)
+    }
+    s.cleanupGRPCStreams(grpcClientConns, streams, cancelFuncs)
+
+    // Wait for the TSO Proxy to recover from the stress. Treat 3 seconds as our SLA.
+    time.Sleep(3 * time.Second)
+
+    for i := 0; i < 10; i++ {
+        s.verifyTSOProxy(s.streams, 100, true)
+    }
+}
+
+func (s *tsoProxyTestSuite) cleanupGRPCStreams(
+    grpcClientConns []*grpc.ClientConn, streams []pdpb.PD_TsoClient, cancelFuncs []context.CancelFunc,
+) {
+    for _, stream := range streams {
+        stream.CloseSend()
+    }
+    for _, conn := range grpcClientConns {
+        conn.Close()
+    }
+    for _, cancelFun := range cancelFuncs {
+        cancelFun()
+    }
+}
+
+// verifyTSOProxy verifies that the TSO Proxy works correctly.
+//
+// 1. If mustReliable == true
+//    there are no gRPC or TSO failures, and the TSO Proxy should return valid, monotonically increasing timestamps.
+//
+// 2. If mustReliable == false
+//    gRPC and TSO failures are allowed, but the TSO Proxy should not panic, block, or deadlock.
+//    If it returns a timestamp, it should be a valid, monotonically increasing timestamp.
+func (s *tsoProxyTestSuite) verifyTSOProxy(
+    streams []pdpb.PD_TsoClient, requestsPerClient int, mustReliable bool,
+) {
+    re := s.Require()
+    reqs := s.generateRequests(requestsPerClient)
+
+    wg := &sync.WaitGroup{}
+    for _, stream := range streams {
+        streamCopy := stream
+        wg.Add(1)
+        go func(streamCopy pdpb.PD_TsoClient) {
+            defer wg.Done()
+            lastPhysical, lastLogical := int64(0), int64(0)
+            for i := 0; i < requestsPerClient; i++ {
+                req := reqs[rand.Intn(requestsPerClient)]
+                err := streamCopy.Send(req)
+                if err != nil && !mustReliable {
+                    continue
+                }
+                re.NoError(err)
+                resp, err := streamCopy.Recv()
+                if err != nil && !mustReliable {
+                    continue
+                }
+                re.NoError(err)
+                re.Equal(req.GetCount(), resp.GetCount())
+                ts := resp.GetTimestamp()
+                count := int64(resp.GetCount())
+                physical, largestLogic, suffixBits := ts.GetPhysical(), ts.GetLogical(), ts.GetSuffixBits()
+                firstLogical := tsoutil.AddLogical(largestLogic, -count+1, suffixBits)
+                re.False(tsoutil.TSLessEqual(physical, firstLogical, lastPhysical, lastLogical))
+            }
+        }(streamCopy)
+    }
+    wg.Wait()
+}
+
+func (s *tsoProxyTestSuite) generateRequests(requestsPerClient int) []*pdpb.TsoRequest {
+    reqs := make([]*pdpb.TsoRequest, requestsPerClient)
+    for i := 0; i < requestsPerClient; i++ {
+        reqs[i] = &pdpb.TsoRequest{
+            Header: &pdpb.RequestHeader{ClusterId: s.apiLeader.GetClusterID()},
+            Count:  uint32(i) + 1, // Make sure the count is positive.
+        }
+    }
+    return reqs
+}
+
+// createTSOStreams creates multiple TSO client streams, and each stream uses a different gRPC connection
+// to simulate multiple clients.
+func createTSOStreams(
+    re *require.Assertions, ctx context.Context,
+    backendEndpoints string, clientCount int, sameContext bool,
+) ([]*grpc.ClientConn, []pdpb.PD_TsoClient, []context.CancelFunc) {
+    grpcClientConns := make([]*grpc.ClientConn, 0, clientCount)
+    streams := make([]pdpb.PD_TsoClient, 0, clientCount)
+    cancelFuncs := make([]context.CancelFunc, 0, clientCount)
+
+    for i := 0; i < clientCount; i++ {
+        conn, err := grpc.Dial(strings.TrimPrefix(backendEndpoints, "http://"), grpc.WithInsecure())
+        re.NoError(err)
+        grpcClientConns = append(grpcClientConns, conn)
+        grpcPDClient := pdpb.NewPDClient(conn)
+        var stream pdpb.PD_TsoClient
+        if sameContext {
+            stream, err = grpcPDClient.Tso(ctx)
+            re.NoError(err)
+        } else {
+            cctx, cancel := context.WithCancel(ctx)
+            cancelFuncs = append(cancelFuncs, cancel)
+            stream, err = grpcPDClient.Tso(cctx)
+            re.NoError(err)
+        }
+        streams = append(streams, stream)
+    }
+
+    return grpcClientConns, streams, cancelFuncs
+}
+
+func tsoProxy(
+    tsoReq *pdpb.TsoRequest, streams []pdpb.PD_TsoClient,
+    concurrentClient bool, requestsPerClient int,
+) error {
+    if concurrentClient {
+        wg := &sync.WaitGroup{}
+        errsReturned := make([]error, len(streams))
+        for index, stream := range streams {
+            streamCopy := stream
+            wg.Add(1)
+            go func(index int, streamCopy pdpb.PD_TsoClient) {
+                defer wg.Done()
+                for i := 0; i < requestsPerClient; i++ {
+                    if err := streamCopy.Send(tsoReq); err != nil {
+                        errsReturned[index] = err
+                        return
+                    }
+                    if _, err := streamCopy.Recv(); err != nil {
+                        return
+                    }
+                }
+            }(index, streamCopy)
+        }
+        wg.Wait()
+        for _, err := range errsReturned {
+            if err != nil {
+                return err
+            }
+        }
+    } else {
+        for _, stream := range streams {
+            for i := 0; i < requestsPerClient; i++ {
+                if err := stream.Send(tsoReq); err != nil {
+                    return err
+                }
+                if _, err := stream.Recv(); err != nil {
+                    return err
+                }
+            }
+        }
+    }
+    return nil
+}
+
+var benchmarkTSOProxyTable = []struct {
+    concurrentClient  bool
+    requestsPerClient int
+}{
+    {true, 2},
+    {true, 10},
+    {true, 100},
+    {false, 2},
+    {false, 10},
+    {false, 100},
+}
+
+// BenchmarkTSOProxy10ClientsSameContext benchmarks TSO proxy performance with 10 clients and the same context.
+func BenchmarkTSOProxy10ClientsSameContext(b *testing.B) {
+    benchmarkTSOProxyNClients(10, true, b)
+}
+
+// BenchmarkTSOProxy10ClientsDiffContext benchmarks TSO proxy performance with 10 clients and different contexts.
+func BenchmarkTSOProxy10ClientsDiffContext(b *testing.B) {
+    benchmarkTSOProxyNClients(10, false, b)
+}
+
+// BenchmarkTSOProxy100ClientsSameContext benchmarks TSO proxy performance with 100 clients and the same context.
+func BenchmarkTSOProxy100ClientsSameContext(b *testing.B) {
+    benchmarkTSOProxyNClients(100, true, b)
+}
+
+// BenchmarkTSOProxy100ClientsDiffContext benchmarks TSO proxy performance with 100 clients and different contexts.
+func BenchmarkTSOProxy100ClientsDiffContext(b *testing.B) {
+    benchmarkTSOProxyNClients(100, false, b)
+}
+
+// BenchmarkTSOProxy1000ClientsSameContext benchmarks TSO proxy performance with 1000 clients and the same context.
+func BenchmarkTSOProxy1000ClientsSameContext(b *testing.B) {
+    benchmarkTSOProxyNClients(1000, true, b)
+}
+
+// BenchmarkTSOProxy1000ClientsDiffContext benchmarks TSO proxy performance with 1000 clients and different contexts.
+func BenchmarkTSOProxy1000ClientsDiffContext(b *testing.B) {
+    benchmarkTSOProxyNClients(1000, false, b)
+}
+
+// benchmarkTSOProxyNClients benchmarks TSO proxy performance.
+func benchmarkTSOProxyNClients(clientCount int, sameContext bool, b *testing.B) {
+    suite := new(tsoProxyTestSuite)
+    suite.SetT(&testing.T{})
+    suite.SetupSuite()
+    re := suite.Require()
+
+    ctx, cancel := context.WithCancel(context.Background())
+    defer cancel()
+
+    grpcClientConns, streams, cancelFuncs := createTSOStreams(re, ctx, suite.backendEndpoints, clientCount, sameContext)
+
+    // Benchmark TSO proxy
+    b.ResetTimer()
+    for _, t := range benchmarkTSOProxyTable {
+        var builder strings.Builder
+        if t.concurrentClient {
+            builder.WriteString("ConcurrentClients_")
+        } else {
+            builder.WriteString("SequentialClients_")
+        }
+        b.Run(fmt.Sprintf("%s_%dReqsPerClient", builder.String(), t.requestsPerClient), func(b *testing.B) {
+            for i := 0; i < b.N; i++ {
+                err := tsoProxy(suite.defaultReq, streams, t.concurrentClient, t.requestsPerClient)
+                re.NoError(err)
+            }
+        })
+    }
+    b.StopTimer()
+
+    suite.cleanupGRPCStreams(grpcClientConns, streams, cancelFuncs)
+
+    suite.TearDownSuite()
+}
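Note (not part of the patch): the LoopWatcher.ForceLoad change above throttles force loads so that two of them are kept at least defaultForceLoadMinimalInterval (200ms) apart, which prevents a burst of NotLeaderErr retries in forwardTSO from triggering an etcd reload on every failed request. The following is a minimal, self-contained Go sketch of that throttling pattern for reference only; the throttledTrigger type and its method names are illustrative and do not exist in the codebase.

package main

import (
	"fmt"
	"sync"
	"time"
)

// throttledTrigger mirrors the pattern used by LoopWatcher.ForceLoad above:
// callers may request a reload at any time, but two reloads are kept at least
// minInterval apart, and signaling the worker never blocks the caller.
type throttledTrigger struct {
	mu          sync.Mutex
	last        time.Time
	minInterval time.Duration
	ch          chan struct{}
}

func newThrottledTrigger(minInterval time.Duration) *throttledTrigger {
	return &throttledTrigger{
		last:        time.Now(), // like lastTimeForceLoad, initialized at construction
		minInterval: minInterval,
		ch:          make(chan struct{}, 1),
	}
}

// Trigger drops the request if the previous one fired too recently; otherwise
// it records the time and signals the worker via a non-blocking send.
func (t *throttledTrigger) Trigger() {
	t.mu.Lock()
	if time.Since(t.last) < t.minInterval {
		t.mu.Unlock()
		return
	}
	t.last = time.Now()
	t.mu.Unlock()

	select {
	case t.ch <- struct{}{}:
	default:
	}
}

func main() {
	trigger := newThrottledTrigger(200 * time.Millisecond)
	time.Sleep(250 * time.Millisecond) // let the initial interval elapse
	for i := 0; i < 5; i++ {
		trigger.Trigger() // only the first call gets through; the rest are dropped
	}
	fmt.Println("pending signals:", len(trigger.ch)) // prints 1
}

Running the sketch prints "pending signals: 1": only the first call after the interval elapses reaches the buffered channel, mirroring how repeated ForceLoad calls collapse into a single reload signal for the watcher loop.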