Skip to content

Commit

Permalink
refactor test
Browse files Browse the repository at this point in the history
  • Loading branch information
arjan-bal committed Oct 16, 2024
1 parent 67f7a1a commit af38951
Showing 1 changed file with 81 additions and 87 deletions.
168 changes: 81 additions & 87 deletions balancer/pickfirst/pickfirstleaf/pickfirstleaf_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -887,9 +887,9 @@ func (s) TestPickFirstLeaf_HappyEyeballs_TFAfterEndOfList(t *testing.T) {
}()

servers := newHangingServerGroup(t, 3)
defer servers.close()
rb := manual.NewBuilderWithScheme("whatever")
rb.InitialState(resolver.State{Addresses: servers.addrs})
addrs := resolverAddrsFromHangingServers(servers)
rb.InitialState(resolver.State{Addresses: addrs})
cc, err := grpc.NewClient("whatever:///this-gets-overwritten",
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, pickfirstleaf.Name)),
Expand All @@ -904,37 +904,37 @@ func (s) TestPickFirstLeaf_HappyEyeballs_TFAfterEndOfList(t *testing.T) {
testutils.AwaitState(ctx, t, cc, connectivity.Connecting)

// Verify that only the first server is contacted.
if err := servers.awaitContacted(ctx, 0); err != nil {
t.Fatalf("Server with address %q not contacted: %v", servers.addrs[0], err)
if err := servers[0].awaitContacted(ctx); err != nil {
t.Fatalf("Server with address %q not contacted: %v", addrs[0], err)
}

// Ensure no other servers are contacted.
if got, want := servers.isContacted(1), false; got != want {
t.Fatalf("Servers.isContacted(%q) = %t, want %t", servers.addrs[1], got, want)
if got, want := servers[1].isContacted(), false; got != want {
t.Fatalf("Server[%q].isContacted() = %t, want %t", addrs[1], got, want)
}
if got, want := servers.isContacted(2), false; got != want {
t.Fatalf("Servers.isContacted(%q) = %t, want %t", servers.addrs[2], got, want)
if got, want := servers[2].isContacted(), false; got != want {
t.Fatalf("Server[%q].isContacted() = %t, want %t", addrs[2], got, want)
}

// Make the happy eyeballs timer fire twice so that pickfirst reaches the
// last address in the list.
timerCh <- struct{}{}

// Verify that the second server is contacted and 3rd isn't.
if err := servers.awaitContacted(ctx, 1); err != nil {
t.Fatalf("Server with address %q not contacted: %v", servers.addrs[1], err)
if err := servers[1].awaitContacted(ctx); err != nil {
t.Fatalf("Server with address %q not contacted: %v", addrs[1], err)
}

if got, want := servers.isContacted(2), false; got != want {
t.Fatalf("Servers.isContacted(%q) = %t, want %t", servers.addrs[2], got, want)
if got, want := servers[2].isContacted(), false; got != want {
t.Fatalf("Server[%q].isContacted() = %t, want %t", addrs[2], got, want)
}
timerCh <- struct{}{}
if err := servers.awaitContacted(ctx, 2); err != nil {
t.Fatalf("Server with address %q not contacted: %v", servers.addrs[2], err)
if err := servers[2].awaitContacted(ctx); err != nil {
t.Fatalf("Server with address %q not contacted: %v", addrs[2], err)
}

// First SubConn Fails.
servers.closeConn(0)
servers[0].closeConn()

// No TF should be reported until the first pass is complete.
shortCtx, shortCancel := context.WithTimeout(ctx, defaultTestShortTimeout)
Expand All @@ -947,12 +947,12 @@ func (s) TestPickFirstLeaf_HappyEyeballs_TFAfterEndOfList(t *testing.T) {
timerCh <- struct{}{}

// Third SubConn fails.
servers.closeConn(2)
servers[2].closeConn()

testutils.AwaitNotState(shortCtx, t, cc, connectivity.TransientFailure)

// Last SubConn fails, this should result in a TF update.
servers.closeConn(1)
servers[1].closeConn()
testutils.AwaitState(ctx, t, cc, connectivity.TransientFailure)
}

Expand Down Expand Up @@ -988,9 +988,9 @@ func (s) TestPickFirstLeaf_HappyEyeballs_TriggerConnectionDelay(t *testing.T) {
}()

servers := newHangingServerGroup(t, 2)
defer servers.close()
addrs := resolverAddrsFromHangingServers(servers)
rb := manual.NewBuilderWithScheme("whatever")
rb.InitialState(resolver.State{Addresses: servers.addrs})
rb.InitialState(resolver.State{Addresses: addrs})
cc, err := grpc.NewClient("whatever:///this-gets-overwritten",
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, pickfirstleaf.Name)),
Expand All @@ -1005,21 +1005,21 @@ func (s) TestPickFirstLeaf_HappyEyeballs_TriggerConnectionDelay(t *testing.T) {
testutils.AwaitState(ctx, t, cc, connectivity.Connecting)

// Verify that the first server is contacted.
if err := servers.awaitContacted(ctx, 0); err != nil {
t.Fatalf("Server with address %q not contacted: %v", servers.addrs[0], err)
if err := servers[0].awaitContacted(ctx); err != nil {
t.Fatalf("Server with address %q not contacted: %v", addrs[0], err)
}

if got, want := servers.isContacted(1), false; got != want {
t.Fatalf("Servers.isContacted(%q) = %t, want %t", servers.addrs[1], got, want)
if got, want := servers[1].isContacted(), false; got != want {
t.Fatalf("Server[%q].isContacted() = %t, want %t", addrs[1], got, want)
}

timerCh <- struct{}{}

// Second connection attempt is successful.
if err := servers.awaitContacted(ctx, 1); err != nil {
t.Fatalf("Server with address %q not contacted: %v", servers.addrs[1], err)
if err := servers[1].awaitContacted(ctx); err != nil {
t.Fatalf("Server with address %q not contacted: %v", addrs[1], err)
}
servers.enterReady(1)
servers[1].enterReady()
testutils.AwaitState(ctx, t, cc, connectivity.Ready)
}

Expand Down Expand Up @@ -1055,9 +1055,9 @@ func (s) TestPickFirstLeaf_HappyEyeballs_TFThenTimerFires(t *testing.T) {
}()

servers := newHangingServerGroup(t, 3)
defer servers.close()
addrs := resolverAddrsFromHangingServers(servers)
rb := manual.NewBuilderWithScheme("whatever")
rb.InitialState(resolver.State{Addresses: servers.addrs})
rb.InitialState(resolver.State{Addresses: addrs})
cc, err := grpc.NewClient("whatever:///this-gets-overwritten",
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, pickfirstleaf.Name)),
Expand All @@ -1072,16 +1072,16 @@ func (s) TestPickFirstLeaf_HappyEyeballs_TFThenTimerFires(t *testing.T) {
testutils.AwaitState(ctx, t, cc, connectivity.Connecting)

// Verify that only the first server is contacted.
if err := servers.awaitContacted(ctx, 0); err != nil {
t.Fatalf("Server with address %q not contacted: %v", servers.addrs[0], err)
if err := servers[0].awaitContacted(ctx); err != nil {
t.Fatalf("Server with address %q not contacted: %v", addrs[0], err)
}

// Ensure no other servers are contacted.
if got, want := servers.isContacted(1), false; got != want {
t.Fatalf("Servers.isContacted(%q) = %t, want %t", servers.addrs[1], got, want)
if got, want := servers[1].isContacted(), false; got != want {
t.Fatalf("Server[%q].isContacted() = %t, want %t", addrs[1], got, want)
}
if got, want := servers.isContacted(2), false; got != want {
t.Fatalf("Servers.isContacted(%q) = %t, want %t", servers.addrs[2], got, want)
if got, want := servers[2].isContacted(), false; got != want {
t.Fatalf("Server[%q].isContacted() = %t, want %t", addrs[2], got, want)
}

// First SubConn Fails.
Expand All @@ -1090,29 +1090,29 @@ func (s) TestPickFirstLeaf_HappyEyeballs_TFThenTimerFires(t *testing.T) {
timerMu.Lock()
timerCh = make(chan struct{})
timerMu.Unlock()
servers.closeConn(0)
servers[0].closeConn()

// The second server is contacted.
// Verify that only the first server is contacted.
if err := servers.awaitContacted(ctx, 1); err != nil {
t.Fatalf("Server with address %q not contacted: %v", servers.addrs[1], err)
if err := servers[1].awaitContacted(ctx); err != nil {
t.Fatalf("Server with address %q not contacted: %v", addrs[1], err)
}

// Ensure no other servers are contacted.
if got, want := servers.isContacted(2), false; got != want {
t.Fatalf("Servers.isContacted(%q) = %t, want %t", servers.addrs[2], got, want)
if got, want := servers[2].isContacted(), false; got != want {
t.Fatalf("Server[%q].isContacted() = %t, want %t", addrs[2], got, want)
}

// The happy eyeballs timer expires, skipping server[1] and requesting the creation
// of a third SubConn.
timerCh <- struct{}{}

if err := servers.awaitContacted(ctx, 2); err != nil {
t.Fatalf("Server with address %q not contacted: %v", servers.addrs[2], err)
if err := servers[2].awaitContacted(ctx); err != nil {
t.Fatalf("Server with address %q not contacted: %v", addrs[2], err)
}

// Second SubConn connects.
servers.enterReady(1)
servers[1].enterReady()
testutils.AwaitState(ctx, t, cc, connectivity.Connecting)
}

Expand Down Expand Up @@ -1221,37 +1221,27 @@ func (c *ccStateSubscriber) OnMessage(msg any) {
c.transitions = append(c.transitions, msg.(connectivity.State))
}

// handingServerGroup is a group of servers that accept a TCP connection and
// remain idle until asked to close the connection. They can be used to control
// handingServe is a server that accept a TCP connection and remains idle until
// asked to close or respond to the connection. They can be used to control
// how long it takes for a subchannel to report a TRANSIENT_FAILURE in tests.
type handingServerGroup struct {
addrs []resolver.Address
listeners []net.Listener
serverConnCloseFuncs []func()
serverContacted []*grpcsync.Event
readyChans []chan struct{}
type hangingServer struct {
addr resolver.Address
listener net.Listener
closeConn func()
contacted *grpcsync.Event
readyChan chan struct{}
}

func newHangingServerGroup(t *testing.T, count int) *handingServerGroup {
listeners := []net.Listener{}
closeFns := []func(){}
addrs := []resolver.Address{}
contactedEvents := []*grpcsync.Event{}
readyChans := []chan struct{}{}

func newHangingServerGroup(t *testing.T, count int) []*hangingServer {
servers := []*hangingServer{}
for i := 0; i < count; i++ {
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("Error while listening. Err: %v", err)
}
listeners = append(listeners, lis)
addrs = append(addrs, resolver.Address{Addr: lis.Addr().String()})
closeChan := make(chan struct{})
closeFns = append(closeFns, sync.OnceFunc(func() { close(closeChan) }))
contacted := grpcsync.NewEvent()
contactedEvents = append(contactedEvents, contacted)
readyChan := make(chan struct{})
readyChans = append(readyChans, readyChan)

go func() {
conn, err := lis.Accept()
Expand All @@ -1277,43 +1267,47 @@ func newHangingServerGroup(t *testing.T, count int) *handingServerGroup {
}
}
}()
}

return &handingServerGroup{
addrs: addrs,
listeners: listeners,
serverConnCloseFuncs: closeFns,
serverContacted: contactedEvents,
readyChans: readyChans,
}
}
server := &hangingServer{
addr: resolver.Address{Addr: lis.Addr().String()},
listener: lis,
closeConn: grpcsync.OnceFunc(func() {
close(closeChan)
}),
contacted: contacted,
readyChan: readyChan,
}
servers = append(servers, server)

func (hg *handingServerGroup) close() {
for _, fn := range hg.serverConnCloseFuncs {
fn()
}
for _, l := range hg.listeners {
l.Close()
t.Cleanup(func() {
server.closeConn()
server.listener.Close()
})
}
}

func (hg *handingServerGroup) closeConn(serverIdx int) {
hg.serverConnCloseFuncs[serverIdx]()
return servers
}

func (hg *handingServerGroup) isContacted(serverIdx int) bool {
return hg.serverContacted[serverIdx].HasFired()
func (s *hangingServer) isContacted() bool {
return s.contacted.HasFired()
}

func (hg *handingServerGroup) enterReady(serverIdx int) {
hg.readyChans[serverIdx] <- struct{}{}
func (s *hangingServer) enterReady() {
s.readyChan <- struct{}{}
}

func (hg *handingServerGroup) awaitContacted(ctx context.Context, serverIdx int) error {
func (s *hangingServer) awaitContacted(ctx context.Context) error {
select {
case <-ctx.Done():
return ctx.Err()
case <-hg.serverContacted[serverIdx].Done():
case <-s.contacted.Done():
}
return nil
}

func resolverAddrsFromHangingServers(servers []*hangingServer) []resolver.Address {
addrs := []resolver.Address{}
for _, srv := range servers {
addrs = append(addrs, srv.addr)
}
return addrs
}

0 comments on commit af38951

Please sign in to comment.