Skip to content

Commit

Permalink
GODRIVER-1658 Fix connection updates to topology (#429)
Browse files Browse the repository at this point in the history
  • Loading branch information
Divjot Arora committed Jun 25, 2020
1 parent 8064395 commit 0eb881e
Show file tree
Hide file tree
Showing 8 changed files with 231 additions and 104 deletions.
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,5 @@ require (
golang.org/x/text v0.3.2 // indirect
golang.org/x/tools v0.0.0-20190531172133-b3315ee88b7d
)

go 1.13
3 changes: 3 additions & 0 deletions mongo/integration/mtest/mongotest.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ type FailPointData struct {
Name string `bson:"codeName"`
Errmsg string `bson:"errmsg"`
} `bson:"writeConcernError,omitempty"`
BlockConnection bool `bson:"blockConnection,omitempty"`
BlockTimeMS int32 `bson:"blockTimeMS,omitempty"`
AppName string `bson:"appName,omitempty"`
}

// T is a wrapper around testing.T.
Expand Down
174 changes: 142 additions & 32 deletions mongo/integration/sdam_error_handling_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,44 +24,154 @@ func TestSDAMErrorHandling(t *testing.T) {
SetRetryWrites(false).
SetPoolMonitor(poolMonitor).
SetWriteConcern(mtest.MajorityWc)
mtOpts := mtest.NewOptions().
Topologies(mtest.ReplicaSet). // Don't run on sharded clusters to avoid complexity of sharded failpoints.
MinServerVersion("4.0"). // 4.0+ is required to use failpoints on replica sets.
ClientOptions(clientOpts)

mt.RunOpts("network errors", mtOpts, func(mt *mtest.T) {
mt.Run("pool cleared on non-timeout network error", func(mt *mtest.T) {
clearPoolChan()
mt.SetFailPoint(mtest.FailPoint{
ConfigureFailPoint: "failCommand",
Mode: mtest.FailPointMode{
Times: 1,
},
Data: mtest.FailPointData{
FailCommands: []string{"insert"},
CloseConnection: true,
},
baseMtOpts := func() *mtest.Options {
mtOpts := mtest.NewOptions().
Topologies(mtest.ReplicaSet). // Don't run on sharded clusters to avoid complexity of sharded failpoints.
MinServerVersion("4.0"). // 4.0+ is required to use failpoints on replica sets.
ClientOptions(clientOpts)

if mt.TopologyKind() == mtest.Sharded {
// Pin to a single mongos because the tests use failpoints.
mtOpts.ClientType(mtest.Pinned)
}
return mtOpts
}

// Set min server version of 4.4 because the during-handshake tests use failpoint features introduced in 4.4 like
// blockConnection and appName.
mt.RunOpts("before handshake completes", baseMtOpts().Auth(true).MinServerVersion("4.4"), func(mt *mtest.T) {
mt.RunOpts("network errors", noClientOpts, func(mt *mtest.T) {
mt.Run("pool cleared on network timeout", func(mt *mtest.T) {
// Assert that the pool is cleared when a connection created by an application operation thread
// encounters a network timeout during handshaking. Unlike the non-timeout test below, we only test
// connections created in the foreground for timeouts because connections created by the pool
// maintenance routine can't be timed out using a context.

appName := "authNetworkTimeoutTest"
// Set failpoint on saslContinue instead of saslStart because saslStart isn't done when using
// speculative auth.
mt.SetFailPoint(mtest.FailPoint{
ConfigureFailPoint: "failCommand",
Mode: mtest.FailPointMode{
Times: 1,
},
Data: mtest.FailPointData{
FailCommands: []string{"saslContinue"},
BlockConnection: true,
BlockTimeMS: 150,
AppName: appName,
},
})

// Reset the client with the appName specified in the failpoint.
clientOpts := options.Client().
SetAppName(appName).
SetRetryWrites(false).
SetPoolMonitor(poolMonitor)
mt.ResetClient(clientOpts)
clearPoolChan()

// The saslContinue blocks for 150ms so run the InsertOne with a 100ms context to cause a network
// timeout during auth and assert that the pool was cleared.
timeoutCtx, cancel := context.WithTimeout(mtest.Background, 100*time.Millisecond)
defer cancel()
_, err := mt.Coll.InsertOne(timeoutCtx, bson.D{{"test", 1}})
assert.NotNil(mt, err, "expected InsertOne error, got nil")
assert.True(mt, isPoolCleared(), "expected pool to be cleared but was not")
})
mt.RunOpts("pool cleared on non-timeout network error", noClientOpts, func(mt *mtest.T) {
mt.Run("background", func(mt *mtest.T) {
// Assert that the pool is cleared when a connection created by the background pool maintenance
// routine encounters a non-timeout network error during handshaking.
appName := "authNetworkErrorTestBackground"

mt.SetFailPoint(mtest.FailPoint{
ConfigureFailPoint: "failCommand",
Mode: mtest.FailPointMode{
Times: 1,
},
Data: mtest.FailPointData{
FailCommands: []string{"saslContinue"},
CloseConnection: true,
AppName: appName,
},
})

clientOpts := options.Client().
SetAppName(appName).
SetMinPoolSize(5).
SetPoolMonitor(poolMonitor)
mt.ResetClient(clientOpts)
clearPoolChan()

time.Sleep(200 * time.Millisecond)
assert.True(mt, isPoolCleared(), "expected pool to be cleared but was not")
})
mt.Run("foreground", func(mt *mtest.T) {
// Assert that the pool is cleared when a connection created by an application thread connection
// checkout encounters a non-timeout network error during handshaking.
appName := "authNetworkErrorTestForeground"

_, err := mt.Coll.InsertOne(mtest.Background, bson.D{{"test", 1}})
assert.NotNil(mt, err, "expected InsertOne error, got nil")
assert.True(mt, isPoolCleared(), "expected pool to be cleared but was not")
mt.SetFailPoint(mtest.FailPoint{
ConfigureFailPoint: "failCommand",
Mode: mtest.FailPointMode{
Times: 1,
},
Data: mtest.FailPointData{
FailCommands: []string{"saslContinue"},
CloseConnection: true,
AppName: appName,
},
})

clientOpts := options.Client().
SetAppName(appName).
SetPoolMonitor(poolMonitor)
mt.ResetClient(clientOpts)
clearPoolChan()

_, err := mt.Coll.InsertOne(mtest.Background, bson.D{{"x", 1}})
assert.NotNil(mt, err, "expected InsertOne error, got nil")
assert.True(mt, isPoolCleared(), "expected pool to be cleared but was not")
})
})
})
mt.Run("pool not cleared on timeout network error", func(mt *mtest.T) {
clearPoolChan()
})
mt.RunOpts("after handshake completes", baseMtOpts(), func(mt *mtest.T) {
mt.RunOpts("network errors", noClientOpts, func(mt *mtest.T) {
mt.Run("pool cleared on non-timeout network error", func(mt *mtest.T) {
clearPoolChan()
mt.SetFailPoint(mtest.FailPoint{
ConfigureFailPoint: "failCommand",
Mode: mtest.FailPointMode{
Times: 1,
},
Data: mtest.FailPointData{
FailCommands: []string{"insert"},
CloseConnection: true,
},
})

_, err := mt.Coll.InsertOne(mtest.Background, bson.D{{"x", 1}})
assert.Nil(mt, err, "InsertOne error: %v", err)
_, err := mt.Coll.InsertOne(mtest.Background, bson.D{{"test", 1}})
assert.NotNil(mt, err, "expected InsertOne error, got nil")
assert.True(mt, isPoolCleared(), "expected pool to be cleared but was not")
})
mt.Run("pool not cleared on timeout network error", func(mt *mtest.T) {
clearPoolChan()

_, err := mt.Coll.InsertOne(mtest.Background, bson.D{{"x", 1}})
assert.Nil(mt, err, "InsertOne error: %v", err)

filter := bson.M{
"$where": "function() { sleep(1000); return false; }",
}
timeoutCtx, cancel := context.WithTimeout(mtest.Background, 100*time.Millisecond)
defer cancel()
_, err = mt.Coll.Find(timeoutCtx, filter)
assert.NotNil(mt, err, "expected Find error, got %v", err)
filter := bson.M{
"$where": "function() { sleep(1000); return false; }",
}
timeoutCtx, cancel := context.WithTimeout(mtest.Background, 100*time.Millisecond)
defer cancel()
_, err = mt.Coll.Find(timeoutCtx, filter)
assert.NotNil(mt, err, "expected Find error, got %v", err)

assert.False(mt, isPoolCleared(), "expected pool to not be cleared but was")
assert.False(mt, isPoolCleared(), "expected pool to not be cleared but was")
})
})
})
}
6 changes: 5 additions & 1 deletion mongo/integration/sessions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,11 @@ func TestSessions(t *testing.T) {
CreateClient(false)
mt := mtest.New(t, mtOpts)

clusterTimeOpts := mtest.NewOptions().ClientOptions(options.Client().SetHeartbeatInterval(50 * time.Second)).
// Pin to a single mongos so heartbeats/handshakes to other mongoses won't cause errors.
// Pin to a single mongos so heartbeats/handshakes to other mongoses won't cause errors.
clusterTimeOpts := mtest.NewOptions().
ClientOptions(options.Client().SetHeartbeatInterval(50 * time.Second)).
ClientType(mtest.Pinned).
CreateClient(false)
mt.RunOpts("cluster time", clusterTimeOpts, func(mt *mtest.T) {
// $clusterTime included in commands
Expand Down
30 changes: 15 additions & 15 deletions x/mongo/driver/topology/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,18 @@ func newConnection(ctx context.Context, addr address.Address, opts ...Connection
return c, nil
}

func (c *connection) processInitializationError(err error) {
atomic.StoreInt32(&c.connected, disconnected)
if c.nc != nil {
_ = c.nc.Close()
}

c.connectErr = ConnectionError{Wrapped: err, init: true}
if c.config.errorHandlingCallback != nil {
c.config.errorHandlingCallback(c.connectErr)
}
}

// connect handles the I/O for a connection. It will dial, configure TLS, and perform
// initialization handshakes.
func (c *connection) connect(ctx context.Context) {
Expand All @@ -101,8 +113,7 @@ func (c *connection) connect(ctx context.Context) {
var tempNc net.Conn
tempNc, err = c.config.dialer.DialContext(ctx, c.addr.Network(), c.addr.String())
if err != nil {
atomic.StoreInt32(&c.connected, disconnected)
c.connectErr = ConnectionError{Wrapped: err, init: true}
c.processInitializationError(err)
return
}
c.nc = tempNc
Expand All @@ -114,11 +125,7 @@ func (c *connection) connect(ctx context.Context) {
// error cases.
tlsNc, err := configureTLS(ctx, c.nc, c.addr, tlsConfig)
if err != nil {
if c.nc != nil {
_ = c.nc.Close()
}
atomic.StoreInt32(&c.connected, disconnected)
c.connectErr = ConnectionError{Wrapped: err, init: true}
c.processInitializationError(err)
return
}
c.nc = tlsNc
Expand All @@ -138,17 +145,10 @@ func (c *connection) connect(ctx context.Context) {
err = handshaker.FinishHandshake(ctx, handshakeConn)
}
if err != nil {
if c.nc != nil {
_ = c.nc.Close()
}
atomic.StoreInt32(&c.connected, disconnected)
c.connectErr = ConnectionError{Wrapped: err, init: true}
c.processInitializationError(err)
return
}

if c.config.descCallback != nil {
c.config.descCallback(c.desc)
}
if len(c.desc.Compression) > 0 {
clientMethodLoop:
for _, method := range c.config.compressors {
Expand Down
43 changes: 21 additions & 22 deletions x/mongo/driver/topology/connection_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (

"go.mongodb.org/mongo-driver/event"
"go.mongodb.org/mongo-driver/x/mongo/driver"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
)

// Dialer is used to make network connections.
Expand Down Expand Up @@ -36,20 +35,20 @@ var DefaultDialer Dialer = &net.Dialer{}
type Handshaker = driver.Handshaker

type connectionConfig struct {
appName string
connectTimeout time.Duration
dialer Dialer
handshaker Handshaker
idleTimeout time.Duration
lifeTimeout time.Duration
cmdMonitor *event.CommandMonitor
readTimeout time.Duration
writeTimeout time.Duration
tlsConfig *tls.Config
compressors []string
zlibLevel *int
zstdLevel *int
descCallback func(description.Server)
appName string
connectTimeout time.Duration
dialer Dialer
handshaker Handshaker
idleTimeout time.Duration
lifeTimeout time.Duration
cmdMonitor *event.CommandMonitor
readTimeout time.Duration
writeTimeout time.Duration
tlsConfig *tls.Config
compressors []string
zlibLevel *int
zstdLevel *int
errorHandlingCallback func(error)
}

func newConnectionConfig(opts ...ConnectionOption) (*connectionConfig, error) {
Expand All @@ -73,16 +72,16 @@ func newConnectionConfig(opts ...ConnectionOption) (*connectionConfig, error) {
return cfg, nil
}

func withServerDescriptionCallback(callback func(description.Server), opts ...ConnectionOption) []ConnectionOption {
return append(opts, ConnectionOption(func(c *connectionConfig) error {
c.descCallback = callback
return nil
}))
}

// ConnectionOption is used to configure a connection.
type ConnectionOption func(*connectionConfig) error

func withErrorHandlingCallback(fn func(error)) ConnectionOption {
return func(c *connectionConfig) error {
c.errorHandlingCallback = fn
return nil
}
}

// WithCompressors sets the compressors that can be used for communication.
func WithCompressors(fn func([]string) []string) ConnectionOption {
return func(c *connectionConfig) error {
Expand Down
44 changes: 23 additions & 21 deletions x/mongo/driver/topology/connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,32 +96,34 @@ func TestConnection(t *testing.T) {
t.Errorf("errors do not match. got %v; want %v", got, want)
}
})
t.Run("calls description callback", func(t *testing.T) {
want := description.Server{Addr: address.Address("1.2.3.4:56789")}
var got description.Server
t.Run("calls error callback", func(t *testing.T) {
handshakerError := errors.New("handshaker error")
var got error

conn, err := newConnection(context.Background(), address.Address(""),
withServerDescriptionCallback(func(desc description.Server) { got = desc },
WithHandshaker(func(Handshaker) Handshaker {
return &testHandshaker{
getDescription: func(context.Context, address.Address, driver.Connection) (description.Server, error) {
return want, nil
},
}
}),
WithDialer(func(Dialer) Dialer {
return DialerFunc(func(context.Context, string, string) (net.Conn, error) {
return &net.TCPConn{}, nil
})
}),
)...,
WithHandshaker(func(Handshaker) Handshaker {
return &testHandshaker{
getDescription: func(context.Context, address.Address, driver.Connection) (description.Server, error) {
return description.Server{}, handshakerError
},
}
}),
WithDialer(func(Dialer) Dialer {
return DialerFunc(func(context.Context, string, string) (net.Conn, error) {
return &net.TCPConn{}, nil
})
}),
withErrorHandlingCallback(func(err error) {
got = err
}),
)
noerr(t, err)
conn.connect(context.Background())

var want error = ConnectionError{Wrapped: handshakerError}
err = conn.wait()
noerr(t, err)
if !cmp.Equal(got, want) {
t.Errorf("Server descriptions do not match. got %v; want %v", got, want)
}
assert.NotNil(t, err, "expected connect error %v, got nil", want)
assert.Equal(t, want, got, "expected error %v, got %v", want, got)
})
})
t.Run("writeWireMessage", func(t *testing.T) {
Expand Down
Loading

0 comments on commit 0eb881e

Please sign in to comment.