diff --git a/net.go b/net.go
index f6a0d45fe..068d8e1ad 100644
--- a/net.go
+++ b/net.go
@@ -522,7 +522,7 @@ func (m *Memberlist) handleIndirectPing(buf []byte, from net.Addr) {
 	// Send the ping.
 	addr := joinHostPort(net.IP(ind.Target).String(), ind.Port)
 	if err := m.encodeAndSendMsg(addr, pingMsg, &ping); err != nil {
-		m.logger.Printf("[ERR] memberlist: Failed to send ping: %s %s", err, LogAddress(from))
+		m.logger.Printf("[ERR] memberlist: Failed to send indirect ping: %s %s", err, LogAddress(from))
 	}
 
 	// Setup a timer to fire off a nack if no ack is seen in time.
diff --git a/state.go b/state.go
index 1af62943e..f5ed65a78 100644
--- a/state.go
+++ b/state.go
@@ -6,6 +6,7 @@ import (
 	"math"
 	"math/rand"
 	"net"
+	"strings"
 	"sync/atomic"
 	"time"
 
@@ -242,6 +243,21 @@ func (m *Memberlist) probeNodeByAddr(addr string) {
 	m.probeNode(n)
 }
 
+// failedRemote reports whether err points to a failure on the other end: a
+// *net.OpError on a "tcp*" network whose Op is "dial", "read", or "write".
+func failedRemote(err error) bool {
+	switch t := err.(type) {
+	case *net.OpError:
+		if strings.HasPrefix(t.Net, "tcp") {
+			switch t.Op {
+			case "dial", "read", "write":
+				return true
+			}
+		}
+	}
+	return false
+}
+
 // probeNode handles a single round of failure checking on a node.
 func (m *Memberlist) probeNode(node *nodeState) {
 	defer metrics.MeasureSince([]string{"memberlist", "probeNode"}, time.Now())
@@ -272,10 +288,20 @@ func (m *Memberlist) probeNode(node *nodeState) {
 	// soon as possible.
 	deadline := sent.Add(probeInterval)
 	addr := node.Address()
+
+	// Defer the awareness update so whatever delta is set below gets applied.
+	var awarenessDelta int // zero until assigned below
+	defer func() {
+		m.awareness.ApplyDelta(awarenessDelta)
+	}()
 	if node.State == stateAlive {
 		if err := m.encodeAndSendMsg(addr, pingMsg, &ping); err != nil {
 			m.logger.Printf("[ERR] memberlist: Failed to send ping: %s", err)
-			return
+			if failedRemote(err) {
+				goto HANDLE_REMOTE_FAILURE // skip the wait for a response
+			} else {
+				return
+			}
 		}
 	} else {
 		var msgs [][]byte
@@ -296,7 +322,11 @@ func (m *Memberlist) probeNode(node *nodeState) {
 		compound := makeCompoundMessage(msgs)
 		if err := m.rawSendMsgPacket(addr, &node.Node, compound.Bytes()); err != nil {
 			m.logger.Printf("[ERR] memberlist: Failed to send compound ping and suspect message to %s: %s", addr, err)
-			return
+			if failedRemote(err) {
+				goto HANDLE_REMOTE_FAILURE // skip the wait for a response
+			} else {
+				return
+			}
 		}
 	}
 
@@ -305,10 +335,7 @@ func (m *Memberlist) probeNode(node *nodeState) {
 	// which will improve our health until we get to the failure scenarios
 	// at the end of this function, which will alter this delta variable
 	// accordingly.
-	awarenessDelta := -1
-	defer func() {
-		m.awareness.ApplyDelta(awarenessDelta)
-	}()
+	awarenessDelta = -1
 
 	// Wait for response or round-trip-time.
 	select {
@@ -333,9 +360,10 @@ func (m *Memberlist) probeNode(node *nodeState) {
 		// probe interval it will give the TCP fallback more time, which
 		// is more active in dealing with lost packets, and it gives more
 		// time to wait for indirect acks/nacks.
-		m.logger.Printf("[DEBUG] memberlist: Failed ping: %v (timeout reached)", node.Name)
+		m.logger.Printf("[DEBUG] memberlist: Failed ping: %s (timeout reached)", node.Name)
 	}
 
+HANDLE_REMOTE_FAILURE:
 	// Get some random live nodes.
 	m.nodeLock.RLock()
 	kNodes := kRandomNodes(m.config.IndirectChecks, m.nodes, func(n *nodeState) bool {
diff --git a/state_test.go b/state_test.go
index def689448..204e03dca 100644
--- a/state_test.go
+++ b/state_test.go
@@ -2118,6 +2118,37 @@ func TestMemberlist_GossipToDead(t *testing.T) {
 	})
 }
 
+func TestMemberlist_FailedRemote(t *testing.T) {
+	type test struct {
+		name     string
+		err      error
+		expected bool
+	}
+	tests := []test{
+		{"nil error", nil, false},
+		{"normal error", fmt.Errorf(""), false},
+		{"net.OpError for file", &net.OpError{Net: "file"}, false},
+		{"net.OpError for udp", &net.OpError{Net: "udp"}, false},
+		{"net.OpError for udp4", &net.OpError{Net: "udp4"}, false},
+		{"net.OpError for udp6", &net.OpError{Net: "udp6"}, false},
+		{"net.OpError for tcp", &net.OpError{Net: "tcp"}, false},
+		{"net.OpError for tcp4", &net.OpError{Net: "tcp4"}, false},
+		{"net.OpError for tcp6", &net.OpError{Net: "tcp6"}, false},
+		{"net.OpError for tcp with dial", &net.OpError{Net: "tcp", Op: "dial"}, true},
+		{"net.OpError for tcp with write", &net.OpError{Net: "tcp", Op: "write"}, true},
+		{"net.OpError for tcp with read", &net.OpError{Net: "tcp", Op: "read"}, true},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			actual := failedRemote(test.err)
+			if actual != test.expected {
+				t.Fatalf("expected %t, got %t", test.expected, actual)
+			}
+		})
+	}
+}
+
 func TestMemberlist_PushPull(t *testing.T) {
 	addr1 := getBindAddr()
 	addr2 := getBindAddr()