diff --git a/internal/test/compilecheck/go.mod b/internal/test/compilecheck/go.mod
index 69d192022a..cc09124838 100644
--- a/internal/test/compilecheck/go.mod
+++ b/internal/test/compilecheck/go.mod
@@ -9,14 +9,14 @@ replace go.mongodb.org/mongo-driver => ../../../
 require go.mongodb.org/mongo-driver v1.11.7
 
 require (
-	github.com/golang/snappy v0.0.1 // indirect
+	github.com/golang/snappy v0.0.4 // indirect
 	github.com/klauspost/compress v1.13.6 // indirect
-	github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe // indirect
+	github.com/montanaflynn/stats v0.7.1 // indirect
 	github.com/xdg-go/pbkdf2 v1.0.0 // indirect
 	github.com/xdg-go/scram v1.1.2 // indirect
 	github.com/xdg-go/stringprep v1.0.4 // indirect
 	github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d // indirect
-	golang.org/x/crypto v0.17.0 // indirect
-	golang.org/x/sync v0.1.0 // indirect
+	golang.org/x/crypto v0.22.0 // indirect
+	golang.org/x/sync v0.7.0 // indirect
 	golang.org/x/text v0.14.0 // indirect
 )
diff --git a/internal/test/compilecheck/go.sum b/internal/test/compilecheck/go.sum
index fe79e66209..802402a881 100644
--- a/internal/test/compilecheck/go.sum
+++ b/internal/test/compilecheck/go.sum
@@ -1,11 +1,11 @@
 github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
-github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4=
-github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
-github.com/google/go-cmp v0.5.2 h1:X2ev0eStA3AbceY54o37/0PQ/UWqKEiiO2dKL5OPaFM=
+github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM=
+github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
+github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
 github.com/klauspost/compress v1.13.6 h1:P76CopJELS0TiO2mebmnzgWaajssP/EszplttgQxcgc=
 github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk=
-github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe h1:iruDEfMl2E6fbMZ9s0scYfZQ84/6SPL6zC8ACM2oIL0=
-github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe/go.mod h1:wL8QJuTMNUDYhXwkmfOly8iTdp5TEcJFWZD2D7SIkUc=
+github.com/montanaflynn/stats v0.7.1 h1:etflOAAHORrCC44V+aR6Ftzort912ZU+YLiSTuV8eaE=
+github.com/montanaflynn/stats v0.7.1/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow=
 github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c=
 github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI=
 github.com/xdg-go/scram v1.1.2 h1:FHX5I5B4i4hKRVRBCFRxq1iQRej7WO3hhBuJf+UUySY=
@@ -17,16 +17,16 @@ github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d/go.mod h1:rHwXgn7Jul
 github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
 golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
 golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
-golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k=
-golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
+golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30=
+golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M=
 golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
 golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
 golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
 golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
 golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
-golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
+golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
 golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@@ -44,4 +44,3 @@ golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGm
 golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
 golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
 golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
-golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
diff --git a/mongo/client.go b/mongo/client.go
index 280749c7dd..4266412aab 100644
--- a/mongo/client.go
+++ b/mongo/client.go
@@ -209,10 +209,6 @@ func NewClient(opts ...*options.ClientOptions) (*Client, error) {
 		clientOpt.SetMaxPoolSize(defaultMaxPoolSize)
 	}
 
-	if err != nil {
-		return nil, err
-	}
-
 	cfg, err := topology.NewConfig(clientOpt, client.clock)
 	if err != nil {
 		return nil, err
diff --git a/mongo/integration/client_test.go b/mongo/integration/client_test.go
index 0139d273da..0e478537b4 100644
--- a/mongo/integration/client_test.go
+++ b/mongo/integration/client_test.go
@@ -526,31 +526,23 @@ func TestClient(t *testing.T) {
 
 		// Assert that the minimum RTT is eventually >250ms.
 		topo := getTopologyFromClient(mt.Client)
-		assert.Soon(mt, func(ctx context.Context) {
-			for {
-				// Stop loop if callback has been canceled.
-				select {
-				case <-ctx.Done():
-					return
-				default:
-				}
-
-				time.Sleep(100 * time.Millisecond)
-
-				// Wait for all of the server's minimum RTTs to be >250ms.
-				done := true
-				for _, desc := range topo.Description().Servers {
-					server, err := topo.FindServer(desc)
-					assert.Nil(mt, err, "FindServer error: %v", err)
-					if server.RTTMonitor().Min() <= 250*time.Millisecond {
-						done = false
-					}
-				}
-				if done {
-					return
+		callback := func() bool {
+			// Wait for all of the server's minimum RTTs to be >250ms.
+			for _, desc := range topo.Description().Servers {
+				server, err := topo.FindServer(desc)
+				assert.NoError(mt, err, "FindServer error: %v", err)
+				if server.RTTMonitor().Min() <= 250*time.Millisecond {
+					return false // not yet >250ms; Eventually retries after the 100ms tick
 				}
 			}
-		}, 10*time.Second)
+
+			return true
+		}
+		assert.Eventually(t,
+			callback,
+			10*time.Second,
+			100*time.Millisecond,
+			"expected that the minimum RTT is eventually >250ms")
 	})
 
 	// Test that if the minimum RTT is greater than the remaining timeout for an operation, the
@@ -574,31 +566,23 @@ func TestClient(t *testing.T) {
 
 		// Assert that the minimum RTT is eventually >250ms.
 		topo := getTopologyFromClient(mt.Client)
-		assert.Soon(mt, func(ctx context.Context) {
-			for {
-				// Stop loop if callback has been canceled.
-				select {
-				case <-ctx.Done():
-					return
-				default:
-				}
-
-				time.Sleep(100 * time.Millisecond)
-
-				// Wait for all of the server's minimum RTTs to be >250ms.
-				done := true
-				for _, desc := range topo.Description().Servers {
-					server, err := topo.FindServer(desc)
-					assert.Nil(mt, err, "FindServer error: %v", err)
-					if server.RTTMonitor().Min() <= 250*time.Millisecond {
-						done = false
-					}
-				}
-				if done {
-					return
+		callback := func() bool {
+			// Wait for all of the server's minimum RTTs to be >250ms.
+			for _, desc := range topo.Description().Servers {
+				server, err := topo.FindServer(desc)
+				assert.NoError(mt, err, "FindServer error: %v", err)
+				if server.RTTMonitor().Min() <= 250*time.Millisecond {
+					return false
 				}
 			}
-		}, 10*time.Second)
+
+			return true
+		}
+		assert.Eventually(t,
+			callback,
+			10*time.Second,
+			100*time.Millisecond,
+			"expected that the minimum RTT is eventually >250ms")
 
 		// Once we've waited for the minimum RTT for the single server to be >250ms, run a bunch of
 		// Ping operations with a timeout of 250ms and expect that they return errors.
@@ -625,31 +609,23 @@ func TestClient(t *testing.T) {
 
 		// Assert that RTT90s are eventually >300ms.
 		topo := getTopologyFromClient(mt.Client)
-		assert.Soon(mt, func(ctx context.Context) {
-			for {
-				// Stop loop if callback has been canceled.
-				select {
-				case <-ctx.Done():
-					return
-				default:
-				}
-
-				time.Sleep(100 * time.Millisecond)
-
-				// Wait for all of the server's RTT90s to be >300ms.
-				done := true
-				for _, desc := range topo.Description().Servers {
-					server, err := topo.FindServer(desc)
-					assert.Nil(mt, err, "FindServer error: %v", err)
-					if server.RTTMonitor().P90() <= 300*time.Millisecond {
-						done = false
-					}
-				}
-				if done {
-					return
+		callback := func() bool {
+			// Wait for all of the server's RTT90s to be >300ms.
+			for _, desc := range topo.Description().Servers {
+				server, err := topo.FindServer(desc)
+				assert.NoError(mt, err, "FindServer error: %v", err)
+				if server.RTTMonitor().P90() <= 300*time.Millisecond {
+					return false
 				}
 			}
-		}, 10*time.Second)
+
+			return true
+		}
+		assert.Eventually(t,
+			callback,
+			10*time.Second,
+			100*time.Millisecond,
+			"expected that the RTT90s are eventually >300ms")
 	})
 
 	// Test that if Timeout is set and the RTT90 is greater than the remaining timeout for an operation, the
@@ -676,31 +652,23 @@ func TestClient(t *testing.T) {
 
 		// Assert that RTT90s are eventually >275ms.
 		topo := getTopologyFromClient(mt.Client)
-		assert.Soon(mt, func(ctx context.Context) {
-			for {
-				// Stop loop if callback has been canceled.
-				select {
-				case <-ctx.Done():
-					return
-				default:
-				}
-
-				time.Sleep(100 * time.Millisecond)
-
-				// Wait for all of the server's RTT90s to be >275ms.
-				done := true
-				for _, desc := range topo.Description().Servers {
-					server, err := topo.FindServer(desc)
-					assert.Nil(mt, err, "FindServer error: %v", err)
-					if server.RTTMonitor().P90() <= 275*time.Millisecond {
-						done = false
-					}
-				}
-				if done {
-					return
+		callback := func() bool {
+			// Wait for all of the server's RTT90s to be >275ms.
+			for _, desc := range topo.Description().Servers {
+				server, err := topo.FindServer(desc)
+				assert.NoError(mt, err, "FindServer error: %v", err)
+				if server.RTTMonitor().P90() <= 275*time.Millisecond {
+					return false
 				}
 			}
-		}, 10*time.Second)
+
+			return true
+		}
+		assert.Eventually(t,
+			callback,
+			10*time.Second,
+			100*time.Millisecond,
+			"expected that the RTT90s are eventually >275ms")
 
 		// Once we've waited for the RTT90 for the servers to be >275ms, run 10 Ping operations
 		// with a timeout of 275ms and expect that they return timeout errors.
diff --git a/mongo/integration/csot_prose_test.go b/mongo/integration/csot_prose_test.go
index 4f9f112b3f..c8ddfd68df 100644
--- a/mongo/integration/csot_prose_test.go
+++ b/mongo/integration/csot_prose_test.go
@@ -89,13 +89,18 @@ func TestCSOTProse(t *testing.T) {
 		mt.RunOpts("serverSelectionTimeoutMS honored if timeoutMS is not set", mtOpts, func(mt *mtest.T) {
 			mt.Parallel()
 
-			callback := func(ctx context.Context) {
-				err := mt.Client.Ping(ctx, nil)
-				assert.NotNil(mt, err, "expected Ping error, got nil")
+			callback := func() bool {
+				err := mt.Client.Ping(context.Background(), nil)
+				assert.Error(mt, err, "expected Ping error, got nil")
+				return true
 			}
 
 			// Assert that Ping fails within 150ms due to server selection timeout.
-			assert.Soon(mt, callback, 150*time.Millisecond)
+			assert.Eventually(t,
+				callback,
+				150*time.Millisecond,
+				time.Millisecond,
+				"expected ping to fail within 150ms")
 		})
 
 		cliOpts = options.Client().ApplyURI("mongodb://invalid/?timeoutMS=100&serverSelectionTimeoutMS=200")
@@ -103,13 +108,18 @@
 		mt.RunOpts("timeoutMS honored for server selection if it's lower than serverSelectionTimeoutMS", mtOpts, func(mt *mtest.T) {
 			mt.Parallel()
 
-			callback := func(ctx context.Context) {
-				err := mt.Client.Ping(ctx, nil)
-				assert.NotNil(mt, err, "expected Ping error, got nil")
+			callback := func() bool {
+				err := mt.Client.Ping(context.Background(), nil)
+				assert.Error(mt, err, "expected Ping error, got nil")
+				return true
 			}
 
 			// Assert that Ping fails within 150ms due to timeout.
-			assert.Soon(mt, callback, 150*time.Millisecond)
+			assert.Eventually(t,
+				callback,
+				150*time.Millisecond,
+				time.Millisecond,
+				"expected ping to fail within 150ms")
 		})
 
 		cliOpts = options.Client().ApplyURI("mongodb://invalid/?timeoutMS=200&serverSelectionTimeoutMS=100")
@@ -117,13 +127,18 @@
 		mt.RunOpts("serverSelectionTimeoutMS honored for server selection if it's lower than timeoutMS", mtOpts, func(mt *mtest.T) {
 			mt.Parallel()
 
-			callback := func(ctx context.Context) {
-				err := mt.Client.Ping(ctx, nil)
-				assert.NotNil(mt, err, "expected Ping error, got nil")
+			callback := func() bool {
+				err := mt.Client.Ping(context.Background(), nil)
+				assert.Error(mt, err, "expected Ping error, got nil")
+				return true
 			}
 
 			// Assert that Ping fails within 150ms due to server selection timeout.
-			assert.Soon(mt, callback, 150*time.Millisecond)
+			assert.Eventually(t,
+				callback,
+				150*time.Millisecond,
+				time.Millisecond,
+				"expected ping to fail within 150ms")
 		})
 
 		cliOpts = options.Client().ApplyURI("mongodb://invalid/?timeoutMS=0&serverSelectionTimeoutMS=100")
@@ -131,13 +146,18 @@
 		mt.RunOpts("serverSelectionTimeoutMS honored for server selection if timeoutMS=0", mtOpts, func(mt *mtest.T) {
 			mt.Parallel()
 
-			callback := func(ctx context.Context) {
-				err := mt.Client.Ping(ctx, nil)
-				assert.NotNil(mt, err, "expected Ping error, got nil")
+			callback := func() bool {
+				err := mt.Client.Ping(context.Background(), nil)
+				assert.Error(mt, err, "expected Ping error, got nil")
+				return true
 			}
 
 			// Assert that Ping fails within 150ms due to server selection timeout.
-			assert.Soon(mt, callback, 150*time.Millisecond)
+			assert.Eventually(t,
+				callback,
+				150*time.Millisecond,
+				time.Millisecond,
+				"expected ping to fail within 150ms")
 		})
 	})
 }
diff --git a/mongo/integration/sdam_error_handling_test.go b/mongo/integration/sdam_error_handling_test.go
index 58cac9ccdd..4a2baf542d 100644
--- a/mongo/integration/sdam_error_handling_test.go
+++ b/mongo/integration/sdam_error_handling_test.go
@@ -85,23 +85,13 @@ func TestSDAMErrorHandling(t *testing.T) {
 				assert.NotNil(mt, err, "expected InsertOne error, got nil")
 				assert.True(mt, mongo.IsTimeout(err), "expected timeout error, got %v", err)
 				assert.True(mt, mongo.IsNetworkError(err), "expected network error, got %v", err)
+
 				// Assert that the pool is cleared within 2 seconds.
-				assert.Soon(mt, func(ctx context.Context) {
-					ticker := time.NewTicker(100 * time.Millisecond)
-					defer ticker.Stop()
-
-					for {
-						select {
-						case <-ticker.C:
-						case <-ctx.Done():
-							return
-						}
-
-						if tpm.IsPoolCleared() {
-							return
-						}
-					}
-				}, 2*time.Second)
+				assert.Eventually(t,
+					tpm.IsPoolCleared,
+					2*time.Second,
+					100*time.Millisecond,
+					"expected pool to be cleared within 2 seconds")
 			})
 
 			mt.RunOpts("pool cleared on non-timeout network error", noClientOpts, func(mt *mtest.T) {
@@ -131,22 +121,11 @@ func TestSDAMErrorHandling(t *testing.T) {
 					SetMinPoolSize(5))
 
 				// Assert that the pool is cleared within 2 seconds.
-				assert.Soon(mt, func(ctx context.Context) {
-					ticker := time.NewTicker(100 * time.Millisecond)
-					defer ticker.Stop()
-
-					for {
-						select {
-						case <-ticker.C:
-						case <-ctx.Done():
-							return
-						}
-
-						if tpm.IsPoolCleared() {
-							return
-						}
-					}
-				}, 2*time.Second)
+				assert.Eventually(t,
+					tpm.IsPoolCleared,
+					2*time.Second,
+					100*time.Millisecond,
+					"expected pool to be cleared within 2 seconds")
 			})
 
 			mt.Run("foreground", func(mt *mtest.T) {
@@ -175,22 +154,11 @@ func TestSDAMErrorHandling(t *testing.T) {
 				assert.False(mt, mongo.IsTimeout(err), "expected non-timeout error, got %v", err)
 
 				// Assert that the pool is cleared within 2 seconds.
-				assert.Soon(mt, func(ctx context.Context) {
-					ticker := time.NewTicker(100 * time.Millisecond)
-					defer ticker.Stop()
-
-					for {
-						select {
-						case <-ticker.C:
-						case <-ctx.Done():
-							return
-						}
-
-						if tpm.IsPoolCleared() {
-							return
-						}
-					}
-				}, 2*time.Second)
+				assert.Eventually(t,
+					tpm.IsPoolCleared,
+					2*time.Second,
+					100*time.Millisecond,
+					"expected pool to be cleared within 2 seconds")
 			})
 		})
 	})
diff --git a/mongo/integration/sdam_prose_test.go b/mongo/integration/sdam_prose_test.go
index f91bab1176..615c77569b 100644
--- a/mongo/integration/sdam_prose_test.go
+++ b/mongo/integration/sdam_prose_test.go
@@ -124,28 +124,23 @@ func TestSDAMProse(t *testing.T) {
 					AppName: "streamingRttTest",
 				},
 			})
 
-			callback := func(ctx context.Context) {
-				for {
-					// Stop loop if callback has been canceled.
-					select {
-					case <-ctx.Done():
-						return
-					default:
+			callback := func() bool {
+				// We don't know which server received the failpoint command, so we wait until any of the server
+				// RTTs cross the threshold.
+				for _, serverDesc := range testTopology.Description().Servers {
+					if serverDesc.AverageRTT > 250*time.Millisecond {
+						return true
 					}
-
-					// We don't know which server received the failpoint command, so we wait until any of the server
-					// RTTs cross the threshold.
-					for _, serverDesc := range testTopology.Description().Servers {
-						if serverDesc.AverageRTT > 250*time.Millisecond {
-							return
-						}
-					}
-
-					// The next update will be in ~500ms.
-					time.Sleep(500 * time.Millisecond)
 				}
+
+				// The next update will be in ~500ms.
+				return false
 			}
-			assert.Soon(t, callback, defaultCallbackTimeout)
+			assert.Eventually(t,
+				callback,
+				defaultCallbackTimeout,
+				500*time.Millisecond,
+				"expected the average RTT of at least one server to exceed 250ms")
 		})
 	})
diff --git a/mongo/integration/unified_runner_events_helper_test.go b/mongo/integration/unified_runner_events_helper_test.go
index 780be40de9..5afc510e14 100644
--- a/mongo/integration/unified_runner_events_helper_test.go
+++ b/mongo/integration/unified_runner_events_helper_test.go
@@ -87,31 +87,23 @@ func waitForEvent(mt *mtest.T, test *testCase, op *operation) {
 	eventType := op.Arguments.Lookup("event").StringValue()
 	expectedCount := int(op.Arguments.Lookup("count").Int32())
 
-	callback := func(ctx context.Context) {
-		for {
-			// Stop loop if callback has been canceled.
-			select {
-			case <-ctx.Done():
-				return
-			default:
-			}
-
-			var count int
-			// Spec tests only ever wait for ServerMarkedUnknown SDAM events for the time being.
-			if eventType == "ServerMarkedUnknownEvent" {
-				count = test.monitor.getServerMarkedUnknownCount()
-			} else {
-				count = test.monitor.getPoolEventCount(eventType)
-			}
-
-			if count >= expectedCount {
-				return
-			}
-			time.Sleep(100 * time.Millisecond)
+	callback := func() bool {
+		var count int
+		// Spec tests only ever wait for ServerMarkedUnknown SDAM events for the time being.
+		if eventType == "ServerMarkedUnknownEvent" {
+			count = test.monitor.getServerMarkedUnknownCount()
+		} else {
+			count = test.monitor.getPoolEventCount(eventType)
 		}
+
+		return count >= expectedCount
 	}
-	assert.Soon(mt, callback, defaultCallbackTimeout)
+	assert.Eventually(mt,
+		callback,
+		defaultCallbackTimeout,
+		100*time.Millisecond,
+		"expected the observed event count to reach the expected count within the timeout")
 }
 
 func assertEventCount(mt *mtest.T, testCase *testCase, op *operation) {
@@ -134,23 +126,16 @@ func recordPrimary(mt *mtest.T, testCase *testCase) {
 }
 
 func waitForPrimaryChange(mt *mtest.T, testCase *testCase, op *operation) {
-	callback := func(ctx context.Context) {
-		for {
-			// Stop loop if callback has been canceled.
-			select {
-			case <-ctx.Done():
-				return
-			default:
-			}
-
-			if getPrimaryAddress(mt, testCase.testTopology, false) != testCase.recordedPrimary {
-				return
-			}
-		}
+	callback := func() bool {
+		return getPrimaryAddress(mt, testCase.testTopology, false) != testCase.recordedPrimary
 	}
 
 	timeout := convertValueToMilliseconds(mt, op.Arguments.Lookup("timeoutMS"))
-	assert.Soon(mt, callback, timeout)
+	assert.Eventually(mt,
+		callback,
+		timeout,
+		100*time.Millisecond,
+		"expected primary address to be different within the timeout period")
 }
 
 // getPrimaryAddress returns the address of the current primary. If failFast is true, the server selection fast path
diff --git a/mongo/with_transactions_test.go b/mongo/with_transactions_test.go
index af7ce98b0c..544053b973 100644
--- a/mongo/with_transactions_test.go
+++ b/mongo/with_transactions_test.go
@@ -399,19 +399,23 @@ func TestConvenientTransactions(t *testing.T) {
 
 		// Insert a document within a session and manually cancel context before
 		// "commitTransaction" can be sent.
-		callback := func(ctx context.Context) {
-			transactionCtx, cancel := context.WithCancel(ctx)
-
+		callback := func() bool {
+			transactionCtx, cancel := context.WithCancel(context.Background())
 			_, _ = sess.WithTransaction(transactionCtx, func(ctx SessionContext) (interface{}, error) {
 				_, err := coll.InsertOne(ctx, bson.M{"x": 1})
-				assert.Nil(t, err, "InsertOne error: %v", err)
+				assert.NoError(t, err, "InsertOne error: %v", err)
 				cancel()
 				return nil, nil
 			})
+			return true
 		}
 
 		// Assert that transaction is canceled within 500ms and not 2 seconds.
-		assert.Soon(t, callback, 500*time.Millisecond)
+		assert.Eventually(t,
+			callback,
+			500*time.Millisecond,
+			time.Millisecond,
+			"expected transaction to be canceled within 500ms")
 
 		// Assert that AbortTransaction was started once and succeeded.
 		assert.Equal(t, 1, len(abortStarted), "expected 1 abortTransaction started event, got %d", len(abortStarted))
@@ -459,19 +463,24 @@ func TestConvenientTransactions(t *testing.T) {
 		assert.Nil(t, err, "StartSession error: %v", err)
 		defer sess.EndSession(context.Background())
 
-		callback := func(ctx context.Context) {
+		callback := func() bool {
 			// Create transaction context with short timeout.
-			withTransactionContext, cancel := context.WithTimeout(ctx, time.Nanosecond)
+			withTransactionContext, cancel := context.WithTimeout(context.Background(), time.Nanosecond)
 			defer cancel()
 
 			_, _ = sess.WithTransaction(withTransactionContext, func(ctx SessionContext) (interface{}, error) {
 				_, err := coll.InsertOne(ctx, bson.D{{}})
 				return nil, err
 			})
+			return true
 		}
 
 		// Assert that transaction fails within 500ms and not 2 seconds.
-		assert.Soon(t, callback, 500*time.Millisecond)
+		assert.Eventually(t,
+			callback,
+			500*time.Millisecond,
+			time.Millisecond,
+			"expected transaction to fail within 500ms")
 	})
 	t.Run("canceled context before callback does not retry", func(t *testing.T) {
 		withTransactionTimeout = 2 * time.Second
@@ -489,19 +498,24 @@ func TestConvenientTransactions(t *testing.T) {
 		assert.Nil(t, err, "StartSession error: %v", err)
 		defer sess.EndSession(context.Background())
 
-		callback := func(ctx context.Context) {
+		callback := func() bool {
 			// Create transaction context and cancel it immediately.
-			withTransactionContext, cancel := context.WithTimeout(ctx, 2*time.Second)
+			withTransactionContext, cancel := context.WithTimeout(context.Background(), 2*time.Second)
 			cancel()
 
 			_, _ = sess.WithTransaction(withTransactionContext, func(ctx SessionContext) (interface{}, error) {
 				_, err := coll.InsertOne(ctx, bson.D{{}})
 				return nil, err
 			})
+			return true
 		}
 
 		// Assert that transaction fails within 500ms and not 2 seconds.
-		assert.Soon(t, callback, 500*time.Millisecond)
+		assert.Eventually(t,
+			callback,
+			500*time.Millisecond,
+			time.Millisecond,
+			"expected transaction to fail within 500ms")
 	})
 	t.Run("slow operation in callback retries", func(t *testing.T) {
 		withTransactionTimeout = 2 * time.Second
@@ -540,8 +554,8 @@ func TestConvenientTransactions(t *testing.T) {
 		assert.Nil(t, err, "StartSession error: %v", err)
 		defer sess.EndSession(context.Background())
 
-		callback := func(ctx context.Context) {
-			_, err = sess.WithTransaction(ctx, func(ctx SessionContext) (interface{}, error) {
+		callback := func() bool {
+			_, err = sess.WithTransaction(context.Background(), func(ctx SessionContext) (interface{}, error) {
 				// Set a timeout of 300ms to cause a timeout on first insertOne
 				// and force a retry.
 				c, cancel := context.WithTimeout(ctx, 300*time.Millisecond)
@@ -550,11 +564,17 @@ func TestConvenientTransactions(t *testing.T) {
 				_, err := coll.InsertOne(c, bson.D{{}})
 				return nil, err
 			})
-			assert.Nil(t, err, "WithTransaction error: %v", err)
+			assert.NoError(t, err, "WithTransaction error: %v", err)
+			return true
 		}
 
 		// Assert that transaction passes within 2 seconds.
-		assert.Soon(t, callback, 2*time.Second)
+		assert.Eventually(t,
+			callback,
+			withTransactionTimeout,
+			time.Millisecond,
+			"expected transaction to succeed within 2 seconds")
+
 	})
 }
diff --git a/x/mongo/driver/topology/connection_test.go b/x/mongo/driver/topology/connection_test.go
index dc774b469b..946f74d8f2 100644
--- a/x/mongo/driver/topology/connection_test.go
+++ b/x/mongo/driver/topology/connection_test.go
@@ -236,13 +236,18 @@ func TestConnection(t *testing.T) {
 			conn := newConnection("", connOpts...)
 
 			var connectErr error
-			callback := func(ctx context.Context) {
-				connectCtx, cancel := context.WithTimeout(ctx, tc.contextTimeout)
+			callback := func() bool {
+				connectCtx, cancel := context.WithTimeout(context.Background(), tc.contextTimeout)
 				defer cancel()
 
 				connectErr = conn.connect(connectCtx)
+				return true
 			}
-			assert.Soon(t, callback, tc.maxConnectTime)
+			assert.Eventually(t,
+				callback,
+				tc.maxConnectTime,
+				time.Millisecond,
+				"expected timeout to apply to socket establishment after maximum connect time")
 
 			ce, ok := connectErr.(ConnectionError)
 			assert.True(t, ok, "expected error %v to be of type %T", connectErr, ConnectionError{})
@@ -271,13 +276,18 @@ func TestConnection(t *testing.T) {
 			conn := newConnection(address.Address(l.Addr().String()), connOpts...)
 
 			var connectErr error
-			callback := func(ctx context.Context) {
-				connectCtx, cancel := context.WithTimeout(ctx, tc.contextTimeout)
+			callback := func() bool {
+				connectCtx, cancel := context.WithTimeout(context.Background(), tc.contextTimeout)
 				defer cancel()
 
 				connectErr = conn.connect(connectCtx)
+				return true
 			}
-			assert.Soon(t, callback, tc.maxConnectTime)
+			assert.Eventually(t,
+				callback,
+				tc.maxConnectTime,
+				time.Millisecond,
+				"expected timeout to apply to TLS handshake after maximum connect time")
 
 			ce, ok := connectErr.(ConnectionError)
 			assert.True(t, ok, "expected error %v to be of type %T", connectErr, ConnectionError{})
diff --git a/x/mongo/driver/topology/topology_errors_test.go b/x/mongo/driver/topology/topology_errors_test.go
index c09ef9731c..c7dc7336e9 100644
--- a/x/mongo/driver/topology/topology_errors_test.go
+++ b/x/mongo/driver/topology/topology_errors_test.go
@@ -46,15 +46,21 @@ func TestTopologyErrors(t *testing.T) {
 		assert.Nil(t, err, "error creating topology: %v", err)
 
 		var serverSelectionErr error
-		callback := func(ctx context.Context) {
-			selectServerCtx, cancel := context.WithTimeout(ctx, 10*time.Millisecond)
+		callback := func() bool {
+			selectServerCtx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
 			defer cancel()
 
 			state := newServerSelectionState(selectNone, make(<-chan time.Time))
 			subCh := make(<-chan description.Topology)
 			_, serverSelectionErr = topo.selectServerFromSubscription(selectServerCtx, subCh, state)
+			return true
 		}
-		assert.Soon(t, callback, 150*time.Millisecond)
+		assert.Eventually(t,
+			callback,
+			150*time.Millisecond,
+			time.Millisecond,
+			"expected server selection to time out within 150ms")
+
 		assert.True(t, errors.Is(serverSelectionErr, context.DeadlineExceeded), "expected %v, received %v", context.DeadlineExceeded, serverSelectionErr)
 	})