diff --git a/x/mongo/driver/topology/server.go b/x/mongo/driver/topology/server.go
index bdd9035ead..d416f6c195 100644
--- a/x/mongo/driver/topology/server.go
+++ b/x/mongo/driver/topology/server.go
@@ -520,6 +520,7 @@ func (s *Server) update() {
 		}
 	}
 
+	timeoutCnt := 0
 	for {
 		// Check if the server is disconnecting. Even if waitForNextCheck has already read from the done channel, we
 		// can safely read from it again because Disconnect closes the channel.
@@ -545,18 +546,42 @@ func (s *Server) update() {
 			continue
 		}
 
-		// Must hold the processErrorLock while updating the server description and clearing the
-		// pool. Not holding the lock leads to possible out-of-order processing of pool.clear() and
-		// pool.ready() calls from concurrent server description updates.
-		s.processErrorLock.Lock()
-		s.updateDescription(desc)
-		if err := desc.LastError; err != nil {
-			// Clear the pool once the description has been updated to Unknown. Pass in a nil service ID to clear
-			// because the monitoring routine only runs for non-load balanced deployments in which servers don't return
-			// IDs.
-			s.pool.clear(err, nil)
+		if isShortcut := func() bool {
+			// Must hold the processErrorLock while updating the server description and clearing the
+			// pool. Not holding the lock leads to possible out-of-order processing of pool.clear() and
+			// pool.ready() calls from concurrent server description updates.
+			s.processErrorLock.Lock()
+			defer s.processErrorLock.Unlock()
+
+			s.updateDescription(desc)
+			// Retry after the first timeout before clearing the pool in case of a FAAS pause as
+			// described in GODRIVER-2577.
+			if err := unwrapConnectionError(desc.LastError); err != nil && timeoutCnt < 1 {
+				if err == context.Canceled || err == context.DeadlineExceeded {
+					timeoutCnt++
+					// We want to immediately retry on timeout error. Continue to next loop.
+					return true
+				}
+				if err, ok := err.(net.Error); ok && err.Timeout() {
+					timeoutCnt++
+					// We want to immediately retry on timeout error. Continue to next loop.
+					return true
+				}
+			}
+			if err := desc.LastError; err != nil {
+				// Clear the pool once the description has been updated to Unknown. Pass in a nil service ID to clear
+				// because the monitoring routine only runs for non-load balanced deployments in which servers don't return
+				// IDs.
+				s.pool.clear(err, nil)
+			}
+			// We're either not handling a timeout error, or we just handled the 2nd consecutive
+			// timeout error. In either case, reset the timeout count to 0 and return false to
+			// continue the normal check process.
+			timeoutCnt = 0
+			return false
+		}(); isShortcut {
+			continue
 		}
-		s.processErrorLock.Unlock()
 
 		// If the server supports streaming or we're already streaming, we want to move to streaming the next response
 		// without waiting. If the server has transitioned to Unknown from a network error, we want to do another
@@ -707,19 +732,31 @@ func (s *Server) check() (description.Server, error) {
 	var err error
 	var durationNanos int64
 
-	// Create a new connection if this is the first check, the connection was closed after an error during the previous
-	// check, or the previous check was cancelled.
+	start := time.Now()
 	if s.conn == nil || s.conn.closed() || s.checkWasCancelled() {
+		// Create a new connection if this is the first check, the connection was closed after an error during the previous
+		// check, or the previous check was cancelled.
+		isNilConn := s.conn == nil
+		if !isNilConn {
+			s.publishServerHeartbeatStartedEvent(s.conn.ID(), false)
+		}
 		// Create a new connection and add it's handshake RTT as a sample.
 		err = s.setupHeartbeatConnection()
+		durationNanos = time.Since(start).Nanoseconds()
 		if err == nil {
 			// Use the description from the connection handshake as the value for this check.
 			s.rttMonitor.addSample(s.conn.helloRTT)
 			descPtr = &s.conn.desc
+			if !isNilConn {
+				s.publishServerHeartbeatSucceededEvent(s.conn.ID(), durationNanos, s.conn.desc, false)
+			}
+		} else {
+			err = unwrapConnectionError(err)
+			if !isNilConn {
+				s.publishServerHeartbeatFailedEvent(s.conn.ID(), durationNanos, err, false)
+			}
 		}
-	}
-
-	if descPtr == nil && err == nil {
+	} else {
 		// An existing connection is being used. Use the server description properties to execute the right heartbeat.
 
 		// Wrap conn in a type that implements driver.StreamerConnection.
@@ -729,7 +766,6 @@ func (s *Server) check() (description.Server, error) {
 		streamable := previousDescription.TopologyVersion != nil
 		s.publishServerHeartbeatStartedEvent(s.conn.ID(), s.conn.getCurrentlyStreaming() || streamable)
 
-		start := time.Now()
 		switch {
 		case s.conn.getCurrentlyStreaming():
 			// The connection is already in a streaming state, so we stream the next response.
diff --git a/x/mongo/driver/topology/server_test.go b/x/mongo/driver/topology/server_test.go
index 9850665db1..ecb001e311 100644
--- a/x/mongo/driver/topology/server_test.go
+++ b/x/mongo/driver/topology/server_test.go
@@ -11,8 +11,12 @@ package topology
 
 import (
 	"context"
+	"crypto/tls"
+	"crypto/x509"
 	"errors"
+	"io/ioutil"
 	"net"
+	"os"
 	"runtime"
 	"sync"
 	"sync/atomic"
@@ -49,6 +53,144 @@ func (cncd *channelNetConnDialer) DialContext(_ context.Context, _, _ string) (n
 	return cnc, nil
 }
 
+type errorQueue struct {
+	errors []error
+	mutex  sync.Mutex
+}
+
+func (eq *errorQueue) head() error {
+	eq.mutex.Lock()
+	defer eq.mutex.Unlock()
+	if len(eq.errors) > 0 {
+		return eq.errors[0]
+	}
+	return nil
+}
+
+func (eq *errorQueue) dequeue() bool {
+	eq.mutex.Lock()
+	defer eq.mutex.Unlock()
+	if len(eq.errors) > 0 {
+		eq.errors = eq.errors[1:]
+		return true
+	}
+	return false
+}
+
+type timeoutConn struct {
+	net.Conn
+	errors *errorQueue
+}
+
+func (c *timeoutConn) Read(b []byte) (int, error) {
+	n, err := 0, c.errors.head()
+	if err == nil {
+		n, err = c.Conn.Read(b)
+	}
+	return n, err
+}
+
+func (c *timeoutConn) Write(b []byte) (int, error) {
+	n, err := 0, c.errors.head()
+	if err == nil {
+		n, err = c.Conn.Write(b)
+	}
+	return n, err
+}
+
+type timeoutDialer struct {
+	Dialer
+	errors *errorQueue
+}
+
+func (d *timeoutDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
+	c, e := d.Dialer.DialContext(ctx, network, address)
+
+	if caFile := os.Getenv("MONGO_GO_DRIVER_CA_FILE"); len(caFile) > 0 {
+		pem, err := ioutil.ReadFile(caFile)
+		if err != nil {
+			return nil, err
+		}
+
+		ca := x509.NewCertPool()
+		if !ca.AppendCertsFromPEM(pem) {
+			return nil, errors.New("unable to load CA file")
+		}
+
+		config := &tls.Config{
+			InsecureSkipVerify: true,
+			RootCAs:            ca,
+		}
+		c = tls.Client(c, config)
+	}
+	return &timeoutConn{c, d.errors}, e
+}
+
+// TestServerHeartbeatTimeout tests timeout retry for GODRIVER-2577.
+func TestServerHeartbeatTimeout(t *testing.T) {
+	networkTimeoutError := &net.DNSError{
+		IsTimeout: true,
+	}
+
+	testCases := []struct {
+		desc              string
+		ioErrors          []error
+		expectPoolCleared bool
+	}{
+		{
+			desc:              "one single timeout should not clear the pool",
+			ioErrors:          []error{nil, networkTimeoutError, nil, networkTimeoutError, nil},
+			expectPoolCleared: false,
+		},
+		{
+			desc:              "continuous timeouts should clear the pool",
+			ioErrors:          []error{nil, networkTimeoutError, networkTimeoutError, nil},
+			expectPoolCleared: true,
+		},
+	}
+	for _, tc := range testCases {
+		tc := tc
+		t.Run(tc.desc, func(t *testing.T) {
+			t.Parallel()
+
+			var wg sync.WaitGroup
+			wg.Add(1)
+
+			errors := &errorQueue{errors: tc.ioErrors}
+			tpm := monitor.NewTestPoolMonitor()
+			server := NewServer(
+				address.Address("localhost:27017"),
+				primitive.NewObjectID(),
+				WithConnectionPoolMonitor(func(*event.PoolMonitor) *event.PoolMonitor {
+					return tpm.PoolMonitor
+				}),
+				WithConnectionOptions(func(opts ...ConnectionOption) []ConnectionOption {
+					return append(opts,
+						WithDialer(func(d Dialer) Dialer {
+							var dialer net.Dialer
+							return &timeoutDialer{&dialer, errors}
+						}))
+				}),
+				WithServerMonitor(func(*event.ServerMonitor) *event.ServerMonitor {
+					return &event.ServerMonitor{
+						ServerHeartbeatStarted: func(e *event.ServerHeartbeatStartedEvent) {
+							if !errors.dequeue() {
+								wg.Done()
+							}
+						},
+					}
+				}),
+				WithHeartbeatInterval(func(time.Duration) time.Duration {
+					return 200 * time.Millisecond
+				}),
+			)
+			require.NoError(t, server.Connect(nil))
+			wg.Wait()
+			assert.Equal(t, tc.expectPoolCleared, tpm.IsPoolCleared(), "expected pool cleared to be %v but was %v", tc.expectPoolCleared, tpm.IsPoolCleared())
+		})
+	}
+}
+
 // TestServerConnectionTimeout tests how different timeout errors are handled during connection
 // creation and server handshake.
 func TestServerConnectionTimeout(t *testing.T) {
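For readers skimming the patch, the standalone sketch below (not part of the diff; the helper name shouldRetryHeartbeat and the main function are illustrative assumptions) mirrors the decision the new closure in Server.update makes: only the first of consecutive timeout-like heartbeat errors is retried immediately, while a second consecutive timeout, or any non-timeout error, falls through to the normal mark-Unknown-and-clear-the-pool path.

```go
// A minimal sketch of the retry-before-clearing decision, assuming a hypothetical
// helper shouldRetryHeartbeat; only the error classification mirrors the patch.
package main

import (
	"context"
	"fmt"
	"net"
)

// shouldRetryHeartbeat reports whether a failed heartbeat should be retried immediately
// instead of clearing the pool: only the first of consecutive timeout-like errors
// (context cancellation/deadline or a net.Error that reports Timeout) is absorbed.
func shouldRetryHeartbeat(err error, timeoutCnt int) bool {
	if err == nil || timeoutCnt >= 1 {
		return false
	}
	if err == context.Canceled || err == context.DeadlineExceeded {
		return true
	}
	if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
		return true
	}
	return false
}

func main() {
	// Two consecutive timeouts: the first is retried, the second clears the pool,
	// matching the "continuous timeouts should clear the pool" test case above.
	timeoutErr := &net.DNSError{IsTimeout: true}

	timeoutCnt := 0
	for i, err := range []error{timeoutErr, timeoutErr} {
		if shouldRetryHeartbeat(err, timeoutCnt) {
			timeoutCnt++
			fmt.Printf("check %d: timeout #%d, retrying without clearing the pool\n", i+1, timeoutCnt)
			continue
		}
		timeoutCnt = 0
		fmt.Printf("check %d: clearing the pool and resetting the counter\n", i+1)
	}
}
```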