diff --git a/balancer/leastrequest/balancer_test.go b/balancer/leastrequest/balancer_test.go
index f79e3850c6c0..e0043db391c8 100644
--- a/balancer/leastrequest/balancer_test.go
+++ b/balancer/leastrequest/balancer_test.go
@@ -117,14 +117,9 @@ func (s) TestParseConfig(t *testing.T) {
     }
 }
 
-// setupBackends spins up three test backends, each listening on a port on
-// localhost. The three backends always reply with an empty response with no
-// error, and for streaming receive until hitting an EOF error.
-func setupBackends(t *testing.T) []string {
-    t.Helper()
-    const numBackends = 3
-    addresses := make([]string, numBackends)
-    // Construct and start three working backends.
+func startBackends(t *testing.T, numBackends int) []*stubserver.StubServer {
+    backends := make([]*stubserver.StubServer, 0, numBackends)
+    // Construct and start working backends.
     for i := 0; i < numBackends; i++ {
         backend := &stubserver.StubServer{
             EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) {
@@ -140,7 +135,21 @@ func setupBackends(t *testing.T) []string {
         }
         t.Logf("Started good TestService backend at: %q", backend.Address)
         t.Cleanup(func() { backend.Stop() })
-        addresses[i] = backend.Address
+        backends = append(backends, backend)
+    }
+    return backends
+}
+
+// setupBackends spins up numBackends test backends, each listening on a port
+// on localhost. The backends always reply with an empty response with no
+// error, and for streaming receive until hitting an EOF error.
+func setupBackends(t *testing.T, numBackends int) []string {
+    t.Helper()
+    addresses := make([]string, numBackends)
+    backends := startBackends(t, numBackends)
+    // Collect the addresses of the started backends.
+    for i := 0; i < numBackends; i++ {
+        addresses[i] = backends[i].Address
     }
     return addresses
 }
@@ -205,7 +214,7 @@ func (s) TestLeastRequestE2E(t *testing.T) {
         index++
         return ret
     }
-    addresses := setupBackends(t)
+    addresses := setupBackends(t, 3)
     mr := manual.NewBuilderWithScheme("lr-e2e")
     defer mr.Close()
 
@@ -321,7 +330,7 @@ func (s) TestLeastRequestPersistsCounts(t *testing.T) {
         index++
         return ret
     }
-    addresses := setupBackends(t)
+    addresses := setupBackends(t, 3)
     mr := manual.NewBuilderWithScheme("lr-e2e")
     defer mr.Close()
 
@@ -462,7 +471,7 @@ func (s) TestLeastRequestPersistsCounts(t *testing.T) {
 // and makes 100 RPCs asynchronously. This makes sure no race conditions happen
 // in this scenario.
 func (s) TestConcurrentRPCs(t *testing.T) {
-    addresses := setupBackends(t)
+    addresses := setupBackends(t, 3)
     mr := manual.NewBuilderWithScheme("lr-e2e")
     defer mr.Close()
 
@@ -508,5 +517,192 @@ func (s) TestConcurrentRPCs(t *testing.T) {
         }()
     }
     wg.Wait()
+}
+
+// Tests that the least request balancer persists RPC counts once it gets new
+// picker updates and backends within an endpoint go down. It first updates
+// the balancer with two endpoints having two addresses each. It verifies that
+// requests are round robined across the first address of each endpoint. It
+// then stops the active backend in endpoint[0] and verifies that the balancer
+// starts using the second address in endpoint[0]. The test then creates a
+// bunch of streams on the two endpoints. Then, it updates the balancer with
+// three endpoints, including the two previous ones. Any created streams
+// should then be started on the new endpoint. The test shuts down the active
+// backend in endpoint[1] and endpoint[2], and verifies that new RPCs are
+// round robined across the active backends in endpoint[1] and endpoint[2].
+func (s) TestLeastRequestEndpoints_MultipleAddresses(t *testing.T) {
+    defer func(u func() uint32) {
+        randuint32 = u
+    }(randuint32)
+    var index int
+    indexes := []uint32{
+        0, 0, 1, 1,
+    }
+    randuint32 = func() uint32 {
+        ret := indexes[index%len(indexes)]
+        index++
+        return ret
+    }
+    backends := startBackends(t, 6)
+    mr := manual.NewBuilderWithScheme("lr-e2e")
+    defer mr.Close()
+
+    // Configure least request as top level balancer of channel.
+    lrscJSON := `
+{
+  "loadBalancingConfig": [
+    {
+      "least_request_experimental": {
+        "choiceCount": 2
+      }
+    }
+  ]
+}`
+    endpoints := []resolver.Endpoint{
+        {Addresses: []resolver.Address{{Addr: backends[0].Address}, {Addr: backends[1].Address}}},
+        {Addresses: []resolver.Address{{Addr: backends[2].Address}, {Addr: backends[3].Address}}},
+        {Addresses: []resolver.Address{{Addr: backends[4].Address}, {Addr: backends[5].Address}}},
+    }
+    sc := internal.ParseServiceConfig.(func(string) *serviceconfig.ParseResult)(lrscJSON)
+    firstTwoEndpoints := []resolver.Endpoint{endpoints[0], endpoints[1]}
+    mr.InitialState(resolver.State{
+        Endpoints:     firstTwoEndpoints,
+        ServiceConfig: sc,
+    })
+
+    cc, err := grpc.NewClient(mr.Scheme()+":///", grpc.WithResolvers(mr), grpc.WithTransportCredentials(insecure.NewCredentials()))
+    if err != nil {
+        t.Fatalf("grpc.NewClient() failed: %v", err)
+    }
+    defer cc.Close()
+    ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+    defer cancel()
+    testServiceClient := testgrpc.NewTestServiceClient(cc)
+
+    // Wait for RPCs to round robin across the two backends. This happens
+    // because a child pickfirst transitioning into READY causes a new picker
+    // update. Once the picker update with the two backends is present, this
+    // test can start to populate those backends with streams.
+    wantAddrs := []resolver.Address{
+        endpoints[0].Addresses[0],
+        endpoints[1].Addresses[0],
+    }
+    if err := checkRoundRobinRPCs(ctx, testServiceClient, wantAddrs); err != nil {
+        t.Fatalf("error in expected round robin: %v", err)
+    }
+
+    // Shut down one of the addresses in endpoints[0]; the child pickfirst
+    // should fall back to the next address in endpoints[0].
+    backends[0].Stop()
+    wantAddrs = []resolver.Address{
+        endpoints[0].Addresses[1],
+        endpoints[1].Addresses[0],
+    }
+    if err := checkRoundRobinRPCs(ctx, testServiceClient, wantAddrs); err != nil {
+        t.Fatalf("error in expected round robin: %v", err)
+    }
+
+    // Start 50 streaming RPCs, and leave them unfinished for the duration of
+    // the test. This will populate the first two endpoints with many active
+    // RPCs.
+    for i := 0; i < 50; i++ {
+        _, err := testServiceClient.FullDuplexCall(ctx)
+        if err != nil {
+            t.Fatalf("testServiceClient.FullDuplexCall failed: %v", err)
+        }
+    }
+    // Update the least request balancer to choice count 3. Also update the
+    // address list, adding a third endpoint. Alongside the injected randomness,
+    // this should trigger the least request balancer to search all created
+    // endpoints. Thus, since endpoint 3 is the new endpoint and the first two
+    // endpoints are populated with RPCs, once the picker update of all 3 READY
+    // pickfirsts takes effect, all new streams should be started on endpoint 3.
+    index = 0
+    indexes = []uint32{
+        0, 1, 2, 3, 4, 5,
+    }
+    lrscJSON = `
+{
+  "loadBalancingConfig": [
+    {
+      "least_request_experimental": {
+        "choiceCount": 3
+      }
+    }
+  ]
+}`
+    sc = internal.ParseServiceConfig.(func(string) *serviceconfig.ParseResult)(lrscJSON)
+    mr.UpdateState(resolver.State{
+        Endpoints:     endpoints,
+        ServiceConfig: sc,
+    })
+    newAddress := endpoints[2].Addresses[0]
+    // Poll for only endpoint 3 to show up. This requires a polling loop
+    // because the picker update with all three endpoints doesn't take effect
+    // immediately; it needs the third pickfirst to become READY.
+    if err := checkRoundRobinRPCs(ctx, testServiceClient, []resolver.Address{newAddress}); err != nil {
+        t.Fatalf("error in expected round robin: %v", err)
+    }
+
+    // Start 25 RPCs, but don't finish them. They should all start on endpoint
+    // 3, since the first two endpoints both have 25 RPCs (and randomness
+    // injection/choiceCount causes all 3 to be compared every iteration).
+    for i := 0; i < 25; i++ {
+        stream, err := testServiceClient.FullDuplexCall(ctx)
+        if err != nil {
+            t.Fatalf("testServiceClient.FullDuplexCall failed: %v", err)
+        }
+        p, ok := peer.FromContext(stream.Context())
+        if !ok {
+            t.Fatalf("testServiceClient.FullDuplexCall has no Peer")
+        }
+        if p.Addr.String() != newAddress.Addr {
+            t.Fatalf("testServiceClient.FullDuplexCall's Peer got: %v, want: %v", p.Addr.String(), newAddress.Addr)
+        }
+    }
+
+    // Now that 25 RPCs are active on each endpoint, the next three RPCs should
+    // round robin, since choiceCount is three and the injected random indexes
+    // cause it to search all three endpoints for fewest outstanding requests
+    // on each iteration.
+    wantAddrCount := map[string]int{
+        endpoints[0].Addresses[1].Addr: 1,
+        endpoints[1].Addresses[0].Addr: 1,
+        endpoints[2].Addresses[0].Addr: 1,
+    }
+    gotAddrCount := make(map[string]int)
+    for i := 0; i < len(endpoints); i++ {
+        stream, err := testServiceClient.FullDuplexCall(ctx)
+        if err != nil {
+            t.Fatalf("testServiceClient.FullDuplexCall failed: %v", err)
+        }
+        p, ok := peer.FromContext(stream.Context())
+        if !ok {
+            t.Fatalf("testServiceClient.FullDuplexCall has no Peer")
+        }
+        if p.Addr != nil {
+            gotAddrCount[p.Addr.String()]++
+        }
+    }
+    if diff := cmp.Diff(gotAddrCount, wantAddrCount); diff != "" {
+        t.Fatalf("addr count (-got, +want): %v", diff)
+    }
+
+    // Shut down the active address for endpoint[1] and endpoint[2]. This
+    // should result in their streams failing. Now the requests should round
+    // robin between endpoint[1] and endpoint[2].
+    backends[2].Stop()
+    backends[4].Stop()
+    index = 0
+    indexes = []uint32{
+        0, 1, 2, 2, 1, 0,
+    }
+    wantAddrs = []resolver.Address{
+        endpoints[1].Addresses[1],
+        endpoints[2].Addresses[1],
+    }
+    if err := checkRoundRobinRPCs(ctx, testServiceClient, wantAddrs); err != nil {
+        t.Fatalf("error in expected round robin: %v", err)
+    }
 }
diff --git a/balancer/leastrequest/leastrequest.go b/balancer/leastrequest/leastrequest.go
index 6dede1a40b70..0157d7ee7fad 100644
--- a/balancer/leastrequest/leastrequest.go
+++ b/balancer/leastrequest/leastrequest.go
@@ -23,21 +23,28 @@ import (
     "encoding/json"
     "fmt"
     rand "math/rand/v2"
+    "sync"
     "sync/atomic"
 
     "google.golang.org/grpc/balancer"
-    "google.golang.org/grpc/balancer/base"
+    "google.golang.org/grpc/balancer/endpointsharding"
+    "google.golang.org/grpc/balancer/pickfirst/pickfirstleaf"
+    "google.golang.org/grpc/connectivity"
     "google.golang.org/grpc/grpclog"
+    internalgrpclog "google.golang.org/grpc/internal/grpclog"
+    "google.golang.org/grpc/resolver"
     "google.golang.org/grpc/serviceconfig"
 )
 
-// randuint32 is a global to stub out in tests.
-var randuint32 = rand.Uint32
-
 // Name is the name of the least request balancer.
 const Name = "least_request_experimental"
 
-var logger = grpclog.Component("least-request")
+var (
+    // randuint32 is a global to stub out in tests.
+    randuint32               = rand.Uint32
+    endpointShardingLBConfig = endpointsharding.PickFirstConfig
+    logger                   = grpclog.Component("least-request")
+)
 
 func init() {
     balancer.Register(bb{})
@@ -80,9 +87,13 @@ func (bb) Name() string {
 }
 
 func (bb) Build(cc balancer.ClientConn, bOpts balancer.BuildOptions) balancer.Balancer {
-    b := &leastRequestBalancer{scRPCCounts: make(map[balancer.SubConn]*atomic.Int32)}
-    baseBuilder := base.NewBalancerBuilder(Name, b, base.Config{HealthCheck: true})
-    b.Balancer = baseBuilder.Build(cc, bOpts)
+    b := &leastRequestBalancer{
+        ClientConn:        cc,
+        endpointRPCCounts: resolver.NewEndpointMap(),
+    }
+    b.child = endpointsharding.NewBalancer(b, bOpts)
+    b.logger = internalgrpclog.NewPrefixLogger(logger, fmt.Sprintf("[%p] ", b))
+    b.logger.Infof("Created")
     return b
 }
 
@@ -90,94 +101,143 @@ type leastRequestBalancer struct {
     // Embeds balancer.Balancer because needs to intercept UpdateClientConnState
     // to learn about choiceCount.
     balancer.Balancer
+    // Embeds balancer.ClientConn because needs to intercept UpdateState calls
+    // from the child balancer.
+    balancer.ClientConn
+    child  balancer.Balancer
+    logger *internalgrpclog.PrefixLogger
 
+    mu          sync.Mutex
     choiceCount uint32
-    scRPCCounts map[balancer.SubConn]*atomic.Int32 // Hold onto RPC counts to keep track for subsequent picker updates.
+    // endpointRPCCounts holds RPC counts to keep track for subsequent picker
+    // updates.
+    endpointRPCCounts *resolver.EndpointMap // endpoint -> *atomic.Int32
+}
+
+func (lrb *leastRequestBalancer) Close() {
+    lrb.child.Close()
+    lrb.endpointRPCCounts = nil
 }
 
-func (lrb *leastRequestBalancer) UpdateClientConnState(s balancer.ClientConnState) error {
-    lrCfg, ok := s.BalancerConfig.(*LBConfig)
+func (lrb *leastRequestBalancer) UpdateClientConnState(ccs balancer.ClientConnState) error {
+    lrCfg, ok := ccs.BalancerConfig.(*LBConfig)
     if !ok {
-        logger.Errorf("least-request: received config with unexpected type %T: %v", s.BalancerConfig, s.BalancerConfig)
+        logger.Errorf("least-request: received config with unexpected type %T: %v", ccs.BalancerConfig, ccs.BalancerConfig)
         return balancer.ErrBadResolverState
     }
+    lrb.mu.Lock()
     lrb.choiceCount = lrCfg.ChoiceCount
-    return lrb.Balancer.UpdateClientConnState(s)
+    lrb.mu.Unlock()
+    // Enable the health listener in pickfirst children for client side health
+    // checks and outlier detection, if configured.
+    ccs.ResolverState = pickfirstleaf.EnableHealthListener(ccs.ResolverState)
+    ccs.BalancerConfig = endpointShardingLBConfig
+    return lrb.child.UpdateClientConnState(ccs)
 }
 
-type scWithRPCCount struct {
-    sc      balancer.SubConn
+type endpointState struct {
+    picker  balancer.Picker
     numRPCs *atomic.Int32
 }
 
-func (lrb *leastRequestBalancer) Build(info base.PickerBuildInfo) balancer.Picker {
-    if logger.V(2) {
-        logger.Infof("least-request: Build called with info: %v", info)
+func (lrb *leastRequestBalancer) UpdateState(state balancer.State) {
+    var readyEndpoints []endpointsharding.ChildState
+    for _, child := range endpointsharding.ChildStatesFromPicker(state.Picker) {
+        if child.State.ConnectivityState == connectivity.Ready {
+            readyEndpoints = append(readyEndpoints, child)
+        }
     }
-    if len(info.ReadySCs) == 0 {
-        return base.NewErrPicker(balancer.ErrNoSubConnAvailable)
+
+    // If no ready pickers are present, simply defer to the round robin picker
+    // from endpoint sharding, which will round robin across the most relevant
+    // pick first children in the highest precedence connectivity state.
+    if len(readyEndpoints) == 0 {
+        lrb.ClientConn.UpdateState(state)
+        return
     }
-    for sc := range lrb.scRPCCounts {
-        if _, ok := info.ReadySCs[sc]; !ok { // If no longer ready, no more need for the ref to count active RPCs.
-            delete(lrb.scRPCCounts, sc)
-        }
+
+    lrb.mu.Lock()
+    defer lrb.mu.Unlock()
+
+    if logger.V(2) {
+        lrb.logger.Infof("UpdateState called with ready endpoints: %v", readyEndpoints)
     }
 
-    // Create new refs if needed.
-    for sc := range info.ReadySCs {
-        if _, ok := lrb.scRPCCounts[sc]; !ok {
-            lrb.scRPCCounts[sc] = new(atomic.Int32)
+    // Reconcile endpoints.
+    newEndpoints := resolver.NewEndpointMap() // endpoint -> nil
+    for _, child := range readyEndpoints {
+        newEndpoints.Set(child.Endpoint, nil)
+    }
+
+    // If endpoints are no longer ready, no need to count their active RPCs.
+    for _, endpoint := range lrb.endpointRPCCounts.Keys() {
+        if _, ok := newEndpoints.Get(endpoint); !ok {
+            lrb.endpointRPCCounts.Delete(endpoint)
         }
     }
 
     // Copy refs to counters into picker.
-    scs := make([]scWithRPCCount, 0, len(info.ReadySCs))
-    for sc := range info.ReadySCs {
-        scs = append(scs, scWithRPCCount{
-            sc:      sc,
-            numRPCs: lrb.scRPCCounts[sc], // guaranteed to be present due to algorithm
+    endpointStates := make([]endpointState, 0, len(readyEndpoints))
+    for _, child := range readyEndpoints {
+        var counter *atomic.Int32
+        if val, ok := lrb.endpointRPCCounts.Get(child.Endpoint); !ok {
+            // Create new counts if needed.
+            counter = new(atomic.Int32)
+            lrb.endpointRPCCounts.Set(child.Endpoint, counter)
+        } else {
+            counter = val.(*atomic.Int32)
+        }
+        endpointStates = append(endpointStates, endpointState{
+            picker:  child.State.Picker,
+            numRPCs: counter,
         })
     }
 
-    return &picker{
-        choiceCount: lrb.choiceCount,
-        subConns:    scs,
-    }
+    lrb.ClientConn.UpdateState(balancer.State{
+        Picker: &picker{
+            choiceCount:    lrb.choiceCount,
+            endpointStates: endpointStates,
+        },
+        ConnectivityState: connectivity.Ready,
+    })
 }
 
 type picker struct {
-    // choiceCount is the number of random SubConns to find the one with
-    // the least request.
-    choiceCount uint32
-    // Built out when receives list of ready RPCs.
-    subConns []scWithRPCCount
+    // choiceCount is the number of random endpoints to sample for choosing the
+    // one with the least requests.
+    choiceCount    uint32
+    endpointStates []endpointState
 }
 
-func (p *picker) Pick(balancer.PickInfo) (balancer.PickResult, error) {
-    var pickedSC *scWithRPCCount
-    var pickedSCNumRPCs int32
+func (p *picker) Pick(pInfo balancer.PickInfo) (balancer.PickResult, error) {
+    var pickedEndpointState *endpointState
+    var pickedEndpointNumRPCs int32
     for i := 0; i < int(p.choiceCount); i++ {
-        index := randuint32() % uint32(len(p.subConns))
-        sc := p.subConns[index]
-        n := sc.numRPCs.Load()
-        if pickedSC == nil || n < pickedSCNumRPCs {
-            pickedSC = &sc
-            pickedSCNumRPCs = n
+        index := randuint32() % uint32(len(p.endpointStates))
+        endpointState := p.endpointStates[index]
+        n := endpointState.numRPCs.Load()
+        if pickedEndpointState == nil || n < pickedEndpointNumRPCs {
+            pickedEndpointState = &endpointState
+            pickedEndpointNumRPCs = n
         }
     }
+    result, err := pickedEndpointState.picker.Pick(pInfo)
+    if err != nil {
+        return result, err
+    }
     // "The counter for a subchannel should be atomically incremented by one
     // after it has been successfully picked by the picker." - A48
-    pickedSC.numRPCs.Add(1)
+    pickedEndpointState.numRPCs.Add(1)
     // "the picker should add a callback for atomically decrementing the
     // subchannel counter once the RPC finishes (regardless of Status code)." -
     // A48.
-    done := func(balancer.DoneInfo) {
-        pickedSC.numRPCs.Add(-1)
+    originalDone := result.Done
+    result.Done = func(info balancer.DoneInfo) {
+        pickedEndpointState.numRPCs.Add(-1)
+        if originalDone != nil {
+            originalDone(info)
+        }
     }
-    return balancer.PickResult{
-        SubConn: pickedSC.sc,
-        Done:    done,
-    }, nil
+    return result, nil
 }
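
Usage sketch (illustrative, not part of this patch): with the balancer registered by importing the leastrequest package, a client can select the policy through the default service config. The target URI and the choiceCount value below are assumptions for the example, not values taken from the diff; the behavioral change in this patch is only that outstanding-RPC counts are tracked per endpoint (via pickfirst children under endpointsharding) rather than per SubConn.

package main

import (
    "log"

    "google.golang.org/grpc"
    "google.golang.org/grpc/credentials/insecure"

    _ "google.golang.org/grpc/balancer/leastrequest" // Registers least_request_experimental.
)

func main() {
    // choiceCount controls how many endpoints are sampled at random per pick;
    // the pick goes to the sampled endpoint with the fewest outstanding RPCs.
    const serviceConfig = `{
  "loadBalancingConfig": [
    {
      "least_request_experimental": {
        "choiceCount": 2
      }
    }
  ]
}`
    // "dns:///my-service:443" is a hypothetical target for this sketch.
    cc, err := grpc.NewClient("dns:///my-service:443",
        grpc.WithTransportCredentials(insecure.NewCredentials()),
        grpc.WithDefaultServiceConfig(serviceConfig))
    if err != nil {
        log.Fatalf("grpc.NewClient() failed: %v", err)
    }
    defer cc.Close()
    // Issue RPCs on cc as usual; picks follow the least-request policy.
}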