diff --git a/balancer/balancer.go b/balancer/balancer.go index f391744f7299..04e47a1254ba 100644 --- a/balancer/balancer.go +++ b/balancer/balancer.go @@ -72,8 +72,21 @@ func unregisterForTesting(name string) { delete(m, name) } +// connectedAddress returns the connected address for a SubConnState. The +// address is only valid if the state is READY. +func connectedAddress(scs SubConnState) resolver.Address { + return scs.connectedAddress +} + +// setConnectedAddress sets the connected address for a SubConnState. +func setConnectedAddress(scs *SubConnState, addr resolver.Address) { + scs.connectedAddress = addr +} + func init() { internal.BalancerUnregister = unregisterForTesting + internal.ConnectedAddress = connectedAddress + internal.SetConnectedAddress = setConnectedAddress } // Get returns the resolver builder registered with the given name. @@ -410,6 +423,9 @@ type SubConnState struct { // ConnectionError is set if the ConnectivityState is TransientFailure, // describing the reason the SubConn failed. Otherwise, it is nil. ConnectionError error + // connectedAddr contains the connected address when ConnectivityState is + // Ready. Otherwise, it is indeterminate. + connectedAddress resolver.Address } // ClientConnState describes the state of a ClientConn relevant to the diff --git a/balancer_wrapper.go b/balancer_wrapper.go index 4161fdf47a8b..554fd3c64afe 100644 --- a/balancer_wrapper.go +++ b/balancer_wrapper.go @@ -25,12 +25,15 @@ import ( "google.golang.org/grpc/balancer" "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/internal" "google.golang.org/grpc/internal/balancer/gracefulswitch" "google.golang.org/grpc/internal/channelz" "google.golang.org/grpc/internal/grpcsync" "google.golang.org/grpc/resolver" ) +var setConnectedAddress = internal.SetConnectedAddress.(func(*balancer.SubConnState, resolver.Address)) + // ccBalancerWrapper sits between the ClientConn and the Balancer. // // ccBalancerWrapper implements methods corresponding to the ones on the @@ -252,7 +255,7 @@ type acBalancerWrapper struct { // updateState is invoked by grpc to push a subConn state update to the // underlying balancer. -func (acbw *acBalancerWrapper) updateState(s connectivity.State, err error) { +func (acbw *acBalancerWrapper) updateState(s connectivity.State, curAddr resolver.Address, err error) { acbw.ccb.serializer.Schedule(func(ctx context.Context) { if ctx.Err() != nil || acbw.ccb.balancer == nil { return @@ -260,7 +263,11 @@ func (acbw *acBalancerWrapper) updateState(s connectivity.State, err error) { // Even though it is optional for balancers, gracefulswitch ensures // opts.StateListener is set, so this cannot ever be nil. // TODO: delete this comment when UpdateSubConnState is removed. - acbw.stateListener(balancer.SubConnState{ConnectivityState: s, ConnectionError: err}) + scs := balancer.SubConnState{ConnectivityState: s, ConnectionError: err} + if s == connectivity.Ready { + setConnectedAddress(&scs, curAddr) + } + acbw.stateListener(scs) }) } diff --git a/clientconn.go b/clientconn.go index 423be7b43b00..663902ae8308 100644 --- a/clientconn.go +++ b/clientconn.go @@ -24,6 +24,7 @@ import ( "fmt" "math" "net/url" + "slices" "strings" "sync" "sync/atomic" @@ -812,17 +813,11 @@ func (cc *ClientConn) applyFailingLBLocked(sc *serviceconfig.ParseResult) { cc.csMgr.updateState(connectivity.TransientFailure) } -// Makes a copy of the input addresses slice and clears out the balancer -// attributes field. Addresses are passed during subconn creation and address -// update operations. In both cases, we will clear the balancer attributes by -// calling this function, and therefore we will be able to use the Equal method -// provided by the resolver.Address type for comparison. -func copyAddressesWithoutBalancerAttributes(in []resolver.Address) []resolver.Address { +// Makes a copy of the input addresses slice. Addresses are passed during +// subconn creation and address update operations. +func copyAddresses(in []resolver.Address) []resolver.Address { out := make([]resolver.Address, len(in)) - for i := range in { - out[i] = in[i] - out[i].BalancerAttributes = nil - } + copy(out, in) return out } @@ -837,7 +832,7 @@ func (cc *ClientConn) newAddrConnLocked(addrs []resolver.Address, opts balancer. ac := &addrConn{ state: connectivity.Idle, cc: cc, - addrs: copyAddressesWithoutBalancerAttributes(addrs), + addrs: copyAddresses(addrs), scopts: opts, dopts: cc.dopts, channelz: channelz.RegisterSubChannel(cc.channelz, ""), @@ -924,22 +919,24 @@ func (ac *addrConn) connect() error { return nil } -func equalAddresses(a, b []resolver.Address) bool { - if len(a) != len(b) { - return false - } - for i, v := range a { - if !v.Equal(b[i]) { - return false - } - } - return true +// equalAddressIgnoringBalAttributes returns true is a and b are considered equal. +// This is different from the Equal method on the resolver.Address type which +// considers all fields to determine equality. Here, we only consider fields +// that are meaningful to the subConn. +func equalAddressIgnoringBalAttributes(a, b *resolver.Address) bool { + return a.Addr == b.Addr && a.ServerName == b.ServerName && + a.Attributes.Equal(b.Attributes) && + a.Metadata == b.Metadata +} + +func equalAddressesIgnoringBalAttributes(a, b []resolver.Address) bool { + return slices.EqualFunc(a, b, func(a, b resolver.Address) bool { return equalAddressIgnoringBalAttributes(&a, &b) }) } // updateAddrs updates ac.addrs with the new addresses list and handles active // connections or connection attempts. func (ac *addrConn) updateAddrs(addrs []resolver.Address) { - addrs = copyAddressesWithoutBalancerAttributes(addrs) + addrs = copyAddresses(addrs) limit := len(addrs) if limit > 5 { limit = 5 @@ -947,7 +944,7 @@ func (ac *addrConn) updateAddrs(addrs []resolver.Address) { channelz.Infof(logger, ac.channelz, "addrConn: updateAddrs addrs (%d of %d): %v", limit, len(addrs), addrs[:limit]) ac.mu.Lock() - if equalAddresses(ac.addrs, addrs) { + if equalAddressesIgnoringBalAttributes(ac.addrs, addrs) { ac.mu.Unlock() return } @@ -966,7 +963,7 @@ func (ac *addrConn) updateAddrs(addrs []resolver.Address) { // Try to find the connected address. for _, a := range addrs { a.ServerName = ac.cc.getServerName(a) - if a.Equal(ac.curAddr) { + if equalAddressIgnoringBalAttributes(&a, &ac.curAddr) { // We are connected to a valid address, so do nothing but // update the addresses. ac.mu.Unlock() @@ -1214,7 +1211,7 @@ func (ac *addrConn) updateConnectivityState(s connectivity.State, lastErr error) } else { channelz.Infof(logger, ac.channelz, "Subchannel Connectivity change to %v, last error: %s", s, lastErr) } - ac.acbw.updateState(s, lastErr) + ac.acbw.updateState(s, ac.curAddr, lastErr) } // adjustParams updates parameters used to create transports upon diff --git a/internal/internal.go b/internal/internal.go index 5d6653986923..46ed257685bf 100644 --- a/internal/internal.go +++ b/internal/internal.go @@ -208,6 +208,13 @@ var ( // ShuffleAddressListForTesting pseudo-randomizes the order of addresses. n // is the number of elements. swap swaps the elements with indexes i and j. ShuffleAddressListForTesting any // func(n int, swap func(i, j int)) + + // ConnectedAddress returns the connected address for a SubConnState. The + // address is only valid if the state is READY. + ConnectedAddress any // func (scs SubConnState) resolver.Address + + // SetConnectedAddress sets the connected address for a SubConnState. + SetConnectedAddress any // func(scs *SubConnState, addr resolver.Address) ) // HealthChecker defines the signature of the client-side LB channel health diff --git a/xds/internal/balancer/clusterimpl/balancer_test.go b/xds/internal/balancer/clusterimpl/balancer_test.go index 76c96decfd7d..5a4bb0f270b2 100644 --- a/xds/internal/balancer/clusterimpl/balancer_test.go +++ b/xds/internal/balancer/clusterimpl/balancer_test.go @@ -32,6 +32,7 @@ import ( "google.golang.org/grpc/balancer/base" "google.golang.org/grpc/balancer/roundrobin" "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/internal" "google.golang.org/grpc/internal/balancer/stub" "google.golang.org/grpc/internal/grpctest" internalserviceconfig "google.golang.org/grpc/internal/serviceconfig" @@ -637,7 +638,10 @@ func (s) TestLoadReporting(t *testing.T) { t.Fatal(err.Error()) } - sc1.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Ready}) + scs := balancer.SubConnState{ConnectivityState: connectivity.Ready} + sca := internal.SetConnectedAddress.(func(*balancer.SubConnState, resolver.Address)) + sca(&scs, addrs[0]) + sc1.UpdateState(scs) // Test pick with one backend. const successCount = 5 const errorCount = 5 diff --git a/xds/internal/balancer/clusterimpl/clusterimpl.go b/xds/internal/balancer/clusterimpl/clusterimpl.go index 164f3099d280..9058f0d01fc8 100644 --- a/xds/internal/balancer/clusterimpl/clusterimpl.go +++ b/xds/internal/balancer/clusterimpl/clusterimpl.go @@ -31,6 +31,7 @@ import ( "google.golang.org/grpc/balancer" "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/internal" "google.golang.org/grpc/internal/balancer/gracefulswitch" "google.golang.org/grpc/internal/buffer" "google.golang.org/grpc/internal/grpclog" @@ -52,6 +53,8 @@ const ( defaultRequestCountMax = 1024 ) +var connectedAddress = internal.ConnectedAddress.(func(balancer.SubConnState) resolver.Address) + func init() { balancer.Register(bb{}) } @@ -360,22 +363,35 @@ func (scw *scWrapper) localityID() xdsinternal.LocalityID { func (b *clusterImplBalancer) NewSubConn(addrs []resolver.Address, opts balancer.NewSubConnOptions) (balancer.SubConn, error) { clusterName := b.getClusterName() newAddrs := make([]resolver.Address, len(addrs)) - var lID xdsinternal.LocalityID for i, addr := range addrs { newAddrs[i] = xds.SetXDSHandshakeClusterName(addr, clusterName) - lID = xdsinternal.GetLocalityID(newAddrs[i]) } var sc balancer.SubConn + scw := &scWrapper{} oldListener := opts.StateListener - opts.StateListener = func(state balancer.SubConnState) { b.updateSubConnState(sc, state, oldListener) } + opts.StateListener = func(state balancer.SubConnState) { + b.updateSubConnState(sc, state, oldListener) + if state.ConnectivityState != connectivity.Ready { + return + } + // Read connected address and call updateLocalityID() based on the connected + // address's locality. https://github.com/grpc/grpc-go/issues/7339 + addr := connectedAddress(state) + lID := xdsinternal.GetLocalityID(addr) + if lID.Empty() { + if b.logger.V(2) { + b.logger.Infof("Locality ID for %s unexpectedly empty", addr) + } + return + } + scw.updateLocalityID(lID) + } sc, err := b.ClientConn.NewSubConn(newAddrs, opts) if err != nil { return nil, err } - // Wrap this SubConn in a wrapper, and add it to the map. - ret := &scWrapper{SubConn: sc} - ret.updateLocalityID(lID) - return ret, nil + scw.SubConn = sc + return scw, nil } func (b *clusterImplBalancer) RemoveSubConn(sc balancer.SubConn) { diff --git a/xds/internal/balancer/clusterimpl/tests/balancer_test.go b/xds/internal/balancer/clusterimpl/tests/balancer_test.go index 4a5c13b8e6b4..d2a6b6d7f757 100644 --- a/xds/internal/balancer/clusterimpl/tests/balancer_test.go +++ b/xds/internal/balancer/clusterimpl/tests/balancer_test.go @@ -310,14 +310,9 @@ func (s) TestLoadReportingPickFirstMultiLocality(t *testing.T) { } mgmtServer.LRSServer.LRSResponseChan <- &resp - // Wait for load to be reported for locality of server 2. - // We (incorrectly) wait for load report for region-2 because presently - // pickfirst always reports load for the locality of the last address in the - // subconn. This will be fixed by ensuring there is only one address per - // subconn. - // TODO(#7339): Change region to region-1 once fixed. - if err := waitForSuccessfulLoadReport(ctx, mgmtServer.LRSServer, "region-2"); err != nil { - t.Fatalf("region-2 did not receive load due to error: %v", err) + // Wait for load to be reported for locality of server 1. + if err := waitForSuccessfulLoadReport(ctx, mgmtServer.LRSServer, "region-1"); err != nil { + t.Fatalf("Server 1 did not receive load due to error: %v", err) } // Stop server 1 and send one more rpc. Now the request should go to server 2. diff --git a/xds/internal/balancer/outlierdetection/balancer_test.go b/xds/internal/balancer/outlierdetection/balancer_test.go index 54eefaa34c1a..39bd51aa6567 100644 --- a/xds/internal/balancer/outlierdetection/balancer_test.go +++ b/xds/internal/balancer/outlierdetection/balancer_test.go @@ -852,7 +852,7 @@ func (s) TestUpdateAddresses(t *testing.T) { } func scwsEqual(gotSCWS subConnWithState, wantSCWS subConnWithState) error { - if gotSCWS.sc != wantSCWS.sc || !cmp.Equal(gotSCWS.state, wantSCWS.state, cmp.AllowUnexported(subConnWrapper{}, addressInfo{}), cmpopts.IgnoreFields(subConnWrapper{}, "scUpdateCh")) { + if gotSCWS.sc != wantSCWS.sc || !cmp.Equal(gotSCWS.state, wantSCWS.state, cmp.AllowUnexported(subConnWrapper{}, addressInfo{}, balancer.SubConnState{}), cmpopts.IgnoreFields(subConnWrapper{}, "scUpdateCh")) { return fmt.Errorf("received SubConnState: %+v, want %+v", gotSCWS, wantSCWS) } return nil diff --git a/xds/internal/internal.go b/xds/internal/internal.go index 7091990500f9..1d8a6b03f1b3 100644 --- a/xds/internal/internal.go +++ b/xds/internal/internal.go @@ -55,6 +55,11 @@ func (l LocalityID) Equal(o any) bool { return l.Region == ol.Region && l.Zone == ol.Zone && l.SubZone == ol.SubZone } +// Empty returns whether or not the locality ID is empty. +func (l LocalityID) Empty() bool { + return l.Region == "" && l.Zone == "" && l.SubZone == "" +} + // LocalityIDFromString converts a json representation of locality, into a // LocalityID struct. func LocalityIDFromString(s string) (ret LocalityID, _ error) {