From ea15f7d7a26b7e2cddf78aa92e4229cf95005b8f Mon Sep 17 00:00:00 2001 From: Preston Vasquez Date: Thu, 27 Jun 2024 15:11:40 -0600 Subject: [PATCH] GODRIVER-2348 Make CSOT feature-gated behavior the default (#1515) Co-authored-by: Matt Dale <9760375+matthewdale@users.noreply.github.com> Co-authored-by: Qingyang Hu <103950869+qingyang-hu@users.noreply.github.com> --- internal/csot/csot.go | 78 ++++- internal/csot/csot_test.go | 249 ++++++++++++++ internal/docexamples/examples.go | 2 - .../client_side_encryption_prose_test.go | 20 +- internal/integration/client_test.go | 2 +- internal/integration/collection_test.go | 10 +- internal/integration/crud_helpers_test.go | 4 - internal/integration/errors_test.go | 185 +++++++--- internal/integration/index_view_test.go | 13 +- internal/integration/json_helpers_test.go | 12 +- internal/integration/mtest/mongotest.go | 2 - .../integration/mtest/opmsg_deployment.go | 7 + .../integration/sdam_error_handling_test.go | 7 +- internal/integration/sdam_prose_test.go | 1 + internal/integration/unified/client_entity.go | 7 +- .../unified/collection_operation_execution.go | 73 +++- .../integration/unified/common_options.go | 9 +- .../unified/database_operation_execution.go | 12 +- .../gridfs_bucket_operation_execution.go | 8 +- internal/integration/unified/operation.go | 11 +- .../integration/unified/session_options.go | 27 +- .../unified/testrunner_operation.go | 1 - .../unified/unified_spec_runner.go | 7 + .../ptrutil/ptr.go | 12 +- mongo/batch_cursor.go | 4 +- mongo/change_stream.go | 107 ++++-- mongo/change_stream_deployment.go | 8 + mongo/change_stream_test.go | 97 +++++- mongo/client.go | 15 +- mongo/client_test.go | 9 + mongo/collection.go | 48 +-- mongo/crud_examples_test.go | 50 +-- mongo/cursor.go | 6 +- mongo/cursor_test.go | 6 +- mongo/database.go | 5 + mongo/errors.go | 1 - mongo/gridfs_bucket.go | 97 +++--- mongo/index_view.go | 23 +- mongo/options/aggregateoptions.go | 18 - mongo/options/clientoptions.go | 56 +-- mongo/options/clientoptions_test.go | 11 - mongo/options/countoptions.go | 20 -- mongo/options/distinctoptions.go | 20 -- mongo/options/estimatedcountoptions.go | 20 -- mongo/options/findoptions.go | 90 ----- mongo/options/gridfsoptions.go | 20 -- mongo/options/indexoptions.go | 60 +--- mongo/options/sessionoptions.go | 21 -- mongo/options/transactionoptions.go | 24 -- mongo/read_write_concern_spec_test.go | 22 +- mongo/session.go | 6 +- mongo/writeconcern/writeconcern.go | 20 -- mongo/writeconcern/writeconcern_test.go | 106 ------ .../retryability-legacy-timeouts.json | 102 +++--- .../retryability-legacy-timeouts.yml | 106 +++--- .../convenient-transactions/commit-retry.json | 2 + .../convenient-transactions/commit-retry.yml | 2 + .../connection-string/write-concern.json | 4 + .../connection-string/write-concern.yml | 4 + .../document/write-concern.json | 4 + .../document/write-concern.yml | 3 + .../transactions/legacy/error-labels.json | 3 + testdata/transactions/legacy/error-labels.yml | 3 + .../transactions/legacy/retryable-commit.json | 1 + .../transactions/legacy/retryable-commit.yml | 1 + .../legacy/transaction-options.json | 3 + .../legacy/transaction-options.yml | 3 + x/mongo/driver/batch_cursor.go | 39 ++- x/mongo/driver/batch_cursor_test.go | 41 --- x/mongo/driver/connstring/connstring.go | 27 -- .../driver/connstring/connstring_spec_test.go | 7 +- x/mongo/driver/connstring/connstring_test.go | 27 -- x/mongo/driver/driver.go | 18 + x/mongo/driver/errors.go | 2 - x/mongo/driver/integration/aggregate_test.go | 8 +- x/mongo/driver/operation.go | 149 +++++--- x/mongo/driver/operation/aggregate.go | 12 - .../driver/operation/commit_transaction.go | 13 - x/mongo/driver/operation/count.go | 12 - x/mongo/driver/operation/create_indexes.go | 12 - x/mongo/driver/operation/distinct.go | 12 - x/mongo/driver/operation/drop_indexes.go | 12 - x/mongo/driver/operation/find.go | 12 - x/mongo/driver/operation/find_and_modify.go | 12 - x/mongo/driver/operation/hello.go | 16 +- x/mongo/driver/operation/list_indexes.go | 12 - x/mongo/driver/operation_test.go | 123 ++++--- x/mongo/driver/session/client_session.go | 52 ++- x/mongo/driver/session/options.go | 7 - x/mongo/driver/topology/CMAP_spec_test.go | 4 +- x/mongo/driver/topology/connection.go | 153 ++------- x/mongo/driver/topology/connection_options.go | 26 -- x/mongo/driver/topology/connection_test.go | 177 +--------- x/mongo/driver/topology/context_listener.go | 91 +++++ x/mongo/driver/topology/pool.go | 32 +- x/mongo/driver/topology/pool_test.go | 105 ++++-- x/mongo/driver/topology/rtt_monitor.go | 17 +- x/mongo/driver/topology/rtt_monitor_test.go | 10 +- x/mongo/driver/topology/server.go | 321 +++++++++++------- x/mongo/driver/topology/server_options.go | 18 +- x/mongo/driver/topology/server_test.go | 219 ++++++++---- x/mongo/driver/topology/topology.go | 67 ++-- .../driver/topology/topology_errors_test.go | 3 +- x/mongo/driver/topology/topology_options.go | 30 +- .../driver/topology/topology_options_test.go | 2 +- x/mongo/driver/topology/topology_test.go | 154 +-------- 106 files changed, 1977 insertions(+), 2067 deletions(-) create mode 100644 internal/csot/csot_test.go rename x/mongo/driver/topology/cancellation_listener.go => internal/ptrutil/ptr.go (56%) create mode 100644 x/mongo/driver/topology/context_listener.go diff --git a/internal/csot/csot.go b/internal/csot/csot.go index 678252c51a..1e7b1901ea 100644 --- a/internal/csot/csot.go +++ b/internal/csot/csot.go @@ -11,26 +11,74 @@ import ( "time" ) -type timeoutKey struct{} +type clientLevel struct{} -// MakeTimeoutContext returns a new context with Client-Side Operation Timeout (CSOT) feature-gated behavior -// and a Timeout set to the passed in Duration. Setting a Timeout on a single operation is not supported in -// public API. -// -// TODO(GODRIVER-2348) We may be able to remove this function once CSOT feature-gated behavior becomes the -// TODO default behavior. -func MakeTimeoutContext(ctx context.Context, to time.Duration) (context.Context, context.CancelFunc) { - // Only use the passed in Duration as a timeout on the Context if it - // is non-zero. - cancelFunc := func() {} - if to != 0 { - ctx, cancelFunc = context.WithTimeout(ctx, to) +func isClientLevel(ctx context.Context) bool { + val := ctx.Value(clientLevel{}) + if val == nil { + return false } - return context.WithValue(ctx, timeoutKey{}, true), cancelFunc + + return val.(bool) } +// IsTimeoutContext checks if the provided context has been assigned a deadline +// or has unlimited retries. func IsTimeoutContext(ctx context.Context) bool { - return ctx.Value(timeoutKey{}) != nil + _, ok := ctx.Deadline() + + return ok || isClientLevel(ctx) +} + +// WithTimeout will set the given timeout on the context, if no deadline has +// already been set. +// +// This function assumes that the timeout field is static, given that the +// timeout should be sourced from the client. Therefore, once a timeout function +// parameter has been applied to the context, it will remain for the lifetime +// of the context. +func WithTimeout(parent context.Context, timeout *time.Duration) (context.Context, context.CancelFunc) { + cancel := func() {} + + if timeout == nil || IsTimeoutContext(parent) { + // In the following conditions, do nothing: + // 1. The parent already has a deadline + // 2. The parent does not have a deadline, but a client-level timeout has + // been applied. + // 3. The parent does not have a deadline, there is not client-level + // timeout, and the timeout parameter DNE. + return parent, cancel + } + + // If a client-level timeout has not been applied, then apply it. + parent = context.WithValue(parent, clientLevel{}, true) + + dur := *timeout + + if dur == 0 { + // If the parent does not have a deadline and the timeout is zero, then + // do nothing. + return parent, cancel + } + + // If the parent does not have a dealine and the timeout is non-zero, then + // apply the timeout. + return context.WithTimeout(parent, dur) +} + +// WithServerSelectionTimeout creates a context with a timeout that is the +// minimum of serverSelectionTimeoutMS and context deadline. The usage of +// non-positive values for serverSelectionTimeoutMS are an anti-pattern and are +// not considered in this calculation. +func WithServerSelectionTimeout( + parent context.Context, + serverSelectionTimeout time.Duration, +) (context.Context, context.CancelFunc) { + if serverSelectionTimeout <= 0 { + return parent, func() {} + } + + return context.WithTimeout(parent, serverSelectionTimeout) } // ZeroRTTMonitor implements the RTTMonitor interface and is used internally for testing. It returns 0 for all diff --git a/internal/csot/csot_test.go b/internal/csot/csot_test.go new file mode 100644 index 0000000000..5b79f6994a --- /dev/null +++ b/internal/csot/csot_test.go @@ -0,0 +1,249 @@ +// Copyright (C) MongoDB, Inc. 2024-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package csot + +import ( + "context" + "testing" + "time" + + "go.mongodb.org/mongo-driver/internal/assert" + "go.mongodb.org/mongo-driver/internal/ptrutil" +) + +func newTestContext(t *testing.T, timeout time.Duration) context.Context { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + t.Cleanup(cancel) + + return ctx +} + +func TestWithServerSelectionTimeout(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + parent context.Context + serverSelectionTimeout time.Duration + wantTimeout time.Duration + wantOk bool + }{ + { + name: "no context deadine and ssto is zero", + parent: context.Background(), + serverSelectionTimeout: 0, + wantTimeout: 0, + wantOk: false, + }, + { + name: "no context deadline and ssto is positive", + parent: context.Background(), + serverSelectionTimeout: 1, + wantTimeout: 1, + wantOk: true, + }, + { + name: "no context deadline and ssto is negative", + parent: context.Background(), + serverSelectionTimeout: -1, + wantTimeout: 0, + wantOk: false, + }, + { + name: "context deadline is zero and ssto is positive", + parent: newTestContext(t, 0), + serverSelectionTimeout: 1, + wantTimeout: 1, + wantOk: true, + }, + { + name: "context deadline is zero and ssto is negative", + parent: newTestContext(t, 0), + serverSelectionTimeout: -1, + wantTimeout: 0, + wantOk: true, + }, + { + name: "context deadline is negative and ssto is zero", + parent: newTestContext(t, -1), + serverSelectionTimeout: 0, + wantTimeout: -1, + wantOk: true, + }, + { + name: "context deadline is negative and ssto is positive", + parent: newTestContext(t, -1), + serverSelectionTimeout: 1, + wantTimeout: 1, + wantOk: true, + }, + { + name: "context deadline is negative and ssto is negative", + parent: newTestContext(t, -1), + serverSelectionTimeout: -1, + wantTimeout: -1, + wantOk: true, + }, + { + name: "context deadline is positive and ssto is zero", + parent: newTestContext(t, 1), + serverSelectionTimeout: 0, + wantTimeout: 1, + wantOk: true, + }, + { + name: "context deadline is positive and equal to ssto", + parent: newTestContext(t, 1), + serverSelectionTimeout: 1, + wantTimeout: 1, + wantOk: true, + }, + { + name: "context deadline is positive lt ssto", + parent: newTestContext(t, 1), + serverSelectionTimeout: 2, + wantTimeout: 2, + wantOk: true, + }, + { + name: "context deadline is positive gt ssto", + parent: newTestContext(t, 2), + serverSelectionTimeout: 1, + wantTimeout: 2, + wantOk: true, + }, + { + name: "context deadline is positive and ssto is negative", + parent: newTestContext(t, -1), + serverSelectionTimeout: -1, + wantTimeout: 1, + wantOk: true, + }, + } + + for _, test := range tests { + test := test // Capture the range variable + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + ctx, cancel := WithServerSelectionTimeout(test.parent, test.serverSelectionTimeout) + t.Cleanup(cancel) + + deadline, gotOk := ctx.Deadline() + assert.Equal(t, test.wantOk, gotOk) + + if gotOk { + delta := time.Until(deadline) - test.wantTimeout + tolerance := 10 * time.Millisecond + + assert.True(t, delta > -1*tolerance, "expected delta=%d > %d", delta, -1*tolerance) + assert.True(t, delta <= tolerance, "expected delta=%d <= %d", delta, tolerance) + } + }) + } +} + +func TestWithTimeout(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + parent context.Context + timeout *time.Duration + wantTimeout time.Duration + wantDeadline bool + wantValues []interface{} + }{ + { + name: "deadline set with non-zero timeout", + parent: newTestContext(t, 1), + timeout: ptrutil.Ptr(time.Duration(2)), + wantTimeout: 1, + wantDeadline: true, + wantValues: []interface{}{}, + }, + { + name: "deadline set with zero timeout", + parent: newTestContext(t, 1), + timeout: ptrutil.Ptr(time.Duration(0)), + wantTimeout: 1, + wantDeadline: true, + wantValues: []interface{}{}, + }, + { + name: "deadline set with nil timeout", + parent: newTestContext(t, 1), + timeout: nil, + wantTimeout: 1, + wantDeadline: true, + wantValues: []interface{}{}, + }, + { + name: "deadline unset with non-zero timeout", + parent: context.Background(), + timeout: ptrutil.Ptr(time.Duration(1)), + wantTimeout: 1, + wantDeadline: true, + wantValues: []interface{}{}, + }, + { + name: "deadline unset with zero timeout", + parent: context.Background(), + timeout: ptrutil.Ptr(time.Duration(0)), + wantTimeout: 0, + wantDeadline: false, + wantValues: []interface{}{clientLevel{}}, + }, + { + name: "deadline unset with nil timeout", + parent: context.Background(), + timeout: nil, + wantTimeout: 0, + wantDeadline: false, + wantValues: []interface{}{}, + }, + { + // If "clientLevel" has been set, but a new timeout is applied + // to the context, then the constructed context should retain the old + // timeout. To simplify the code, we assume the first timeout is static. + name: "deadline unset with non-zero timeout at clientLevel", + parent: context.WithValue(context.Background(), clientLevel{}, true), + timeout: ptrutil.Ptr(time.Duration(1)), + wantTimeout: 0, + wantDeadline: false, + wantValues: []interface{}{}, + }, + } + + for _, test := range tests { + test := test // Capture the range variable + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + ctx, cancel := WithTimeout(test.parent, test.timeout) + t.Cleanup(cancel) + + deadline, gotDeadline := ctx.Deadline() + assert.Equal(t, test.wantDeadline, gotDeadline) + + if gotDeadline { + delta := time.Until(deadline) - test.wantTimeout + tolerance := 10 * time.Millisecond + + assert.True(t, delta > -1*tolerance, "expected delta=%d > %d", delta, -1*tolerance) + assert.True(t, delta <= tolerance, "expected delta=%d <= %d", delta, tolerance) + } + + for _, wantValue := range test.wantValues { + assert.NotNil(t, ctx.Value(wantValue), "expected context to have value %v", wantValue) + } + }) + } + +} diff --git a/internal/docexamples/examples.go b/internal/docexamples/examples.go index 2f95bce65f..71064947a9 100644 --- a/internal/docexamples/examples.go +++ b/internal/docexamples/examples.go @@ -1978,7 +1978,6 @@ func WithTransactionExample(ctx context.Context) error { // Prereq: Create collections. wcMajority := writeconcern.Majority() - wcMajority.WTimeout = 1 * time.Second wcMajorityCollectionOpts := options.Collection().SetWriteConcern(wcMajority) fooColl := client.Database("mydb1").Collection("foo", wcMajorityCollectionOpts) barColl := client.Database("mydb1").Collection("bar", wcMajorityCollectionOpts) @@ -2559,7 +2558,6 @@ func CausalConsistencyExamples(client *mongo.Client) error { rc := readconcern.Majority() wc := writeconcern.Majority() - wc.WTimeout = 1000 // Use a causally-consistent session to run some operations opts := options.Session().SetDefaultReadConcern(rc).SetDefaultWriteConcern(wc) session1, err := client.StartSession(opts) diff --git a/internal/integration/client_side_encryption_prose_test.go b/internal/integration/client_side_encryption_prose_test.go index d5c829ea43..af6f4af7b6 100644 --- a/internal/integration/client_side_encryption_prose_test.go +++ b/internal/integration/client_side_encryption_prose_test.go @@ -53,16 +53,6 @@ const ( maxBsonObjSize = 16777216 // max bytes in BSON object ) -func containsSubstring(possibleSubstrings []string, str string) bool { - for _, possibleSubstring := range possibleSubstrings { - if strings.Contains(str, possibleSubstring) { - return true - } - } - - return false -} - func TestClientSideEncryptionProse(t *testing.T) { t.Parallel() @@ -150,7 +140,6 @@ func TestClientSideEncryptionProse(t *testing.T) { // Insert the copied key document into keyvault.datakeys with majority write concern. wcMajority := writeconcern.Majority() - wcMajority.WTimeout = 1 * time.Second wcMajorityCollectionOpts := options.Collection().SetWriteConcern(wcMajority) wcmColl := cse.kvClient.Database(kvDatabase).Collection(dkCollection, wcMajorityCollectionOpts) _, err = wcmColl.InsertOne(context.Background(), alteredKeydoc) @@ -1001,7 +990,7 @@ func TestClientSideEncryptionProse(t *testing.T) { if len(tc.errorSubstring) > 0 { assert.NotNil(mt, err, "expected error, got nil") - assert.True(t, containsSubstring(tc.errorSubstring, err.Error()), + assert.True(t, containsPattern(tc.errorSubstring, err.Error()), "expected tc.errorSubstring=%v to contain %v, but it didn't", tc.errorSubstring, err.Error()) return @@ -1031,7 +1020,7 @@ func TestClientSideEncryptionProse(t *testing.T) { _, err = invalidClientEncryption.CreateDataKey(context.Background(), tc.provider, invalidKeyOpts) assert.NotNil(mt, err, "expected CreateDataKey error, got nil") - assert.True(t, containsSubstring(tc.invalidClientEncryptionErrorSubstring, err.Error()), + assert.True(t, containsPattern(tc.invalidClientEncryptionErrorSubstring, err.Error()), "expected tc.invalidClientEncryptionErrorSubstring=%v to contain %v, but it didn't", tc.invalidClientEncryptionErrorSubstring, err.Error()) }) @@ -1635,7 +1624,7 @@ func TestClientSideEncryptionProse(t *testing.T) { "x509: certificate is not authorized to sign other certificates", // All others } - assert.True(t, containsSubstring(possibleErrors, err.Error()), + assert.True(t, containsPattern(possibleErrors, err.Error()), "expected possibleErrors=%v to contain %v, but it didn't", possibleErrors, err.Error()) @@ -1896,7 +1885,6 @@ func TestClientSideEncryptionProse(t *testing.T) { } wcMajority := writeconcern.Majority() - wcMajority.WTimeout = 1 * time.Second wcMajorityCollectionOpts := options.Collection().SetWriteConcern(wcMajority) wcmColl := cse.kvClient.Database(kvDatabase).Collection(dkCollection, wcMajorityCollectionOpts) _, err = wcmColl.Indexes().CreateOne(context.Background(), keyVaultIndex) @@ -2254,7 +2242,7 @@ func TestClientSideEncryptionProse(t *testing.T) { "Client.Timeout or context cancellation while reading body", // > 1.20 on all OS } - assert.True(t, containsSubstring(possibleErrors, err.Error()), + assert.True(t, containsPattern(possibleErrors, err.Error()), "expected possibleErrors=%v to contain %v, but it didn't", possibleErrors, err.Error()) }) diff --git a/internal/integration/client_test.go b/internal/integration/client_test.go index 4d8e633590..822b517cd6 100644 --- a/internal/integration/client_test.go +++ b/internal/integration/client_test.go @@ -324,7 +324,7 @@ func TestClient(t *testing.T) { // apply the correct URI. invalidClientOpts := options.Client(). SetServerSelectionTimeout(100 * time.Millisecond).SetHosts([]string{"invalid:123"}). - SetConnectTimeout(500 * time.Millisecond).SetSocketTimeout(500 * time.Millisecond) + SetConnectTimeout(500 * time.Millisecond).SetTimeout(500 * time.Millisecond) integtest.AddTestServerAPIVersion(invalidClientOpts) client, err := mongo.Connect(invalidClientOpts) assert.Nil(mt, err, "Connect error: %v", err) diff --git a/internal/integration/collection_test.go b/internal/integration/collection_test.go index f520e16c89..dd1eb51e72 100644 --- a/internal/integration/collection_test.go +++ b/internal/integration/collection_test.go @@ -11,7 +11,6 @@ import ( "errors" "strings" "testing" - "time" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/event" @@ -36,8 +35,7 @@ var ( // for various operations. It includes a timeout because legacy servers will wait for all W nodes to respond, // causing tests to hang. impossibleWc = &writeconcern.WriteConcern{ - W: 30, - WTimeout: time.Second, + W: 30, } ) @@ -862,7 +860,7 @@ func TestCollection(t *testing.T) { count int64 }{ {"no options", nil, 5}, - {"options", options.EstimatedDocumentCount().SetMaxTime(1 * time.Second), 5}, + {"options", options.EstimatedDocumentCount().SetComment("1"), 5}, } for _, tc := range testCases { mt.Run(tc.name, func(mt *mtest.T) { @@ -884,7 +882,7 @@ func TestCollection(t *testing.T) { }{ {"no options", bson.D{}, nil, all}, {"filter", bson.D{{"x", bson.D{{"$gt", 2}}}}, nil, all[2:]}, - {"options", bson.D{}, options.Distinct().SetMaxTime(5000000000), all}, + {"options", bson.D{}, options.Distinct().SetComment("1"), all}, } for _, tc := range testCases { mt.Run(tc.name, func(mt *mtest.T) { @@ -1166,7 +1164,6 @@ func TestCollection(t *testing.T) { SetComment(expectedComment). SetHint(indexName). SetMax(bson.D{{"x", int32(5)}}). - SetMaxTime(1 * time.Second). SetMin(bson.D{{"x", int32(0)}}). SetProjection(bson.D{{"x", int32(1)}}). SetReturnKey(false). @@ -1188,7 +1185,6 @@ func TestCollection(t *testing.T) { AppendString("comment", expectedComment). AppendString("hint", indexName). StartDocument("max").AppendInt32("x", 5).FinishDocument(). - AppendInt32("maxTimeMS", 1000). StartDocument("min").AppendInt32("x", 0).FinishDocument(). StartDocument("projection").AppendInt32("x", 1).FinishDocument(). AppendBoolean("returnKey", false). diff --git a/internal/integration/crud_helpers_test.go b/internal/integration/crud_helpers_test.go index 355f934add..108515b11b 100644 --- a/internal/integration/crud_helpers_test.go +++ b/internal/integration/crud_helpers_test.go @@ -174,8 +174,6 @@ func executeAggregate(mt *mtest.T, agg aggregator, sess *mongo.Session, args bso opts.SetBatchSize(val.Int32()) case "collation": opts.SetCollation(createCollation(mt, val.Document())) - case "maxTimeMS": - opts.SetMaxTime(time.Duration(val.Int32()) * time.Millisecond) case "allowDiskUse": opts.SetAllowDiskUse(val.Boolean()) case "session": @@ -348,8 +346,6 @@ func setFindModifiers(modifiersDoc bson.Raw, opts *options.FindOptions) { opts.SetHint(val.Document()) case "$max": opts.SetMax(val.Document()) - case "$maxTimeMS": - opts.SetMaxTime(time.Duration(val.Int32()) * time.Millisecond) case "$min": opts.SetMin(val.Document()) case "$returnKey": diff --git a/internal/integration/errors_test.go b/internal/integration/errors_test.go index 8c6c0fb812..b3a4094c15 100644 --- a/internal/integration/errors_test.go +++ b/internal/integration/errors_test.go @@ -15,6 +15,7 @@ import ( "fmt" "io" "net" + "regexp" "testing" "time" @@ -46,6 +47,17 @@ func (n netErr) Temporary() bool { var _ net.Error = (*netErr)(nil) +func containsPattern(patterns []string, str string) bool { + for _, pattern := range patterns { + re := regexp.MustCompile(pattern) + if re.MatchString(str) { + return true + } + } + + return false +} + func TestErrors(t *testing.T) { mt := mtest.New(t, noClientOpts) @@ -96,39 +108,22 @@ func TestErrors(t *testing.T) { } timeoutCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() - _, err = mt.Coll.Find(timeoutCtx, filter) - evt := mt.GetStartedEvent() - assert.Equal(mt, "find", evt.CommandName, "expected command 'find', got %q", evt.CommandName) - assert.True(mt, errors.Is(err, context.DeadlineExceeded), - "errors.Is failure: expected error %v to be %v", err, context.DeadlineExceeded) - }) - - mt.Run("socketTimeoutMS timeouts return network errors", func(mt *mtest.T) { - _, err := mt.Coll.InsertOne(context.Background(), bson.D{{"x", 1}}) - assert.Nil(mt, err, "InsertOne error: %v", err) + _, err = mt.Coll.Find(timeoutCtx, filter) - // Reset the test client to have a 100ms socket timeout. We do this here rather than passing it in as a - // test option using mt.RunOpts because that could cause the collection creation or InsertOne to fail. - resetClientOpts := options.Client(). - SetSocketTimeout(100 * time.Millisecond) - mt.ResetClient(resetClientOpts) + assert.Error(mt, err) - mt.ClearEvents() - filter := bson.M{ - "$where": "function() { sleep(1000); return false; }", + errPatterns := []string{ + context.DeadlineExceeded.Error(), + `^\(MaxTimeMSExpired\) Executor error during find command.*:: caused by :: operation exceeded time limit$`, } - _, err = mt.Coll.Find(context.Background(), filter) + + assert.True(t, containsPattern(errPatterns, err.Error()), + "expected possibleErrors=%v to contain %v, but it didn't", + errPatterns, err.Error()) evt := mt.GetStartedEvent() assert.Equal(mt, "find", evt.CommandName, "expected command 'find', got %q", evt.CommandName) - - assert.False(mt, errors.Is(err, context.DeadlineExceeded), - "errors.Is failure: expected error %v to not be %v", err, context.DeadlineExceeded) - var netErr net.Error - ok := errors.As(err, &netErr) - assert.True(mt, ok, "errors.As failure: expected error %v to be a net.Error", err) - assert.True(mt, netErr.Timeout(), "expected error %v to be a network timeout", err) }) }) mt.Run("ServerError", func(mt *mtest.T) { @@ -505,26 +500,124 @@ func TestErrors(t *testing.T) { err error result bool }{ - {"context timeout", mongo.CommandError{ - 100, "", []string{"other"}, "blah", context.DeadlineExceeded, nil}, true}, - {"deadline would be exceeded", mongo.CommandError{ - 100, "", []string{"other"}, "blah", driver.ErrDeadlineWouldBeExceeded, nil}, true}, - {"server selection timeout", mongo.CommandError{ - 100, "", []string{"other"}, "blah", topology.ErrServerSelectionTimeout, nil}, true}, - {"wait queue timeout", mongo.CommandError{ - 100, "", []string{"other"}, "blah", topology.WaitQueueTimeoutError{}, nil}, true}, - {"ServerError NetworkTimeoutError", mongo.CommandError{ - 100, "", []string{"NetworkTimeoutError"}, "blah", nil, nil}, true}, - {"ServerError ExceededTimeLimitError", mongo.CommandError{ - 100, "", []string{"ExceededTimeLimitError"}, "blah", nil, nil}, true}, - {"ServerError false", mongo.CommandError{ - 100, "", []string{"other"}, "blah", nil, nil}, false}, - {"net error true", mongo.CommandError{ - 100, "", []string{"other"}, "blah", netErr{true}, nil}, true}, - {"net error false", netErr{false}, false}, - {"wrapped error", fmt.Errorf("%w", mongo.CommandError{ - 100, "", []string{"other"}, "blah", context.DeadlineExceeded, nil}), true}, - {"other error", errors.New("foo"), false}, + { + name: "context timeout", + err: mongo.CommandError{ + Code: 100, + Message: "", + Labels: []string{"other"}, + Name: "blah", + Wrapped: context.DeadlineExceeded, + Raw: nil, + }, + result: true, + }, + { + name: "deadline would be exceeded", + err: mongo.CommandError{ + Code: 100, + Message: "", + Labels: []string{"other"}, + Name: "blah", + Wrapped: driver.ErrDeadlineWouldBeExceeded, + Raw: nil, + }, + result: true, + }, + { + name: "server selection timeout", + err: mongo.CommandError{ + Code: 100, + Message: "", + Labels: []string{"other"}, + Name: "blah", + Wrapped: context.DeadlineExceeded, + Raw: nil, + }, + result: true, + }, + { + name: "wait queue timeout", + err: mongo.CommandError{ + Code: 100, + Message: "", + Labels: []string{"other"}, + Name: "blah", + Wrapped: topology.WaitQueueTimeoutError{}, + Raw: nil, + }, + result: true, + }, + { + name: "ServerError NetworkTimeoutError", + err: mongo.CommandError{ + Code: 100, + Message: "", + Labels: []string{"NetworkTimeoutError"}, + Name: "blah", + Wrapped: nil, + Raw: nil, + }, + result: true, + }, + { + name: "ServerError ExceededTimeLimitError", + err: mongo.CommandError{ + Code: 100, + Message: "", + Labels: []string{"ExceededTimeLimitError"}, + Name: "blah", + Wrapped: nil, + Raw: nil, + }, + result: true, + }, + { + name: "ServerError false", + err: mongo.CommandError{ + Code: 100, + Message: "", + Labels: []string{"other"}, + Name: "blah", + Wrapped: nil, + Raw: nil, + }, + result: false, + }, + { + name: "net error true", + err: mongo.CommandError{ + Code: 100, + Message: "", + Labels: []string{"other"}, + Name: "blah", + Wrapped: netErr{true}, + Raw: nil, + }, + result: true, + }, + { + name: "net error false", + err: netErr{false}, + result: false, + }, + { + name: "wrapped error", + err: fmt.Errorf("%w", mongo.CommandError{ + Code: 100, + Message: "", + Labels: []string{"other"}, + Name: "blah", + Wrapped: context.DeadlineExceeded, + Raw: nil, + }), + result: true, + }, + { + name: "other error", + err: errors.New("foo"), + result: false, + }, } for _, tc := range testCases { mt.Run(tc.name, func(mt *mtest.T) { diff --git a/internal/integration/index_view_test.go b/internal/integration/index_view_test.go index 5b4a46e42f..e3d1ae687c 100644 --- a/internal/integration/index_view_test.go +++ b/internal/integration/index_view_test.go @@ -10,7 +10,6 @@ import ( "context" "errors" "testing" - "time" "github.com/google/go-cmp/cmp" "go.mongodb.org/mongo-driver/bson" @@ -554,16 +553,20 @@ func TestIndexView(t *testing.T) { assert.True(mt, cmp.Equal(specs, expectedSpecs), "expected specifications to match: %v", cmp.Diff(specs, expectedSpecs)) }) mt.RunOpts("options passed to listIndexes", mtest.NewOptions().MinServerVersion("3.0"), func(mt *mtest.T) { - opts := options.ListIndexes().SetMaxTime(100 * time.Millisecond) + opts := options.ListIndexes().SetBatchSize(1) _, err := mt.Coll.Indexes().ListSpecifications(context.Background(), opts) assert.Nil(mt, err, "ListSpecifications error: %v", err) evt := mt.GetStartedEvent() assert.Equal(mt, evt.CommandName, "listIndexes", "expected %q command to be sent, got %q", "listIndexes", evt.CommandName) - maxTimeMS, ok := evt.Command.Lookup("maxTimeMS").Int64OK() - assert.True(mt, ok, "expected command %v to contain %q field", evt.Command, "maxTimeMS") - assert.Equal(mt, int64(100), maxTimeMS, "expected maxTimeMS value to be 100, got %d", maxTimeMS) + + cursorDoc, ok := evt.Command.Lookup("cursor").DocumentOK() + assert.True(mt, ok, "expected command: %v to contain a cursor document", evt.Command) + + batchSize, ok := cursorDoc.Lookup("batchSize").Int32OK() + assert.True(mt, ok, "expected command %v to contain %q field", evt.Command, "batchSize") + assert.Equal(mt, int32(1), batchSize, "expected batchSize value to be 1, got %d", batchSize) }) }) mt.Run("drop one", func(mt *mtest.T) { diff --git a/internal/integration/json_helpers_test.go b/internal/integration/json_helpers_test.go index 24877da159..194c316413 100644 --- a/internal/integration/json_helpers_test.go +++ b/internal/integration/json_helpers_test.go @@ -111,9 +111,6 @@ func createClientOptions(t testing.TB, opts bson.Raw) *options.ClientOptions { case "serverSelectionTimeoutMS": sst := convertValueToMilliseconds(t, opt) clientOpts.SetServerSelectionTimeout(sst) - case "socketTimeoutMS": - st := convertValueToMilliseconds(t, opt) - clientOpts.SetSocketTimeout(st) case "minPoolSize": clientOpts.SetMinPoolSize(uint64(opt.AsInt64())) case "maxPoolSize": @@ -301,9 +298,6 @@ func createSessionOptions(t testing.TB, opts bson.Raw) *options.SessionOptions { if txnOpts.WriteConcern != nil { sessOpts.SetDefaultWriteConcern(txnOpts.WriteConcern) } - if txnOpts.MaxCommitTime != nil { - sessOpts.SetDefaultMaxCommitTime(txnOpts.MaxCommitTime) - } default: t.Fatalf("unrecognized session option: %v", name) } @@ -378,8 +372,7 @@ func createTransactionOptions(t testing.TB, opts bson.Raw) *options.TransactionO case "readConcern": txnOpts.SetReadConcern(createReadConcern(opt)) case "maxCommitTimeMS": - t := time.Duration(opt.Int32()) * time.Millisecond - txnOpts.SetMaxCommitTime(&t) + t.Skip("GODRIVER-2348: maxCommitTimeMS is deprecated") default: t.Fatalf("unrecognized transaction option: %v", opt) } @@ -406,9 +399,6 @@ func createWriteConcern(t testing.TB, opt bson.RawValue) *writeconcern.WriteConc val := elem.Value() switch key { - case "wtimeout": - wtimeout := convertValueToMilliseconds(t, val) - wc.WTimeout = wtimeout case "j": j := val.Boolean() wc.Journal = &j diff --git a/internal/integration/mtest/mongotest.go b/internal/integration/mtest/mongotest.go index 785044a42f..affa4233df 100644 --- a/internal/integration/mtest/mongotest.go +++ b/internal/integration/mtest/mongotest.go @@ -14,7 +14,6 @@ import ( "sync" "sync/atomic" "testing" - "time" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/event" @@ -522,7 +521,6 @@ func (t *T) ClearCollections() { // re-instantiating the collection with a majority write concern before dropping. collname := coll.created.Name() wcm := writeconcern.Majority() - wcm.WTimeout = 1 * time.Second wccoll := t.DB.Collection(collname, options.Collection().SetWriteConcern(wcm)) _ = wccoll.Drop(context.Background()) diff --git a/internal/integration/mtest/opmsg_deployment.go b/internal/integration/mtest/opmsg_deployment.go index 6a0a1021c1..bcc10275e6 100644 --- a/internal/integration/mtest/opmsg_deployment.go +++ b/internal/integration/mtest/opmsg_deployment.go @@ -9,6 +9,7 @@ package mtest import ( "context" "errors" + "time" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/internal/csot" @@ -133,6 +134,12 @@ func (md *mockDeployment) SelectServer(context.Context, description.ServerSelect return md, nil } +// GetServerSelectionTimeout returns zero as a server selection timeout is not +// applicable for mock deployments. +func (*mockDeployment) GetServerSelectionTimeout() time.Duration { + return 0 +} + // Kind implements the Deployment interface. It always returns description.TopologyKindSingle. func (md *mockDeployment) Kind() description.TopologyKind { return description.TopologyKindSingle diff --git a/internal/integration/sdam_error_handling_test.go b/internal/integration/sdam_error_handling_test.go index 091563d2ad..5f0b768cef 100644 --- a/internal/integration/sdam_error_handling_test.go +++ b/internal/integration/sdam_error_handling_test.go @@ -75,9 +75,10 @@ func TestSDAMErrorHandling(t *testing.T) { mt.ResetClient(baseClientOpts(). SetAppName(appName). SetPoolMonitor(tpm.PoolMonitor). - // Set a 100ms socket timeout so that the saslContinue delay of 150ms causes a - // timeout during socket read (i.e. a timeout not caused by the InsertOne context). - SetSocketTimeout(100 * time.Millisecond)) + // Set a 100ms connect timeout so that the saslContinue delay of 150ms + // causes a timeout during a heartbeat (i.e. a timeout not caused by + // the InsertOne context). + SetConnectTimeout(100 * time.Millisecond)) // Use context.Background() so that the new connection will not time out due to an // operation-scoped timeout. diff --git a/internal/integration/sdam_prose_test.go b/internal/integration/sdam_prose_test.go index 69aa7b2dee..d0df936080 100644 --- a/internal/integration/sdam_prose_test.go +++ b/internal/integration/sdam_prose_test.go @@ -205,6 +205,7 @@ func TestServerHeartbeatStartedEvent(t *testing.T) { server := topology.NewServer( address, bson.NewObjectID(), + 1*time.Second, topology.WithServerMonitor(func(*event.ServerMonitor) *event.ServerMonitor { return &event.ServerMonitor{ ServerHeartbeatStarted: func(e *event.ServerHeartbeatStartedEvent) { diff --git a/internal/integration/unified/client_entity.go b/internal/integration/unified/client_entity.go index 2d4c87b94b..200b0130b4 100644 --- a/internal/integration/unified/client_entity.go +++ b/internal/integration/unified/client_entity.go @@ -612,7 +612,12 @@ func setClientOptionsFromURIOptions(clientOpts *options.ClientOptions, uriOpts b case "retrywrites": clientOpts.SetRetryWrites(value.(bool)) case "sockettimeoutms": - clientOpts.SetSocketTimeout(time.Duration(value.(int32)) * time.Millisecond) + // TODO(DRIVERS-2829): Error here instead of skip to ensure that if new + // tests are added containing socketTimeoutMS (a legacy timeout option + // that we have removed as of v2), then a CSOT analogue exists. Once we + // have ensured an analogue exists, extend "skippedTestDescriptions" to + // avoid this error. + return newSkipTestError("the socketTimeoutMS client option is not supported") case "w": wc.W = value wcSet = true diff --git a/internal/integration/unified/collection_operation_execution.go b/internal/integration/unified/collection_operation_execution.go index d211013920..3869ecd7e5 100644 --- a/internal/integration/unified/collection_operation_execution.go +++ b/internal/integration/unified/collection_operation_execution.go @@ -67,8 +67,6 @@ func executeAggregate(ctx context.Context, operation *operation) (*operationResu return nil, fmt.Errorf("error creating hint: %w", err) } opts.SetHint(hint) - case "maxTimeMS": - opts.SetMaxTime(time.Duration(val.Int32()) * time.Millisecond) case "maxAwaitTimeMS": opts.SetMaxAwaitTime(time.Duration(val.Int32()) * time.Millisecond) case "pipeline": @@ -194,7 +192,12 @@ func executeCountDocuments(ctx context.Context, operation *operation) (*operatio case "limit": opts.SetLimit(val.Int64()) case "maxTimeMS": - opts.SetMaxTime(time.Duration(val.Int32()) * time.Millisecond) + // TODO(DRIVERS-2829): Error here instead of skip to ensure that if new + // tests are added containing maxTimeMS (a legacy timeout option that we + // have removed as of v2), then a CSOT analogue exists. Once we have + // ensured an analogue exists, extend "skippedTestDescriptions" to avoid + // this error. + return nil, fmt.Errorf("the maxTimeMS collection option is not supported") case "skip": opts.SetSkip(int64(val.Int32())) default: @@ -523,7 +526,12 @@ func executeDistinct(ctx context.Context, operation *operation) (*operationResul case "filter": filter = val.Document() case "maxTimeMS": - opts.SetMaxTime(time.Duration(val.Int32()) * time.Millisecond) + // TODO(DRIVERS-2829): Error here instead of skip to ensure that if new + // tests are added containing maxTimeMS (a legacy timeout option that we + // have removed as of v2), then a CSOT analogue exists. Once we have + // ensured an analogue exists, extend "skippedTestDescriptions" to avoid + // this error. + return nil, fmt.Errorf("the maxTimeMS collection option is not supported") default: return nil, fmt.Errorf("unrecognized distinct option %q", key) } @@ -566,7 +574,12 @@ func executeDropIndex(ctx context.Context, operation *operation) (*operationResu case "name": name = val.StringValue() case "maxTimeMS": - dropIndexOpts.SetMaxTime(time.Duration(val.Int32()) * time.Millisecond) + // TODO(DRIVERS-2829): Error here instead of skip to ensure that if new + // tests are added containing maxTimeMS (a legacy timeout option that we + // have removed as of v2), then a CSOT analogue exists. Once we have + // ensured an analogue exists, extend "skippedTestDescriptions" to avoid + // this error. + return nil, fmt.Errorf("the maxTimeMS collection option is not supported") default: return nil, fmt.Errorf("unrecognized dropIndex option %q", key) } @@ -589,11 +602,15 @@ func executeDropIndexes(ctx context.Context, operation *operation) (*operationRe elems, _ := operation.Arguments.Elements() for _, elem := range elems { key := elem.Key() - val := elem.Value() switch key { case "maxTimeMS": - dropIndexOpts.SetMaxTime(time.Duration(val.Int32()) * time.Millisecond) + // TODO(DRIVERS-2829): Error here instead of skip to ensure that if new + // tests are added containing maxTimeMS (a legacy timeout option that we + // have removed as of v2), then a CSOT analogue exists. Once we have + // ensured an analogue exists, extend "skippedTestDescriptions" to avoid + // this error. + return nil, fmt.Errorf("the maxTimeMS collection option is not supported") default: return nil, fmt.Errorf("unrecognized dropIndexes option %q", key) } @@ -654,7 +671,12 @@ func executeEstimatedDocumentCount(ctx context.Context, operation *operation) (* case "comment": opts.SetComment(val) case "maxTimeMS": - opts.SetMaxTime(time.Duration(val.Int32()) * time.Millisecond) + // TODO(DRIVERS-2829): Error here instead of skip to ensure that if new + // tests are added containing maxTimeMS (a legacy timeout option that we + // have removed as of v2), then a CSOT analogue exists. Once we have + // ensured an analogue exists, extend "skippedTestDescriptions" to avoid + // this error. + return nil, fmt.Errorf("the maxTimeMS collection option is not supported") default: return nil, fmt.Errorf("unrecognized estimatedDocumentCount option %q", key) } @@ -731,7 +753,12 @@ func executeFindOne(ctx context.Context, operation *operation) (*operationResult } opts.SetHint(hint) case "maxTimeMS": - opts.SetMaxTime(time.Duration(val.Int32()) * time.Millisecond) + // TODO(DRIVERS-2829): Error here instead of skip to ensure that if new + // tests are added containing maxTimeMS (a legacy timeout option that we + // have removed as of v2), then a CSOT analogue exists. Once we have + // ensured an analogue exists, extend "skippedTestDescriptions" to avoid + // this error. + return nil, fmt.Errorf("the maxTimeMS collection option is not supported") case "projection": opts.SetProjection(val.Document()) case "sort": @@ -790,7 +817,12 @@ func executeFindOneAndDelete(ctx context.Context, operation *operation) (*operat } opts.SetHint(hint) case "maxTimeMS": - opts.SetMaxTime(time.Duration(val.Int32()) * time.Millisecond) + // TODO(DRIVERS-2829): Error here instead of skip to ensure that if new + // tests are added containing maxTimeMS (a legacy timeout option that we + // have removed as of v2), then a CSOT analogue exists. Once we have + // ensured an analogue exists, extend "skippedTestDescriptions" to avoid + // this error. + return nil, fmt.Errorf("the maxTimeMS collection option is not supported") case "projection": opts.SetProjection(val.Document()) case "sort": @@ -856,7 +888,12 @@ func executeFindOneAndReplace(ctx context.Context, operation *operation) (*opera case "let": opts.SetLet(val.Document()) case "maxTimeMS": - opts.SetMaxTime(time.Duration(val.Int32()) * time.Millisecond) + // TODO(DRIVERS-2829): Error here instead of skip to ensure that if new + // tests are added containing maxTimeMS (a legacy timeout option that we + // have removed as of v2), then a CSOT analogue exists. Once we have + // ensured an analogue exists, extend "skippedTestDescriptions" to avoid + // this error. + return nil, fmt.Errorf("the maxTimeMS collection option is not supported") case "projection": opts.SetProjection(val.Document()) case "replacement": @@ -940,7 +977,12 @@ func executeFindOneAndUpdate(ctx context.Context, operation *operation) (*operat case "let": opts.SetLet(val.Document()) case "maxTimeMS": - opts.SetMaxTime(time.Duration(val.Int32()) * time.Millisecond) + // TODO(DRIVERS-2829): Error here instead of skip to ensure that if new + // tests are added containing maxTimeMS (a legacy timeout option that we + // have removed as of v2), then a CSOT analogue exists. Once we have + // ensured an analogue exists, extend "skippedTestDescriptions" to avoid + // this error. + return nil, fmt.Errorf("the maxTimeMS collection option is not supported") case "projection": opts.SetProjection(val.Document()) case "returnDocument": @@ -1403,7 +1445,12 @@ func createFindCursor(ctx context.Context, operation *operation) (*cursorResult, case "max": opts.SetMax(val.Document()) case "maxTimeMS": - opts.SetMaxTime(time.Duration(val.Int32()) * time.Millisecond) + // TODO(DRIVERS-2829): Error here instead of skip to ensure that if new + // tests are added containing maxTimeMS (a legacy timeout option that we + // have removed as of v2), then a CSOT analogue exists. Once we have + // ensured an analogue exists, extend "skippedTestDescriptions" to avoid + // this error. + return nil, fmt.Errorf("the maxTimeMS collection option is not supported") case "min": opts.SetMin(val.Document()) case "noCursorTimeout": diff --git a/internal/integration/unified/common_options.go b/internal/integration/unified/common_options.go index 7c34586325..2b78466a9b 100644 --- a/internal/integration/unified/common_options.go +++ b/internal/integration/unified/common_options.go @@ -28,9 +28,8 @@ func (rc *readConcern) toReadConcernOption() *readconcern.ReadConcern { } type writeConcern struct { - Journal *bool `bson:"journal"` - W interface{} `bson:"w"` - WTimeoutMS *int32 `bson:"wtimeoutMS"` + Journal *bool `bson:"journal"` + W interface{} `bson:"w"` } func (wc *writeConcern) toWriteConcernOption() (*writeconcern.WriteConcern, error) { @@ -51,10 +50,6 @@ func (wc *writeConcern) toWriteConcernOption() (*writeconcern.WriteConcern, erro return nil, fmt.Errorf("invalid type for write concern 'w' field %T", wc.W) } } - if wc.WTimeoutMS != nil { - wTimeout := time.Duration(*wc.WTimeoutMS) * time.Millisecond - c.WTimeout = wTimeout - } return c, nil } diff --git a/internal/integration/unified/database_operation_execution.go b/internal/integration/unified/database_operation_execution.go index 269dd7e929..156b2b29a5 100644 --- a/internal/integration/unified/database_operation_execution.go +++ b/internal/integration/unified/database_operation_execution.go @@ -284,7 +284,6 @@ func executeRunCursorCommand(ctx context.Context, operation *operation) (*operat batchSize int32 command bson.Raw comment bson.Raw - maxTime time.Duration ) opts := options.RunCmd() @@ -306,7 +305,12 @@ func executeRunCursorCommand(ctx context.Context, operation *operation) (*operat case "comment": comment = val.Document() case "maxTimeMS": - maxTime = time.Duration(val.AsInt64()) * time.Millisecond + // TODO(DRIVERS-2829): Error here instead of skip to ensure that if new + // tests are added containing maxTimeMS (a legacy timeout option that we + // have removed as of v2), then a CSOT analogue exists. Once we have + // ensured an analogue exists, extend "skippedTestDescriptions" to avoid + // this error. + return nil, fmt.Errorf("the maxTimeMS database option is not supported") case "cursorTimeout": return nil, newSkipTestError("cursorTimeout not supported") case "timeoutMode": @@ -329,10 +333,6 @@ func executeRunCursorCommand(ctx context.Context, operation *operation) (*operat cursor.SetBatchSize(batchSize) } - if maxTime > 0 { - cursor.SetMaxTime(maxTime) - } - if len(comment) > 0 { cursor.SetComment(comment) } diff --git a/internal/integration/unified/gridfs_bucket_operation_execution.go b/internal/integration/unified/gridfs_bucket_operation_execution.go index d2ca0f5652..512c582842 100644 --- a/internal/integration/unified/gridfs_bucket_operation_execution.go +++ b/internal/integration/unified/gridfs_bucket_operation_execution.go @@ -13,7 +13,6 @@ import ( "fmt" "io" "math" - "time" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo/options" @@ -39,7 +38,12 @@ func createBucketFindCursor(ctx context.Context, operation *operation) (*cursorR switch key { case "maxTimeMS": - opts.SetMaxTime(time.Duration(val.Int32()) * time.Millisecond) + // TODO(DRIVERS-2829): Error here instead of skip to ensure that if new + // tests are added containing maxTimeMS (a legacy timeout option that we + // have removed as of v2), then a CSOT analogue exists. Once we have + // ensured an analogue exists, extend "skippedTestDescriptions" to avoid + // this error. + return nil, fmt.Errorf("the maxTimeMS gridfs option is not supported") case "filter": filter = val.Document() default: diff --git a/internal/integration/unified/operation.go b/internal/integration/unified/operation.go index 59aa36ae8c..989e58673c 100644 --- a/internal/integration/unified/operation.go +++ b/internal/integration/unified/operation.go @@ -93,7 +93,16 @@ func (op *operation) run(ctx context.Context, loopDone <-chan struct{}) (*operat // Special handling for the "timeoutMS" field because it applies to (almost) all operations. if tms, ok := op.Arguments.Lookup("timeoutMS").Int32OK(); ok { timeout := time.Duration(tms) * time.Millisecond - newCtx, cancelFunc := csot.MakeTimeoutContext(ctx, timeout) + + // Note that a 0-timeout at the operation level is not actually possible + // in Go. This would result in an immediate "context deadline exceeded" + // error. + // + // To achieve an "infinite" case, users would have to rely on either (1) + // defining a 0 timeout at the client-level, or (2) use + // context.Background() at the operation-level. + newCtx, cancelFunc := csot.WithTimeout(ctx, &timeout) + // Redefine ctx to be the new timeout-derived context. ctx = newCtx // Cancel the timeout-derived context at the end of run to avoid a context leak. diff --git a/internal/integration/unified/session_options.go b/internal/integration/unified/session_options.go index 9882073b22..c02865d975 100644 --- a/internal/integration/unified/session_options.go +++ b/internal/integration/unified/session_options.go @@ -8,7 +8,6 @@ package unified import ( "fmt" - "time" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo/options" @@ -24,11 +23,10 @@ var _ bson.Unmarshaler = (*transactionOptions)(nil) func (to *transactionOptions) UnmarshalBSON(data []byte) error { var temp struct { - RC *readConcern `bson:"readConcern"` - RP *ReadPreference `bson:"readPreference"` - WC *writeConcern `bson:"writeConcern"` - MaxCommitTimeMS *int64 `bson:"maxCommitTimeMS"` - Extra map[string]interface{} `bson:",inline"` + RC *readConcern `bson:"readConcern"` + RP *ReadPreference `bson:"readPreference"` + WC *writeConcern `bson:"writeConcern"` + Extra map[string]interface{} `bson:",inline"` } if err := bson.Unmarshal(data, &temp); err != nil { return fmt.Errorf("error unmarshalling to temporary transactionOptions object: %v", err) @@ -38,10 +36,6 @@ func (to *transactionOptions) UnmarshalBSON(data []byte) error { } to.TransactionOptions = options.Transaction() - if temp.MaxCommitTimeMS != nil { - mctms := time.Duration(*temp.MaxCommitTimeMS) * time.Millisecond - to.SetMaxCommitTime(&mctms) - } if rc := temp.RC; rc != nil { to.SetReadConcern(rc.toReadConcernOption()) } @@ -72,11 +66,10 @@ var _ bson.Unmarshaler = (*sessionOptions)(nil) func (so *sessionOptions) UnmarshalBSON(data []byte) error { var temp struct { - Causal *bool `bson:"causalConsistency"` - MaxCommitTimeMS *int64 `bson:"maxCommitTimeMS"` - TxnOptions *transactionOptions `bson:"defaultTransactionOptions"` - Snapshot *bool `bson:"snapshot"` - Extra map[string]interface{} `bson:",inline"` + Causal *bool `bson:"causalConsistency"` + TxnOptions *transactionOptions `bson:"defaultTransactionOptions"` + Snapshot *bool `bson:"snapshot"` + Extra map[string]interface{} `bson:",inline"` } if err := bson.Unmarshal(data, &temp); err != nil { return fmt.Errorf("error unmarshalling to temporary sessionOptions object: %v", err) @@ -89,10 +82,6 @@ func (so *sessionOptions) UnmarshalBSON(data []byte) error { if temp.Causal != nil { so.SetCausalConsistency(*temp.Causal) } - if temp.MaxCommitTimeMS != nil { - mctms := time.Duration(*temp.MaxCommitTimeMS) * time.Millisecond - so.SetDefaultMaxCommitTime(&mctms) - } if temp.TxnOptions != nil { if rc := temp.TxnOptions.ReadConcern; rc != nil { so.SetDefaultReadConcern(rc) diff --git a/internal/integration/unified/testrunner_operation.go b/internal/integration/unified/testrunner_operation.go index a5dbc3e75a..dfa9a124d5 100644 --- a/internal/integration/unified/testrunner_operation.go +++ b/internal/integration/unified/testrunner_operation.go @@ -436,7 +436,6 @@ func waitForEvent(ctx context.Context, args waitForEventArguments) error { if args.eventCompleted(client) { return nil } - } time.Sleep(100 * time.Millisecond) diff --git a/internal/integration/unified/unified_spec_runner.go b/internal/integration/unified/unified_spec_runner.go index af71caa44a..45f2c164fa 100644 --- a/internal/integration/unified/unified_spec_runner.go +++ b/internal/integration/unified/unified_spec_runner.go @@ -45,6 +45,13 @@ var ( "listSearchIndexes ignores read and write concern": "Sync GODRIVER-3074, but skip testing bug GODRIVER-3043", "updateSearchIndex ignores the read and write concern": "Sync GODRIVER-3074, but skip testing bug GODRIVER-3043", + // TODO(DRIVERS-2829): Create CSOT Legacy Timeout Analogues and Compatibility Field + "Reset server and pool after network timeout error during authentication": "Uses unsupported socketTimeoutMS", + "Ignore network timeout error on find": "Uses unsupported socketTimeoutMS", + "A successful find with options": "Uses unsupported maxTimeMS", + "estimatedDocumentCount with maxTimeMS": "Uses unsupported maxTimeMS", + "supports configuring getMore maxTimeMS": "Uses unsupported maxTimeMS", + // TODO(GODRIVER-3137): Implement Gossip cluster time" "unpin after TransientTransactionError error on commit": "Implement GODRIVER-3137", diff --git a/x/mongo/driver/topology/cancellation_listener.go b/internal/ptrutil/ptr.go similarity index 56% rename from x/mongo/driver/topology/cancellation_listener.go rename to internal/ptrutil/ptr.go index caca988057..bf64aad178 100644 --- a/x/mongo/driver/topology/cancellation_listener.go +++ b/internal/ptrutil/ptr.go @@ -1,14 +1,12 @@ -// Copyright (C) MongoDB, Inc. 2017-present. +// Copyright (C) MongoDB, Inc. 2024-present. // // Licensed under the Apache License, Version 2.0 (the "License"); you may // not use this file except in compliance with the License. You may obtain // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 -package topology +package ptrutil -import "context" - -type cancellationListener interface { - Listen(context.Context, func()) - StopListening() bool +// Ptr will return the memory location of the given value. +func Ptr[T any](val T) *T { + return &val } diff --git a/mongo/batch_cursor.go b/mongo/batch_cursor.go index 9e87b00ae4..a50fa899cf 100644 --- a/mongo/batch_cursor.go +++ b/mongo/batch_cursor.go @@ -40,13 +40,13 @@ type batchCursor interface { // the cursor that implements it. SetBatchSize(int32) - // SetMaxTime will set the maximum amount of time the server will allow + // SetMaxAwaitTime will set the maximum amount of time the server will allow // the operations to execute. The server will error if this field is set // but the cursor is not configured with awaitData=true. // // The time.Duration value passed by this setter will be converted and // rounded down to the nearest millisecond. - SetMaxTime(time.Duration) + SetMaxAwaitTime(time.Duration) // SetComment will set a user-configurable comment that can be used to // identify the operation in server logs. diff --git a/mongo/change_stream.go b/mongo/change_stream.go index cc051b5f08..f02010f53f 100644 --- a/mongo/change_stream.go +++ b/mongo/change_stream.go @@ -151,6 +151,33 @@ func mergeChangeStreamOptions(opts ...*options.ChangeStreamOptions) *options.Cha return csOpts } +// validChangeStreamTimeouts will return "false" if maxAwaitTimeMS is set, +// timeoutMS is set to a non-zero value, and maxAwaitTimeMS is greater than or +// equal to timeoutMS. Otherwise, the timeouts are valid. +func validChangeStreamTimeouts(ctx context.Context, cs *ChangeStream) bool { + if cs.options == nil || cs.client == nil { + return true + } + + maxAwaitTime := cs.options.MaxAwaitTime + timeout := cs.client.timeout + + if maxAwaitTime == nil { + return true + } + + if deadline, ok := ctx.Deadline(); ok { + ctxTimeout := time.Until(deadline) + timeout = &ctxTimeout + } + + if timeout == nil { + return true + } + + return *timeout <= 0 || *maxAwaitTime < *timeout +} + func newChangeStream(ctx context.Context, config changeStreamConfig, pipeline interface{}, opts ...*options.ChangeStreamOptions) (*ChangeStream, error) { if ctx == nil { @@ -161,12 +188,14 @@ func newChangeStream(ctx context.Context, config changeStreamConfig, pipeline in cursorOpts.MarshalValueEncoderFn = newEncoderFn(config.bsonOpts, config.registry) + changeStreamOpts := mergeChangeStreamOptions(opts...) + cs := &ChangeStream{ client: config.client, bsonOpts: config.bsonOpts, registry: config.registry, streamType: config.streamType, - options: mergeChangeStreamOptions(opts...), + options: changeStreamOpts, selector: &serverselector.Composite{ Selectors: []description.ServerSelector{ &serverselector.ReadPref{ReadPref: config.readPreference}, @@ -208,7 +237,7 @@ func newChangeStream(ctx context.Context, config changeStreamConfig, pipeline in cs.cursorOptions.BatchSize = *cs.options.BatchSize } if cs.options.MaxAwaitTime != nil { - cs.cursorOptions.MaxTimeMS = int64(*cs.options.MaxAwaitTime / time.Millisecond) + cs.cursorOptions.SetMaxAwaitTime(*cs.options.MaxAwaitTime) } if cs.options.Custom != nil { // Marshal all custom options before passing to the initial aggregate. Return @@ -297,10 +326,18 @@ func (cs *ChangeStream) executeOperation(ctx context.Context, resuming bool) err var server driver.Server var conn *mnet.Connection - if server, cs.err = cs.client.deployment.SelectServer(ctx, cs.selector); cs.err != nil { + // Apply the client-level timeout if the operation-level timeout is not set. + ctx, cancel := csot.WithTimeout(ctx, cs.client.timeout) + defer cancel() + + connCtx, cancel := csot.WithServerSelectionTimeout(ctx, cs.client.deployment.GetServerSelectionTimeout()) + defer cancel() + + if server, cs.err = cs.client.deployment.SelectServer(connCtx, cs.selector); cs.err != nil { return cs.Err() } - if conn, cs.err = server.Connection(ctx); cs.err != nil { + + if conn, cs.err = server.Connection(connCtx); cs.err != nil { return cs.Err() } defer conn.Close() @@ -329,17 +366,6 @@ func (cs *ChangeStream) executeOperation(ctx context.Context, resuming bool) err cs.aggregate.Pipeline(plArr) } - // If no deadline is set on the passed-in context, cs.client.timeout is set, and context is not already - // a Timeout context, honor cs.client.timeout in new Timeout context for change stream operation execution - // and potential retry. - if _, deadlineSet := ctx.Deadline(); !deadlineSet && cs.client.timeout != nil && !csot.IsTimeoutContext(ctx) { - newCtx, cancelFunc := csot.MakeTimeoutContext(ctx, *cs.client.timeout) - // Redefine ctx to be the new timeout-derived context. - ctx = newCtx - // Cancel the timeout-derived context at the end of executeOperation to avoid a context leak. - defer cancelFunc() - } - // Execute the aggregate, retrying on retryable errors once (1) if retryable reads are enabled and // infinitely (-1) if context is a Timeout context. var retries int @@ -366,16 +392,20 @@ AggregateExecuteLoop: break AggregateExecuteLoop } + connCtx, cancel := csot.WithServerSelectionTimeout(ctx, cs.client.deployment.GetServerSelectionTimeout()) + defer cancel() + // If error is retryable: subtract 1 from retries, redo server selection, checkout // a connection, and restart loop. retries-- - server, err = cs.client.deployment.SelectServer(ctx, cs.selector) + server, err = cs.client.deployment.SelectServer(connCtx, cs.selector) if err != nil { break AggregateExecuteLoop } conn.Close() - conn, err = server.Connection(ctx) + + conn, err = server.Connection(connCtx) if err != nil { break AggregateExecuteLoop } @@ -646,26 +676,35 @@ func (cs *ChangeStream) ResumeToken() bson.Raw { return cs.resumeToken } -// Next gets the next event for this change stream. It returns true if there were no errors and the next event document -// is available. +// Next gets the next event for this change stream. It returns true if there +// were no errors and the next event document is available. // -// Next blocks until an event is available, an error occurs, or ctx expires. If ctx expires, the error -// will be set to ctx.Err(). In an error case, Next will return false. +// Next blocks until an event is available, an error occurs, or ctx expires. +// If ctx expires, the error will be set to ctx.Err(). In an error case, Next +// will return false. // // If Next returns false, subsequent calls will also return false. func (cs *ChangeStream) Next(ctx context.Context) bool { return cs.next(ctx, false) } -// TryNext attempts to get the next event for this change stream. It returns true if there were no errors and the next -// event document is available. +// TryNext attempts to get the next event for this change stream. It returns +// true if there were no errors and the next event document is available. +// +// TryNext returns false if the change stream is closed by the server, an error +// occurs when getting changes from the server, the next change is not yet +// available, or ctx expires. // -// TryNext returns false if the change stream is closed by the server, an error occurs when getting changes from the -// server, the next change is not yet available, or ctx expires. If ctx expires, the error will be set to ctx.Err(). +// If ctx expires, the error will be set to ctx.Err(). Users can either call +// TryNext again or close the existing change stream and create a new one. It is +// suggested to close and re-create the stream with ah higher timeout if the +// timeout occurs before any events have been received, which is a signal that +// the server is timing out before it can finish processing the existing oplog. // -// If TryNext returns false and an error occurred or the change stream was closed -// (i.e. cs.Err() != nil || cs.ID() == 0), subsequent attempts will also return false. Otherwise, it is safe to call -// TryNext again until a change is available. +// If TryNext returns false and an error occurred or the change stream was +// closed (i.e. cs.Err() != nil || cs.ID() == 0), subsequent attempts will also +// return false. Otherwise, it is safe to call TryNext again until a change is +// available. // // This method requires driver version >= 1.2.0. func (cs *ChangeStream) TryNext(ctx context.Context) bool { @@ -703,6 +742,18 @@ func (cs *ChangeStream) next(ctx context.Context, nonBlocking bool) bool { } func (cs *ChangeStream) loopNext(ctx context.Context, nonBlocking bool) { + if !validChangeStreamTimeouts(ctx, cs) { + cs.err = fmt.Errorf("MaxAwaitTime must be less than the operation timeout") + + return + } + + // Apply the client-level timeout if the operation-level timeout is not set. + // This calculation is also done in "executeOperation" but cursor.Next is also + // blocking and should honor client-level timeouts. + ctx, cancel := csot.WithTimeout(ctx, cs.client.timeout) + defer cancel() + for { if cs.cursor == nil { return diff --git a/mongo/change_stream_deployment.go b/mongo/change_stream_deployment.go index 64f30095c8..b4fdbd2690 100644 --- a/mongo/change_stream_deployment.go +++ b/mongo/change_stream_deployment.go @@ -8,6 +8,7 @@ package mongo import ( "context" + "time" "go.mongodb.org/mongo-driver/x/mongo/driver" "go.mongodb.org/mongo-driver/x/mongo/driver/description" @@ -25,6 +26,7 @@ var _ driver.Server = (*changeStreamDeployment)(nil) var _ driver.ErrorProcessor = (*changeStreamDeployment)(nil) func (c *changeStreamDeployment) SelectServer(context.Context, description.ServerSelector) (driver.Server, error) { + return c, nil } @@ -48,3 +50,9 @@ func (c *changeStreamDeployment) ProcessError(err error, describer mnet.Describe return ep.ProcessError(err, describer) } + +// GetServerSelectionTimeout returns zero as a server selection timeout is not +// applicable for change stream deployments. +func (*changeStreamDeployment) GetServerSelectionTimeout() time.Duration { + return 0 +} diff --git a/mongo/change_stream_test.go b/mongo/change_stream_test.go index fa44713593..5b1193a0da 100644 --- a/mongo/change_stream_test.go +++ b/mongo/change_stream_test.go @@ -7,7 +7,9 @@ package mongo import ( + "context" "testing" + "time" "go.mongodb.org/mongo-driver/internal/assert" "go.mongodb.org/mongo-driver/mongo/options" @@ -15,7 +17,7 @@ import ( func TestChangeStream(t *testing.T) { t.Run("nil cursor", func(t *testing.T) { - cs := &ChangeStream{} + cs := &ChangeStream{client: &Client{}} id := cs.ID() assert.Equal(t, int64(0), id, "expected ID 0, got %v", id) @@ -90,3 +92,96 @@ func TestMergeChangeStreamOptions(t *testing.T) { }) } } + +func TestValidChangeStreamTimeouts(t *testing.T) { + t.Parallel() + + newDurPtr := func(dur time.Duration) *time.Duration { + return &dur + } + + tests := []struct { + name string + parent context.Context + maxAwaitTimeout, timeout *time.Duration + wantTimeout time.Duration + want bool + }{ + { + name: "no context deadline and no timeouts", + parent: context.Background(), + maxAwaitTimeout: nil, + timeout: nil, + wantTimeout: 0, + want: true, + }, + { + name: "no context deadline and maxAwaitTimeout", + parent: context.Background(), + maxAwaitTimeout: newDurPtr(1), + timeout: nil, + wantTimeout: 0, + want: true, + }, + { + name: "no context deadline and timeout", + parent: context.Background(), + maxAwaitTimeout: nil, + timeout: newDurPtr(1), + wantTimeout: 0, + want: true, + }, + { + name: "no context deadline and maxAwaitTime gt timeout", + parent: context.Background(), + maxAwaitTimeout: newDurPtr(2), + timeout: newDurPtr(1), + wantTimeout: 0, + want: false, + }, + { + name: "no context deadline and maxAwaitTime lt timeout", + parent: context.Background(), + maxAwaitTimeout: newDurPtr(1), + timeout: newDurPtr(2), + wantTimeout: 0, + want: true, + }, + { + name: "no context deadline and maxAwaitTime eq timeout", + parent: context.Background(), + maxAwaitTimeout: newDurPtr(1), + timeout: newDurPtr(1), + wantTimeout: 0, + want: false, + }, + { + name: "no context deadline and maxAwaitTime with negative timeout", + parent: context.Background(), + maxAwaitTimeout: newDurPtr(1), + timeout: newDurPtr(-1), + wantTimeout: 0, + want: true, + }, + } + + for _, test := range tests { + test := test // Capture the range variable + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + cs := &ChangeStream{ + options: &options.ChangeStreamOptions{ + MaxAwaitTime: test.maxAwaitTimeout, + }, + client: &Client{ + timeout: test.timeout, + }, + } + + got := validChangeStreamTimeouts(test.parent, cs) + assert.Equal(t, test.want, got) + }) + } +} diff --git a/mongo/client.go b/mongo/client.go index c6b6f8174d..d3e00bef17 100644 --- a/mongo/client.go +++ b/mongo/client.go @@ -180,7 +180,6 @@ func newClient(opts ...*options.ClientOptions) (*Client, error) { if clientOpt.RetryReads != nil { client.retryReads = *clientOpt.RetryReads } - // Timeout client.timeout = clientOpt.Timeout client.httpClient = clientOpt.HTTPClient // WriteConcern @@ -210,7 +209,13 @@ func newClient(opts ...*options.ClientOptions) (*Client, error) { if err != nil { return nil, err } - client.serverAPI = topology.ServerAPIFromServerOptions(cfg.ServerOpts) + + var connectTimeout time.Duration + if clientOpt.ConnectTimeout != nil { + connectTimeout = *clientOpt.ConnectTimeout + } + + client.serverAPI = topology.ServerAPIFromServerOptions(connectTimeout, cfg.ServerOpts) if client.deployment == nil { client.deployment, err = topology.New(cfg) @@ -392,9 +397,6 @@ func (c *Client) StartSession(opts ...*options.SessionOptions) (*Session, error) if opt.DefaultWriteConcern != nil { sopts.DefaultWriteConcern = opt.DefaultWriteConcern } - if opt.DefaultMaxCommitTime != nil { - sopts.DefaultMaxCommitTime = opt.DefaultMaxCommitTime - } if opt.Snapshot != nil { sopts.Snapshot = opt.Snapshot } @@ -419,9 +421,6 @@ func (c *Client) StartSession(opts ...*options.SessionOptions) (*Session, error) if sopts.DefaultReadPreference != nil { coreOpts.DefaultReadPreference = sopts.DefaultReadPreference } - if sopts.DefaultMaxCommitTime != nil { - coreOpts.DefaultMaxCommitTime = sopts.DefaultMaxCommitTime - } if sopts.Snapshot != nil { coreOpts.Snapshot = sopts.Snapshot } diff --git a/mongo/client_test.go b/mongo/client_test.go index e5d08642b3..6e7607be91 100644 --- a/mongo/client_test.go +++ b/mongo/client_test.go @@ -501,4 +501,13 @@ func TestClient(t *testing.T) { }) } }) + t.Run("negative timeout will err", func(t *testing.T) { + t.Parallel() + + copts := options.Client().SetTimeout(-1 * time.Second) + _, err := Connect(copts) + + errmsg := `invalid value "-1s" for "Timeout": value must be positive` + assert.Equal(t, errmsg, err.Error(), "expected error %v, got %v", errmsg, err.Error()) + }) } diff --git a/mongo/collection.go b/mongo/collection.go index c24ab9273f..ea75a4a0cd 100644 --- a/mongo/collection.go +++ b/mongo/collection.go @@ -945,9 +945,6 @@ func mergeAggregateOptions(opts ...*options.AggregateOptions) *options.Aggregate if ao.Collation != nil { aggOpts.Collation = ao.Collation } - if ao.MaxTime != nil { - aggOpts.MaxTime = ao.MaxTime - } if ao.MaxAwaitTime != nil { aggOpts.MaxAwaitTime = ao.MaxAwaitTime } @@ -1032,8 +1029,7 @@ func aggregate(a aggregateParams) (cur *Cursor, err error) { Crypt(a.client.cryptFLE). ServerAPI(a.client.serverAPI). HasOutputStage(hasOutputStage). - Timeout(a.client.timeout). - MaxTime(ao.MaxTime) + Timeout(a.client.timeout) if ao.AllowDiskUse != nil { op.AllowDiskUse(*ao.AllowDiskUse) @@ -1050,7 +1046,7 @@ func aggregate(a aggregateParams) (cur *Cursor, err error) { op.Collation(bsoncore.Document(ao.Collation.ToDocument())) } if ao.MaxAwaitTime != nil { - cursorOpts.MaxTimeMS = int64(*ao.MaxAwaitTime / time.Millisecond) + cursorOpts.SetMaxAwaitTime(*ao.MaxAwaitTime) } if ao.Comment != nil { comment, err := marshalValue(ao.Comment, a.bsonOpts, a.registry) @@ -1147,9 +1143,6 @@ func (coll *Collection) CountDocuments(ctx context.Context, filter interface{}, if co.Limit != nil { countOpts.Limit = co.Limit } - if co.MaxTime != nil { - countOpts.MaxTime = co.MaxTime - } if co.Skip != nil { countOpts.Skip = co.Skip } @@ -1178,7 +1171,7 @@ func (coll *Collection) CountDocuments(ctx context.Context, filter interface{}, op := operation.NewAggregate(pipelineArr).Session(sess).ReadConcern(rc).ReadPreference(coll.readPreference). CommandMonitor(coll.client.monitor).ServerSelector(selector).ClusterClock(coll.client.clock).Database(coll.db.name). Collection(coll.name).Deployment(coll.client.deployment).Crypt(coll.client.cryptFLE).ServerAPI(coll.client.serverAPI). - Timeout(coll.client.timeout).MaxTime(countOpts.MaxTime) + Timeout(coll.client.timeout) if countOpts.Collation != nil { op.Collation(bsoncore.Document(countOpts.Collation.ToDocument())) } @@ -1269,16 +1262,13 @@ func (coll *Collection) EstimatedDocumentCount(ctx context.Context, if opt.Comment != nil { co.Comment = opt.Comment } - if opt.MaxTime != nil { - co.MaxTime = opt.MaxTime - } } selector := makeReadPrefSelector(sess, coll.readSelector, coll.client.localThreshold) op := operation.NewCount().Session(sess).ClusterClock(coll.client.clock). Database(coll.db.name).Collection(coll.name).CommandMonitor(coll.client.monitor). Deployment(coll.client.deployment).ReadConcern(rc).ReadPreference(coll.readPreference). ServerSelector(selector).Crypt(coll.client.cryptFLE).ServerAPI(coll.client.serverAPI). - Timeout(coll.client.timeout).MaxTime(co.MaxTime) + Timeout(coll.client.timeout) if co.Comment != nil { comment, err := marshalValue(co.Comment, coll.bsonOpts, coll.registry) @@ -1352,9 +1342,6 @@ func (coll *Collection) Distinct( if do.Comment != nil { option.Comment = do.Comment } - if do.MaxTime != nil { - option.MaxTime = do.MaxTime - } } op := operation.NewDistinct(fieldName, f). @@ -1362,7 +1349,7 @@ func (coll *Collection) Distinct( Database(coll.db.name).Collection(coll.name).CommandMonitor(coll.client.monitor). Deployment(coll.client.deployment).ReadConcern(rc).ReadPreference(coll.readPreference). ServerSelector(selector).Crypt(coll.client.cryptFLE).ServerAPI(coll.client.serverAPI). - Timeout(coll.client.timeout).MaxTime(option.MaxTime) + Timeout(coll.client.timeout) if option.Collation != nil { op.Collation(bsoncore.Document(option.Collation.ToDocument())) @@ -1439,9 +1426,6 @@ func mergeFindOptions(opts ...*options.FindOptions) *options.FindOptions { if opt.MaxAwaitTime != nil { fo.MaxAwaitTime = opt.MaxAwaitTime } - if opt.MaxTime != nil { - fo.MaxTime = opt.MaxTime - } if opt.Min != nil { fo.Min = opt.Min } @@ -1517,7 +1501,7 @@ func (coll *Collection) Find(ctx context.Context, filter interface{}, CommandMonitor(coll.client.monitor).ServerSelector(selector). ClusterClock(coll.client.clock).Database(coll.db.name).Collection(coll.name). Deployment(coll.client.deployment).Crypt(coll.client.cryptFLE).ServerAPI(coll.client.serverAPI). - Timeout(coll.client.timeout).MaxTime(fo.MaxTime).Logger(coll.client.logger) + Timeout(coll.client.timeout).Logger(coll.client.logger) cursorOpts := coll.client.createBaseCursorOptions() @@ -1588,7 +1572,7 @@ func (coll *Collection) Find(ctx context.Context, filter interface{}, op.Max(max) } if fo.MaxAwaitTime != nil { - cursorOpts.MaxTimeMS = int64(*fo.MaxAwaitTime / time.Millisecond) + cursorOpts.SetMaxAwaitTime(*fo.MaxAwaitTime) } if fo.Min != nil { min, err := marshal(fo.Min, coll.bsonOpts, coll.registry) @@ -1656,7 +1640,6 @@ func newFindOptionsFromFindOneOptions(opts ...*options.FindOneOptions) []*option Comment: opt.Comment, Hint: opt.Hint, Max: opt.Max, - MaxTime: opt.MaxTime, Min: opt.Min, Projection: opt.Projection, ReturnKey: opt.ReturnKey, @@ -1769,9 +1752,6 @@ func mergeFindOneAndDeleteOptions(opts ...*options.FindOneAndDeleteOptions) *opt if opt.Comment != nil { fo.Comment = opt.Comment } - if opt.MaxTime != nil { - fo.MaxTime = opt.MaxTime - } if opt.Projection != nil { fo.Projection = opt.Projection } @@ -1808,8 +1788,7 @@ func (coll *Collection) FindOneAndDelete(ctx context.Context, filter interface{} return &SingleResult{err: err} } fod := mergeFindOneAndDeleteOptions(opts...) - op := operation.NewFindAndModify(f).Remove(true).ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout). - MaxTime(fod.MaxTime) + op := operation.NewFindAndModify(f).Remove(true).ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout) if fod.Collation != nil { op = op.Collation(bsoncore.Document(fod.Collation.ToDocument())) } @@ -1875,9 +1854,6 @@ func mergeFindOneAndReplaceOptions(opts ...*options.FindOneAndReplaceOptions) *o if opt.Comment != nil { fo.Comment = opt.Comment } - if opt.MaxTime != nil { - fo.MaxTime = opt.MaxTime - } if opt.Projection != nil { fo.Projection = opt.Projection } @@ -1932,7 +1908,7 @@ func (coll *Collection) FindOneAndReplace(ctx context.Context, filter interface{ fo := mergeFindOneAndReplaceOptions(opts...) op := operation.NewFindAndModify(f).Update(bsoncore.Value{Type: bsoncore.TypeEmbeddedDocument, Data: r}). - ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout).MaxTime(fo.MaxTime) + ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout) if fo.BypassDocumentValidation != nil && *fo.BypassDocumentValidation { op = op.BypassDocumentValidation(*fo.BypassDocumentValidation) } @@ -2010,9 +1986,6 @@ func mergeFindOneAndUpdateOptions(opts ...*options.FindOneAndUpdateOptions) *opt if opt.Comment != nil { fo.Comment = opt.Comment } - if opt.MaxTime != nil { - fo.MaxTime = opt.MaxTime - } if opt.Projection != nil { fo.Projection = opt.Projection } @@ -2064,8 +2037,7 @@ func (coll *Collection) FindOneAndUpdate(ctx context.Context, filter interface{} } fo := mergeFindOneAndUpdateOptions(opts...) - op := operation.NewFindAndModify(f).ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout). - MaxTime(fo.MaxTime) + op := operation.NewFindAndModify(f).ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout) u, err := marshalUpdateValue(update, coll.bsonOpts, coll.registry, true) if err != nil { diff --git a/mongo/crud_examples_test.go b/mongo/crud_examples_test.go index 9ef1a63acd..658ff2451c 100644 --- a/mongo/crud_examples_test.go +++ b/mongo/crud_examples_test.go @@ -201,7 +201,7 @@ func ExampleCollection_Aggregate() { }}, }}, } - opts := options.Aggregate().SetMaxTime(2 * time.Second) + opts := options.Aggregate() cursor, err := coll.Aggregate( context.TODO(), mongo.Pipeline{groupStage}, @@ -264,14 +264,13 @@ func ExampleCollection_BulkWrite() { func ExampleCollection_CountDocuments() { var coll *mongo.Collection + // Specify a timeout to limit the amount of time the operation can run on + // the server. + ctx, cancel := context.WithTimeout(context.TODO(), time.Second) + defer cancel() + // Count the number of times the name "Bob" appears in the collection. - // Specify the MaxTime option to limit the amount of time the operation can - // run on the server. - opts := options.Count().SetMaxTime(2 * time.Second) - count, err := coll.CountDocuments( - context.TODO(), - bson.D{{"name", "Bob"}}, - opts) + count, err := coll.CountDocuments(ctx, bson.D{{"name", "Bob"}}, nil) if err != nil { log.Fatal(err) } @@ -317,13 +316,15 @@ func ExampleCollection_DeleteOne() { func ExampleCollection_Distinct() { var coll *mongo.Collection + // Specify a timeout to limit the amount of time the operation can run on + // the server. + ctx, cancel := context.WithTimeout(context.TODO(), time.Second) + defer cancel() + // Find all unique values for the "name" field for documents in which the // "age" field is greater than 25. - // Specify the MaxTime option to limit the amount of time the operation can - // run on the server. filter := bson.D{{"age", bson.D{{"$gt", 25}}}} - opts := options.Distinct().SetMaxTime(2 * time.Second) - res := coll.Distinct(context.TODO(), "name", filter, opts) + res := coll.Distinct(ctx, "name", filter) if err := res.Err(); err != nil { log.Fatal(err) } @@ -341,11 +342,13 @@ func ExampleCollection_Distinct() { func ExampleCollection_EstimatedDocumentCount() { var coll *mongo.Collection + // Specify a timeout to limit the amount of time the operation can run on + // the server. + ctx, cancel := context.WithTimeout(context.TODO(), time.Second) + defer cancel() + // Get and print an estimated of the number of documents in the collection. - // Specify the MaxTime option to limit the amount of time the operation can - // run on the server. - opts := options.EstimatedDocumentCount().SetMaxTime(2 * time.Second) - count, err := coll.EstimatedDocumentCount(context.TODO(), opts) + count, err := coll.EstimatedDocumentCount(ctx, nil) if err != nil { log.Fatal(err) } @@ -1053,8 +1056,7 @@ func ExampleIndexView_CreateMany() { // Specify the MaxTime option to limit the amount of time the operation can // run on the server - opts := options.CreateIndexes().SetMaxTime(2 * time.Second) - names, err := indexView.CreateMany(context.TODO(), models, opts) + names, err := indexView.CreateMany(context.TODO(), models, nil) if err != nil { log.Fatal(err) } @@ -1065,17 +1067,19 @@ func ExampleIndexView_CreateMany() { func ExampleIndexView_List() { var indexView *mongo.IndexView - // Specify the MaxTime option to limit the amount of time the operation can - // run on the server - opts := options.ListIndexes().SetMaxTime(2 * time.Second) - cursor, err := indexView.List(context.TODO(), opts) + // Specify a timeout to limit the amount of time the operation can run on + // the server. + ctx, cancel := context.WithTimeout(context.TODO(), time.Second) + defer cancel() + + cursor, err := indexView.List(ctx, nil) if err != nil { log.Fatal(err) } // Get a slice of all indexes returned and print them out. var results []bson.M - if err = cursor.All(context.TODO(), &results); err != nil { + if err = cursor.All(ctx, &results); err != nil { log.Fatal(err) } fmt.Println(results) diff --git a/mongo/cursor.go b/mongo/cursor.go index 8f07b1ee9b..1ae74e1a76 100644 --- a/mongo/cursor.go +++ b/mongo/cursor.go @@ -394,14 +394,14 @@ func (c *Cursor) SetBatchSize(batchSize int32) { c.bc.SetBatchSize(batchSize) } -// SetMaxTime will set the maximum amount of time the server will allow the +// SetMaxAwaitTime will set the maximum amount of time the server will allow the // operations to execute. The server will error if this field is set but the // cursor is not configured with awaitData=true. // // The time.Duration value passed by this setter will be converted and rounded // down to the nearest millisecond. -func (c *Cursor) SetMaxTime(dur time.Duration) { - c.bc.SetMaxTime(dur) +func (c *Cursor) SetMaxAwaitTime(dur time.Duration) { + c.bc.SetMaxAwaitTime(dur) } // SetComment will set a user-configurable comment that can be used to identify diff --git a/mongo/cursor_test.go b/mongo/cursor_test.go index be877b7a6c..45a3247b15 100644 --- a/mongo/cursor_test.go +++ b/mongo/cursor_test.go @@ -95,9 +95,9 @@ func (tbc *testBatchCursor) Close(context.Context) error { return nil } -func (tbc *testBatchCursor) SetBatchSize(int32) {} -func (tbc *testBatchCursor) SetComment(interface{}) {} -func (tbc *testBatchCursor) SetMaxTime(time.Duration) {} +func (tbc *testBatchCursor) SetBatchSize(int32) {} +func (tbc *testBatchCursor) SetComment(interface{}) {} +func (tbc *testBatchCursor) SetMaxAwaitTime(time.Duration) {} func TestCursor(t *testing.T) { t.Run("loops until docs available", func(t *testing.T) {}) diff --git a/mongo/database.go b/mongo/database.go index 4748d3d2b0..36296a11b7 100644 --- a/mongo/database.go +++ b/mongo/database.go @@ -14,6 +14,7 @@ import ( "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/internal/csfle" + "go.mongodb.org/mongo-driver/internal/csot" "go.mongodb.org/mongo-driver/internal/serverselector" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/readconcern" @@ -727,10 +728,14 @@ func (db *Database) createCollectionWithEncryptedFields(ctx context.Context, nam // That is OK. This wire version check is a best effort to inform users earlier if using a QEv2 driver with a QEv1 server. { const QEv2WireVersion = 21 + ctx, cancel := csot.WithServerSelectionTimeout(ctx, db.client.deployment.GetServerSelectionTimeout()) + defer cancel() + server, err := db.client.deployment.SelectServer(ctx, &serverselector.Write{}) if err != nil { return fmt.Errorf("error selecting server to check maxWireVersion: %w", err) } + conn, err := server.Connection(ctx) if err != nil { return fmt.Errorf("error getting connection to check maxWireVersion: %w", err) diff --git a/mongo/errors.go b/mongo/errors.go index f3e7bbd43d..5b2c039898 100644 --- a/mongo/errors.go +++ b/mongo/errors.go @@ -124,7 +124,6 @@ func IsDuplicateKeyError(err error) bool { var timeoutErrs = [...]error{ context.DeadlineExceeded, driver.ErrDeadlineWouldBeExceeded, - topology.ErrServerSelectionTimeout, } // IsTimeout returns true if err was caused by a timeout. For error chains, diff --git a/mongo/gridfs_bucket.go b/mongo/gridfs_bucket.go index e5016a5179..dd3661877b 100644 --- a/mongo/gridfs_bucket.go +++ b/mongo/gridfs_bucket.go @@ -61,7 +61,8 @@ type upload struct { // filename. // // The context provided to this method controls the entire lifetime of an -// upload stream io.Writer. +// upload stream io.Writer. If the context does set a deadline, then the +// client-level timeout will be used to cap the lifetime of the stream. func (b *GridFSBucket) OpenUploadStream( ctx context.Context, filename string, @@ -74,13 +75,17 @@ func (b *GridFSBucket) OpenUploadStream( // ID and filename. // // The context provided to this method controls the entire lifetime of an -// upload stream io.Writer. +// upload stream io.Writer. If the context does set a deadline, then the +// client-level timeout will be used to cap the lifetime of the stream. func (b *GridFSBucket) OpenUploadStreamWithID( ctx context.Context, fileID interface{}, filename string, opts ...*options.UploadOptions, ) (*GridFSUploadStream, error) { + ctx, cancel := csot.WithTimeout(ctx, b.db.client.timeout) + defer cancel() + if err := b.checkFirstWrite(ctx); err != nil { return nil, err } @@ -100,7 +105,8 @@ func (b *GridFSBucket) OpenUploadStreamWithID( // bucket that also require a custom deadline. // // The context provided to this method controls the entire lifetime of an -// upload stream io.Writer. +// upload stream io.Writer. If the context does set a deadline, then the +// client-level timeout will be used to cap the lifetime of the stream. func (b *GridFSBucket) UploadFromStream( ctx context.Context, filename string, @@ -119,7 +125,8 @@ func (b *GridFSBucket) UploadFromStream( // bucket that also require a custom deadline. // // The context provided to this method controls the entire lifetime of an -// upload stream io.Writer. +// upload stream io.Writer. If the context does set a deadline, then the +// client-level timeout will be used to cap the lifetime of the stream. func (b *GridFSBucket) UploadFromStreamWithID( ctx context.Context, fileID interface{}, @@ -157,8 +164,9 @@ func (b *GridFSBucket) UploadFromStreamWithID( // OpenDownloadStream creates a stream from which the contents of the file can // be read. // -// The context provided to this method controls the entire lifetime of a -// download stream io.Reader. +// The context provided to this method controls the entire lifetime of an +// upload stream io.Writer. If the context does set a deadline, then the +// client-level timeout will be used to cap the lifetime of the stream. func (b *GridFSBucket) OpenDownloadStream(ctx context.Context, fileID interface{}) (*GridFSDownloadStream, error) { return b.openDownloadStream(ctx, bson.D{{"_id", fileID}}) } @@ -171,8 +179,9 @@ func (b *GridFSBucket) OpenDownloadStream(ctx context.Context, fileID interface{ // cannot be done concurrently with other read operations operations on this // bucket that also require a custom deadline. // -// The context provided to this method controls the entire lifetime of a -// download stream io.Reader. +// The context provided to this method controls the entire lifetime of an +// upload stream io.Writer. If the context does set a deadline, then the +// client-level timeout will be used to cap the lifetime of the stream. func (b *GridFSBucket) DownloadToStream(ctx context.Context, fileID interface{}, stream io.Writer) (int64, error) { ds, err := b.OpenDownloadStream(ctx, fileID) if err != nil { @@ -185,8 +194,9 @@ func (b *GridFSBucket) DownloadToStream(ctx context.Context, fileID interface{}, // OpenDownloadStreamByName opens a download stream for the file with the given // filename. // -// The context provided to this method controls the entire lifetime of a -// download stream io.Reader. +// The context provided to this method controls the entire lifetime of an +// upload stream io.Writer. If the context does set a deadline, then the +// client-level timeout will be used to cap the lifetime of the stream. func (b *GridFSBucket) OpenDownloadStreamByName( ctx context.Context, filename string, @@ -227,8 +237,9 @@ func (b *GridFSBucket) OpenDownloadStreamByName( // cannot be done concurrently with other read operations operations on this // bucket that also require a custom deadline. // -// The context provided to this method controls the entire lifetime of a -// download stream io.Reader. +// The context provided to this method controls the entire lifetime of an +// upload stream io.Writer. If the context does set a deadline, then the +// client-level timeout will be used to cap the lifetime of the stream. func (b *GridFSBucket) DownloadToStreamByName( ctx context.Context, filename string, @@ -243,23 +254,13 @@ func (b *GridFSBucket) DownloadToStreamByName( return b.downloadToStream(ds, stream) } -// Delete deletes all chunks and metadata associated with the file with the given file ID and runs the underlying -// delete operations with the provided context. -// -// Use the context parameter to time-out or cancel the delete operation. The deadline set by SetWriteDeadline is ignored. +// Delete deletes all chunks and metadata associated with the file with the +// given file ID and runs the underlying delete operations with the provided +// context. func (b *GridFSBucket) Delete(ctx context.Context, fileID interface{}) error { - // If no deadline is set on the passed-in context, Timeout is set on the Client, and context is - // not already a Timeout context, honor Timeout in new Timeout context for operation execution to - // be shared by both delete operations. - if _, deadlineSet := ctx.Deadline(); !deadlineSet && b.db.Client().timeout != nil && !csot.IsTimeoutContext(ctx) { - newCtx, cancelFunc := csot.MakeTimeoutContext(ctx, *b.db.Client().timeout) - // Redefine ctx to be the new timeout-derived context. - ctx = newCtx - // Cancel the timeout-derived context at the end of Execute to avoid a context leak. - defer cancelFunc() - } - - // Delete document in files collection and then chunks to minimize race conditions. + ctx, cancel := csot.WithTimeout(ctx, b.db.client.timeout) + defer cancel() + res, err := b.filesColl.DeleteOne(ctx, bson.D{{"_id", fileID}}) if err == nil && res.DeletedCount == 0 { err = ErrFileNotFound @@ -272,11 +273,8 @@ func (b *GridFSBucket) Delete(ctx context.Context, fileID interface{}) error { return b.deleteChunks(ctx, fileID) } -// Find returns the files collection documents that match the given filter and runs the underlying -// find query with the provided context. -// -// Use the context parameter to time-out or cancel the find operation. The deadline set by SetReadDeadline -// is ignored. +// Find returns the files collection documents that match the given filter and +// runs the underlying find query with the provided context. func (b *GridFSBucket) Find( ctx context.Context, filter interface{}, @@ -296,9 +294,6 @@ func (b *GridFSBucket) Find( if opt.Limit != nil { gfsOpts.Limit = opt.Limit } - if opt.MaxTime != nil { - gfsOpts.MaxTime = opt.MaxTime - } if opt.NoCursorTimeout != nil { gfsOpts.NoCursorTimeout = opt.NoCursorTimeout } @@ -319,9 +314,6 @@ func (b *GridFSBucket) Find( if gfsOpts.Limit != nil { find.SetLimit(int64(*gfsOpts.Limit)) } - if gfsOpts.MaxTime != nil { - find.SetMaxTime(*gfsOpts.MaxTime) - } if gfsOpts.NoCursorTimeout != nil { find.SetNoCursorTimeout(*gfsOpts.NoCursorTimeout) } @@ -336,11 +328,6 @@ func (b *GridFSBucket) Find( } // Rename renames the stored file with the specified file ID. -// -// If this operation requires a custom write deadline to be set on the bucket, it cannot be done concurrently with other -// write operations operations on this bucket that also require a custom deadline -// -// Use SetWriteDeadline to set a deadline for the rename operation. func (b *GridFSBucket) Rename(ctx context.Context, fileID interface{}, newFilename string) error { res, err := b.filesColl.UpdateOne(ctx, bson.D{{"_id", fileID}}, @@ -357,21 +344,11 @@ func (b *GridFSBucket) Rename(ctx context.Context, fileID interface{}, newFilena return nil } -// Drop drops the files and chunks collections associated with this bucket and runs the drop operations with -// the provided context. -// -// Use the context parameter to time-out or cancel the drop operation. The deadline set by SetWriteDeadline is ignored. +// Drop drops the files and chunks collections associated with this bucket and +// runs the drop operations with the provided context. func (b *GridFSBucket) Drop(ctx context.Context) error { - // If no deadline is set on the passed-in context, Timeout is set on the Client, and context is - // not already a Timeout context, honor Timeout in new Timeout context for operation execution to - // be shared by both drop operations. - if _, deadlineSet := ctx.Deadline(); !deadlineSet && b.db.Client().timeout != nil && !csot.IsTimeoutContext(ctx) { - newCtx, cancelFunc := csot.MakeTimeoutContext(ctx, *b.db.Client().timeout) - // Redefine ctx to be the new timeout-derived context. - ctx = newCtx - // Cancel the timeout-derived context at the end of Execute to avoid a context leak. - defer cancelFunc() - } + ctx, cancel := csot.WithTimeout(ctx, b.db.client.timeout) + defer cancel() err := b.filesColl.Drop(ctx) if err != nil { @@ -396,6 +373,9 @@ func (b *GridFSBucket) openDownloadStream( filter interface{}, opts ...*options.FindOneOptions, ) (*GridFSDownloadStream, error) { + ctx, cancel := csot.WithTimeout(ctx, b.db.client.timeout) + defer cancel() + result := b.filesColl.FindOne(ctx, filter, opts...) // Unmarshal the data into a File instance, which can be passed to newGridFSDownloadStream. The _id value has to be @@ -425,6 +405,7 @@ func (b *GridFSBucket) openDownloadStream( if err != nil { return nil, err } + // The chunk size can be overridden for individual files, so the expected chunk size should be the "chunkSize" // field from the files collection document, not the bucket's chunk size. return newGridFSDownloadStream(ctx, chunksCursor, foundFile.ChunkSize, foundFile), nil diff --git a/mongo/index_view.go b/mongo/index_view.go index 84f4d71dc4..59ea8c8e26 100644 --- a/mongo/index_view.go +++ b/mongo/index_view.go @@ -108,15 +108,12 @@ func (iv IndexView) List(ctx context.Context, opts ...*options.ListIndexesOption if opt.BatchSize != nil { lio.BatchSize = opt.BatchSize } - if opt.MaxTime != nil { - lio.MaxTime = opt.MaxTime - } } if lio.BatchSize != nil { op = op.BatchSize(*lio.BatchSize) cursorOpts.BatchSize = *lio.BatchSize } - op = op.MaxTime(lio.MaxTime) + retry := driver.RetryNone if iv.coll.client.retryReads { retry = driver.RetryOncePerCommand @@ -269,9 +266,6 @@ func (iv IndexView) CreateMany(ctx context.Context, models []IndexModel, opts .. if opt == nil { continue } - if opt.MaxTime != nil { - option.MaxTime = opt.MaxTime - } if opt.CommitQuorum != nil { option.CommitQuorum = opt.CommitQuorum } @@ -281,7 +275,7 @@ func (iv IndexView) CreateMany(ctx context.Context, models []IndexModel, opts .. Session(sess).WriteConcern(wc).ClusterClock(iv.coll.client.clock). Database(iv.coll.db.name).Collection(iv.coll.name).CommandMonitor(iv.coll.client.monitor). Deployment(iv.coll.client.deployment).ServerSelector(selector).ServerAPI(iv.coll.client.serverAPI). - Timeout(iv.coll.client.timeout).MaxTime(option.MaxTime).Crypt(iv.coll.client.cryptFLE) + Timeout(iv.coll.client.timeout).Crypt(iv.coll.client.cryptFLE) if option.CommitQuorum != nil { commitQuorum, err := marshalValue(option.CommitQuorum, iv.coll.bsonOpts, iv.coll.registry) if err != nil { @@ -383,7 +377,7 @@ func (iv IndexView) createOptionsDoc(opts *options.IndexOptions) (bsoncore.Docum return optsDoc, nil } -func (iv IndexView) drop(ctx context.Context, name string, opts ...*options.DropIndexesOptions) (bson.Raw, error) { +func (iv IndexView) drop(ctx context.Context, name string, _ ...*options.DropIndexesOptions) (bson.Raw, error) { if ctx == nil { ctx = context.Background() } @@ -409,21 +403,12 @@ func (iv IndexView) drop(ctx context.Context, name string, opts ...*options.Drop selector := makePinnedSelector(sess, iv.coll.writeSelector) - dio := options.DropIndexes() - for _, opt := range opts { - if opt == nil { - continue - } - if opt.MaxTime != nil { - dio.MaxTime = opt.MaxTime - } - } op := operation.NewDropIndexes(name). Session(sess).WriteConcern(wc).CommandMonitor(iv.coll.client.monitor). ServerSelector(selector).ClusterClock(iv.coll.client.clock). Database(iv.coll.db.name).Collection(iv.coll.name). Deployment(iv.coll.client.deployment).ServerAPI(iv.coll.client.serverAPI). - Timeout(iv.coll.client.timeout).MaxTime(dio.MaxTime).Crypt(iv.coll.client.cryptFLE) + Timeout(iv.coll.client.timeout).Crypt(iv.coll.client.cryptFLE) err = op.Execute(ctx) if err != nil { diff --git a/mongo/options/aggregateoptions.go b/mongo/options/aggregateoptions.go index 6a8c26faab..2c068a582e 100644 --- a/mongo/options/aggregateoptions.go +++ b/mongo/options/aggregateoptions.go @@ -32,14 +32,6 @@ type AggregateOptions struct { // default value is nil, which means the default collation of the collection will be used. Collation *Collation - // The maximum amount of time that the query can run on the server. The default value is nil, meaning that there - // is no time limit for query execution. - // - // NOTE(benjirewis): MaxTime will be deprecated in a future release. The more general Timeout option may be used - // in its place to control the amount of time that a single operation can run before returning an error. MaxTime - // is ignored if Timeout is set on the client. - MaxTime *time.Duration - // The maximum amount of time that the server should wait for new documents to satisfy a tailable cursor query. // This option is only valid for MongoDB versions >= 3.2 and is ignored for previous server versions. MaxAwaitTime *time.Duration @@ -95,16 +87,6 @@ func (ao *AggregateOptions) SetCollation(c *Collation) *AggregateOptions { return ao } -// SetMaxTime sets the value for the MaxTime field. -// -// NOTE(benjirewis): MaxTime will be deprecated in a future release. The more general Timeout -// option may be used in its place to control the amount of time that a single operation can -// run before returning an error. MaxTime is ignored if Timeout is set on the client. -func (ao *AggregateOptions) SetMaxTime(d time.Duration) *AggregateOptions { - ao.MaxTime = &d - return ao -} - // SetMaxAwaitTime sets the value for the MaxAwaitTime field. func (ao *AggregateOptions) SetMaxAwaitTime(d time.Duration) *AggregateOptions { ao.MaxAwaitTime = &d diff --git a/mongo/options/clientoptions.go b/mongo/options/clientoptions.go index b5fe1931ea..ccd068029f 100644 --- a/mongo/options/clientoptions.go +++ b/mongo/options/clientoptions.go @@ -251,13 +251,6 @@ type ClientOptions struct { // Deprecated: This option is for internal use only and should not be set. It may be changed or removed in any // release. Deployment driver.Deployment - - // SocketTimeout specifies the timeout to be used for the Client's socket reads and writes. - // - // NOTE(benjirewis): SocketTimeout will be deprecated in a future release. The more general Timeout option - // may be used in its place to control the amount of time that a single operation can run before returning - // an error. Setting SocketTimeout and Timeout on a single client will result in undefined behavior. - SocketTimeout *time.Duration } // Client creates a new ClientOptions instance. @@ -325,6 +318,10 @@ func (c *ClientOptions) validate() error { return fmt.Errorf("invalid server monitoring mode: %q", *mode) } + if to := c.Timeout; to != nil && *to < 0 { + return fmt.Errorf(`invalid value %q for "Timeout": value must be positive`, *to) + } + return nil } @@ -478,10 +475,6 @@ func (c *ClientOptions) ApplyURI(uri string) *ClientOptions { c.ServerSelectionTimeout = &cs.ServerSelectionTimeout } - if cs.SocketTimeoutSet { - c.SocketTimeout = &cs.SocketTimeout - } - if cs.SRVMaxHosts != 0 { c.SRVMaxHosts = &cs.SRVMaxHosts } @@ -531,7 +524,7 @@ func (c *ClientOptions) ApplyURI(uri string) *ClientOptions { c.TLSConfig = tlsConfig } - if cs.JSet || cs.WString != "" || cs.WNumberSet || cs.WTimeoutSet { + if cs.JSet || cs.WString != "" || cs.WNumberSet { c.WriteConcern = &writeconcern.WriteConcern{} if len(cs.WString) > 0 { @@ -543,10 +536,6 @@ func (c *ClientOptions) ApplyURI(uri string) *ClientOptions { if cs.JSet { c.WriteConcern.Journal = &cs.J } - - if cs.WTimeoutSet { - c.WriteConcern.WTimeout = cs.WTimeout - } } if cs.ZlibLevelSet { @@ -831,29 +820,19 @@ func (c *ClientOptions) SetServerSelectionTimeout(d time.Duration) *ClientOption return c } -// SetSocketTimeout specifies how long the driver will wait for a socket read or write to return before returning a -// network error. This can also be set through the "socketTimeoutMS" URI option (e.g. "socketTimeoutMS=1000"). The -// default value is 0, meaning no timeout is used and socket operations can block indefinitely. -// -// NOTE(benjirewis): SocketTimeout will be deprecated in a future release. The more general Timeout option may be used -// in its place to control the amount of time that a single operation can run before returning an error. Setting -// SocketTimeout and Timeout on a single client will result in undefined behavior. -func (c *ClientOptions) SetSocketTimeout(d time.Duration) *ClientOptions { - c.SocketTimeout = &d - return c -} - -// SetTimeout specifies the amount of time that a single operation run on this Client can execute before returning an error. -// The deadline of any operation run through the Client will be honored above any Timeout set on the Client; Timeout will only -// be honored if there is no deadline on the operation Context. Timeout can also be set through the "timeoutMS" URI option -// (e.g. "timeoutMS=1000"). The default value is nil, meaning operations do not inherit a timeout from the Client. +// SetTimeout specifies the amount of time that a single operation run on this +// Client can execute before returning an error. The deadline of any operation +// run through the Client will be honored above any Timeout set on the Client; +// Timeout will only be honored if there is no deadline on the operation +// Context. Timeout can also be set through the "timeoutMS" URI option +// (e.g. "timeoutMS=1000"). The default value is nil, meaning operations do not +// inherit a timeout from the Client. // -// If any Timeout is set (even 0) on the Client, the values of MaxTime on operation options, TransactionOptions.MaxCommitTime and -// SessionOptions.DefaultMaxCommitTime will be ignored. Setting Timeout and SocketTimeout or WriteConcern.wTimeout will result -// in undefined behavior. +// The value for a Timeout must be positive. // -// NOTE(benjirewis): SetTimeout represents unstable, provisional API. The behavior of the driver when a Timeout is specified is -// subject to change. +// If any Timeout is set (even 0) on the Client, the values of MaxTime on +// operation options, TransactionOptions.MaxCommitTime and +// SessionOptions.DefaultMaxCommitTime will be ignored. func (c *ClientOptions) SetTimeout(d time.Duration) *ClientOptions { c.Timeout = &d return c @@ -1088,9 +1067,6 @@ func MergeClientOptions(opts ...*ClientOptions) *ClientOptions { if opt.Direct != nil { c.Direct = opt.Direct } - if opt.SocketTimeout != nil { - c.SocketTimeout = opt.SocketTimeout - } if opt.SRVMaxHosts != nil { c.SRVMaxHosts = opt.SRVMaxHosts } diff --git a/mongo/options/clientoptions_test.go b/mongo/options/clientoptions_test.go index beba45514f..70131ded57 100644 --- a/mongo/options/clientoptions_test.go +++ b/mongo/options/clientoptions_test.go @@ -85,7 +85,6 @@ func TestClientOptions(t *testing.T) { {"RetryWrites", (*ClientOptions).SetRetryWrites, true, "RetryWrites", true}, {"ServerSelectionTimeout", (*ClientOptions).SetServerSelectionTimeout, 5 * time.Second, "ServerSelectionTimeout", true}, {"Direct", (*ClientOptions).SetDirect, true, "Direct", true}, - {"SocketTimeout", (*ClientOptions).SetSocketTimeout, 5 * time.Second, "SocketTimeout", true}, {"TLSConfig", (*ClientOptions).SetTLSConfig, &tls.Config{}, "TLSConfig", false}, {"WriteConcern", (*ClientOptions).SetWriteConcern, writeconcern.Majority(), "WriteConcern", false}, {"ZlibLevel", (*ClientOptions).SetZlibLevel, 6, "ZlibLevel", true}, @@ -390,11 +389,6 @@ func TestClientOptions(t *testing.T) { "mongodb://localhost/?serverSelectionTimeoutMS=45000", baseClient().SetServerSelectionTimeout(45 * time.Second), }, - { - "SocketTimeout", - "mongodb://localhost/?socketTimeoutMS=15000", - baseClient().SetSocketTimeout(15 * time.Second), - }, { "TLS CACertificate", "mongodb://localhost/?ssl=true&sslCertificateAuthorityFile=testdata/ca.pem", @@ -440,11 +434,6 @@ func TestClientOptions(t *testing.T) { "mongodb://localhost/?w=3", baseClient().SetWriteConcern(&writeconcern.WriteConcern{W: 3}), }, - { - "WriteConcern WTimeout", - "mongodb://localhost/?wTimeoutMS=45000", - baseClient().SetWriteConcern(&writeconcern.WriteConcern{WTimeout: 45 * time.Second}), - }, { "ZLibLevel", "mongodb://localhost/?zlibCompressionLevel=4", diff --git a/mongo/options/countoptions.go b/mongo/options/countoptions.go index a47550f6d2..7321bb2743 100644 --- a/mongo/options/countoptions.go +++ b/mongo/options/countoptions.go @@ -6,8 +6,6 @@ package options -import "time" - // CountOptions represents options that can be used to configure a CountDocuments operation. type CountOptions struct { // Specifies a collation to use for string comparisons during the operation. This option is only valid for MongoDB @@ -28,14 +26,6 @@ type CountOptions struct { // documents matching the filter will be counted. Limit *int64 - // The maximum amount of time that the query can run on the server. The default value is nil, meaning that there is - // no time limit for query execution. - // - // NOTE(benjirewis): MaxTime will be deprecated in a future release. The more general Timeout option may be used in - // its place to control the amount of time that a single operation can run before returning an error. MaxTime is - // ignored if Timeout is set on the client. - MaxTime *time.Duration - // The number of documents to skip before counting. The default value is 0. Skip *int64 } @@ -69,16 +59,6 @@ func (co *CountOptions) SetLimit(i int64) *CountOptions { return co } -// SetMaxTime sets the value for the MaxTime field. -// -// NOTE(benjirewis): MaxTime will be deprecated in a future release. The more general Timeout -// option may be used in its place to control the amount of time that a single operation can -// run before returning an error. MaxTime is ignored if Timeout is set on the client. -func (co *CountOptions) SetMaxTime(d time.Duration) *CountOptions { - co.MaxTime = &d - return co -} - // SetSkip sets the value for the Skip field. func (co *CountOptions) SetSkip(i int64) *CountOptions { co.Skip = &i diff --git a/mongo/options/distinctoptions.go b/mongo/options/distinctoptions.go index 4cfcb98526..33efd58006 100644 --- a/mongo/options/distinctoptions.go +++ b/mongo/options/distinctoptions.go @@ -6,8 +6,6 @@ package options -import "time" - // DistinctOptions represents options that can be used to configure a Distinct operation. type DistinctOptions struct { // Specifies a collation to use for string comparisons during the operation. This option is only valid for MongoDB @@ -18,14 +16,6 @@ type DistinctOptions struct { // A string or document that will be included in server logs, profiling logs, and currentOp queries to help trace // the operation. The default value is nil, which means that no comment will be included in the logs. Comment interface{} - - // The maximum amount of time that the query can run on the server. The default value is nil, meaning that there - // is no time limit for query execution. - // - // NOTE(benjirewis): MaxTime will be deprecated in a future release. The more general Timeout option may be - // used in its place to control the amount of time that a single operation can run before returning an error. - // MaxTime is ignored if Timeout is set on the client. - MaxTime *time.Duration } // Distinct creates a new DistinctOptions instance. @@ -44,13 +34,3 @@ func (do *DistinctOptions) SetComment(comment interface{}) *DistinctOptions { do.Comment = comment return do } - -// SetMaxTime sets the value for the MaxTime field. -// -// NOTE(benjirewis): MaxTime will be deprecated in a future release. The more general Timeout -// option may be used in its place to control the amount of time that a single operation can -// run before returning an error. MaxTime is ignored if Timeout is set on the client. -func (do *DistinctOptions) SetMaxTime(d time.Duration) *DistinctOptions { - do.MaxTime = &d - return do -} diff --git a/mongo/options/estimatedcountoptions.go b/mongo/options/estimatedcountoptions.go index b7d52bef6d..5f32ab13ba 100644 --- a/mongo/options/estimatedcountoptions.go +++ b/mongo/options/estimatedcountoptions.go @@ -6,21 +6,11 @@ package options -import "time" - // EstimatedDocumentCountOptions represents options that can be used to configure an EstimatedDocumentCount operation. type EstimatedDocumentCountOptions struct { // A string or document that will be included in server logs, profiling logs, and currentOp queries to help trace // the operation. The default is nil, which means that no comment will be included in the logs. Comment interface{} - - // The maximum amount of time that the query can run on the server. The default value is nil, meaning that there - // is no time limit for query execution. - // - // NOTE(benjirewis): MaxTime will be deprecated in a future release. The more general Timeout option may be used - // in its place to control the amount of time that a single operation can run before returning an error. MaxTime - // is ignored if Timeout is set on the client. - MaxTime *time.Duration } // EstimatedDocumentCount creates a new EstimatedDocumentCountOptions instance. @@ -33,13 +23,3 @@ func (eco *EstimatedDocumentCountOptions) SetComment(comment interface{}) *Estim eco.Comment = comment return eco } - -// SetMaxTime sets the value for the MaxTime field. -// -// NOTE(benjirewis): MaxTime will be deprecated in a future release. The more general Timeout option -// may be used in its place to control the amount of time that a single operation can run before -// returning an error. MaxTime is ignored if Timeout is set on the client. -func (eco *EstimatedDocumentCountOptions) SetMaxTime(d time.Duration) *EstimatedDocumentCountOptions { - eco.MaxTime = &d - return eco -} diff --git a/mongo/options/findoptions.go b/mongo/options/findoptions.go index 705fefc3f3..e8c8fa4c60 100644 --- a/mongo/options/findoptions.go +++ b/mongo/options/findoptions.go @@ -58,14 +58,6 @@ type FindOptions struct { // MongoDB versions >= 3.2. For other cursor types or previous server versions, this option is ignored. MaxAwaitTime *time.Duration - // MaxTime is the maximum amount of time that the query can run on the server. The default value is nil, meaning that there - // is no time limit for query execution. - // - // NOTE(benjirewis): MaxTime will be deprecated in a future release. The more general Timeout option may be used in its - // place to control the amount of time that a single operation can run before returning an error. MaxTime is ignored if - // Timeout is set on the client. - MaxTime *time.Duration - // Min is a document specifying the inclusive lower bound for a specific index. The default value is 0, which means that // there is no minimum value. Min interface{} @@ -171,16 +163,6 @@ func (f *FindOptions) SetMaxAwaitTime(d time.Duration) *FindOptions { return f } -// SetMaxTime specifies the max time to allow the query to run. -// -// NOTE(benjirewis): MaxTime will be deprecated in a future release. The more general Timeout -// option may be used used in its place to control the amount of time that a single operation -// can run before returning an error. MaxTime is ignored if Timeout is set on the client. -func (f *FindOptions) SetMaxTime(d time.Duration) *FindOptions { - f.MaxTime = &d - return f -} - // SetMin sets the value for the Min field. func (f *FindOptions) SetMin(min interface{}) *FindOptions { f.Min = min @@ -248,14 +230,6 @@ type FindOneOptions struct { // there is no maximum value. Max interface{} - // The maximum amount of time that the query can run on the server. The default value is nil, meaning that there - // is no time limit for query execution. - // - // NOTE(benjirewis): MaxTime will be deprecated in a future release. The more general Timeout option may be used - // in its place to control the amount of time that a single operation can run before returning an error. MaxTime - // is ignored if Timeout is set on the client. - MaxTime *time.Duration - // A document specifying the inclusive lower bound for a specific index. The default value is 0, which means that // there is no minimum value. Min interface{} @@ -315,16 +289,6 @@ func (f *FindOneOptions) SetMax(max interface{}) *FindOneOptions { return f } -// SetMaxTime sets the value for the MaxTime field. -// -// NOTE(benjirewis): MaxTime will be deprecated in a future release. The more general Timeout -// option may be used in its place to control the amount of time that a single operation can -// run before returning an error. MaxTime is ignored if Timeout is set on the client. -func (f *FindOneOptions) SetMaxTime(d time.Duration) *FindOneOptions { - f.MaxTime = &d - return f -} - // SetMin sets the value for the Min field. func (f *FindOneOptions) SetMin(min interface{}) *FindOneOptions { f.Min = min @@ -378,14 +342,6 @@ type FindOneAndReplaceOptions struct { // the operation. The default value is nil, which means that no comment will be included in the logs. Comment interface{} - // The maximum amount of time that the query can run on the server. The default value is nil, meaning that there - // is no time limit for query execution. - // - // NOTE(benjirewis): MaxTime will be deprecated in a future release. The more general Timeout option may be used - // in its place to control the amount of time that a single operation can run before returning an error. MaxTime - // is ignored if Timeout is set on the client. - MaxTime *time.Duration - // A document describing which fields will be included in the document returned by the operation. The default value // is nil, which means all fields will be included. Projection interface{} @@ -441,16 +397,6 @@ func (f *FindOneAndReplaceOptions) SetComment(comment interface{}) *FindOneAndRe return f } -// SetMaxTime sets the value for the MaxTime field. -// -// NOTE(benjirewis): MaxTime will be deprecated in a future release. The more general Timeout -// option may be used in its place to control the amount of time that a single operation can -// run before returning an error. MaxTime is ignored if Timeout is set on the client. -func (f *FindOneAndReplaceOptions) SetMaxTime(d time.Duration) *FindOneAndReplaceOptions { - f.MaxTime = &d - return f -} - // SetProjection sets the value for the Projection field. func (f *FindOneAndReplaceOptions) SetProjection(projection interface{}) *FindOneAndReplaceOptions { f.Projection = projection @@ -509,14 +455,6 @@ type FindOneAndUpdateOptions struct { // the operation. The default value is nil, which means that no comment will be included in the logs. Comment interface{} - // The maximum amount of time that the query can run on the server. The default value is nil, meaning that there - // is no time limit for query execution. - // - // NOTE(benjirewis): MaxTime will be deprecated in a future release. The more general Timeout option may be used - // in its place to control the amount of time that a single operation can run before returning an error. MaxTime is - // ignored if Timeout is set on the client. - MaxTime *time.Duration - // A document describing which fields will be included in the document returned by the operation. The default value // is nil, which means all fields will be included. Projection interface{} @@ -578,16 +516,6 @@ func (f *FindOneAndUpdateOptions) SetComment(comment interface{}) *FindOneAndUpd return f } -// SetMaxTime sets the value for the MaxTime field. -// -// NOTE(benjirewis): MaxTime will be deprecated in a future release. The more general Timeout -// option may be used in its place to control the amount of time that a single operation can -// run before returning an error. MaxTime is ignored if Timeout is set on the client. -func (f *FindOneAndUpdateOptions) SetMaxTime(d time.Duration) *FindOneAndUpdateOptions { - f.MaxTime = &d - return f -} - // SetProjection sets the value for the Projection field. func (f *FindOneAndUpdateOptions) SetProjection(projection interface{}) *FindOneAndUpdateOptions { f.Projection = projection @@ -635,14 +563,6 @@ type FindOneAndDeleteOptions struct { // the operation. The default value is nil, which means that no comment will be included in the logs. Comment interface{} - // The maximum amount of time that the query can run on the server. The default value is nil, meaning that there - // is no time limit for query execution. - // - // NOTE(benjirewis): MaxTime will be deprecated in a future release. The more general Timeout option may be used - // in its place to control the amount of time that a single operation can run before returning an error. MaxTime - // is ignored if Timeout is set on the client. - MaxTime *time.Duration - // A document describing which fields will be included in the document returned by the operation. The default value // is nil, which means all fields will be included. Projection interface{} @@ -684,16 +604,6 @@ func (f *FindOneAndDeleteOptions) SetComment(comment interface{}) *FindOneAndDel return f } -// SetMaxTime sets the value for the MaxTime field. -// -// NOTE(benjirewis): MaxTime will be deprecated in a future release. The more general Timeout -// option may be used in its place to control the amount of time that a single operation can -// run before returning an error. MaxTime is ignored if Timeout is set on the client. -func (f *FindOneAndDeleteOptions) SetMaxTime(d time.Duration) *FindOneAndDeleteOptions { - f.MaxTime = &d - return f -} - // SetProjection sets the value for the Projection field. func (f *FindOneAndDeleteOptions) SetProjection(projection interface{}) *FindOneAndDeleteOptions { f.Projection = projection diff --git a/mongo/options/gridfsoptions.go b/mongo/options/gridfsoptions.go index 10d454c89d..c8dcf447fc 100644 --- a/mongo/options/gridfsoptions.go +++ b/mongo/options/gridfsoptions.go @@ -7,8 +7,6 @@ package options import ( - "time" - "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo/readconcern" "go.mongodb.org/mongo-driver/mongo/readpref" @@ -155,14 +153,6 @@ type GridFSFindOptions struct { // batch. The default value is 0. Limit *int32 - // The maximum amount of time that the query can run on the server. The default value is nil, meaning that there - // is no time limit for query execution. - // - // NOTE(benjirewis): MaxTime will be deprecated in a future release. The more general Timeout option may be used - // in its place to control the amount of time that a single operation can run before returning an error. MaxTime - // is ignored if Timeout is set on the client. - MaxTime *time.Duration - // If true, the cursor created by the operation will not timeout after a period of inactivity. The default value // is false. NoCursorTimeout *bool @@ -198,16 +188,6 @@ func (f *GridFSFindOptions) SetLimit(i int32) *GridFSFindOptions { return f } -// SetMaxTime sets the value for the MaxTime field. -// -// NOTE(benjirewis): MaxTime will be deprecated in a future release. The more general Timeout -// option may be used in its place to control the amount of time that a single operation can -// run before returning an error. MaxTime is ignored if Timeout is set on the client. -func (f *GridFSFindOptions) SetMaxTime(d time.Duration) *GridFSFindOptions { - f.MaxTime = &d - return f -} - // SetNoCursorTimeout sets the value for the NoCursorTimeout field. func (f *GridFSFindOptions) SetNoCursorTimeout(b bool) *GridFSFindOptions { f.NoCursorTimeout = &b diff --git a/mongo/options/indexoptions.go b/mongo/options/indexoptions.go index 1837b1037a..82675d6a6a 100644 --- a/mongo/options/indexoptions.go +++ b/mongo/options/indexoptions.go @@ -6,10 +6,6 @@ package options -import ( - "time" -) - // CreateIndexesOptions represents options that can be used to configure IndexView.CreateOne and IndexView.CreateMany // operations. type CreateIndexesOptions struct { @@ -26,14 +22,6 @@ type CreateIndexesOptions struct { // is specified for MongoDB versions <= 4.2. The default value is nil, meaning that the server-side default will be // used. See dochub.mongodb.org/core/index-commit-quorum for more information. CommitQuorum interface{} - - // The maximum amount of time that the query can run on the server. The default value is nil, meaning that there - // is no time limit for query execution. - // - // NOTE(benjirewis): MaxTime will be deprecated in a future release. The more general Timeout option may be used - // in its place to control the amount of time that a single operation can run before returning an error. MaxTime - // is ignored if Timeout is set on the client. - MaxTime *time.Duration } // CreateIndexes creates a new CreateIndexesOptions instance. @@ -41,16 +29,6 @@ func CreateIndexes() *CreateIndexesOptions { return &CreateIndexesOptions{} } -// SetMaxTime sets the value for the MaxTime field. -// -// NOTE(benjirewis): MaxTime will be deprecated in a future release. The more general Timeout -// option may be used in its place to control the amount of time that a single operation can -// run before returning an error. MaxTime is ignored if Timeout is set on the client. -func (c *CreateIndexesOptions) SetMaxTime(d time.Duration) *CreateIndexesOptions { - c.MaxTime = &d - return c -} - // SetCommitQuorumInt sets the value for the CommitQuorum field as an int32. func (c *CreateIndexesOptions) SetCommitQuorumInt(quorum int32) *CreateIndexesOptions { c.CommitQuorum = quorum @@ -77,43 +55,17 @@ func (c *CreateIndexesOptions) SetCommitQuorumVotingMembers() *CreateIndexesOpti // DropIndexesOptions represents options that can be used to configure IndexView.DropOne and IndexView.DropAll // operations. -type DropIndexesOptions struct { - // The maximum amount of time that the query can run on the server. The default value is nil, meaning that there - // is no time limit for query execution. - // - // NOTE(benjirewis): MaxTime will be deprecated in a future release. The more general Timeout option may be used - // in its place to control the amount of time that a single operation can run before returning an error. MaxTime - // is ignored if Timeout is set on the client. - MaxTime *time.Duration -} +type DropIndexesOptions struct{} // DropIndexes creates a new DropIndexesOptions instance. func DropIndexes() *DropIndexesOptions { return &DropIndexesOptions{} } -// SetMaxTime sets the value for the MaxTime field. -// -// NOTE(benjirewis): MaxTime will be deprecated in a future release. The more general Timeout -// option may be used in its place to control the amount of time that a single operation can -// run before returning an error. MaxTime is ignored if Timeout is set on the client. -func (d *DropIndexesOptions) SetMaxTime(duration time.Duration) *DropIndexesOptions { - d.MaxTime = &duration - return d -} - // ListIndexesOptions represents options that can be used to configure an IndexView.List operation. type ListIndexesOptions struct { // The maximum number of documents to be included in each batch returned by the server. BatchSize *int32 - - // The maximum amount of time that the query can run on the server. The default value is nil, meaning that there - // is no time limit for query execution. - // - // NOTE(benjirewis): MaxTime will be deprecated in a future release. The more general Timeout option may be used - // in its place to control the amount of time that a single operation can run before returning an error. MaxTime - // is ignored if Timeout is set on the client. - MaxTime *time.Duration } // ListIndexes creates a new ListIndexesOptions instance. @@ -127,16 +79,6 @@ func (l *ListIndexesOptions) SetBatchSize(i int32) *ListIndexesOptions { return l } -// SetMaxTime sets the value for the MaxTime field. -// -// NOTE(benjirewis): MaxTime will be deprecated in a future release. The more general Timeout -// option may be used in its place to control the amount of time that a single operation can -// run before returning an error. MaxTime is ignored if Timeout is set on the client. -func (l *ListIndexesOptions) SetMaxTime(d time.Duration) *ListIndexesOptions { - l.MaxTime = &d - return l -} - // IndexOptions represents options that can be used to configure a new index created through the IndexView.CreateOne // or IndexView.CreateMany operations. type IndexOptions struct { diff --git a/mongo/options/sessionoptions.go b/mongo/options/sessionoptions.go index 4e1fdb1114..d83610b173 100644 --- a/mongo/options/sessionoptions.go +++ b/mongo/options/sessionoptions.go @@ -7,8 +7,6 @@ package options import ( - "time" - "go.mongodb.org/mongo-driver/mongo/readconcern" "go.mongodb.org/mongo-driver/mongo/readpref" "go.mongodb.org/mongo-driver/mongo/writeconcern" @@ -36,14 +34,6 @@ type SessionOptions struct { // the write concern of the client used to start the session will be used. DefaultWriteConcern *writeconcern.WriteConcern - // The default maximum amount of time that a CommitTransaction operation executed in the session can run on the - // server. The default value is nil, which means that that there is no time limit for execution. - // - // NOTE(benjirewis): DefaultMaxCommitTime will be deprecated in a future release. The more general Timeout option - // may be used in its place to control the amount of time that a single operation can run before returning an - // error. DefaultMaxCommitTime is ignored if Timeout is set on the client. - DefaultMaxCommitTime *time.Duration - // If true, all read operations performed with this session will be read from the same snapshot. This option cannot // be set to true if CausalConsistency is set to true. Transactions and write operations are not allowed on // snapshot sessions and will error. The default value is false. @@ -79,17 +69,6 @@ func (s *SessionOptions) SetDefaultWriteConcern(wc *writeconcern.WriteConcern) * return s } -// SetDefaultMaxCommitTime sets the value for the DefaultMaxCommitTime field. -// -// NOTE(benjirewis): DefaultMaxCommitTime will be deprecated in a future release. The more -// general Timeout option may be used in its place to control the amount of time that a -// single operation can run before returning an error. DefaultMaxCommitTime is ignored if -// Timeout is set on the client. -func (s *SessionOptions) SetDefaultMaxCommitTime(mct *time.Duration) *SessionOptions { - s.DefaultMaxCommitTime = mct - return s -} - // SetSnapshot sets the value for the Snapshot field. func (s *SessionOptions) SetSnapshot(b bool) *SessionOptions { s.Snapshot = &b diff --git a/mongo/options/transactionoptions.go b/mongo/options/transactionoptions.go index 2bc4c2166c..c346b0f63f 100644 --- a/mongo/options/transactionoptions.go +++ b/mongo/options/transactionoptions.go @@ -7,8 +7,6 @@ package options import ( - "time" - "go.mongodb.org/mongo-driver/mongo/readconcern" "go.mongodb.org/mongo-driver/mongo/readpref" "go.mongodb.org/mongo-driver/mongo/writeconcern" @@ -27,18 +25,6 @@ type TransactionOptions struct { // The write concern for operations in the transaction. The default value is nil, which means that the default // write concern of the session used to start the transaction will be used. WriteConcern *writeconcern.WriteConcern - - // The default maximum amount of time that a CommitTransaction operation executed in the session can run on the - // server. The default value is nil, meaning that there is no time limit for execution. - - // The maximum amount of time that a CommitTransaction operation can executed in the transaction can run on the - // server. The default value is nil, which means that the default maximum commit time of the session used to - // start the transaction will be used. - // - // NOTE(benjirewis): MaxCommitTime will be deprecated in a future release. The more general Timeout option may - // be used in its place to control the amount of time that a single operation can run before returning an error. - // MaxCommitTime is ignored if Timeout is set on the client. - MaxCommitTime *time.Duration } // Transaction creates a new TransactionOptions instance. @@ -63,13 +49,3 @@ func (t *TransactionOptions) SetWriteConcern(wc *writeconcern.WriteConcern) *Tra t.WriteConcern = wc return t } - -// SetMaxCommitTime sets the value for the MaxCommitTime field. -// -// NOTE(benjirewis): MaxCommitTime will be deprecated in a future release. The more general Timeout -// option may be used in its place to control the amount of time that a single operation can run before -// returning an error. MaxCommitTime is ignored if Timeout is set on the client. -func (t *TransactionOptions) SetMaxCommitTime(mct *time.Duration) *TransactionOptions { - t.MaxCommitTime = mct - return t -} diff --git a/mongo/read_write_concern_spec_test.go b/mongo/read_write_concern_spec_test.go index ec49bb91db..ffa41039b1 100644 --- a/mongo/read_write_concern_spec_test.go +++ b/mongo/read_write_concern_spec_test.go @@ -13,7 +13,6 @@ import ( "path" "reflect" "testing" - "time" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/internal/assert" @@ -48,6 +47,7 @@ type connectionStringTest struct { Valid bool `bson:"valid"` ReadConcern bson.Raw `bson:"readConcern"` WriteConcern bson.Raw `bson:"writeConcern"` + SkipReason string `bson:"skipReason"` } type documentTestFile struct { @@ -98,6 +98,10 @@ func runConnectionStringTestFile(t *testing.T, filePath string) { } func runConnectionStringTest(t *testing.T, test connectionStringTest) { + if test.SkipReason != "" { + t.Skip(test.SkipReason) + } + cs, err := connstring.ParseAndValidate(test.URI) if !test.Valid { assert.NotNil(t, err, "expected Parse error, got nil") @@ -122,11 +126,6 @@ func runConnectionStringTest(t *testing.T, test connectionStringTest) { assert.Equal(t, expected, cs.WString, "expected w value %v, got %v", expected, cs.WString) } } - if expectedWc.timeoutSet { - assert.True(t, cs.WTimeoutSet, "expected WTimeoutSet, got false") - assert.Equal(t, expectedWc.WTimeout, cs.WTimeout, - "expected timeout value %v, got %v", expectedWc.WTimeout, cs.WTimeout) - } if expectedWc.jSet { assert.True(t, cs.JSet, "expected JSet, got false") assert.Equal(t, *expectedWc.Journal, cs.J, "expected j value %v, got %v", *expectedWc.Journal, cs.J) @@ -221,9 +220,8 @@ func readConcernFromRaw(t *testing.T, rc bson.Raw) *readconcern.ReadConcern { type writeConcern struct { *writeconcern.WriteConcern - jSet bool - wSet bool - timeoutSet bool + jSet bool + wSet bool } func writeConcernFromRaw(t *testing.T, wcRaw bson.Raw) writeConcern { @@ -247,14 +245,12 @@ func writeConcernFromRaw(t *testing.T, wcRaw bson.Raw) writeConcern { default: t.Fatalf("unexpected type for w: %v", val.Type) } - case "wtimeoutMS": - wc.timeoutSet = true - timeout := time.Duration(val.Int32()) * time.Millisecond - wc.WriteConcern.WTimeout = timeout case "journal": wc.jSet = true j := val.Boolean() wc.WriteConcern.Journal = &j + case "wtimeoutMS": // Do nothing, this field is deprecated + t.Skip("the wtimeoutMS write concern option is not supported") default: t.Fatalf("unrecognized write concern field: %v", key) } diff --git a/mongo/session.go b/mongo/session.go index 778abebc63..5df2d800f4 100644 --- a/mongo/session.go +++ b/mongo/session.go @@ -214,15 +214,11 @@ func (s *Session) StartTransaction(opts ...*options.TransactionOptions) error { if opt.WriteConcern != nil { topts.WriteConcern = opt.WriteConcern } - if opt.MaxCommitTime != nil { - topts.MaxCommitTime = opt.MaxCommitTime - } } coreOpts := &session.TransactionOptions{ ReadConcern: topts.ReadConcern, ReadPreference: topts.ReadPreference, WriteConcern: topts.WriteConcern, - MaxCommitTime: topts.MaxCommitTime, } return s.clientSession.StartTransaction(coreOpts) @@ -282,7 +278,7 @@ func (s *Session) CommitTransaction(ctx context.Context) error { Session(s.clientSession).ClusterClock(s.client.clock).Database("admin").Deployment(s.deployment). WriteConcern(s.clientSession.CurrentWc).ServerSelector(selector).Retry(driver.RetryOncePerCommand). CommandMonitor(s.client.monitor).RecoveryToken(bsoncore.Document(s.clientSession.RecoveryToken)). - ServerAPI(s.client.serverAPI).MaxTime(s.clientSession.CurrentMct) + ServerAPI(s.client.serverAPI) err = op.Execute(ctx) // Return error without updating transaction state if it is a timeout, as the transaction has not diff --git a/mongo/writeconcern/writeconcern.go b/mongo/writeconcern/writeconcern.go index 2e4d2ade16..2309e01836 100644 --- a/mongo/writeconcern/writeconcern.go +++ b/mongo/writeconcern/writeconcern.go @@ -13,7 +13,6 @@ package writeconcern import ( "errors" "fmt" - "time" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" @@ -74,17 +73,6 @@ type WriteConcern struct { // For more information about the "j" option, see // https://www.mongodb.com/docs/manual/reference/write-concern/#j-option Journal *bool - - // WTimeout specifies a time limit for the write concern. It sets the - // "wtimeout" option in a MongoDB write concern. - // - // It is only applicable for "w" values greater than 1. Using a WTimeout and - // setting Timeout on the Client at the same time will result in undefined - // behavior. - // - // For more information about the "wtimeout" option, see - // https://www.mongodb.com/docs/manual/reference/write-concern/#wtimeout - WTimeout time.Duration } // Unacknowledged returns a WriteConcern that requests no acknowledgment of @@ -183,14 +171,6 @@ func (wc *WriteConcern) MarshalBSONValue() (bson.Type, []byte, error) { elems = bsoncore.AppendBooleanElement(elems, "j", *wc.Journal) } - if wc.WTimeout < 0 { - return 0, nil, ErrNegativeWTimeout - } - - if wc.WTimeout != 0 { - elems = bsoncore.AppendInt64Element(elems, "wtimeout", int64(wc.WTimeout/time.Millisecond)) - } - if len(elems) == 0 { return 0, nil, ErrEmptyWriteConcern } diff --git a/mongo/writeconcern/writeconcern_test.go b/mongo/writeconcern/writeconcern_test.go index b3486fe4f9..07f7b9c3ae 100644 --- a/mongo/writeconcern/writeconcern_test.go +++ b/mongo/writeconcern/writeconcern_test.go @@ -7,118 +7,12 @@ package writeconcern_test import ( - "errors" "testing" - "time" - "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/internal/assert" - "go.mongodb.org/mongo-driver/internal/require" "go.mongodb.org/mongo-driver/mongo/writeconcern" ) -func TestWriteConcern_MarshalBSONValue(t *testing.T) { - t.Parallel() - - boolPtr := func(b bool) *bool { return &b } - - testCases := []struct { - name string - wc *writeconcern.WriteConcern - wantType bson.Type - wantValue bson.D - wantError error - }{ - { - name: "all fields", - wc: &writeconcern.WriteConcern{ - W: "majority", - Journal: boolPtr(false), - WTimeout: 1 * time.Minute, - }, - wantType: bson.TypeEmbeddedDocument, - wantValue: bson.D{ - {Key: "w", Value: "majority"}, - {Key: "j", Value: false}, - {Key: "wtimeout", Value: int64(60_000)}, - }, - }, - { - name: "string W", - wc: &writeconcern.WriteConcern{W: "majority"}, - wantType: bson.TypeEmbeddedDocument, - wantValue: bson.D{{Key: "w", Value: "majority"}}, - }, - { - name: "int W", - wc: &writeconcern.WriteConcern{W: 1}, - wantType: bson.TypeEmbeddedDocument, - wantValue: bson.D{{Key: "w", Value: int32(1)}}, - }, - { - name: "int32 W", - wc: &writeconcern.WriteConcern{W: int32(1)}, - wantError: errors.New("WriteConcern.W must be a string or int, but is a int32"), - }, - { - name: "bool W", - wc: &writeconcern.WriteConcern{W: false}, - wantError: errors.New("WriteConcern.W must be a string or int, but is a bool"), - }, - { - name: "W=0 and J=true", - wc: &writeconcern.WriteConcern{W: 0, Journal: boolPtr(true)}, - wantError: writeconcern.ErrInconsistent, - }, - { - name: "negative W", - wc: &writeconcern.WriteConcern{W: -1}, - wantError: writeconcern.ErrNegativeW, - }, - { - name: "negative WTimeout", - wc: &writeconcern.WriteConcern{W: 1, WTimeout: -1}, - wantError: writeconcern.ErrNegativeWTimeout, - }, - { - name: "empty", - wc: &writeconcern.WriteConcern{}, - wantError: writeconcern.ErrEmptyWriteConcern, - }, - { - name: "nil", - wc: nil, - wantError: writeconcern.ErrEmptyWriteConcern, - }, - } - - for _, tc := range testCases { - tc := tc // Capture range variable. - - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - typ, b, err := tc.wc.MarshalBSONValue() - if tc.wantError != nil { - assert.Equal(t, tc.wantError, err, "expected and actual errors do not match") - return - } - require.NoError(t, err, "bson.MarshalValue error") - - assert.Equal(t, tc.wantType, typ, "expected and actual BSON types do not match") - - rv := bson.RawValue{ - Type: typ, - Value: b, - } - var gotValue bson.D - err = rv.Unmarshal(&gotValue) - require.NoError(t, err, "error unmarshaling RawValue") - assert.Equal(t, tc.wantValue, gotValue, "expected and actual BSON values do not match") - }) - } -} - func TestWriteConcern(t *testing.T) { boolPtr := func(b bool) *bool { return &b } diff --git a/testdata/client-side-operations-timeout/retryability-legacy-timeouts.json b/testdata/client-side-operations-timeout/retryability-legacy-timeouts.json index 6cf1f4ce6e..aded781aee 100644 --- a/testdata/client-side-operations-timeout/retryability-legacy-timeouts.json +++ b/testdata/client-side-operations-timeout/retryability-legacy-timeouts.json @@ -6,7 +6,7 @@ "minServerVersion": "4.4", "topologies": [ "replicaset", - "sharded-replicaset" + "sharded" ] } ], @@ -73,7 +73,7 @@ "insert" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -132,7 +132,7 @@ "insert" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -194,7 +194,7 @@ "insert" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -255,7 +255,7 @@ "insert" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -319,7 +319,7 @@ "delete" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -376,7 +376,7 @@ "delete" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -436,7 +436,7 @@ "update" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -496,7 +496,7 @@ "update" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -559,7 +559,7 @@ "update" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -621,7 +621,7 @@ "update" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -686,7 +686,7 @@ "findAndModify" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -743,7 +743,7 @@ "findAndModify" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -803,7 +803,7 @@ "findAndModify" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -863,7 +863,7 @@ "findAndModify" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -926,7 +926,7 @@ "findAndModify" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -988,7 +988,7 @@ "findAndModify" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -1053,7 +1053,7 @@ "insert" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -1118,7 +1118,7 @@ "insert" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -1186,7 +1186,7 @@ "listDatabases" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -1243,7 +1243,7 @@ "listDatabases" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -1303,7 +1303,7 @@ "listDatabases" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -1357,7 +1357,7 @@ "listDatabases" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -1414,7 +1414,7 @@ "aggregate" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -1471,7 +1471,7 @@ "aggregate" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -1531,7 +1531,7 @@ "aggregate" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -1595,7 +1595,7 @@ "aggregate" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -1662,7 +1662,7 @@ "listCollections" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -1719,7 +1719,7 @@ "listCollections" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -1779,7 +1779,7 @@ "listCollections" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -1836,7 +1836,7 @@ "listCollections" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -1896,7 +1896,7 @@ "aggregate" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -1953,7 +1953,7 @@ "aggregate" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -2013,7 +2013,7 @@ "aggregate" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -2070,7 +2070,7 @@ "aggregate" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -2130,7 +2130,7 @@ "count" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -2187,7 +2187,7 @@ "count" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -2247,7 +2247,7 @@ "aggregate" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -2304,7 +2304,7 @@ "aggregate" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -2364,7 +2364,7 @@ "count" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -2418,7 +2418,7 @@ "count" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -2475,7 +2475,7 @@ "distinct" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -2533,7 +2533,7 @@ "distinct" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -2594,7 +2594,7 @@ "find" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -2651,7 +2651,7 @@ "find" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -2711,7 +2711,7 @@ "find" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -2768,7 +2768,7 @@ "find" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -2828,7 +2828,7 @@ "listIndexes" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -2882,7 +2882,7 @@ "listIndexes" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -2939,7 +2939,7 @@ "aggregate" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } @@ -2996,7 +2996,7 @@ "aggregate" ], "blockConnection": true, - "blockTimeMS": 110 + "blockTimeMS": 125 } } } diff --git a/testdata/client-side-operations-timeout/retryability-legacy-timeouts.yml b/testdata/client-side-operations-timeout/retryability-legacy-timeouts.yml index de3eb9971d..8ada5fb791 100644 --- a/testdata/client-side-operations-timeout/retryability-legacy-timeouts.yml +++ b/testdata/client-side-operations-timeout/retryability-legacy-timeouts.yml @@ -6,7 +6,7 @@ schemaVersion: "1.9" runOnRequirements: - minServerVersion: "4.4" - topologies: ["replicaset", "sharded-replicaset"] + topologies: ["replicaset", "sharded"] createEntities: - client: @@ -38,8 +38,8 @@ initialData: tests: # For each retryable operation, run two tests: # - # 1. Socket timeouts are retried once - Each test constructs a client entity with socketTimeoutMS=50, configures a - # fail point to block the operation once for 110ms, and expects the operation to succeed. + # 1. Socket timeouts are retried once - Each test constructs a client entity with socketTimeoutMS=100, configures a + # fail point to block the operation once for 125ms, and expects the operation to succeed. # # 2. Operations fail after two consecutive socket timeouts - Same as (1) but the fail point is configured to block # the operation twice and the test expects the operation to fail. @@ -56,7 +56,7 @@ tests: data: failCommands: ["insert"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: insertOne object: *collection arguments: @@ -87,7 +87,7 @@ tests: data: failCommands: ["insert"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: insertOne object: *collection arguments: @@ -121,7 +121,7 @@ tests: data: failCommands: ["insert"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: insertMany object: *collection arguments: @@ -153,7 +153,7 @@ tests: data: failCommands: ["insert"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: insertMany object: *collection arguments: @@ -188,7 +188,7 @@ tests: data: failCommands: ["delete"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: deleteOne object: *collection arguments: @@ -219,7 +219,7 @@ tests: data: failCommands: ["delete"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: deleteOne object: *collection arguments: @@ -253,7 +253,7 @@ tests: data: failCommands: ["update"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: replaceOne object: *collection arguments: @@ -285,7 +285,7 @@ tests: data: failCommands: ["update"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: replaceOne object: *collection arguments: @@ -320,7 +320,7 @@ tests: data: failCommands: ["update"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: updateOne object: *collection arguments: @@ -352,7 +352,7 @@ tests: data: failCommands: ["update"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: updateOne object: *collection arguments: @@ -387,7 +387,7 @@ tests: data: failCommands: ["findAndModify"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: findOneAndDelete object: *collection arguments: @@ -418,7 +418,7 @@ tests: data: failCommands: ["findAndModify"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: findOneAndDelete object: *collection arguments: @@ -452,7 +452,7 @@ tests: data: failCommands: ["findAndModify"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: findOneAndReplace object: *collection arguments: @@ -484,7 +484,7 @@ tests: data: failCommands: ["findAndModify"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: findOneAndReplace object: *collection arguments: @@ -519,7 +519,7 @@ tests: data: failCommands: ["findAndModify"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: findOneAndUpdate object: *collection arguments: @@ -551,7 +551,7 @@ tests: data: failCommands: ["findAndModify"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: findOneAndUpdate object: *collection arguments: @@ -586,7 +586,7 @@ tests: data: failCommands: ["insert"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: bulkWrite object: *collection arguments: @@ -619,7 +619,7 @@ tests: data: failCommands: ["insert"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: bulkWrite object: *collection arguments: @@ -655,7 +655,7 @@ tests: data: failCommands: ["listDatabases"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: listDatabases object: *client arguments: @@ -686,7 +686,7 @@ tests: data: failCommands: ["listDatabases"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: listDatabases object: *client arguments: @@ -720,7 +720,7 @@ tests: data: failCommands: ["listDatabases"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: listDatabaseNames object: *client @@ -749,7 +749,7 @@ tests: data: failCommands: ["listDatabases"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: listDatabaseNames object: *client @@ -781,7 +781,7 @@ tests: data: failCommands: ["aggregate"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: createChangeStream object: *client arguments: @@ -812,7 +812,7 @@ tests: data: failCommands: ["aggregate"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: createChangeStream object: *client arguments: @@ -846,7 +846,7 @@ tests: data: failCommands: ["aggregate"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: aggregate object: *database arguments: @@ -877,7 +877,7 @@ tests: data: failCommands: ["aggregate"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: aggregate object: *database arguments: @@ -911,7 +911,7 @@ tests: data: failCommands: ["listCollections"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: listCollections object: *database arguments: @@ -942,7 +942,7 @@ tests: data: failCommands: ["listCollections"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: listCollections object: *database arguments: @@ -976,7 +976,7 @@ tests: data: failCommands: ["listCollections"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: listCollectionNames object: *database arguments: @@ -1007,7 +1007,7 @@ tests: data: failCommands: ["listCollections"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: listCollectionNames object: *database arguments: @@ -1041,7 +1041,7 @@ tests: data: failCommands: ["aggregate"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: createChangeStream object: *database arguments: @@ -1072,7 +1072,7 @@ tests: data: failCommands: ["aggregate"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: createChangeStream object: *database arguments: @@ -1106,7 +1106,7 @@ tests: data: failCommands: ["aggregate"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: aggregate object: *collection arguments: @@ -1137,7 +1137,7 @@ tests: data: failCommands: ["aggregate"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: aggregate object: *collection arguments: @@ -1171,7 +1171,7 @@ tests: data: failCommands: ["count"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: count object: *collection arguments: @@ -1202,7 +1202,7 @@ tests: data: failCommands: ["count"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: count object: *collection arguments: @@ -1236,7 +1236,7 @@ tests: data: failCommands: ["aggregate"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: countDocuments object: *collection arguments: @@ -1267,7 +1267,7 @@ tests: data: failCommands: ["aggregate"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: countDocuments object: *collection arguments: @@ -1301,7 +1301,7 @@ tests: data: failCommands: ["count"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: estimatedDocumentCount object: *collection @@ -1330,7 +1330,7 @@ tests: data: failCommands: ["count"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: estimatedDocumentCount object: *collection @@ -1362,7 +1362,7 @@ tests: data: failCommands: ["distinct"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: distinct object: *collection arguments: @@ -1394,7 +1394,7 @@ tests: data: failCommands: ["distinct"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: distinct object: *collection arguments: @@ -1429,7 +1429,7 @@ tests: data: failCommands: ["find"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: find object: *collection arguments: @@ -1460,7 +1460,7 @@ tests: data: failCommands: ["find"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: find object: *collection arguments: @@ -1494,7 +1494,7 @@ tests: data: failCommands: ["find"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: findOne object: *collection arguments: @@ -1525,7 +1525,7 @@ tests: data: failCommands: ["find"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: findOne object: *collection arguments: @@ -1559,7 +1559,7 @@ tests: data: failCommands: ["listIndexes"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: listIndexes object: *collection @@ -1588,7 +1588,7 @@ tests: data: failCommands: ["listIndexes"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: listIndexes object: *collection @@ -1620,7 +1620,7 @@ tests: data: failCommands: ["aggregate"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: createChangeStream object: *collection arguments: @@ -1651,7 +1651,7 @@ tests: data: failCommands: ["aggregate"] blockConnection: true - blockTimeMS: 110 + blockTimeMS: 125 - name: createChangeStream object: *collection arguments: diff --git a/testdata/convenient-transactions/commit-retry.json b/testdata/convenient-transactions/commit-retry.json index 02e38460d0..6257e99345 100644 --- a/testdata/convenient-transactions/commit-retry.json +++ b/testdata/convenient-transactions/commit-retry.json @@ -150,6 +150,7 @@ }, { "description": "commitTransaction retry only overwrites write concern w option", + "skipReason": "GODRIVER-2348: wtimeout is deprecated", "failPoint": { "configureFailPoint": "failCommand", "mode": { @@ -429,6 +430,7 @@ }, { "description": "commit is not retried after MaxTimeMSExpired error", + "skipReason": "GODRIVER-2348: maxTimeMS is deprecated", "failPoint": { "configureFailPoint": "failCommand", "mode": { diff --git a/testdata/convenient-transactions/commit-retry.yml b/testdata/convenient-transactions/commit-retry.yml index 74c03dd9fb..3ff6497ae4 100644 --- a/testdata/convenient-transactions/commit-retry.yml +++ b/testdata/convenient-transactions/commit-retry.yml @@ -99,6 +99,7 @@ tests: - { _id: 1 } - description: commitTransaction retry only overwrites write concern w option + skipReason: "GODRIVER-2348: wtimeout is deprecated" failPoint: configureFailPoint: failCommand mode: { times: 2 } @@ -260,6 +261,7 @@ tests: - { _id: 1 } - description: commit is not retried after MaxTimeMSExpired error + skipReason: "GODRIVER-2348: maxTimeMS is deprecated" failPoint: configureFailPoint: failCommand mode: { times: 1 } diff --git a/testdata/read-write-concern/connection-string/write-concern.json b/testdata/read-write-concern/connection-string/write-concern.json index 51bdf821c3..a81e297dae 100644 --- a/testdata/read-write-concern/connection-string/write-concern.json +++ b/testdata/read-write-concern/connection-string/write-concern.json @@ -33,6 +33,7 @@ }, { "description": "wtimeoutMS as a valid number", + "skipReason": "GODRIVER-2348: the wtimeoutMS write concern option is not supported", "uri": "mongodb://localhost/?wtimeoutMS=500", "valid": true, "warning": false, @@ -42,6 +43,7 @@ }, { "description": "wtimeoutMS as an invalid number", + "skipReason": "GODRIVER-2348: the wtimeoutMS write concern option is not supported", "uri": "mongodb://localhost/?wtimeoutMS=-500", "valid": false, "warning": null @@ -66,6 +68,7 @@ }, { "description": "All options combined", + "skipReason": "GODRIVER-2348: the wtimeoutMS write concern option is not supported", "uri": "mongodb://localhost/?w=3&wtimeoutMS=500&journal=true", "valid": true, "warning": false, @@ -96,6 +99,7 @@ }, { "description": "Unacknowledged with w and wtimeoutMS", + "skipReason": "GODRIVER-2348: the wtimeoutMS write concern option is not supported", "uri": "mongodb://localhost/?w=0&wtimeoutMS=500", "valid": true, "warning": false, diff --git a/testdata/read-write-concern/connection-string/write-concern.yml b/testdata/read-write-concern/connection-string/write-concern.yml index ca61085865..52e09170e8 100644 --- a/testdata/read-write-concern/connection-string/write-concern.yml +++ b/testdata/read-write-concern/connection-string/write-concern.yml @@ -24,12 +24,14 @@ tests: writeConcern: { w: "majority" } - description: "wtimeoutMS as a valid number" + skipReason: "GODRIVER-2348: the wtimeoutMS write concern option is not supported" uri: "mongodb://localhost/?wtimeoutMS=500" valid: true warning: false writeConcern: { wtimeoutMS: 500 } - description: "wtimeoutMS as an invalid number" + skipReason: "GODRIVER-2348: the wtimeoutMS write concern option is not supported" uri: "mongodb://localhost/?wtimeoutMS=-500" valid: false warning: ~ @@ -47,6 +49,7 @@ tests: writeConcern: { journal: true } - description: "All options combined" + skipReason: "GODRIVER-2348: the wtimeoutMS write concern option is not supported" uri: "mongodb://localhost/?w=3&wtimeoutMS=500&journal=true" valid: true warning: false @@ -65,6 +68,7 @@ tests: writeConcern: { w: 0, journal: false } - description: "Unacknowledged with w and wtimeoutMS" + skipReason: "GODRIVER-2348: the wtimeoutMS write concern option is not supported" uri: "mongodb://localhost/?w=0&wtimeoutMS=500" valid: true warning: false diff --git a/testdata/read-write-concern/document/write-concern.json b/testdata/read-write-concern/document/write-concern.json index 64cd5d0eae..fe81741e70 100644 --- a/testdata/read-write-concern/document/write-concern.json +++ b/testdata/read-write-concern/document/write-concern.json @@ -56,6 +56,7 @@ }, { "description": "WTimeoutMS", + "skipReason": "GODRIVER-2348: the wtimeoutMS write concern option is not supported", "valid": true, "writeConcern": { "wtimeoutMS": 1000 @@ -68,6 +69,7 @@ }, { "description": "WTimeoutMS as an invalid number", + "skipReason": "GODRIVER-2348: the wtimeoutMS write concern option is not supported", "valid": false, "writeConcern": { "wtimeoutMS": -1000 @@ -114,6 +116,7 @@ }, { "description": "Unacknowledged with wtimeoutMS", + "skipReason": "GODRIVER-2348: the wtimeoutMS write concern option is not supported", "valid": true, "writeConcern": { "w": 0, @@ -156,6 +159,7 @@ }, { "description": "Everything", + "skipReason": "GODRIVER-2348: the wtimeoutMS write concern option is not supported", "valid": true, "writeConcern": { "w": 3, diff --git a/testdata/read-write-concern/document/write-concern.yml b/testdata/read-write-concern/document/write-concern.yml index bd82fdd59d..0c31f6958b 100644 --- a/testdata/read-write-concern/document/write-concern.yml +++ b/testdata/read-write-concern/document/write-concern.yml @@ -36,6 +36,7 @@ tests: isAcknowledged: true - description: "WTimeoutMS" + skipReason: "GODRIVER-2348: the wtimeoutMS write concern option is not supported" valid: true writeConcern: { wtimeoutMS: 1000 } writeConcernDocument: { wtimeout: 1000 } @@ -43,6 +44,7 @@ tests: isAcknowledged: true - description: "WTimeoutMS as an invalid number" + skipReason: "GODRIVER-2348: the wtimeoutMS write concern option is not supported" valid: false writeConcern: { wtimeoutMS: -1000 } writeConcernDocument: ~ @@ -71,6 +73,7 @@ tests: isAcknowledged: false - description: "Unacknowledged with wtimeoutMS" + skipReason: "GODRIVER-2348: the wtimeoutMS write concern option is not supported" valid: true writeConcern: { w: 0, wtimeoutMS: 500 } writeConcernDocument: { w: 0, wtimeout: 500 } diff --git a/testdata/transactions/legacy/error-labels.json b/testdata/transactions/legacy/error-labels.json index a57f216b9b..8bb5af7700 100644 --- a/testdata/transactions/legacy/error-labels.json +++ b/testdata/transactions/legacy/error-labels.json @@ -1687,6 +1687,7 @@ }, { "description": "do not add UnknownTransactionCommitResult label to MaxTimeMSExpired inside transactions", + "skipReason": "GODRIVER-2348: maxTimeMS is deprecated", "failPoint": { "configureFailPoint": "failCommand", "mode": { @@ -1817,6 +1818,7 @@ }, { "description": "add UnknownTransactionCommitResult label to MaxTimeMSExpired", + "skipReason": "GODRIVER-2348: maxCommitTimeMS is deprecated", "failPoint": { "configureFailPoint": "failCommand", "mode": { @@ -1949,6 +1951,7 @@ }, { "description": "add UnknownTransactionCommitResult label to writeConcernError MaxTimeMSExpired", + "skipReason": "GODRIVER-2348: maxCommitTimeMS is deprecated", "failPoint": { "configureFailPoint": "failCommand", "mode": { diff --git a/testdata/transactions/legacy/error-labels.yml b/testdata/transactions/legacy/error-labels.yml index 5f2c7085c1..d9c461eadf 100644 --- a/testdata/transactions/legacy/error-labels.yml +++ b/testdata/transactions/legacy/error-labels.yml @@ -1029,6 +1029,7 @@ tests: - _id: 1 - description: do not add UnknownTransactionCommitResult label to MaxTimeMSExpired inside transactions + skipReason: "GODRIVER-2348: maxTimeMS is deprecated" failPoint: configureFailPoint: failCommand @@ -1109,6 +1110,7 @@ tests: data: [] - description: add UnknownTransactionCommitResult label to MaxTimeMSExpired + skipReason: "GODRIVER-2348: maxCommitTimeMS is deprecated" failPoint: configureFailPoint: failCommand @@ -1190,6 +1192,7 @@ tests: - _id: 1 - description: add UnknownTransactionCommitResult label to writeConcernError MaxTimeMSExpired + skipReason: "GODRIVER-2348: maxCommitTimeMS is deprecated" failPoint: configureFailPoint: failCommand diff --git a/testdata/transactions/legacy/retryable-commit.json b/testdata/transactions/legacy/retryable-commit.json index d83a1d9f52..dde1714603 100644 --- a/testdata/transactions/legacy/retryable-commit.json +++ b/testdata/transactions/legacy/retryable-commit.json @@ -161,6 +161,7 @@ }, { "description": "commitTransaction applies majority write concern on retries", + "skipReason": "GODRIVER-2348: wtimeout is deprecated", "clientOptions": { "retryWrites": false }, diff --git a/testdata/transactions/legacy/retryable-commit.yml b/testdata/transactions/legacy/retryable-commit.yml index 8e0037f28e..f48b53f1b9 100644 --- a/testdata/transactions/legacy/retryable-commit.yml +++ b/testdata/transactions/legacy/retryable-commit.yml @@ -102,6 +102,7 @@ tests: - _id: 1 - description: commitTransaction applies majority write concern on retries + skipReason: "GODRIVER-2348: wtimeout is deprecated" clientOptions: retryWrites: false diff --git a/testdata/transactions/legacy/transaction-options.json b/testdata/transactions/legacy/transaction-options.json index 25d245dca5..d474e3773d 100644 --- a/testdata/transactions/legacy/transaction-options.json +++ b/testdata/transactions/legacy/transaction-options.json @@ -318,6 +318,7 @@ }, { "description": "transaction options inherited from defaultTransactionOptions", + "skipReason": "GODRIVER-2348: maxCommitTimeMS is deprecated", "sessionOptions": { "session0": { "defaultTransactionOptions": { @@ -479,6 +480,7 @@ }, { "description": "startTransaction options override defaults", + "skipReason": "GODRIVER-2348: maxCommitTimeMS is deprecated", "clientOptions": { "readConcernLevel": "local", "w": 1 @@ -668,6 +670,7 @@ }, { "description": "defaultTransactionOptions override client options", + "skipReason": "GODRIVER-2348: maxCommitTimeMS is deprecated", "clientOptions": { "readConcernLevel": "local", "w": 1 diff --git a/testdata/transactions/legacy/transaction-options.yml b/testdata/transactions/legacy/transaction-options.yml index 461e87d55f..314e0284a6 100644 --- a/testdata/transactions/legacy/transaction-options.yml +++ b/testdata/transactions/legacy/transaction-options.yml @@ -260,6 +260,7 @@ tests: outcome: *outcome - description: startTransaction options override defaults + skipReason: "GODRIVER-2348: maxCommitTimeMS is deprecated" clientOptions: readConcernLevel: local @@ -381,6 +382,7 @@ tests: outcome: *outcome - description: defaultTransactionOptions override client options + skipReason: "GODRIVER-2348: maxCommitTimeMS is deprecated" clientOptions: readConcernLevel: local @@ -665,6 +667,7 @@ tests: - _id: 1 - description: readPreference inherited from defaultTransactionOptions + skipReason: "GODRIVER-2348: maxCommitTimeMS is deprecated" clientOptions: readPreference: primary diff --git a/x/mongo/driver/batch_cursor.go b/x/mongo/driver/batch_cursor.go index f78ef652fe..6716016924 100644 --- a/x/mongo/driver/batch_cursor.go +++ b/x/mongo/driver/batch_cursor.go @@ -45,7 +45,6 @@ type BatchCursor struct { errorProcessor ErrorProcessor // This will only be set when pinning to a connection. connection *mnet.Connection batchSize int32 - maxTimeMS int64 currentBatch *bsoncore.Iterator firstBatch bool cmdMonitor *event.CommandMonitor @@ -53,6 +52,10 @@ type BatchCursor struct { crypt Crypt serverAPI *ServerAPIOptions + // maxAwaitTime is only valid for tailable awaitData cursors. If this option + // is set, it will be used as the "maxTimeMS" field on getMore commands. + maxAwaitTime *time.Duration + // legacy server (< 3.2) fields limit int32 numReturned int32 // number of docs returned by server @@ -157,12 +160,21 @@ func NewCursorResponse(info ResponseInfo) (CursorResponse, error) { type CursorOptions struct { BatchSize int32 Comment bsoncore.Value - MaxTimeMS int64 Limit int32 CommandMonitor *event.CommandMonitor Crypt Crypt ServerAPI *ServerAPIOptions MarshalValueEncoderFn func(io.Writer) (*bson.Encoder, error) + + // MaxAwaitTime is only valid for tailable awaitData cursors. If this option + // is set, it will be used as the "maxTimeMS" field on getMore commands. + MaxAwaitTime *time.Duration +} + +// SetMaxAwaitTime will set the maxTimeMS value on getMore commands for +// tailable awaitData cursors. +func (cursorOptions *CursorOptions) SetMaxAwaitTime(dur time.Duration) { + cursorOptions.MaxAwaitTime = &dur } // NewBatchCursor creates a new BatchCursor from the provided parameters. @@ -185,7 +197,7 @@ func NewBatchCursor( connection: cr.Connection, errorProcessor: cr.ErrorProcessor, batchSize: opts.BatchSize, - maxTimeMS: opts.MaxTimeMS, + maxAwaitTime: opts.MaxAwaitTime, cmdMonitor: opts.CommandMonitor, firstBatch: true, postBatchResumeToken: cr.postBatchResumeToken, @@ -363,14 +375,15 @@ func (bc *BatchCursor) getMore(ctx context.Context) { } bc.err = Operation{ - CommandFn: func(dst []byte, desc description.SelectedServer) ([]byte, error) { + CommandFn: func(dst []byte, _ description.SelectedServer) ([]byte, error) { dst = bsoncore.AppendInt64Element(dst, "getMore", bc.id) dst = bsoncore.AppendStringElement(dst, "collection", bc.collection) if numToReturn > 0 { dst = bsoncore.AppendInt32Element(dst, "batchSize", numToReturn) } - if bc.maxTimeMS > 0 { - dst = bsoncore.AppendInt64Element(dst, "maxTimeMS", bc.maxTimeMS) + + if bc.maxAwaitTime != nil && *bc.maxAwaitTime > 0 { + dst = bsoncore.AppendInt64Element(dst, "maxTimeMS", int64(*bc.maxAwaitTime)/int64(time.Millisecond)) } comment, err := codecutil.MarshalValue(bc.comment, bc.encoderFn) @@ -471,14 +484,14 @@ func (bc *BatchCursor) SetBatchSize(size int32) { bc.batchSize = size } -// SetMaxTime will set the maximum amount of time the server will allow the +// SetMaxAwaitTime will set the maximum amount of time the server will allow the // operations to execute. The server will error if this field is set but the // cursor is not configured with awaitData=true. // // The time.Duration value passed by this setter will be converted and rounded // down to the nearest millisecond. -func (bc *BatchCursor) SetMaxTime(dur time.Duration) { - bc.maxTimeMS = int64(dur / time.Millisecond) +func (bc *BatchCursor) SetMaxAwaitTime(dur time.Duration) { + bc.maxAwaitTime = &dur } // SetComment sets the comment for future getMore operations. @@ -509,7 +522,7 @@ var _ Deployment = (*loadBalancedCursorDeployment)(nil) var _ Server = (*loadBalancedCursorDeployment)(nil) var _ ErrorProcessor = (*loadBalancedCursorDeployment)(nil) -func (lbcd *loadBalancedCursorDeployment) SelectServer(_ context.Context, _ description.ServerSelector) (Server, error) { +func (lbcd *loadBalancedCursorDeployment) SelectServer(context.Context, description.ServerSelector) (Server, error) { return lbcd, nil } @@ -529,3 +542,9 @@ func (lbcd *loadBalancedCursorDeployment) RTTMonitor() RTTMonitor { func (lbcd *loadBalancedCursorDeployment) ProcessError(err error, desc mnet.Describer) ProcessErrorResult { return lbcd.errorProcessor.ProcessError(err, desc) } + +// GetServerSelectionTimeout returns zero as a server selection timeout is not +// applicable for load-balanced cursor deployments. +func (*loadBalancedCursorDeployment) GetServerSelectionTimeout() time.Duration { + return 0 +} diff --git a/x/mongo/driver/batch_cursor_test.go b/x/mongo/driver/batch_cursor_test.go index c57434cb83..7c9ad38c7b 100644 --- a/x/mongo/driver/batch_cursor_test.go +++ b/x/mongo/driver/batch_cursor_test.go @@ -8,7 +8,6 @@ package driver import ( "testing" - "time" "go.mongodb.org/mongo-driver/internal/assert" ) @@ -91,43 +90,3 @@ func TestBatchCursor(t *testing.T) { } }) } - -func TestBatchCursorSetMaxTime(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - dur time.Duration - want int64 - }{ - { - name: "empty", - dur: 0, - want: 0, - }, - { - name: "partial milliseconds are truncated", - dur: 10_900 * time.Microsecond, - want: 10, - }, - { - name: "millisecond input", - dur: 10 * time.Millisecond, - want: 10, - }, - } - - for _, test := range tests { - test := test - - t.Run(test.name, func(t *testing.T) { - t.Parallel() - - bc := BatchCursor{} - bc.SetMaxTime(test.dur) - - got := bc.maxTimeMS - assert.Equal(t, test.want, got, "expected and actual maxTimeMS are different") - }) - } -} diff --git a/x/mongo/driver/connstring/connstring.go b/x/mongo/driver/connstring/connstring.go index 43fae2fb1a..4465afe3be 100644 --- a/x/mongo/driver/connstring/connstring.go +++ b/x/mongo/driver/connstring/connstring.go @@ -187,10 +187,6 @@ type ConnString struct { ZstdLevel int ZstdLevelSet bool - WTimeout time.Duration - WTimeoutSet bool - WTimeoutSetFromOption bool - Options map[string][]string UnknownOptions map[string][]string } @@ -650,24 +646,6 @@ func (u *ConnString) addOptions(connectionArgPairs []string) error { u.WString = value u.WNumberSet = false - - case "wtimeoutms": - n, err := strconv.Atoi(value) - if err != nil || n < 0 { - return fmt.Errorf("invalid value for %q: %q", key, value) - } - u.WTimeout = time.Duration(n) * time.Millisecond - u.WTimeoutSet = true - case "wtimeout": - // Defer to wtimeoutms, but not to a manually-set option. - if u.WTimeoutSet { - break - } - n, err := strconv.Atoi(value) - if err != nil || n < 0 { - return fmt.Errorf("invalid value for %q: %q", key, value) - } - u.WTimeout = time.Duration(n) * time.Millisecond case "zlibcompressionlevel": level, err := strconv.Atoi(value) if err != nil || (level < -1 || level > 9) { @@ -1032,11 +1010,6 @@ func (p *parser) parse(original string) (*ConnString, error) { return nil, err } - // If WTimeout was set from manual options passed in, set WTImeoutSet to true. - if connStr.WTimeoutSetFromOption { - connStr.WTimeoutSet = true - } - return connStr, nil } diff --git a/x/mongo/driver/connstring/connstring_spec_test.go b/x/mongo/driver/connstring/connstring_spec_test.go index aea68eba71..af7b25f385 100644 --- a/x/mongo/driver/connstring/connstring_spec_test.go +++ b/x/mongo/driver/connstring/connstring_spec_test.go @@ -99,6 +99,8 @@ func runTestsInFile(t *testing.T, dirname string, filename string, warningsError var skipDescriptions = map[string]struct{}{ "Valid options specific to single-threaded drivers are parsed correctly": {}, + // GODRIVER-2348: the wtimeoutMS write concern option is not supported. + "Valid read and write concern are parsed correctly": {}, } var skipKeywords = []string{ @@ -106,6 +108,9 @@ var skipKeywords = []string{ "tlsAllowInvalidCertificates", "tlsDisableCertificateRevocationCheck", "serverSelectionTryOnce", + + // GODRIVER-2348: the wtimeoutMS write concern option is not supported. + "wTimeoutMS", } func runTest(t *testing.T, filename string, test testCase, warningsError bool) { @@ -277,8 +282,6 @@ func verifyConnStringOptions(t *testing.T, cs *connstring.ConnString, options ma } else { require.Equal(t, value, cs.WString) } - case "wtimeoutms": - require.Equal(t, value, float64(cs.WTimeout/time.Millisecond)) case "waitqueuetimeoutms": case "zlibcompressionlevel": require.Equal(t, value, float64(cs.ZlibLevel)) diff --git a/x/mongo/driver/connstring/connstring_test.go b/x/mongo/driver/connstring/connstring_test.go index 84c8ff1d45..001cd72fe5 100644 --- a/x/mongo/driver/connstring/connstring_test.go +++ b/x/mongo/driver/connstring/connstring_test.go @@ -564,33 +564,6 @@ func TestSocketTimeout(t *testing.T) { } } -func TestWTimeout(t *testing.T) { - tests := []struct { - s string - expected time.Duration - err bool - }{ - {s: "wtimeoutMS=10", expected: time.Duration(10) * time.Millisecond}, - {s: "wtimeoutMS=100", expected: time.Duration(100) * time.Millisecond}, - {s: "wtimeoutMS=-2", err: true}, - {s: "wtimeoutMS=gsdge", err: true}, - } - - for _, test := range tests { - s := fmt.Sprintf("mongodb://localhost/?%s", test.s) - t.Run(s, func(t *testing.T) { - cs, err := connstring.ParseAndValidate(s) - if test.err { - require.Error(t, err) - } else { - require.NoError(t, err) - require.Equal(t, test.expected, cs.WTimeout) - require.True(t, cs.WTimeoutSet) - } - }) - } -} - func TestCompressionOptions(t *testing.T) { tests := []struct { name string diff --git a/x/mongo/driver/driver.go b/x/mongo/driver/driver.go index 16992b4099..b6a95e32da 100644 --- a/x/mongo/driver/driver.go +++ b/x/mongo/driver/driver.go @@ -28,6 +28,12 @@ import ( type Deployment interface { SelectServer(context.Context, description.ServerSelector) (Server, error) Kind() description.TopologyKind + + // GetServerSelectionTimeout returns a timeout that should be used to set a + // deadline for server selection. This logic is not handleded internally by + // the ServerSelector, as a resulting deadline may be applicable by follow-up + // operations such as checking out a connection. + GetServerSelectionTimeout() time.Duration } // Connector represents a type that can connect to a server. @@ -144,6 +150,12 @@ func (ssd SingleServerDeployment) SelectServer(context.Context, description.Serv // Kind implements the Deployment interface. It always returns description.TopologyKindSingle. func (SingleServerDeployment) Kind() description.TopologyKind { return description.TopologyKindSingle } +// GetServerSelectionTimeout returns zero as a server selection timeout is not +// applicable for single server deployments. +func (SingleServerDeployment) GetServerSelectionTimeout() time.Duration { + return 0 +} + // SingleConnectionDeployment is an implementation of Deployment that always returns the same Connection. This // implementation should only be used for connection handshakes and server heartbeats as it does not implement // ErrorProcessor, which is necessary for application operations. @@ -159,6 +171,12 @@ func (scd SingleConnectionDeployment) SelectServer(context.Context, description. return scd, nil } +// GetServerSelectionTimeout returns zero as a server selection timeout is not +// applicable for single connection deployment. +func (SingleConnectionDeployment) GetServerSelectionTimeout() time.Duration { + return 0 +} + // Kind implements the Deployment interface. It always returns description.TopologyKindSingle. func (SingleConnectionDeployment) Kind() description.TopologyKind { return description.TopologyKindSingle diff --git a/x/mongo/driver/errors.go b/x/mongo/driver/errors.go index 3a189318cb..b12ac5d396 100644 --- a/x/mongo/driver/errors.go +++ b/x/mongo/driver/errors.go @@ -59,8 +59,6 @@ var ( ErrDeadlineWouldBeExceeded = fmt.Errorf( "operation not sent to server, as Timeout would be exceeded: %w", context.DeadlineExceeded) - // ErrNegativeMaxTime is returned when MaxTime on an operation is a negative value. - ErrNegativeMaxTime = errors.New("a negative value was provided for MaxTime on an operation") ) // QueryFailureError is an error representing a command failure as a document. diff --git a/x/mongo/driver/integration/aggregate_test.go b/x/mongo/driver/integration/aggregate_test.go index 824c06f993..c7cbcfc7d5 100644 --- a/x/mongo/driver/integration/aggregate_test.go +++ b/x/mongo/driver/integration/aggregate_test.go @@ -84,9 +84,13 @@ func TestAggregate(t *testing.T) { op := operation.NewAggregate(bsoncore.BuildDocumentFromElements(nil)). Collection(collName).Database(dbName).Deployment(top).ServerSelector(&serverselector.Write{}). CommandMonitor(monitor).BatchSize(2) - err = op.Execute(context.Background()) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + + err = op.Execute(ctx) noerr(t, err) - batchCursor, err := op.Result(driver.CursorOptions{MaxTimeMS: 10, BatchSize: 2, CommandMonitor: monitor}) + batchCursor, err := op.Result(driver.CursorOptions{BatchSize: 2, CommandMonitor: monitor}) noerr(t, err) var e *event.CommandStartedEvent diff --git a/x/mongo/driver/operation.go b/x/mongo/driver/operation.go index 84bf6a9fe1..61110e5467 100644 --- a/x/mongo/driver/operation.go +++ b/x/mongo/driver/operation.go @@ -11,6 +11,7 @@ import ( "context" "errors" "fmt" + "math" "net" "strconv" "strings" @@ -48,6 +49,12 @@ var ( ErrNonPrimaryReadPref = errors.New("read preference in a transaction must be primary") // errDatabaseNameEmpty occurs when a database name is not provided. errDatabaseNameEmpty = errors.New("database name cannot be empty") + // errEmptyWriteConcern indicates that a write concern has no fields set. + errEmptyWriteConcern = errors.New("a write concern must have at least one field set") + // errNegativeW indicates that a negative integer `w` field was specified. + errNegativeW = errors.New("write concern `w` field cannot be a negative number") + // errInconsistent indicates that an inconsistent write concern was specified. + errInconsistent = errors.New("a write concern cannot have both w=0 and j=true") ) const ( @@ -280,9 +287,6 @@ type Operation struct { // read preference will not be added to the command on wire versions < 13. IsOutputAggregate bool - // MaxTime specifies the maximum amount of time to allow the operation to run on the server. - MaxTime *time.Duration - // Timeout is the amount of time that this operation can execute before returning an error. The default value // nil, which means that the timeout of the operation's caller will be used. Timeout *time.Duration @@ -293,6 +297,10 @@ type Operation struct { // OP_MSG as well as for logging server selection data. Name string + // OmitMaxTimeMS will ensure that wire messages sent to the server in service + // of the operation do not contain a maxTimeMS field. + OmitMaxTimeMS bool + // omitReadPreference is a boolean that indicates whether to omit the // read preference from the command. This omition includes the case // where a default read preference is used when the operation @@ -408,6 +416,9 @@ func (op Operation) getServerAndConnection( requestID int32, deprioritized []description.Server, ) (Server, *mnet.Connection, error) { + ctx, cancel := csot.WithServerSelectionTimeout(ctx, op.Deployment.GetServerSelectionTimeout()) + defer cancel() + server, err := op.selectServer(ctx, requestID, deprioritized) if err != nil { if op.Client != nil && @@ -485,15 +496,8 @@ func (op Operation) Execute(ctx context.Context) error { return err } - // If no deadline is set on the passed-in context, op.Timeout is set, and context is not already - // a Timeout context, honor op.Timeout in new Timeout context for operation execution. - if _, deadlineSet := ctx.Deadline(); !deadlineSet && op.Timeout != nil && !csot.IsTimeoutContext(ctx) { - newCtx, cancelFunc := csot.MakeTimeoutContext(ctx, *op.Timeout) - // Redefine ctx to be the new timeout-derived context. - ctx = newCtx - // Cancel the timeout-derived context at the end of Execute to avoid a context leak. - defer cancelFunc() - } + ctx, cancel := csot.WithTimeout(ctx, op.Timeout) + defer cancel() if op.Client != nil { if err := op.Client.StartCommand(); err != nil { @@ -1181,6 +1185,7 @@ func (op Operation) addBatchArray(dst []byte) []byte { } func (op Operation) createLegacyHandshakeWireMessage( + ctx context.Context, maxTimeMS uint64, dst []byte, desc description.SelectedServer, @@ -1225,7 +1230,7 @@ func (op Operation) createLegacyHandshakeWireMessage( return dst, info, err } - dst, err = op.addWriteConcern(dst, desc) + dst, err = op.addWriteConcern(ctx, dst, desc) if err != nil { return dst, info, err } @@ -1297,7 +1302,7 @@ func (op Operation) createMsgWireMessage( if err != nil { return dst, info, err } - dst, err = op.addWriteConcern(dst, desc) + dst, err = op.addWriteConcern(ctx, dst, desc) if err != nil { return dst, info, err } @@ -1364,7 +1369,7 @@ func (op Operation) createWireMessage( requestID int32, ) ([]byte, startedInformation, error) { if isLegacyHandshake(op, desc) { - return op.createLegacyHandshakeWireMessage(maxTimeMS, dst, desc) + return op.createLegacyHandshakeWireMessage(ctx, maxTimeMS, dst, desc) } return op.createMsgWireMessage(ctx, maxTimeMS, dst, desc, conn, requestID) @@ -1474,7 +1479,54 @@ func (op Operation) addReadConcern(dst []byte, desc description.SelectedServer) return bsoncore.AppendDocumentElement(dst, "readConcern", data), nil } -func (op Operation) addWriteConcern(dst []byte, desc description.SelectedServer) ([]byte, error) { +func marshalBSONWriteConcern(wc writeconcern.WriteConcern, wtimeout time.Duration) (bson.Type, []byte, error) { + var elems []byte + if wc.W != nil { + // Only support string or int values for W. That aligns with the + // documentation and the behavior of other functions, like Acknowledged. + switch w := wc.W.(type) { + case int: + if w < 0 { + return 0, nil, errNegativeW + } + + // If Journal=true and W=0, return an error because that write + // concern is ambiguous. + if wc.Journal != nil && *wc.Journal && w == 0 { + return 0, nil, errInconsistent + } + + // Check for lower and upper bounds on architecture-dependent int. + if w > math.MaxInt32 { + return 0, nil, fmt.Errorf("WriteConcern.W overflows int32: %v", wc.W) + } + + elems = bsoncore.AppendInt32Element(elems, "w", int32(w)) + case string: + elems = bsoncore.AppendStringElement(elems, "w", w) + default: + return 0, + nil, + fmt.Errorf("WriteConcern.W must be a string or int, but is a %T", wc.W) + } + } + + if wc.Journal != nil { + elems = bsoncore.AppendBooleanElement(elems, "j", *wc.Journal) + } + + if wtimeout != 0 { + elems = bsoncore.AppendInt64Element(elems, "wtimeout", int64(wtimeout/time.Millisecond)) + } + + if len(elems) == 0 { + return 0, nil, errEmptyWriteConcern + } + + return bson.TypeEmbeddedDocument, bsoncore.BuildDocument(nil, elems), nil +} + +func (op Operation) addWriteConcern(ctx context.Context, dst []byte, desc description.SelectedServer) ([]byte, error) { if op.MinimumWriteConcernWireVersion > 0 && (desc.WireVersion == nil || !driverutil.VersionRangeIncludes(*desc.WireVersion, op.MinimumWriteConcernWireVersion)) { @@ -1485,15 +1537,27 @@ func (op Operation) addWriteConcern(dst []byte, desc description.SelectedServer) return dst, nil } - t, data, err := wc.MarshalBSONValue() - if errors.Is(err, writeconcern.ErrEmptyWriteConcern) { + // The specifications for committing a transaction states: + // + // > if the modified write concern does not include a wtimeout value, drivers + // > MUST also apply wtimeout: 10000 to the write concern in order to avoid + // > waiting forever (or until a socket timeout) if the majority write concern + // > cannot be satisfied. + var wtimeout time.Duration + if _, ok := ctx.Deadline(); op.Client != nil && op.Timeout == nil && !ok { + wtimeout = op.Client.CurrentWTimeout + } + + typ, wcBSON, err := marshalBSONWriteConcern(*wc, wtimeout) + if errors.Is(err, errEmptyWriteConcern) { return dst, nil } + if err != nil { return dst, err } - return append(bsoncore.AppendHeader(dst, bsoncore.Type(t), "writeConcern"), data...), nil + return append(bsoncore.AppendHeader(dst, bsoncore.Type(typ), "writeConcern"), wcBSON...), nil } func (op Operation) addSession(dst []byte, desc description.SelectedServer) ([]byte, error) { @@ -1557,34 +1621,29 @@ func (op Operation) addClusterTime(dst []byte, desc description.SelectedServer) // operation's MaxTimeMS if set. If no MaxTimeMS is set on the operation, and context is // not a Timeout context, calculateMaxTimeMS returns 0. func (op Operation) calculateMaxTimeMS(ctx context.Context, rttMin time.Duration, rttStats string) (uint64, error) { - if csot.IsTimeoutContext(ctx) { - if deadline, ok := ctx.Deadline(); ok { - remainingTimeout := time.Until(deadline) - - // Always round up to the next millisecond value so we never truncate the calculated - // maxTimeMS value (e.g. 400 microseconds evaluates to 1ms, not 0ms). - maxTimeMS := int64((remainingTimeout - rttMin + time.Millisecond - 1) / time.Millisecond) - if maxTimeMS <= 0 { - return 0, fmt.Errorf( - "remaining time %v until context deadline is less than or equal to rtt minimum: %w\n%v", - remainingTimeout, - ErrDeadlineWouldBeExceeded, - rttStats) - } + if op.OmitMaxTimeMS { + return 0, nil + } - return uint64(maxTimeMS), nil - } - } else if op.MaxTime != nil { - // Users are not allowed to pass a negative value as MaxTime. A value of 0 would indicate - // no timeout and is allowed. - if *op.MaxTime < 0 { - return 0, ErrNegativeMaxTime - } - // Always round up to the next millisecond value so we never truncate the requested - // MaxTime value (e.g. 400 microseconds evaluates to 1ms, not 0ms). - return uint64((*op.MaxTime + (time.Millisecond - 1)) / time.Millisecond), nil + deadline, ok := ctx.Deadline() + if !ok { + return 0, nil } - return 0, nil + + remainingTimeout := time.Until(deadline) + + // Always round up to the next millisecond value so we never truncate the calculated + // maxTimeMS value (e.g. 400 microseconds evaluates to 1ms, not 0ms). + maxTimeMS := int64((remainingTimeout - rttMin + time.Millisecond - 1) / time.Millisecond) + if maxTimeMS <= 0 { + return 0, fmt.Errorf( + "remaining time %v until context deadline is less than or equal to rtt minimum: %w\n%v", + remainingTimeout, + ErrDeadlineWouldBeExceeded, + rttStats) + } + + return uint64(maxTimeMS), nil } // updateClusterTimes updates the cluster times for the session and cluster clock attached to this diff --git a/x/mongo/driver/operation/aggregate.go b/x/mongo/driver/operation/aggregate.go index 3fe4ca2fe3..92c0186a49 100644 --- a/x/mongo/driver/operation/aggregate.go +++ b/x/mongo/driver/operation/aggregate.go @@ -30,7 +30,6 @@ type Aggregate struct { collation bsoncore.Document comment bsoncore.Value hint bsoncore.Value - maxTime *time.Duration pipeline bsoncore.Document session *session.Client clock *session.ClusterClock @@ -109,7 +108,6 @@ func (a *Aggregate) Execute(ctx context.Context) error { MinimumWriteConcernWireVersion: 5, ServerAPI: a.serverAPI, IsOutputAggregate: a.hasOutputStage, - MaxTime: a.maxTime, Timeout: a.timeout, Name: driverutil.AggregateOp, }.Execute(ctx) @@ -224,16 +222,6 @@ func (a *Aggregate) Hint(hint bsoncore.Value) *Aggregate { return a } -// MaxTime specifies the maximum amount of time to allow the query to run on the server. -func (a *Aggregate) MaxTime(maxTime *time.Duration) *Aggregate { - if a == nil { - a = new(Aggregate) - } - - a.maxTime = maxTime - return a -} - // Pipeline determines how data is transformed for an aggregation. func (a *Aggregate) Pipeline(pipeline bsoncore.Document) *Aggregate { if a == nil { diff --git a/x/mongo/driver/operation/commit_transaction.go b/x/mongo/driver/operation/commit_transaction.go index 42a79e2f56..b014affd15 100644 --- a/x/mongo/driver/operation/commit_transaction.go +++ b/x/mongo/driver/operation/commit_transaction.go @@ -9,7 +9,6 @@ package operation import ( "context" "errors" - "time" "go.mongodb.org/mongo-driver/event" "go.mongodb.org/mongo-driver/internal/driverutil" @@ -22,7 +21,6 @@ import ( // CommitTransaction attempts to commit a transaction. type CommitTransaction struct { - maxTime *time.Duration recoveryToken bsoncore.Document session *session.Client clock *session.ClusterClock @@ -63,7 +61,6 @@ func (ct *CommitTransaction) Execute(ctx context.Context) error { Crypt: ct.crypt, Database: ct.database, Deployment: ct.deployment, - MaxTime: ct.maxTime, Selector: ct.selector, WriteConcern: ct.writeConcern, ServerAPI: ct.serverAPI, @@ -81,16 +78,6 @@ func (ct *CommitTransaction) command(dst []byte, _ description.SelectedServer) ( return dst, nil } -// MaxTime specifies the maximum amount of time to allow the query to run on the server. -func (ct *CommitTransaction) MaxTime(maxTime *time.Duration) *CommitTransaction { - if ct == nil { - ct = new(CommitTransaction) - } - - ct.maxTime = maxTime - return ct -} - // RecoveryToken sets the recovery token to use when committing or aborting a sharded transaction. func (ct *CommitTransaction) RecoveryToken(recoveryToken bsoncore.Document) *CommitTransaction { if ct == nil { diff --git a/x/mongo/driver/operation/count.go b/x/mongo/driver/operation/count.go index 5625b79bd9..6aac998bf5 100644 --- a/x/mongo/driver/operation/count.go +++ b/x/mongo/driver/operation/count.go @@ -24,7 +24,6 @@ import ( // Count represents a count operation. type Count struct { - maxTime *time.Duration query bsoncore.Document session *session.Client clock *session.ClusterClock @@ -120,7 +119,6 @@ func (c *Count) Execute(ctx context.Context) error { Crypt: c.crypt, Database: c.database, Deployment: c.deployment, - MaxTime: c.maxTime, ReadConcern: c.readConcern, ReadPreference: c.readPreference, Selector: c.selector, @@ -150,16 +148,6 @@ func (c *Count) command(dst []byte, _ description.SelectedServer) ([]byte, error return dst, nil } -// MaxTime specifies the maximum amount of time to allow the query to run on the server. -func (c *Count) MaxTime(maxTime *time.Duration) *Count { - if c == nil { - c = new(Count) - } - - c.maxTime = maxTime - return c -} - // Query determines what results are returned from find. func (c *Count) Query(query bsoncore.Document) *Count { if c == nil { diff --git a/x/mongo/driver/operation/create_indexes.go b/x/mongo/driver/operation/create_indexes.go index 0192379e2b..06c8fd8118 100644 --- a/x/mongo/driver/operation/create_indexes.go +++ b/x/mongo/driver/operation/create_indexes.go @@ -25,7 +25,6 @@ import ( type CreateIndexes struct { commitQuorum bsoncore.Value indexes bsoncore.Document - maxTime *time.Duration session *session.Client clock *session.ClusterClock collection string @@ -112,7 +111,6 @@ func (ci *CreateIndexes) Execute(ctx context.Context) error { Crypt: ci.crypt, Database: ci.database, Deployment: ci.deployment, - MaxTime: ci.maxTime, Selector: ci.selector, WriteConcern: ci.writeConcern, ServerAPI: ci.serverAPI, @@ -158,16 +156,6 @@ func (ci *CreateIndexes) Indexes(indexes bsoncore.Document) *CreateIndexes { return ci } -// MaxTime specifies the maximum amount of time to allow the query to run on the server. -func (ci *CreateIndexes) MaxTime(maxTime *time.Duration) *CreateIndexes { - if ci == nil { - ci = new(CreateIndexes) - } - - ci.maxTime = maxTime - return ci -} - // Session sets the session for this operation. func (ci *CreateIndexes) Session(session *session.Client) *CreateIndexes { if ci == nil { diff --git a/x/mongo/driver/operation/distinct.go b/x/mongo/driver/operation/distinct.go index a13bd2b7b4..a59e4ced35 100644 --- a/x/mongo/driver/operation/distinct.go +++ b/x/mongo/driver/operation/distinct.go @@ -25,7 +25,6 @@ import ( type Distinct struct { collation bsoncore.Document key *string - maxTime *time.Duration query bsoncore.Document session *session.Client clock *session.ClusterClock @@ -99,7 +98,6 @@ func (d *Distinct) Execute(ctx context.Context) error { Crypt: d.crypt, Database: d.database, Deployment: d.deployment, - MaxTime: d.maxTime, ReadConcern: d.readConcern, ReadPreference: d.readPreference, Selector: d.selector, @@ -150,16 +148,6 @@ func (d *Distinct) Key(key string) *Distinct { return d } -// MaxTime specifies the maximum amount of time to allow the query to run on the server. -func (d *Distinct) MaxTime(maxTime *time.Duration) *Distinct { - if d == nil { - d = new(Distinct) - } - - d.maxTime = maxTime - return d -} - // Query specifies which documents to return distinct values from. func (d *Distinct) Query(query bsoncore.Document) *Distinct { if d == nil { diff --git a/x/mongo/driver/operation/drop_indexes.go b/x/mongo/driver/operation/drop_indexes.go index 597d04ac88..a758f34970 100644 --- a/x/mongo/driver/operation/drop_indexes.go +++ b/x/mongo/driver/operation/drop_indexes.go @@ -24,7 +24,6 @@ import ( // DropIndexes performs an dropIndexes operation. type DropIndexes struct { index *string - maxTime *time.Duration session *session.Client clock *session.ClusterClock collection string @@ -95,7 +94,6 @@ func (di *DropIndexes) Execute(ctx context.Context) error { Crypt: di.crypt, Database: di.database, Deployment: di.deployment, - MaxTime: di.maxTime, Selector: di.selector, WriteConcern: di.writeConcern, ServerAPI: di.serverAPI, @@ -123,16 +121,6 @@ func (di *DropIndexes) Index(index string) *DropIndexes { return di } -// MaxTime specifies the maximum amount of time to allow the query to run on the server. -func (di *DropIndexes) MaxTime(maxTime *time.Duration) *DropIndexes { - if di == nil { - di = new(DropIndexes) - } - - di.maxTime = maxTime - return di -} - // Session sets the session for this operation. func (di *DropIndexes) Session(session *session.Client) *DropIndexes { if di == nil { diff --git a/x/mongo/driver/operation/find.go b/x/mongo/driver/operation/find.go index 1e34b8da8a..bdbad6d610 100644 --- a/x/mongo/driver/operation/find.go +++ b/x/mongo/driver/operation/find.go @@ -35,7 +35,6 @@ type Find struct { let bsoncore.Document limit *int64 max bsoncore.Document - maxTime *time.Duration min bsoncore.Document noCursorTimeout *bool oplogReplay *bool @@ -100,7 +99,6 @@ func (f *Find) Execute(ctx context.Context) error { Crypt: f.crypt, Database: f.database, Deployment: f.deployment, - MaxTime: f.maxTime, ReadConcern: f.readConcern, ReadPreference: f.readPreference, Selector: f.selector, @@ -299,16 +297,6 @@ func (f *Find) Max(max bsoncore.Document) *Find { return f } -// MaxTime specifies the maximum amount of time to allow the query to run on the server. -func (f *Find) MaxTime(maxTime *time.Duration) *Find { - if f == nil { - f = new(Find) - } - - f.maxTime = maxTime - return f -} - // Min sets an inclusive lower bound for a specific index. func (f *Find) Min(min bsoncore.Document) *Find { if f == nil { diff --git a/x/mongo/driver/operation/find_and_modify.go b/x/mongo/driver/operation/find_and_modify.go index 12d241f710..51af9ffbcf 100644 --- a/x/mongo/driver/operation/find_and_modify.go +++ b/x/mongo/driver/operation/find_and_modify.go @@ -29,7 +29,6 @@ type FindAndModify struct { collation bsoncore.Document comment bsoncore.Value fields bsoncore.Document - maxTime *time.Duration newDocument *bool query bsoncore.Document remove *bool @@ -137,7 +136,6 @@ func (fam *FindAndModify) Execute(ctx context.Context) error { CommandMonitor: fam.monitor, Database: fam.database, Deployment: fam.deployment, - MaxTime: fam.maxTime, Selector: fam.selector, WriteConcern: fam.writeConcern, Crypt: fam.crypt, @@ -265,16 +263,6 @@ func (fam *FindAndModify) Fields(fields bsoncore.Document) *FindAndModify { return fam } -// MaxTime specifies the maximum amount of time to allow the operation to run on the server. -func (fam *FindAndModify) MaxTime(maxTime *time.Duration) *FindAndModify { - if fam == nil { - fam = new(FindAndModify) - } - - fam.maxTime = maxTime - return fam -} - // NewDocument specifies whether to return the modified document or the original. Defaults to false (return original). func (fam *FindAndModify) NewDocument(newDocument bool) *FindAndModify { if fam == nil { diff --git a/x/mongo/driver/operation/hello.go b/x/mongo/driver/operation/hello.go index 8e6c59de38..9a3993120f 100644 --- a/x/mongo/driver/operation/hello.go +++ b/x/mongo/driver/operation/hello.go @@ -47,6 +47,7 @@ type Hello struct { maxAwaitTimeMS *int64 serverAPI *driver.ServerAPIOptions loadBalanced bool + omitMaxTimeMS bool res bsoncore.Document } @@ -590,7 +591,8 @@ func (h *Hello) createOperation() driver.Operation { h.res = info.ServerResponse return nil }, - ServerAPI: h.serverAPI, + ServerAPI: h.serverAPI, + OmitMaxTimeMS: h.omitMaxTimeMS, } if isLegacyHandshake(h.serverAPI, h.loadBalanced) { @@ -650,3 +652,15 @@ func (h *Hello) GetHandshakeInformation(ctx context.Context, _ address.Address, func (h *Hello) FinishHandshake(context.Context, *mnet.Connection) error { return nil } + +// OmitMaxTimeMS will ensure maxTimMS is not included in the wire message +// constructed to send a hello request. +func (h *Hello) OmitMaxTimeMS(val bool) *Hello { + if h == nil { + h = new(Hello) + } + + h.omitMaxTimeMS = val + + return h +} diff --git a/x/mongo/driver/operation/list_indexes.go b/x/mongo/driver/operation/list_indexes.go index d4cbe8a337..a14873a7ac 100644 --- a/x/mongo/driver/operation/list_indexes.go +++ b/x/mongo/driver/operation/list_indexes.go @@ -22,7 +22,6 @@ import ( // ListIndexes performs a listIndexes operation. type ListIndexes struct { batchSize *int32 - maxTime *time.Duration session *session.Client clock *session.ClusterClock collection string @@ -76,7 +75,6 @@ func (li *ListIndexes) Execute(ctx context.Context) error { CommandMonitor: li.monitor, Database: li.database, Deployment: li.deployment, - MaxTime: li.maxTime, Selector: li.selector, Crypt: li.crypt, Legacy: driver.LegacyListIndexes, @@ -113,16 +111,6 @@ func (li *ListIndexes) BatchSize(batchSize int32) *ListIndexes { return li } -// MaxTime specifies the maximum amount of time to allow the query to run on the server. -func (li *ListIndexes) MaxTime(maxTime *time.Duration) *ListIndexes { - if li == nil { - li = new(ListIndexes) - } - - li.maxTime = maxTime - return li -} - // Session sets the session for this operation. func (li *ListIndexes) Session(session *session.Client) *ListIndexes { if li == nil { diff --git a/x/mongo/driver/operation_test.go b/x/mongo/driver/operation_test.go index 0e3da7007c..f209134b79 100644 --- a/x/mongo/driver/operation_test.go +++ b/x/mongo/driver/operation_test.go @@ -230,7 +230,8 @@ func TestOperation(t *testing.T) { want := bsoncore.AppendDocumentElement(nil, "writeConcern", bsoncore.BuildDocumentFromElements( nil, bsoncore.AppendStringElement(nil, "w", "majority"), )) - got, err := Operation{WriteConcern: writeconcern.Majority()}.addWriteConcern(nil, description.SelectedServer{}) + got, err := Operation{WriteConcern: writeconcern.Majority()}. + addWriteConcern(context.Background(), nil, description.SelectedServer{}) noerr(t, err) if !bytes.Equal(got, want) { t.Errorf("WriteConcern elements do not match. got %v; want %v", got, want) @@ -270,15 +271,12 @@ func TestOperation(t *testing.T) { }) t.Run("calculateMaxTimeMS", func(t *testing.T) { var ( - timeout = 5 * time.Second - maxTime = 2 * time.Second - negMaxTime = -2 * time.Second - shortRTT = 50 * time.Millisecond - longRTT = 10 * time.Second - verShortRTT = 400 * time.Microsecond + timeout = 5 * time.Second + shortRTT = 50 * time.Millisecond + longRTT = 10 * time.Second ) - timeoutCtx, cancel := csot.MakeTimeoutContext(context.Background(), timeout) + timeoutCtx, cancel := csot.WithTimeout(context.Background(), &timeout) defer cancel() testCases := []struct { @@ -293,43 +291,14 @@ func TestOperation(t *testing.T) { }{ { name: "uses context deadline and rtt90 with timeout", - op: Operation{MaxTime: &maxTime}, ctx: timeoutCtx, rttMin: shortRTT, rttStats: "", want: 5000, err: nil, }, - { - name: "uses MaxTime without timeout", - op: Operation{MaxTime: &maxTime}, - ctx: context.Background(), - rttMin: longRTT, - rttStats: "", - want: 2000, - err: nil, - }, - { - name: "errors when remaining timeout is less than rtt90", - op: Operation{MaxTime: &maxTime}, - ctx: timeoutCtx, - rttMin: timeout, - rttStats: "", - want: 0, - err: ErrDeadlineWouldBeExceeded, - }, - { - name: "errors when MaxTime is negative", - op: Operation{MaxTime: &negMaxTime}, - ctx: context.Background(), - rttMin: longRTT, - rttStats: "", - want: 0, - err: ErrNegativeMaxTime, - }, { name: "sub millisecond rtt should round up", - op: Operation{MaxTime: &verShortRTT}, ctx: context.Background(), rttMin: longRTT, rttStats: "", @@ -651,7 +620,7 @@ func TestOperation(t *testing.T) { assert.NotNil(t, err, "expected an error from Execute(), got nil") // Assert that error is just context deadline exceeded and is therefore not a driver.Error marked // with the TransientTransactionError label. - assert.Equal(t, err, context.DeadlineExceeded, "expected context.DeadlineExceeded error, got %v", err) + assert.True(t, errors.Is(err, context.DeadlineExceeded)) }) t.Run("canceled context not marked as TransientTransactionError", func(t *testing.T) { conn := mnet.NewConnection(&mockConnection{}) @@ -710,18 +679,24 @@ type mockDeployment struct { selector description.ServerSelector } returns struct { - server Server - err error - retry bool - kind description.TopologyKind + server Server + err error + retry bool + kind description.TopologyKind + serverSelectionTimeout time.Duration } } func (m *mockDeployment) SelectServer(_ context.Context, desc description.ServerSelector) (Server, error) { m.params.selector = desc + return m.returns.server, m.returns.err } +func (m *mockDeployment) GetServerSelectionTimeout() time.Duration { + return m.returns.serverSelectionTimeout +} + func (m *mockDeployment) Kind() description.TopologyKind { return m.returns.kind } type mockServerSelector struct{} @@ -974,3 +949,67 @@ func TestFilterDeprioritizedServers(t *testing.T) { }) } } + +func TestMarshalBSONWriteConcern(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + writeConcern writeconcern.WriteConcern + wantBSONType bson.Type + wtimeout time.Duration + want bson.D + wantErr string + }{ + { + name: "empty", + writeConcern: writeconcern.WriteConcern{}, + wantBSONType: 0x0, + want: nil, + wtimeout: 0, + wantErr: "a write concern must have at least one field set", + }, + { + name: "journal only", + writeConcern: *writeconcern.Journaled(), + wantBSONType: bson.TypeEmbeddedDocument, + want: bson.D{{"j", true}}, + wtimeout: 0, + wantErr: "a write concern must have at least one field set", + }, + { + name: "journal and wtimout", + writeConcern: *writeconcern.Journaled(), + wtimeout: 10 * time.Millisecond, + wantBSONType: bson.TypeEmbeddedDocument, + want: bson.D{{"j", true}, {"wtimeout", int64(10 * time.Millisecond / time.Millisecond)}}, + wantErr: "a write concern must have at least one field set", + }, + } + + for _, test := range tests { + test := test // Capture the range variable + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + gotBSONType, gotBSON, gotErr := marshalBSONWriteConcern(test.writeConcern, test.wtimeout) + assert.Equal(t, test.wantBSONType, gotBSONType) + + wantBSON := []byte(nil) + + if test.want != nil { + var err error + + wantBSON, err = bson.Marshal(test.want) + assert.NoError(t, err) + } + + assert.Equal(t, wantBSON, gotBSON) + + if gotErr != nil { + assert.EqualError(t, gotErr, test.wantErr) + } + }) + } +} diff --git a/x/mongo/driver/session/client_session.go b/x/mongo/driver/session/client_session.go index d535ec54c9..4228c4e98f 100644 --- a/x/mongo/driver/session/client_session.go +++ b/x/mongo/driver/session/client_session.go @@ -57,6 +57,8 @@ const ( Aborted ) +const defaultWriteConcernTimeout = 10_000 * time.Millisecond + // String implements the fmt.Stringer interface. func (s TransactionState) String() string { switch s { @@ -104,16 +106,15 @@ type Client struct { // options for the current transaction // most recently set by transactionopt - CurrentRc *readconcern.ReadConcern - CurrentRp *readpref.ReadPref - CurrentWc *writeconcern.WriteConcern - CurrentMct *time.Duration + CurrentRc *readconcern.ReadConcern + CurrentRp *readpref.ReadPref + CurrentWc *writeconcern.WriteConcern + CurrentWTimeout time.Duration // default transaction options - transactionRc *readconcern.ReadConcern - transactionRp *readpref.ReadPref - transactionWc *writeconcern.WriteConcern - transactionMaxCommitTime *time.Duration + transactionRc *readconcern.ReadConcern + transactionRp *readpref.ReadPref + transactionWc *writeconcern.WriteConcern pool *Pool TransactionState TransactionState @@ -189,9 +190,6 @@ func NewClientSession(pool *Pool, clientID uuid.UUID, opts ...*ClientOptions) (* if mergedOpts.DefaultWriteConcern != nil { c.transactionWc = mergedOpts.DefaultWriteConcern } - if mergedOpts.DefaultMaxCommitTime != nil { - c.transactionMaxCommitTime = mergedOpts.DefaultMaxCommitTime - } if mergedOpts.Snapshot != nil { c.Snapshot = *mergedOpts.Snapshot } @@ -399,7 +397,6 @@ func (c *Client) StartTransaction(opts *TransactionOptions) error { c.CurrentRc = opts.ReadConcern c.CurrentRp = opts.ReadPreference c.CurrentWc = opts.WriteConcern - c.CurrentMct = opts.MaxCommitTime } if c.CurrentRc == nil { @@ -414,10 +411,6 @@ func (c *Client) StartTransaction(opts *TransactionOptions) error { c.CurrentWc = c.transactionWc } - if c.CurrentMct == nil { - c.CurrentMct = c.transactionMaxCommitTime - } - if !c.CurrentWc.Acknowledged() { _ = c.clearTransactionOpts() return ErrUnackWCUnsupported @@ -449,21 +442,22 @@ func (c *Client) CommitTransaction() error { return nil } -// UpdateCommitTransactionWriteConcern will set the write concern to majority and potentially set a -// w timeout of 10 seconds. This should be called after a commit transaction operation fails with a -// retryable error or after a successful commit transaction operation. +// UpdateCommitTransactionWriteConcern will set the write concern to majority. +// This should be called after a commit transaction operation fails with a +// retryable error or after a successful commit transaction operation +// +// Per the transaction specifications, when commitTransaction is retried, if +// the modified write concern does not include a "wtimeout" value, drivers +// MUST apply "wtimeout: 10000" to the write concern in order to avoid waiting +// forever (oruntil a socket timeout) if the majority write concern cannot be +// satisfied. This field abstracts that functionality. For more information, +// see SPEC-1185. func (c *Client) UpdateCommitTransactionWriteConcern() { - wc := &writeconcern.WriteConcern{} - timeout := 10 * time.Second - if c.CurrentWc != nil { - *wc = *c.CurrentWc - if c.CurrentWc.WTimeout != 0 { - timeout = c.CurrentWc.WTimeout - } + c.CurrentWc = &writeconcern.WriteConcern{ + W: "majority", } - wc.W = "majority" - wc.WTimeout = timeout - c.CurrentWc = wc + + c.CurrentWTimeout = defaultWriteConcernTimeout } // CheckAbortTransaction checks to see if allowed to abort transaction and returns diff --git a/x/mongo/driver/session/options.go b/x/mongo/driver/session/options.go index ee7c301d64..67749f09cb 100644 --- a/x/mongo/driver/session/options.go +++ b/x/mongo/driver/session/options.go @@ -7,8 +7,6 @@ package session import ( - "time" - "go.mongodb.org/mongo-driver/mongo/readconcern" "go.mongodb.org/mongo-driver/mongo/readpref" "go.mongodb.org/mongo-driver/mongo/writeconcern" @@ -20,7 +18,6 @@ type ClientOptions struct { DefaultReadConcern *readconcern.ReadConcern DefaultWriteConcern *writeconcern.WriteConcern DefaultReadPreference *readpref.ReadPref - DefaultMaxCommitTime *time.Duration Snapshot *bool } @@ -29,7 +26,6 @@ type TransactionOptions struct { ReadConcern *readconcern.ReadConcern WriteConcern *writeconcern.WriteConcern ReadPreference *readpref.ReadPref - MaxCommitTime *time.Duration } func mergeClientOptions(opts ...*ClientOptions) *ClientOptions { @@ -50,9 +46,6 @@ func mergeClientOptions(opts ...*ClientOptions) *ClientOptions { if opt.DefaultWriteConcern != nil { c.DefaultWriteConcern = opt.DefaultWriteConcern } - if opt.DefaultMaxCommitTime != nil { - c.DefaultMaxCommitTime = opt.DefaultMaxCommitTime - } if opt.Snapshot != nil { c.Snapshot = opt.Snapshot } diff --git a/x/mongo/driver/topology/CMAP_spec_test.go b/x/mongo/driver/topology/CMAP_spec_test.go index 62283d2156..d65c97ca3d 100644 --- a/x/mongo/driver/topology/CMAP_spec_test.go +++ b/x/mongo/driver/topology/CMAP_spec_test.go @@ -208,7 +208,7 @@ func runCMAPTest(t *testing.T, testFileName string) { } })) - s := NewServer("mongodb://fake", bson.NewObjectID(), sOpts...) + s := NewServer("mongodb://fake", bson.NewObjectID(), defaultConnectionTimeout, sOpts...) s.state = serverConnected require.NoError(t, err, "error connecting connection pool") defer s.pool.close(context.Background()) @@ -274,7 +274,6 @@ func runCMAPTest(t *testing.T, testFileName string) { } checkEvents(t, test.Events, testInfo.finalEventChan, test.Ignore) - } func checkEvents(t *testing.T, expectedEvents []cmapEvent, actualEvents chan *event.PoolEvent, ignoreEvents []string) { @@ -290,7 +289,6 @@ func checkEvents(t *testing.T, expectedEvents []cmapEvent, actualEvents chan *ev } if expectedEvent.Address != nil { - if expectedEvent.Address == float64(42) { // can be any address if validEvent.Address == "" { t.Errorf("expected address in event, instead received none in %v", expectedEvent.EventType) diff --git a/x/mongo/driver/topology/connection.go b/x/mongo/driver/topology/connection.go index 43d45c1515..cd35c6f66d 100644 --- a/x/mongo/driver/topology/connection.go +++ b/x/mongo/driver/topology/connection.go @@ -56,8 +56,6 @@ type connection struct { addr address.Address idleTimeout time.Duration idleDeadline atomic.Value // Stores a time.Time - readTimeout time.Duration - writeTimeout time.Duration desc description.Server helloRTT time.Duration compressor wiremessage.CompressorID @@ -65,13 +63,13 @@ type connection struct { zstdLevel int connectDone chan struct{} config *connectionConfig - cancelConnectContext context.CancelFunc connectContextMade chan struct{} canStream bool currentlyStreaming bool - connectContextMutex sync.Mutex - cancellationListener cancellationListener - serverConnectionID *int64 // the server's ID for this client's connection + cancellationListener contextListener + connectListener contextListener // Cancels blocking ops during connect + serverConnectionID *int64 // the server's ID for this client's connection + prevCanceled atomic.Value // pool related fields pool *pool @@ -90,12 +88,11 @@ func newConnection(addr address.Address, opts ...ConnectionOption) *connection { id: id, addr: addr, idleTimeout: cfg.idleTimeout, - readTimeout: cfg.readTimeout, - writeTimeout: cfg.writeTimeout, connectDone: make(chan struct{}), config: cfg, connectContextMade: make(chan struct{}), - cancellationListener: newCancellListener(), + cancellationListener: newContextDoneListener(), + connectListener: newNonBlockingContextDoneListener(), } // Connections to non-load balanced deployments should eagerly set the generation numbers so errors encountered // at any point during connection establishment can be processed without the connection being considered stale. @@ -141,6 +138,7 @@ func (c *connection) connect(ctx context.Context) (err error) { return nil } + defer c.closeConnectContext() defer close(c.connectDone) // If connect returns an error, set the connection status as disconnected and close the @@ -165,35 +163,17 @@ func (c *connection) connect(ctx context.Context) (err error) { // cancellation still applies but with an added timeout to ensure the connectTimeoutMS option is applied to socket // establishment and the TLS handshake as a whole. This is created outside of the connectContextMutex lock to avoid // holding the lock longer than necessary. - c.connectContextMutex.Lock() - var handshakeCtx context.Context - handshakeCtx, c.cancelConnectContext = context.WithCancel(ctx) - c.connectContextMutex.Unlock() + ctx, cancel := context.WithCancel(ctx) + defer cancel() - dialCtx := handshakeCtx - var dialCancel context.CancelFunc - if c.config.connectTimeout != 0 { - dialCtx, dialCancel = context.WithTimeout(handshakeCtx, c.config.connectTimeout) - defer dialCancel() - } - - defer func() { - var cancelFn context.CancelFunc + go func() { + defer cancel() - c.connectContextMutex.Lock() - cancelFn = c.cancelConnectContext - c.cancelConnectContext = nil - c.connectContextMutex.Unlock() - - if cancelFn != nil { - cancelFn() - } + c.connectListener.Listen(ctx, func() {}) }() - close(c.connectContextMade) - // Assign the result of DialContext to a temporary net.Conn to ensure that c.nc is not set in an error case. - tempNc, err := c.config.dialer.DialContext(dialCtx, c.addr.Network(), c.addr.String()) + tempNc, err := c.config.dialer.DialContext(ctx, c.addr.Network(), c.addr.String()) if err != nil { return ConnectionError{Wrapped: err, init: true} } @@ -209,7 +189,8 @@ func (c *connection) connect(ctx context.Context) (err error) { DisableEndpointChecking: c.config.disableOCSPEndpointCheck, HTTPClient: c.config.httpClient, } - tlsNc, err := configureTLS(dialCtx, c.config.tlsConnectionSource, c.nc, c.addr, tlsConfig, ocspOpts) + tlsNc, err := configureTLS(ctx, c.config.tlsConnectionSource, c.nc, c.addr, tlsConfig, ocspOpts) + if err != nil { return ConnectionError{Wrapped: err, init: true} } @@ -226,10 +207,9 @@ func (c *connection) connect(ctx context.Context) (err error) { handshakeStartTime := time.Now() iconn := initConnection{c} - handshakeConn := mnet.NewConnection(iconn) - handshakeInfo, err = handshaker.GetHandshakeInformation(handshakeCtx, c.addr, handshakeConn) + handshakeInfo, err = handshaker.GetHandshakeInformation(ctx, c.addr, handshakeConn) if err == nil { // We only need to retain the Description field as the connection's description. The authentication-related // fields in handshakeInfo are tracked by the handshaker if necessary. @@ -253,7 +233,7 @@ func (c *connection) connect(ctx context.Context) (err error) { // If we successfully finished the first part of the handshake and verified LB state, continue with the rest of // the handshake. - err = handshaker.FinishHandshake(handshakeCtx, handshakeConn) + err = handshaker.FinishHandshake(ctx, handshakeConn) } // We have a failed handshake here @@ -299,16 +279,8 @@ func (c *connection) wait() { } func (c *connection) closeConnectContext() { - <-c.connectContextMade - var cancelFn context.CancelFunc - - c.connectContextMutex.Lock() - cancelFn = c.cancelConnectContext - c.cancelConnectContext = nil - c.connectContextMutex.Unlock() - - if cancelFn != nil { - cancelFn() + if c.connectListener != nil { + c.connectListener.StopListening() } } @@ -347,17 +319,7 @@ func (c *connection) writeWireMessage(ctx context.Context, wm []byte) error { } } - var deadline time.Time - if c.writeTimeout != 0 { - deadline = time.Now().Add(c.writeTimeout) - } - - var contextDeadlineUsed bool - if dl, ok := ctx.Deadline(); ok && (deadline.IsZero() || dl.Before(deadline)) { - contextDeadlineUsed = true - deadline = dl - } - + deadline, contextDeadlineUsed := ctx.Deadline() if err := c.nc.SetWriteDeadline(deadline); err != nil { return ConnectionError{ConnectionID: c.id, Wrapped: err, message: "failed to set write deadline"} } @@ -401,17 +363,7 @@ func (c *connection) readWireMessage(ctx context.Context) ([]byte, error) { } } - var deadline time.Time - if c.readTimeout != 0 { - deadline = time.Now().Add(c.readTimeout) - } - - var contextDeadlineUsed bool - if dl, ok := ctx.Deadline(); ok && (deadline.IsZero() || dl.Before(deadline)) { - contextDeadlineUsed = true - deadline = dl - } - + deadline, contextDeadlineUsed := ctx.Deadline() if err := c.nc.SetReadDeadline(deadline); err != nil { return nil, ConnectionError{ConnectionID: c.id, Wrapped: err, message: "failed to set read deadline"} } @@ -484,6 +436,12 @@ func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string, } func (c *connection) close() error { + // Stop any blocking operations occurring in connect(), but await closing the + // connections directly before closing the connection context. This ensures + // that closing a connection will manifest as an io.EOF error, avoiding + // non-deterministic connection closure errors. + defer c.closeConnectContext() + // Overwrite the connection state as the first step so only the first close call will execute. if !atomic.CompareAndSwapInt64(&c.state, connConnected, connDisconnected) { return nil @@ -535,11 +493,6 @@ func (c *connection) getCurrentlyStreaming() bool { return c.currentlyStreaming } -func (c *connection) setSocketTimeout(timeout time.Duration) { - c.readTimeout = timeout - c.writeTimeout = timeout -} - func (c *connection) ID() string { return c.id } @@ -548,6 +501,14 @@ func (c *connection) ServerConnectionID() *int64 { return c.serverConnectionID } +func (c *connection) previousCanceled() bool { + if val := c.prevCanceled.Load(); val != nil { + return val.(bool) + } + + return false +} + // initConnection is an adapter used during connection initialization. It has the minimum // functionality necessary to implement the driver.Connection interface, which is required to pass a // *connection to a Handshaker. @@ -851,47 +812,3 @@ func configureTLS(ctx context.Context, } return client, nil } - -// TODO: Naming? - -// cancellListener listens for context cancellation and notifies listeners via a -// callback function. -type cancellListener struct { - aborted bool - done chan struct{} -} - -// newCancellListener constructs a cancellListener. -func newCancellListener() *cancellListener { - return &cancellListener{ - done: make(chan struct{}), - } -} - -// Listen blocks until the provided context is cancelled or listening is aborted -// via the StopListening function. If this detects that the context has been -// cancelled (i.e. errors.Is(ctx.Err(), context.Canceled), the provided callback is -// called to abort in-progress work. Even if the context expires, this function -// will block until StopListening is called. -func (c *cancellListener) Listen(ctx context.Context, abortFn func()) { - c.aborted = false - - select { - case <-ctx.Done(): - if errors.Is(ctx.Err(), context.Canceled) { - c.aborted = true - abortFn() - } - - <-c.done - case <-c.done: - } -} - -// StopListening stops the in-progress Listen call. This blocks if there is no -// in-progress Listen call. This function will return true if the provided abort -// callback was called when listening for cancellation on the previous context. -func (c *cancellListener) StopListening() bool { - c.done <- struct{}{} - return c.aborted -} diff --git a/x/mongo/driver/topology/connection_options.go b/x/mongo/driver/topology/connection_options.go index 41533a149a..f45da5d460 100644 --- a/x/mongo/driver/topology/connection_options.go +++ b/x/mongo/driver/topology/connection_options.go @@ -48,13 +48,10 @@ type Handshaker = driver.Handshaker type generationNumberFn func(serviceID *bson.ObjectID) uint64 type connectionConfig struct { - connectTimeout time.Duration dialer Dialer handshaker Handshaker idleTimeout time.Duration cmdMonitor *event.CommandMonitor - readTimeout time.Duration - writeTimeout time.Duration tlsConfig *tls.Config httpClient *http.Client compressors []string @@ -69,7 +66,6 @@ type connectionConfig struct { func newConnectionConfig(opts ...ConnectionOption) *connectionConfig { cfg := &connectionConfig{ - connectTimeout: 30 * time.Second, dialer: nil, tlsConnectionSource: defaultTLSConnectionSource, httpClient: httputil.DefaultHTTPClient, @@ -107,14 +103,6 @@ func WithCompressors(fn func([]string) []string) ConnectionOption { } } -// WithConnectTimeout configures the maximum amount of time a dial will wait for a -// Connect to complete. The default is 30 seconds. -func WithConnectTimeout(fn func(time.Duration) time.Duration) ConnectionOption { - return func(c *connectionConfig) { - c.connectTimeout = fn(c.connectTimeout) - } -} - // WithDialer configures the Dialer to use when making a new connection to MongoDB. func WithDialer(fn func(Dialer) Dialer) ConnectionOption { return func(c *connectionConfig) { @@ -137,20 +125,6 @@ func WithIdleTimeout(fn func(time.Duration) time.Duration) ConnectionOption { } } -// WithReadTimeout configures the maximum read time for a connection. -func WithReadTimeout(fn func(time.Duration) time.Duration) ConnectionOption { - return func(c *connectionConfig) { - c.readTimeout = fn(c.readTimeout) - } -} - -// WithWriteTimeout configures the maximum write time for a connection. -func WithWriteTimeout(fn func(time.Duration) time.Duration) ConnectionOption { - return func(c *connectionConfig) { - c.writeTimeout = fn(c.writeTimeout) - } -} - // WithTLSConfig configures the TLS options for a connection. func WithTLSConfig(fn func(*tls.Config) *tls.Config) ConnectionOption { return func(c *connectionConfig) { diff --git a/x/mongo/driver/topology/connection_test.go b/x/mongo/driver/topology/connection_test.go index bcad63acbb..b5158c596d 100644 --- a/x/mongo/driver/topology/connection_test.go +++ b/x/mongo/driver/topology/connection_test.go @@ -118,7 +118,6 @@ func TestConnection(t *testing.T) { err := conn.connect(context.Background()) assert.Nil(t, err, "error establishing connection: %v", err) - assert.Nil(t, conn.cancelConnectContext, "cancellation function was not cleared") }) t.Run("connect cancelled", func(t *testing.T) { // In the case where connection establishment is cancelled, the closeConnectContext function @@ -149,7 +148,6 @@ func TestConnection(t *testing.T) { // Simulate cancelling connection establishment and assert that this clears the CancelFunc. conn.closeConnectContext() - assert.Nil(t, conn.cancelConnectContext, "cancellation function was not cleared") close(doneChan) wg.Wait() }) @@ -203,154 +201,6 @@ func TestConnection(t *testing.T) { } }) }) - t.Run("connectTimeout is applied correctly", func(t *testing.T) { - testCases := []struct { - name string - contextTimeout time.Duration - connectTimeout time.Duration - maxConnectTime time.Duration - }{ - // The timeout to dial a connection should be min(context timeout, connectTimeoutMS), so 1ms for - // both of the tests declared below. Both tests also specify a 50ms max connect time to provide - // a large buffer for lag and avoid test flakiness. - - {"context timeout is lower", 1 * time.Millisecond, 100 * time.Millisecond, 50 * time.Millisecond}, - {"connect timeout is lower", 100 * time.Millisecond, 1 * time.Millisecond, 50 * time.Millisecond}, - } - - for _, tc := range testCases { - t.Run("timeout applied to socket establishment: "+tc.name, func(t *testing.T) { - // Ensure the initial connection dial can be timed out and the connection propagates the error - // from the dialer in this case. - - connOpts := []ConnectionOption{ - WithDialer(func(Dialer) Dialer { - return DialerFunc(func(ctx context.Context, _, _ string) (net.Conn, error) { - <-ctx.Done() - return nil, ctx.Err() - }) - }), - WithConnectTimeout(func(time.Duration) time.Duration { - return tc.connectTimeout - }), - } - conn := newConnection("", connOpts...) - - var connectErr error - callback := func() bool { - connectCtx, cancel := context.WithTimeout(context.Background(), tc.contextTimeout) - defer cancel() - - connectErr = conn.connect(connectCtx) - return true - } - assert.Eventually(t, - callback, - tc.maxConnectTime, - time.Millisecond, - "expected timeout to apply to socket establishment after maximum connect time") - - ce, ok := connectErr.(ConnectionError) - assert.True(t, ok, "expected error %v to be of type %T", connectErr, ConnectionError{}) - assert.Equal(t, context.DeadlineExceeded, ce.Unwrap(), "expected wrapped error to be %v, got %v", - context.DeadlineExceeded, ce.Unwrap()) - }) - t.Run("timeout applied to TLS handshake: "+tc.name, func(t *testing.T) { - // Ensure the TLS handshake can be timed out and the connection propagates the error from the - // tlsConn in this case. - - // Start a TCP listener on a random port and use the listener address as the - // target for connections. The listener will act as a source of connections - // that never respond, allowing the timeout logic to always trigger. - l, err := net.Listen("tcp", "localhost:0") - assert.Nil(t, err, "net.Listen() error: %q", err) - defer l.Close() - - connOpts := []ConnectionOption{ - WithConnectTimeout(func(time.Duration) time.Duration { - return tc.connectTimeout - }), - WithTLSConfig(func(*tls.Config) *tls.Config { - return &tls.Config{ServerName: "test"} - }), - } - conn := newConnection(address.Address(l.Addr().String()), connOpts...) - - var connectErr error - callback := func() bool { - connectCtx, cancel := context.WithTimeout(context.Background(), tc.contextTimeout) - defer cancel() - - connectErr = conn.connect(connectCtx) - return true - } - assert.Eventually(t, - callback, - tc.maxConnectTime, - time.Millisecond, - "expected timeout to apply to TLS handshake after maximum connect time") - - ce, ok := connectErr.(ConnectionError) - assert.True(t, ok, "expected error %v to be of type %T", connectErr, ConnectionError{}) - - isTimeout := func(err error) bool { - if errors.Is(err, context.DeadlineExceeded) { - return true - } - if ne, ok := err.(net.Error); ok { - return ne.Timeout() - } - return false - } - assert.True(t, - isTimeout(ce.Unwrap()), - "expected wrapped error to be a timeout error, but got %q", - ce.Unwrap()) - }) - t.Run("timeout is not applied to handshaker: "+tc.name, func(t *testing.T) { - // Ensure that no additional timeout is applied to the handshake after the connection has been - // established. - - var getInfoCtx, finishCtx context.Context - handshaker := &testHandshaker{ - getHandshakeInformation: func(ctx context.Context, _ address.Address, _ *mnet.Connection) (driver.HandshakeInformation, error) { - getInfoCtx = ctx - return driver.HandshakeInformation{}, nil - }, - finishHandshake: func(ctx context.Context, _ *mnet.Connection) error { - finishCtx = ctx - return nil - }, - } - - connOpts := []ConnectionOption{ - WithConnectTimeout(func(time.Duration) time.Duration { - return tc.connectTimeout - }), - WithDialer(func(Dialer) Dialer { - return DialerFunc(func(context.Context, string, string) (net.Conn, error) { - return &net.TCPConn{}, nil - }) - }), - WithHandshaker(func(Handshaker) Handshaker { - return handshaker - }), - } - conn := newConnection("", connOpts...) - - err := conn.connect(context.Background()) - assert.Nil(t, err, "connect error: %v", err) - - assertNoContextTimeout := func(t *testing.T, ctx context.Context) { - t.Helper() - dl, ok := ctx.Deadline() - assert.False(t, ok, "expected context to have no deadline, but got deadline %v", dl) - } - assertNoContextTimeout(t, getInfoCtx) - assertNoContextTimeout(t, finishCtx) - }) - } - }) }) t.Run("writeWireMessage", func(t *testing.T) { t.Run("closed connection", func(t *testing.T) { @@ -365,14 +215,10 @@ func TestConnection(t *testing.T) { testCases := []struct { name string ctxDeadline time.Duration - timeout time.Duration deadline time.Time }{ - {"no deadline", 0, 0, time.Now().Add(1 * time.Second)}, - {"ctx deadline", 5 * time.Second, 0, time.Now().Add(6 * time.Second)}, - {"timeout", 0, 10 * time.Second, time.Now().Add(11 * time.Second)}, - {"both (ctx wins)", 15 * time.Second, 20 * time.Second, time.Now().Add(16 * time.Second)}, - {"both (timeout wins)", 30 * time.Second, 25 * time.Second, time.Now().Add(26 * time.Second)}, + {"no deadline", 0, time.Now().Add(1 * time.Second)}, + {"ctx deadline", 5 * time.Second, time.Now().Add(6 * time.Second)}, } for _, tc := range testCases { @@ -389,7 +235,7 @@ func TestConnection(t *testing.T) { message: "failed to set write deadline", } tnc := &testNetConn{deadlineerr: errors.New("set writeDeadline error")} - conn := &connection{id: "foobar", nc: tnc, writeTimeout: tc.timeout, state: connConnected} + conn := &connection{id: "foobar", nc: tnc, state: connConnected} got := conn.writeWireMessage(ctx, []byte{}) if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) { t.Errorf("errors do not match. got %v; want %v", got, want) @@ -494,14 +340,10 @@ func TestConnection(t *testing.T) { testCases := []struct { name string ctxDeadline time.Duration - timeout time.Duration deadline time.Time }{ - {"no deadline", 0, 0, time.Now().Add(1 * time.Second)}, - {"ctx deadline", 5 * time.Second, 0, time.Now().Add(6 * time.Second)}, - {"timeout", 0, 10 * time.Second, time.Now().Add(11 * time.Second)}, - {"both (ctx wins)", 15 * time.Second, 20 * time.Second, time.Now().Add(16 * time.Second)}, - {"both (timeout wins)", 30 * time.Second, 25 * time.Second, time.Now().Add(26 * time.Second)}, + {"no deadline", 0, time.Now().Add(1 * time.Second)}, + {"ctx deadline", 5 * time.Second, time.Now().Add(6 * time.Second)}, } for _, tc := range testCases { @@ -518,7 +360,7 @@ func TestConnection(t *testing.T) { message: "failed to set read deadline", } tnc := &testNetConn{deadlineerr: errors.New("set readDeadline error")} - conn := &connection{id: "foobar", nc: tnc, readTimeout: tc.timeout, state: connConnected} + conn := &connection{id: "foobar", nc: tnc, state: connConnected} _, got := conn.readWireMessage(ctx) if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) { t.Errorf("errors do not match. got %v; want %v", got, want) @@ -786,7 +628,8 @@ func TestConnection(t *testing.T) { addr := bootstrapConnections(t, numConns, func(nc net.Conn) {}) pool := newPool(poolConfig{ - Address: address.Address(addr.String()), + Address: address.Address(addr.String()), + ConnectTimeout: defaultConnectionTimeout, }) err := pool.ready() assert.Nil(t, err, "pool.connect() error: %v", err) @@ -1205,7 +1048,7 @@ func (d *dialer) lenclosed() int { } type testCancellationListener struct { - listener *cancellListener + listener *contextDoneListener numListen int numStopListening int aborted bool @@ -1215,7 +1058,7 @@ type testCancellationListener struct { // returned by the StopListening method. func newTestCancellationListener(aborted bool) *testCancellationListener { return &testCancellationListener{ - listener: newCancellListener(), + listener: newContextDoneListener(), aborted: aborted, } } diff --git a/x/mongo/driver/topology/context_listener.go b/x/mongo/driver/topology/context_listener.go new file mode 100644 index 0000000000..99c252c87c --- /dev/null +++ b/x/mongo/driver/topology/context_listener.go @@ -0,0 +1,91 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package topology + +import ( + "context" + "errors" + "sync/atomic" +) + +type contextListener interface { + Listen(context.Context, func()) + StopListening() bool +} + +// contextDoneListener listens for context-ending eventsand notifies listeners +// via a callback function. +type contextDoneListener struct { + aborted atomic.Value + done chan struct{} + blockOnDone bool +} + +var _ contextListener = &contextDoneListener{} + +// newContextDoneListener constructs a contextDoneListener that will block +// when a context is done until StopListening is called. +func newContextDoneListener() *contextDoneListener { + return &contextDoneListener{ + done: make(chan struct{}), + blockOnDone: true, + } +} + +// newNonBlockingContextDoneLIstener constructs a contextDoneListener that +// will not block when a context is done. In this case there are two ways to +// unblock the listener: a finished context or a call to StopListening. +func newNonBlockingContextDoneListener() *contextDoneListener { + return &contextDoneListener{ + done: make(chan struct{}), + blockOnDone: false, + } +} + +// Listen blocks until the provided context is cancelled or listening is aborted +// via the StopListening function. If this detects that the context has been +// cancelled (i.e. errors.Is(ctx.Err(), context.Canceled), the provided callback +// is called to abort in-progress work. If blockOnDone is true, this function +// will block until StopListening is called, even if the context expires. +func (c *contextDoneListener) Listen(ctx context.Context, abortFn func()) { + c.aborted.Store(false) + + select { + case <-ctx.Done(): + if errors.Is(ctx.Err(), context.Canceled) { + c.aborted.Store(true) + + abortFn() + } + + if c.blockOnDone { + <-c.done + } + case <-c.done: + } +} + +// StopListening stops the in-progress Listen call. If blockOnDone is true, then +// this blocks if there is no in-progress Listen call. This function will return +// true if the provided abort callback was called when listening for +// cancellation on the previous context. +func (c *contextDoneListener) StopListening() bool { + if c.blockOnDone { + c.done <- struct{}{} + } else { + select { + case c.done <- struct{}{}: + default: + } + } + + if aborted := c.aborted.Load(); aborted != nil { + return aborted.(bool) + } + + return false +} diff --git a/x/mongo/driver/topology/pool.go b/x/mongo/driver/topology/pool.go index 122e13111c..4a1b82b431 100644 --- a/x/mongo/driver/topology/pool.go +++ b/x/mongo/driver/topology/pool.go @@ -78,6 +78,7 @@ type poolConfig struct { PoolMonitor *event.PoolMonitor Logger *logger.Logger handshakeErrFn func(error, uint64, *bson.ObjectID) + ConnectTimeout time.Duration } type pool struct { @@ -122,9 +123,10 @@ type pool struct { conns map[int64]*connection // conns holds all currently open connections. newConnWait wantConnQueue // newConnWait holds all wantConn requests for new connections. - idleMu sync.Mutex // idleMu guards idleConns, idleConnWait - idleConns []*connection // idleConns holds all idle connections. - idleConnWait wantConnQueue // idleConnWait holds all wantConn requests for idle connections. + idleMu sync.Mutex // idleMu guards idleConns, idleConnWait + idleConns []*connection // idleConns holds all idle connections. + idleConnWait wantConnQueue // idleConnWait holds all wantConn requests for idle connections. + connectTimeout time.Duration } // getState returns the current state of the pool. Callers must not hold the stateMu lock. @@ -221,6 +223,7 @@ func newPool(config poolConfig, connOpts ...ConnectionOption) *pool { createConnectionsCond: sync.NewCond(&sync.Mutex{}), conns: make(map[int64]*connection, config.MaxPoolSize), idleConns: make([]*connection, 0, config.MaxPoolSize), + connectTimeout: config.ConnectTimeout, } // minSize must not exceed maxSize if maxSize is not 0 if pool.maxSize != 0 && pool.minSize > pool.maxSize { @@ -1108,9 +1111,26 @@ func (p *pool) createConnections(ctx context.Context, wg *sync.WaitGroup) { } start := time.Now() - // Pass the createConnections context to connect to allow pool close to cancel connection - // establishment so shutdown doesn't block indefinitely if connectTimeout=0. - err := conn.connect(ctx) + // Pass the createConnections context to connect to allow pool close to + // cancel connection establishment so shutdown doesn't block indefinitely if + // connectTimeout=0. + // + // Per the specifications, an explicit value of connectTimeout=0 means the + // timeout is "infinite". + + var cancel context.CancelFunc + + connctx := context.Background() + if p.connectTimeout != 0 { + connctx, cancel = context.WithTimeout(ctx, p.connectTimeout) + } + + err := conn.connect(connctx) + + if cancel != nil { + cancel() + } + if err != nil { w.tryDeliver(nil, err) diff --git a/x/mongo/driver/topology/pool_test.go b/x/mongo/driver/topology/pool_test.go index 3001aa9b1b..69a7cce726 100644 --- a/x/mongo/driver/topology/pool_test.go +++ b/x/mongo/driver/topology/pool_test.go @@ -67,7 +67,8 @@ func TestPool(t *testing.T) { }) p1 := newPool(poolConfig{ - Address: address.Address(addr.String()), + Address: address.Address(addr.String()), + ConnectTimeout: defaultConnectionTimeout, }) err := p1.ready() noerr(t, err) @@ -92,7 +93,9 @@ func TestPool(t *testing.T) { t.Run("calling close multiple times does not panic", func(t *testing.T) { t.Parallel() - p := newPool(poolConfig{}) + p := newPool(poolConfig{ + ConnectTimeout: defaultConnectionTimeout, + }) err := p.ready() noerr(t, err) @@ -112,7 +115,8 @@ func TestPool(t *testing.T) { d := newdialer(&net.Dialer{}) p := newPool(poolConfig{ - Address: address.Address(addr.String()), + Address: address.Address(addr.String()), + ConnectTimeout: defaultConnectionTimeout, }, WithDialer(func(Dialer) Dialer { return d })) err := p.ready() noerr(t, err) @@ -148,7 +152,8 @@ func TestPool(t *testing.T) { d := newdialer(&net.Dialer{}) p := newPool(poolConfig{ - Address: address.Address(addr.String()), + Address: address.Address(addr.String()), + ConnectTimeout: defaultConnectionTimeout, }, WithDialer(func(Dialer) Dialer { return d })) err := p.ready() noerr(t, err) @@ -183,7 +188,8 @@ func TestPool(t *testing.T) { }) p := newPool(poolConfig{ - Address: address.Address(addr.String()), + Address: address.Address(addr.String()), + ConnectTimeout: defaultConnectionTimeout, }) err := p.ready() noerr(t, err) @@ -229,7 +235,8 @@ func TestPool(t *testing.T) { }) p := newPool(poolConfig{ - Address: address.Address(addr.String()), + Address: address.Address(addr.String()), + ConnectTimeout: defaultConnectionTimeout, }) err := p.ready() noerr(t, err) @@ -284,7 +291,8 @@ func TestPool(t *testing.T) { }) p := newPool(poolConfig{ - Address: address.Address(addr.String()), + Address: address.Address(addr.String()), + ConnectTimeout: defaultConnectionTimeout, }) err := p.ready() noerr(t, err) @@ -313,7 +321,8 @@ func TestPool(t *testing.T) { }) p := newPool(poolConfig{ - Address: address.Address(addr.String()), + Address: address.Address(addr.String()), + ConnectTimeout: defaultConnectionTimeout, }) err := p.ready() noerr(t, err) @@ -369,7 +378,8 @@ func TestPool(t *testing.T) { }) p := newPool(poolConfig{ - Address: address.Address(addr.String()), + Address: address.Address(addr.String()), + ConnectTimeout: defaultConnectionTimeout, }) err := p.ready() noerr(t, err) @@ -407,7 +417,8 @@ func TestPool(t *testing.T) { }) p := newPool(poolConfig{ - Address: address.Address(addr.String()), + Address: address.Address(addr.String()), + ConnectTimeout: defaultConnectionTimeout, }) err := p.ready() noerr(t, err) @@ -456,7 +467,9 @@ func TestPool(t *testing.T) { t.Parallel() dialErr := errors.New("create new connection error") - p := newPool(poolConfig{}, WithDialer(func(Dialer) Dialer { + p := newPool(poolConfig{ + ConnectTimeout: defaultConnectionTimeout, + }, WithDialer(func(Dialer) Dialer { return DialerFunc(func(context.Context, string, string) (net.Conn, error) { return nil, dialErr }) @@ -493,8 +506,9 @@ func TestPool(t *testing.T) { d := newdialer(&net.Dialer{}) p := newPool( poolConfig{ - Address: address.Address(addr.String()), - MaxIdleTime: time.Millisecond, + Address: address.Address(addr.String()), + MaxIdleTime: time.Millisecond, + ConnectTimeout: defaultConnectionTimeout, }, WithDialer(func(Dialer) Dialer { return d }), ) @@ -538,7 +552,8 @@ func TestPool(t *testing.T) { d := newdialer(&net.Dialer{}) p := newPool(poolConfig{ - Address: address.Address(addr.String()), + Address: address.Address(addr.String()), + ConnectTimeout: defaultConnectionTimeout, }, WithDialer(func(Dialer) Dialer { return d })) err := p.ready() noerr(t, err) @@ -565,7 +580,8 @@ func TestPool(t *testing.T) { }) p := newPool(poolConfig{ - Address: address.Address(addr.String()), + Address: address.Address(addr.String()), + ConnectTimeout: defaultConnectionTimeout, }) err := p.ready() noerr(t, err) @@ -583,7 +599,9 @@ func TestPool(t *testing.T) { t.Parallel() p := newPool( - poolConfig{}, + poolConfig{ + ConnectTimeout: defaultConnectionTimeout, + }, WithHandshaker(func(Handshaker) Handshaker { return operation.NewHello() }), @@ -632,8 +650,9 @@ func TestPool(t *testing.T) { }) p := newPool(poolConfig{ - Address: address.Address(addr.String()), - MaxPoolSize: 1, + Address: address.Address(addr.String()), + MaxPoolSize: 1, + ConnectTimeout: defaultConnectionTimeout, }) err := p.ready() noerr(t, err) @@ -672,8 +691,9 @@ func TestPool(t *testing.T) { }) p := newPool(poolConfig{ - Address: address.Address(addr.String()), - MaxPoolSize: 1, + Address: address.Address(addr.String()), + MaxPoolSize: 1, + ConnectTimeout: defaultConnectionTimeout, }) err := p.ready() noerr(t, err) @@ -727,8 +747,9 @@ func TestPool(t *testing.T) { d := newdialer(&net.Dialer{}) p := newPool( poolConfig{ - Address: address.Address(addr.String()), - MaxPoolSize: 2, + Address: address.Address(addr.String()), + MaxPoolSize: 2, + ConnectTimeout: defaultConnectionTimeout, }, WithDialer(func(Dialer) Dialer { return d }), ) @@ -794,8 +815,9 @@ func TestPool(t *testing.T) { }) p := newPool(poolConfig{ - Address: address.Address(addr.String()), - MaxPoolSize: 1, + Address: address.Address(addr.String()), + MaxPoolSize: 1, + ConnectTimeout: defaultConnectionTimeout, }) err := p.ready() noerr(t, err) @@ -834,7 +856,8 @@ func TestPool(t *testing.T) { }) p := newPool(poolConfig{ - Address: address.Address(addr.String()), + Address: address.Address(addr.String()), + ConnectTimeout: defaultConnectionTimeout, }) err := p.ready() noerr(t, err) @@ -867,7 +890,8 @@ func TestPool(t *testing.T) { d := newdialer(&net.Dialer{}) p := newPool(poolConfig{ - Address: address.Address(addr.String()), + Address: address.Address(addr.String()), + ConnectTimeout: defaultConnectionTimeout, }, WithDialer(func(Dialer) Dialer { return d })) err := p.ready() noerr(t, err) @@ -897,7 +921,8 @@ func TestPool(t *testing.T) { }) p1 := newPool(poolConfig{ - Address: address.Address(addr.String()), + Address: address.Address(addr.String()), + ConnectTimeout: defaultConnectionTimeout, }) err := p1.ready() noerr(t, err) @@ -927,8 +952,9 @@ func TestPool(t *testing.T) { d := newdialer(&net.Dialer{}) p := newPool(poolConfig{ - Address: address.Address(addr.String()), - MaxIdleTime: 100 * time.Millisecond, + Address: address.Address(addr.String()), + MaxIdleTime: 100 * time.Millisecond, + ConnectTimeout: defaultConnectionTimeout, }, WithDialer(func(Dialer) Dialer { return d })) err := p.ready() noerr(t, err) @@ -960,9 +986,10 @@ func TestPool(t *testing.T) { d := newdialer(&net.Dialer{}) p := newPool(poolConfig{ - Address: address.Address(addr.String()), - MinPoolSize: 3, - MaxIdleTime: 10 * time.Millisecond, + Address: address.Address(addr.String()), + MinPoolSize: 3, + MaxIdleTime: 10 * time.Millisecond, + ConnectTimeout: defaultConnectionTimeout, }, WithDialer(func(Dialer) Dialer { return d })) err := p.ready() noerr(t, err) @@ -1000,8 +1027,9 @@ func TestPool(t *testing.T) { d := newdialer(&net.Dialer{}) p := newPool(poolConfig{ - Address: address.Address(addr.String()), - MinPoolSize: 3, + Address: address.Address(addr.String()), + MinPoolSize: 3, + ConnectTimeout: defaultConnectionTimeout, }, WithDialer(func(Dialer) Dialer { return d })) err := p.ready() noerr(t, err) @@ -1024,9 +1052,10 @@ func TestPool(t *testing.T) { d := newdialer(&net.Dialer{}) p := newPool(poolConfig{ - Address: address.Address(addr.String()), - MinPoolSize: 20, - MaxPoolSize: 2, + Address: address.Address(addr.String()), + MinPoolSize: 20, + MaxPoolSize: 2, + ConnectTimeout: defaultConnectionTimeout, }, WithDialer(func(Dialer) Dialer { return d })) err := p.ready() noerr(t, err) @@ -1052,6 +1081,7 @@ func TestPool(t *testing.T) { Address: address.Address(addr.String()), // Set the pool's maintain interval to 10ms so that it allows the test to run quickly. MaintainInterval: 10 * time.Millisecond, + ConnectTimeout: defaultConnectionTimeout, }, WithDialer(func(Dialer) Dialer { return d })) err := p.ready() noerr(t, err) @@ -1102,6 +1132,7 @@ func TestPool(t *testing.T) { MinPoolSize: 3, // Set the pool's maintain interval to 10ms so that it allows the test to run quickly. MaintainInterval: 10 * time.Millisecond, + ConnectTimeout: defaultConnectionTimeout, }, WithDialer(func(Dialer) Dialer { return d })) err := p.ready() noerr(t, err) diff --git a/x/mongo/driver/topology/rtt_monitor.go b/x/mongo/driver/topology/rtt_monitor.go index 21eafd18f2..03bcc06aa9 100644 --- a/x/mongo/driver/topology/rtt_monitor.go +++ b/x/mongo/driver/topology/rtt_monitor.go @@ -29,12 +29,9 @@ type rttConfig struct { // the operation takes longer than the interval. interval time.Duration - // The timeout applied to running the "hello" operation. If the timeout is reached while running - // the operation, the RTT sample is discarded. The default is 1 minute. - timeout time.Duration - minRTTWindow time.Duration createConnectionFn func() *connection + connectTimeout time.Duration createOperationFn func(*mnet.Connection) *operation.Hello } @@ -115,7 +112,11 @@ func (r *rttMonitor) start() { for { conn := r.cfg.createConnectionFn() - err := conn.connect(r.ctx) + + ctx, cancel := context.WithTimeout(r.ctx, r.cfg.connectTimeout) + defer cancel() + + err := conn.connect(ctx) // Add an RTT sample from the new connection handshake and start a runHellos() loop if we // successfully established the new connection. Otherwise, close the connection and try to @@ -161,11 +162,7 @@ func (r *rttMonitor) runHellos(conn *connection) { // server or a proxy stops responding to requests on the RTT connection but does not close // the TCP socket, effectively creating an operation that will never complete. We expect // that "connectTimeoutMS" provides at least enough time for a single round trip. - timeout := r.cfg.timeout - if timeout <= 0 { - timeout = conn.config.connectTimeout - } - ctx, cancel := context.WithTimeout(r.ctx, timeout) + ctx, cancel := context.WithTimeout(r.ctx, r.cfg.connectTimeout) start := time.Now() iconn := mnet.NewConnection(initConnection{conn}) diff --git a/x/mongo/driver/topology/rtt_monitor_test.go b/x/mongo/driver/topology/rtt_monitor_test.go index 7abfe024fc..f2677c8979 100644 --- a/x/mongo/driver/topology/rtt_monitor_test.go +++ b/x/mongo/driver/topology/rtt_monitor_test.go @@ -91,7 +91,8 @@ func TestRTTMonitor(t *testing.T) { return newMockSlowConn(makeHelloReply(), 10*time.Millisecond), nil }) rtt := newRTTMonitor(&rttConfig{ - interval: 10 * time.Millisecond, + interval: 10 * time.Millisecond, + connectTimeout: defaultConnectionTimeout, createConnectionFn: func() *connection { return newConnection("", WithDialer(func(Dialer) Dialer { return dialer })) }, @@ -150,7 +151,8 @@ func TestRTTMonitor(t *testing.T) { return newMockSlowConn(makeHelloReply(), 10*time.Millisecond), nil }) rtt := newRTTMonitor(&rttConfig{ - interval: 10 * time.Millisecond, + connectTimeout: defaultConnectionTimeout, + interval: 10 * time.Millisecond, createConnectionFn: func() *connection { return newConnection("", WithDialer(func(Dialer) Dialer { return dialer })) }, @@ -248,8 +250,8 @@ func TestRTTMonitor(t *testing.T) { }() rtt := newRTTMonitor(&rttConfig{ - interval: 10 * time.Millisecond, - timeout: 100 * time.Millisecond, + interval: 10 * time.Millisecond, + connectTimeout: 100 * time.Millisecond, createConnectionFn: func() *connection { return newConnection(address.Address(l.Addr().String())) }, diff --git a/x/mongo/driver/topology/server.go b/x/mongo/driver/topology/server.go index 862f9c6d48..8d53dfd62e 100644 --- a/x/mongo/driver/topology/server.go +++ b/x/mongo/driver/topology/server.go @@ -133,16 +133,12 @@ type Server struct { currentSubscriberID uint64 subscriptionsClosed bool - // heartbeat and cancellation related fields - // globalCtx should be created in NewServer and cancelled in Disconnect to signal that the server is shutting down. - // heartbeatCtx should be used for individual heartbeats and should be a child of globalCtx so that it will be - // cancelled automatically during shutdown. - heartbeatLock sync.Mutex - conn *connection - globalCtx context.Context - globalCtxCancel context.CancelFunc - heartbeatCtx context.Context - heartbeatCtxCancel context.CancelFunc + conn *connection + + // Calling StopListening on the heartbeatListner will cancel the context + // passed to the heartbeat check. This will result in the current connection + // being closed. + heartbeatListener contextListener processErrorLock sync.Mutex rttMonitor *rttMonitor @@ -160,9 +156,10 @@ func ConnectServer( addr address.Address, updateCallback updateTopologyCallback, topologyID bson.ObjectID, + connectTimeout time.Duration, opts ...ServerOption, ) (*Server, error) { - srvr := NewServer(addr, topologyID, opts...) + srvr := NewServer(addr, topologyID, connectTimeout, opts...) err := srvr.Connect(updateCallback) if err != nil { return nil, err @@ -172,9 +169,14 @@ func ConnectServer( // NewServer creates a new server. The mongodb server at the address will be monitored // on an internal monitoring goroutine. -func NewServer(addr address.Address, topologyID bson.ObjectID, opts ...ServerOption) *Server { - cfg := newServerConfig(opts...) - globalCtx, globalCtxCancel := context.WithCancel(context.Background()) +func NewServer( + addr address.Address, + topologyID bson.ObjectID, + connectTimeout time.Duration, + opts ...ServerOption, +) *Server { + cfg := newServerConfig(connectTimeout, opts...) + s := &Server{ state: serverDisconnected, @@ -187,9 +189,8 @@ func NewServer(addr address.Address, topologyID bson.ObjectID, opts ...ServerOpt topologyID: topologyID, - subscribers: make(map[uint64]chan description.Server), - globalCtx: globalCtx, - globalCtxCancel: globalCtxCancel, + subscribers: make(map[uint64]chan description.Server), + heartbeatListener: newNonBlockingContextDoneListener(), } s.desc.Store(newDefaultServerDescription(addr)) rttCfg := &rttConfig{ @@ -197,6 +198,7 @@ func NewServer(addr address.Address, topologyID bson.ObjectID, opts ...ServerOpt minRTTWindow: 5 * time.Minute, createConnectionFn: s.createConnection, createOperationFn: s.createBaseOperation, + connectTimeout: connectTimeout, } s.rttMonitor = newRTTMonitor(rttCfg) @@ -211,6 +213,7 @@ func NewServer(addr address.Address, topologyID bson.ObjectID, opts ...ServerOpt PoolMonitor: cfg.poolMonitor, Logger: cfg.logger, handshakeErrFn: s.ProcessHandshakeError, + ConnectTimeout: connectTimeout, } connectionOpts := copyConnectionOpts(cfg.connectionOpts) @@ -299,13 +302,9 @@ func (s *Server) Disconnect(ctx context.Context) error { s.updateTopologyCallback.Store((updateTopologyCallback)(nil)) - // Cancel the global context so any new contexts created from it will be automatically cancelled. Close the done - // channel so the update() routine will know that it can stop. Cancel any in-progress monitoring checks at the end. - // The done channel is closed before cancelling the check so the update routine() will immediately detect that it - // can stop rather than trying to create new connections until the read from done succeeds. - s.globalCtxCancel() close(s.done) - s.cancelCheck() + + s.heartbeatListener.StopListening() s.pool.close(ctx) @@ -380,7 +379,7 @@ func (s *Server) ProcessHandshakeError(err error, startingGenerationNumber uint6 // checking logic above has already determined that this description is not stale. s.updateDescription(newServerDescriptionFromError(s.address, wrappedConnErr, nil)) s.pool.clear(err, serviceID) - s.cancelCheck() + s.heartbeatListener.StopListening() } // Description returns a description of the server as of the last heartbeat. @@ -559,10 +558,65 @@ func (s *Server) ProcessError(err error, describer mnet.Describer) driver.Proces // updateDescription. s.updateDescription(newServerDescriptionFromError(s.address, err, nil)) s.pool.clear(err, serviceID) - s.cancelCheck() + s.heartbeatListener.StopListening() return driver.ConnectionPoolCleared } +type serverChecker interface { + check(ctx context.Context) (description.Server, error) +} + +var _ serverChecker = &Server{} + +// checkServerWithSignal will run the server heartbeat check, canceling if the +// sig channel's buffer is emptied or is closed. +func checkServerWithSignal( + checker serverChecker, + conn *connection, + listener contextListener, +) (description.Server, error) { + // Create a context for the blocking operations associated with checking the + // status of a server. + // + // The Server Monitoring spec already mandates that drivers set and + // dynamically update the read/write timeout of the dedicated connections + // used in monitoring threads, so we rely on that to time out commands + // rather than adding complexity to the behavior of timeoutMS. + ctx, cancel := context.WithCancel(context.Background()) + + defer listener.StopListening() + defer cancel() + + go func(conn *connection) { + defer cancel() + + var aborted bool + listener.Listen(ctx, func() { + aborted = true + }) + + // Close the connection if the listener was stopped before + // checkServerWithSignal terminates. + if !aborted { + if conn == nil { + return + } + + // If the connection exists, we need to wait for it to be connected + // because conn.connect() and conn.close() cannot be called concurrently. + // If the connection wasn't successfully opened, its state was set back + // to disconnected, so calling conn.close() will be a no-op. + conn.closeConnectContext() + conn.wait() + conn.prevCanceled.Store(true) + _ = conn.close() + } + + }(conn) + + return checker.check(ctx) +} + // update handle performing heartbeats and updating any subscribers of the // newest description.Server retrieved. func (s *Server) update() { @@ -587,8 +641,6 @@ func (s *Server) update() { s.subscriptionsClosed = true s.subLock.Unlock() - // We don't need to take s.heartbeatLock here because closeServer is called synchronously when the select checks - // below detect that the server is being closed, so we can be sure that the connection isn't being used. if s.conn != nil { _ = s.conn.close() } @@ -626,8 +678,9 @@ func (s *Server) update() { previousDescription := s.Description() - // Perform the next check. - desc, err := s.check() + desc, err := checkServerWithSignal(s, s.conn, s.heartbeatListener) + + // The only error returned from checkServerWithSignal is errCheckCancelled. if errors.Is(err, errCheckCancelled) { if atomic.LoadInt64(&s.state) != serverConnected { continue @@ -754,11 +807,6 @@ func (s *Server) updateDescription(desc description.Server) { func (s *Server) createConnection() *connection { opts := copyConnectionOpts(s.cfg.connectionOpts) opts = append(opts, - WithConnectTimeout(func(time.Duration) time.Duration { return s.cfg.heartbeatTimeout }), - WithReadTimeout(func(time.Duration) time.Duration { return s.cfg.heartbeatTimeout }), - WithWriteTimeout(func(time.Duration) time.Duration { return s.cfg.heartbeatTimeout }), - // We override whatever handshaker is currently attached to the options with a basic - // one because need to make sure we don't do auth. WithHandshaker(func(h Handshaker) Handshaker { return operation.NewHello().AppName(s.cfg.appname).Compressors(s.cfg.compressionOpts). ServerAPI(s.cfg.serverAPI) @@ -776,48 +824,19 @@ func copyConnectionOpts(opts []ConnectionOption) []ConnectionOption { return optsCopy } -func (s *Server) setupHeartbeatConnection() error { +func (s *Server) setupHeartbeatConnection(ctx context.Context) error { conn := s.createConnection() - // Take the lock when assigning the context and connection because they're accessed by cancelCheck. - s.heartbeatLock.Lock() - if s.heartbeatCtxCancel != nil { - // Ensure the previous context is cancelled to avoid a leak. - s.heartbeatCtxCancel() - } - s.heartbeatCtx, s.heartbeatCtxCancel = context.WithCancel(s.globalCtx) s.conn = conn - s.heartbeatLock.Unlock() - - return s.conn.connect(s.heartbeatCtx) -} - -// cancelCheck cancels in-progress connection dials and reads. It does not set any fields on the server. -func (s *Server) cancelCheck() { - var conn *connection - // Take heartbeatLock for mutual exclusion with the checks in the update function. - s.heartbeatLock.Lock() - if s.heartbeatCtx != nil { - s.heartbeatCtxCancel() - } - conn = s.conn - s.heartbeatLock.Unlock() + if s.cfg.connectTimeout != 0 { + var cancelFn context.CancelFunc + ctx, cancelFn = context.WithTimeout(ctx, s.cfg.connectTimeout) - if conn == nil { - return + defer cancelFn() } - // If the connection exists, we need to wait for it to be connected because conn.connect() and - // conn.close() cannot be called concurrently. If the connection wasn't successfully opened, its - // state was set back to disconnected, so calling conn.close() will be a no-op. - conn.closeConnectContext() - conn.wait() - _ = conn.close() -} - -func (s *Server) checkWasCancelled() bool { - return s.heartbeatCtx.Err() != nil + return s.conn.connect(ctx) } func (s *Server) createBaseOperation(conn *mnet.Connection) *operation.Hello { @@ -843,24 +862,119 @@ func isStreamable(srv *Server) bool { return srv.Description().Kind != description.Unknown && srv.Description().TopologyVersion != nil } -func (s *Server) check() (description.Server, error) { +func (s *Server) streamable() bool { + return isStreamingEnabled(s) && isStreamable(s) +} + +// getHeartbeatTimeout will return the maximum allowable duration for streaming +// or polling a hello command during server monitoring. +func getHeartbeatTimeout(srv *Server) time.Duration { + if srv.conn.getCurrentlyStreaming() || srv.streamable() { + // If connectTimeoutMS=0, the operation timeout should be infinite. + // Otherwise, it is connectTimeoutMS + heartbeatFrequencyMS to account for + // the fact that the query will block for heartbeatFrequencyMS + // server-side. + streamingTO := srv.cfg.connectTimeout + if streamingTO != 0 { + streamingTO += srv.cfg.heartbeatInterval + } + + return streamingTO + } + + // The server doesn't support the awaitable protocol. Set the timeout to + // connectTimeoutMS and execute a regular heartbeat without any additional + // parameters. + return srv.cfg.connectTimeout +} + +// withHeartbeatTimeout will apply the appropriate timeout to the parent context +// for server monitoring. +func withHeartbeatTimeout(parent context.Context, srv *Server) (context.Context, context.CancelFunc) { + var cancel context.CancelFunc + + timeout := getHeartbeatTimeout(srv) + if timeout == 0 { + return parent, cancel + } + + return context.WithTimeout(parent, timeout) +} + +// doHandshake will construct the hello operation use to execute a handshake +// between the client and a server. Depending on the configuration and version, +// this function will either poll, stream, or resume streaming. +func doHandshake(ctx context.Context, srv *Server) (description.Server, error) { + heartbeatConn := mnet.NewConnection(initConnection{srv.conn}) + handshakeOp := srv.createBaseOperation(heartbeatConn) + + // Using timeoutMS in the monitoring and RTT calculation threads would require + // another special case in the code that derives maxTimeMS from timeoutMS + // because the awaitable hello requests sent to 4.4+ servers already have a + // maxAwaitTimeMS field. Adding maxTimeMS also does not help for non-awaitable + // hello commands because we expect them to execute quickly on the server. The + // Server Monitoring spec already mandates that drivers set and dynamically + // update the read/write timeout of the dedicated connections used in + // monitoring threads, so we rely on that to time out commands rather than + // adding complexity to the behavior of timeoutMS. + handshakeOp = handshakeOp.OmitMaxTimeMS(true) + + // Apply monitoring timeout. + ctx, cancel := withHeartbeatTimeout(ctx, srv) + defer cancel() + + // If we are currently streaming, get more data and return the result. + if srv.conn.getCurrentlyStreaming() { + if err := handshakeOp.StreamResponse(ctx, heartbeatConn); err != nil { + return description.Server{}, err + } + + return handshakeOp.Result(srv.address), nil + } + + // If the server supports streaming, update it so the next handshake streams + // the response. + if srv.streamable() { + srv.conn.setCanStream(true) + + maxAwaitTimeMS := int64(srv.cfg.heartbeatInterval) / 1e6 + + handshakeOp = handshakeOp. + TopologyVersion(srv.Description().TopologyVersion). + MaxAwaitTimeMS(maxAwaitTimeMS) + } + + // Perform the handshake. + if err := handshakeOp.Execute(ctx); err != nil { + return description.Server{}, err + } + + return handshakeOp.Result(srv.address), nil +} + +func (s *Server) check(ctx context.Context) (description.Server, error) { var descPtr *description.Server var err error - var duration time.Duration + var execDuration time.Duration start := time.Now() + var previousCanceled bool + if s.conn != nil { + previousCanceled = s.conn.previousCanceled() + } + // Create a new connection if this is the first check, the connection was closed after an error during the previous // check, or the previous check was cancelled. - if s.conn == nil || s.conn.closed() || s.checkWasCancelled() { + if s.conn == nil || s.conn.closed() || previousCanceled { connID := "0" if s.conn != nil { connID = s.conn.ID() } s.publishServerHeartbeatStartedEvent(connID, false) // Create a new connection and add it's handshake RTT as a sample. - err = s.setupHeartbeatConnection() - duration = time.Since(start) + err = s.setupHeartbeatConnection(ctx) + execDuration = time.Since(start) connID = "0" if s.conn != nil { connID = s.conn.ID() @@ -869,80 +983,47 @@ func (s *Server) check() (description.Server, error) { // Use the description from the connection handshake as the value for this check. s.rttMonitor.addSample(s.conn.helloRTT) descPtr = &s.conn.desc - s.publishServerHeartbeatSucceededEvent(connID, duration, s.conn.desc, false) + s.publishServerHeartbeatSucceededEvent(connID, execDuration, s.conn.desc, false) } else { err = unwrapConnectionError(err) - s.publishServerHeartbeatFailedEvent(connID, duration, err, false) + s.publishServerHeartbeatFailedEvent(connID, execDuration, err, false) } } else { - // An existing connection is being used. Use the server description properties to execute the right heartbeat. - - // Wrap conn in a type that implements driver.StreamerConnection. - iconn := initConnection{s.conn} - heartbeatConn := mnet.NewConnection(iconn) + // An existing connection is being used. Use the server description + // properties to execute the right heartbeat. - baseOperation := s.createBaseOperation(heartbeatConn) - previousDescription := s.Description() streamable := isStreamingEnabled(s) && isStreamable(s) s.publishServerHeartbeatStartedEvent(s.conn.ID(), s.conn.getCurrentlyStreaming() || streamable) - switch { - case s.conn.getCurrentlyStreaming(): - // The connection is already in a streaming state, so we stream the next response. - err = baseOperation.StreamResponse(s.heartbeatCtx, heartbeatConn) - case streamable: - // The server supports the streamable protocol. Set the socket timeout to - // connectTimeoutMS+heartbeatFrequencyMS and execute an awaitable hello request. Set conn.canStream so - // the wire message will advertise streaming support to the server. - - // Calculation for maxAwaitTimeMS is taken from time.Duration.Milliseconds (added in Go 1.13). - maxAwaitTimeMS := int64(s.cfg.heartbeatInterval) / 1e6 - // If connectTimeoutMS=0, the socket timeout should be infinite. Otherwise, it is connectTimeoutMS + - // heartbeatFrequencyMS to account for the fact that the query will block for heartbeatFrequencyMS - // server-side. - socketTimeout := s.cfg.heartbeatTimeout - if socketTimeout != 0 { - socketTimeout += s.cfg.heartbeatInterval - } - s.conn.setSocketTimeout(socketTimeout) - baseOperation = baseOperation.TopologyVersion(previousDescription.TopologyVersion). - MaxAwaitTimeMS(maxAwaitTimeMS) - s.conn.setCanStream(true) - err = baseOperation.Execute(s.heartbeatCtx) - default: - // The server doesn't support the awaitable protocol. Set the socket timeout to connectTimeoutMS and - // execute a regular heartbeat without any additional parameters. - - s.conn.setSocketTimeout(s.cfg.heartbeatTimeout) - err = baseOperation.Execute(s.heartbeatCtx) - } + var tempDesc description.Server + tempDesc, err = doHandshake(ctx, s) // Perform a handshake with the server - duration = time.Since(start) + execDuration = time.Since(start) // We need to record an RTT sample in the polling case so that if the server // is < 4.4, or if polling is specified by the user, then the // RTT-short-circuit feature of CSOT is not disabled. if !streamable { - s.rttMonitor.addSample(duration) + s.rttMonitor.addSample(execDuration) } if err == nil { - tempDesc := baseOperation.Result(s.address) descPtr = &tempDesc - s.publishServerHeartbeatSucceededEvent(s.conn.ID(), duration, tempDesc, s.conn.getCurrentlyStreaming() || streamable) + s.publishServerHeartbeatSucceededEvent(s.conn.ID(), execDuration, + tempDesc, s.conn.getCurrentlyStreaming() || streamable) } else { // Close the connection here rather than below so we ensure we're not closing a connection that wasn't // successfully created. if s.conn != nil { _ = s.conn.close() } - s.publishServerHeartbeatFailedEvent(s.conn.ID(), duration, err, s.conn.getCurrentlyStreaming() || streamable) + s.publishServerHeartbeatFailedEvent(s.conn.ID(), execDuration, err, s.conn.getCurrentlyStreaming() || streamable) } } if descPtr != nil { - // The check was successful. Set the average RTT and the 90th percentile RTT and return. + // The check was successful. Set the average RTT and return. desc := *descPtr desc.AverageRTT = s.rttMonitor.EWMA() desc.AverageRTTSet = true @@ -951,7 +1032,7 @@ func (s *Server) check() (description.Server, error) { return desc, nil } - if s.checkWasCancelled() { + if previousCanceled { // If the previous check was cancelled, we don't want to clear the pool. Return a sentinel error so the caller // will know that an actual error didn't occur. return emptyDescription, errCheckCancelled diff --git a/x/mongo/driver/topology/server_options.go b/x/mongo/driver/topology/server_options.go index c02600e232..bfd1218d12 100644 --- a/x/mongo/driver/topology/server_options.go +++ b/x/mongo/driver/topology/server_options.go @@ -25,7 +25,7 @@ type serverConfig struct { connectionOpts []ConnectionOption appname string heartbeatInterval time.Duration - heartbeatTimeout time.Duration + connectTimeout time.Duration serverMonitoringMode string serverMonitor *event.ServerMonitor registry *bson.Registry @@ -43,10 +43,10 @@ type serverConfig struct { poolMaintainInterval time.Duration } -func newServerConfig(opts ...ServerOption) *serverConfig { +func newServerConfig(connectTimeout time.Duration, opts ...ServerOption) *serverConfig { cfg := &serverConfig{ heartbeatInterval: 10 * time.Second, - heartbeatTimeout: 10 * time.Second, + connectTimeout: connectTimeout, registry: defaultRegistry, } @@ -65,8 +65,8 @@ type ServerOption func(*serverConfig) // ServerAPIFromServerOptions will return the server API options if they have been functionally set on the ServerOption // slice. -func ServerAPIFromServerOptions(opts []ServerOption) *driver.ServerAPIOptions { - return newServerConfig(opts...).serverAPI +func ServerAPIFromServerOptions(connectTimeout time.Duration, opts []ServerOption) *driver.ServerAPIOptions { + return newServerConfig(connectTimeout, opts...).serverAPI } func withMonitoringDisabled(fn func(bool) bool) ServerOption { @@ -103,14 +103,6 @@ func WithHeartbeatInterval(fn func(time.Duration) time.Duration) ServerOption { } } -// WithHeartbeatTimeout configures how long to wait for a heartbeat socket to -// connection. -func WithHeartbeatTimeout(fn func(time.Duration) time.Duration) ServerOption { - return func(cfg *serverConfig) { - cfg.heartbeatTimeout = fn(cfg.heartbeatTimeout) - } -} - // WithMaxConnections configures the maximum number of connections to allow for // a given server. If max is 0, then maximum connection pool size is not limited. func WithMaxConnections(fn func(uint64) uint64) ServerOption { diff --git a/x/mongo/driver/topology/server_test.go b/x/mongo/driver/topology/server_test.go index 1fd20fdadb..8b18d2408c 100644 --- a/x/mongo/driver/topology/server_test.go +++ b/x/mongo/driver/topology/server_test.go @@ -33,6 +33,7 @@ import ( "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" "go.mongodb.org/mongo-driver/x/mongo/driver/auth" + "go.mongodb.org/mongo-driver/x/mongo/driver/connstring" "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/drivertest" "go.mongodb.org/mongo-driver/x/mongo/driver/mnet" @@ -166,6 +167,7 @@ func TestServerHeartbeatTimeout(t *testing.T) { server := NewServer( address.Address("localhost:27017"), bson.NewObjectID(), + defaultConnectionTimeout, WithConnectionPoolMonitor(func(*event.PoolMonitor) *event.PoolMonitor { return tpm.PoolMonitor }), @@ -218,6 +220,7 @@ func TestServerConnectionTimeout(t *testing.T) { desc: "successful connection should not clear the pool", expectErr: false, expectPoolCleared: false, + connectTimeout: defaultConnectionTimeout, }, { desc: "timeout error during dialing should clear the pool", @@ -262,6 +265,7 @@ func TestServerConnectionTimeout(t *testing.T) { }, expectErr: true, expectPoolCleared: true, + connectTimeout: defaultConnectionTimeout, }, { desc: "operation context timeout with unrelated dial errors should clear the pool", @@ -300,15 +304,13 @@ func TestServerConnectionTimeout(t *testing.T) { server := NewServer( address.Address(l.Addr().String()), bson.NewObjectID(), + tc.connectTimeout, WithConnectionPoolMonitor(func(*event.PoolMonitor) *event.PoolMonitor { return tpm.PoolMonitor }), // Replace the default dialer and handshaker with the test dialer and handshaker, if // present. WithConnectionOptions(func(opts ...ConnectionOption) []ConnectionOption { - if tc.connectTimeout > 0 { - opts = append(opts, WithConnectTimeout(func(time.Duration) time.Duration { return tc.connectTimeout })) - } if tc.dialer != nil { opts = append(opts, WithDialer(tc.dialer)) } @@ -381,6 +383,7 @@ func TestServer(t *testing.T) { s := NewServer( address.Address("localhost"), bson.NewObjectID(), + defaultConnectionTimeout, WithConnectionOptions(func(connOpts ...ConnectionOption) []ConnectionOption { return append(connOpts, WithHandshaker(func(Handshaker) Handshaker { @@ -567,7 +570,13 @@ func TestServer(t *testing.T) { WithMaxConnecting(func(uint64) uint64 { return 1 }), } - server, err := ConnectServer(address.Address("localhost:27017"), nil, bson.NewObjectID(), serverOpts...) + server, err := ConnectServer( + address.Address("localhost:27017"), + nil, + bson.NewObjectID(), + defaultConnectionTimeout, + serverOpts..., + ) assert.Nil(t, err, "ConnectServer error: %v", err) defer func() { _ = server.Disconnect(context.Background()) @@ -601,6 +610,7 @@ func TestServer(t *testing.T) { d := newdialer(&net.Dialer{}) s := NewServer(address.Address(addr.String()), bson.NewObjectID(), + defaultConnectionTimeout, WithConnectionOptions(func(option ...ConnectionOption) []ConnectionOption { return []ConnectionOption{WithDialer(func(_ Dialer) Dialer { return d })} }), @@ -648,7 +658,14 @@ func TestServer(t *testing.T) { updated.Store(true) return desc } - s, err := ConnectServer(address.Address("localhost"), updateCallback, bson.NewObjectID()) + + s, err := ConnectServer( + address.Address("localhost"), + updateCallback, + bson.NewObjectID(), + defaultConnectionTimeout, + ) + require.NoError(t, err) s.updateDescription(description.Server{Addr: s.address}) require.True(t, updated.Load().(bool)) @@ -663,10 +680,10 @@ func TestServer(t *testing.T) { return append(connOpts, dialerOpt) }) - s := NewServer(address.Address("localhost:27017"), bson.NewObjectID(), serverOpt) + s := NewServer(address.Address("localhost:27017"), bson.NewObjectID(), defaultConnectionTimeout, serverOpt) // do a heartbeat with a nil connection so a new one will be dialed - _, err := s.check() + _, err := s.check(context.Background()) assert.Nil(t, err, "check error: %v", err) assert.NotNil(t, s.conn, "no connection dialed in check") @@ -683,7 +700,7 @@ func TestServer(t *testing.T) { if err = channelConn.AddResponse(makeHelloReply()); err != nil { t.Fatalf("error adding response: %v", err) } - _, err = s.check() + _, err = s.check(context.Background()) assert.Nil(t, err, "check error: %v", err) wm = channelConn.GetWrittenMessage() @@ -727,10 +744,10 @@ func TestServer(t *testing.T) { WithServerMonitor(func(*event.ServerMonitor) *event.ServerMonitor { return sdam }), } - s := NewServer(address.Address("localhost:27017"), bson.NewObjectID(), serverOpts...) + s := NewServer(address.Address("localhost:27017"), bson.NewObjectID(), defaultConnectionTimeout, serverOpts...) // set up heartbeat connection, which doesn't send events - _, err := s.check() + _, err := s.check(context.Background()) assert.Nil(t, err, "check error: %v", err) channelConn := s.conn.nc.(*drivertest.ChannelNetConn) @@ -742,7 +759,7 @@ func TestServer(t *testing.T) { if err = channelConn.AddResponse(makeHelloReply()); err != nil { t.Fatalf("error adding response: %v", err) } - _, err = s.check() + _, err = s.check(context.Background()) _ = channelConn.GetWrittenMessage() assert.Nil(t, err, "check error: %v", err) @@ -764,7 +781,7 @@ func TestServer(t *testing.T) { // do a heartbeat with a non-nil connection readErr := errors.New("error") channelConn.ReadErr <- readErr - _, err = s.check() + _, err = s.check(context.Background()) _ = channelConn.GetWrittenMessage() assert.Nil(t, err, "check error: %v", err) @@ -787,65 +804,10 @@ func TestServer(t *testing.T) { s := NewServer(address.Address("localhost"), bson.NewObjectID(), + defaultConnectionTimeout, WithServerAppName(func(string) string { return name })) require.Equal(t, name, s.cfg.appname, "expected appname to be: %v, got: %v", name, s.cfg.appname) }) - t.Run("createConnection overwrites WithSocketTimeout", func(t *testing.T) { - socketTimeout := 40 * time.Second - - s := NewServer( - address.Address("localhost"), - bson.NewObjectID(), - WithConnectionOptions(func(connOpts ...ConnectionOption) []ConnectionOption { - return append( - connOpts, - WithReadTimeout(func(time.Duration) time.Duration { return socketTimeout }), - WithWriteTimeout(func(time.Duration) time.Duration { return socketTimeout }), - ) - }), - ) - - conn := s.createConnection() - assert.Equal(t, s.cfg.heartbeatTimeout, 10*time.Second, "expected heartbeatTimeout to be: %v, got: %v", 10*time.Second, s.cfg.heartbeatTimeout) - assert.Equal(t, s.cfg.heartbeatTimeout, conn.readTimeout, "expected readTimeout to be: %v, got: %v", s.cfg.heartbeatTimeout, conn.readTimeout) - assert.Equal(t, s.cfg.heartbeatTimeout, conn.writeTimeout, "expected writeTimeout to be: %v, got: %v", s.cfg.heartbeatTimeout, conn.writeTimeout) - }) - t.Run("heartbeat contexts are not leaked", func(t *testing.T) { - // The context created for heartbeats should be cancelled when it is no longer needed to avoid leaks. - - server, err := ConnectServer( - address.Address("invalid"), - nil, - bson.NewObjectID(), - withMonitoringDisabled(func(bool) bool { - return true - }), - ) - assert.Nil(t, err, "ConnectServer error: %v", err) - - // Expect check to return an error in the server description because the server address doesn't exist. This is - // OK because we just want to ensure the heartbeat context is created. - desc, err := server.check() - assert.Nil(t, err, "check error: %v", err) - assert.NotNil(t, desc.LastError, "expected server description to contain an error, got nil") - assert.NotNil(t, server.heartbeatCtx, "expected heartbeatCtx to be non-nil, got nil") - assert.Nil(t, server.heartbeatCtx.Err(), "expected heartbeatCtx error to be nil, got %v", server.heartbeatCtx.Err()) - - // Override heartbeatCtxCancel with a wrapper that records whether or not it was called. - oldCancelFn := server.heartbeatCtxCancel - var previousCtxCancelled bool - server.heartbeatCtxCancel = func() { - previousCtxCancelled = true - oldCancelFn() - } - - // The second check call should attempt to create a new heartbeat connection and should cancel the previous - // heartbeatCtx during the process. - desc, err = server.check() - assert.Nil(t, err, "check error: %v", err) - assert.NotNil(t, desc.LastError, "expected server description to contain an error, got nil") - assert.True(t, previousCtxCancelled, "expected check to cancel previous context but did not") - }) } func TestServer_ProcessError(t *testing.T) { @@ -1188,7 +1150,7 @@ func TestServer_ProcessError(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - server := NewServer(address.Address(""), bson.NewObjectID()) + server := NewServer(address.Address(""), bson.NewObjectID(), defaultConnectionTimeout) server.state = serverConnected err := server.pool.ready() require.Nil(t, err, "pool.ready() error: %v", err) @@ -1213,6 +1175,82 @@ func TestServer_ProcessError(t *testing.T) { } } +func TestServer_getSocketTimeout(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + enableStreaming bool + connectTimeout time.Duration + heartbeatInterval time.Duration + want time.Duration + }{ + { + name: "server is streamable with connectTimeout and no heartbeat interval", + enableStreaming: true, + connectTimeout: 1, + heartbeatInterval: 0, + want: 1, + }, + { + name: "server is streamable with connectTimeout and heartbeat interval", + enableStreaming: true, + connectTimeout: 1, + heartbeatInterval: 1, + want: 2, + }, + { + name: "server is streamable with no connectTimeout and heartbeat interval", + enableStreaming: true, + connectTimeout: 0, + heartbeatInterval: 1, + want: 0, + }, + { + name: "server is streamable with no connectTimeout and no heartbeat interval", + enableStreaming: true, + connectTimeout: 0, + heartbeatInterval: 0, + want: 0, + }, + { + name: "server is not streamable", + enableStreaming: false, + connectTimeout: 1, + heartbeatInterval: 0, + want: 1, + }, + } + + for _, test := range tests { + test := test // Capture the range variable + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + srv := &Server{ + cfg: &serverConfig{ + connectTimeout: test.connectTimeout, + heartbeatInterval: test.heartbeatInterval, + }, + conn: &connection{}, + } + + srv.desc.Store(description.Server{ + Kind: description.ServerKind(description.TopologyKindReplicaSet), + TopologyVersion: &description.TopologyVersion{}, + }) + + if test.enableStreaming { + srv.cfg.serverMonitoringMode = connstring.ServerMonitoringModeStream + } + + got := getHeartbeatTimeout(srv) + assert.Equal(t, test.want, got) + }) + } +} + // includesClientMetadata will return true if the wire message includes the // "client" field. func includesClientMetadata(t *testing.T, wm []byte) bool { @@ -1303,3 +1341,46 @@ func newServerDescription( LastError: lastError, } } + +type mockServerChecker struct { + sleep time.Duration +} + +var _ serverChecker = &mockServerChecker{} + +func (checker *mockServerChecker) check(ctx context.Context) (description.Server, error) { + select { + case <-ctx.Done(): + return description.Server{}, ctx.Err() + case <-time.After(checker.sleep): + } + + return description.Server{}, nil +} + +func TestCheckServerWithSignal(t *testing.T) { + t.Run("check finishes before signal", func(t *testing.T) { + listener := newNonBlockingContextDoneListener() + go func() { + defer listener.StopListening() + + time.Sleep(105 * time.Millisecond) + }() + + _, err := checkServerWithSignal(&mockServerChecker{sleep: 100 * time.Millisecond}, &connection{}, listener) + assert.NoError(t, err) + }) + + t.Run("check finishes after signal", func(t *testing.T) { + listener := newNonBlockingContextDoneListener() + go func() { + defer listener.StopListening() + + time.Sleep(100 * time.Millisecond) + }() + + _, err := checkServerWithSignal(&mockServerChecker{sleep: 1 * time.Second}, &connection{}, listener) + assert.Error(t, err) + assert.ErrorIs(t, err, context.Canceled) + }) +} diff --git a/x/mongo/driver/topology/topology.go b/x/mongo/driver/topology/topology.go index d9e9de1f50..60077cea85 100644 --- a/x/mongo/driver/topology/topology.go +++ b/x/mongo/driver/topology/topology.go @@ -63,10 +63,6 @@ var ErrTopologyClosed = errors.New("topology is closed") // already connected Topology. var ErrTopologyConnected = errors.New("topology is connected or connecting") -// ErrServerSelectionTimeout is returned from server selection when the server -// selection process took longer than allowed by the timeout. -var ErrServerSelectionTimeout = errors.New("server selection timeout") - // MonitorMode represents the way in which a server is monitored. type MonitorMode uint8 @@ -126,18 +122,6 @@ var ( _ driver.Subscriber = &Topology{} ) -type serverSelectionState struct { - selector description.ServerSelector - timeoutChan <-chan time.Time -} - -func newServerSelectionState(selector description.ServerSelector, timeoutChan <-chan time.Time) serverSelectionState { - return serverSelectionState{ - selector: selector, - timeoutChan: timeoutChan, - } -} - // New creates a new topology. A "nil" config is interpreted as the default configuration. func New(cfg *Config) (*Topology, error) { if cfg == nil { @@ -503,9 +487,8 @@ func (t *Topology) RequestImmediateCheck() { t.serversLock.Unlock() } -// SelectServer selects a server with given a selector. SelectServer complies with the -// server selection spec, and will time out after serverSelectionTimeout or when the -// parent context is done. +// SelectServer selects a server with given a selector, returning the remaining +// computedServerSelectionTimeout. func (t *Topology) SelectServer(ctx context.Context, ss description.ServerSelector) (driver.Server, error) { if atomic.LoadInt64(&t.state) != topologyConnected { if mustLogServerSelection(t, logger.LevelDebug) { @@ -514,17 +497,9 @@ func (t *Topology) SelectServer(ctx context.Context, ss description.ServerSelect return nil, ErrTopologyClosed } - var ssTimeoutCh <-chan time.Time - - if t.cfg.ServerSelectionTimeout > 0 { - ssTimeout := time.NewTimer(t.cfg.ServerSelectionTimeout) - ssTimeoutCh = ssTimeout.C - defer ssTimeout.Stop() - } var doneOnce bool var sub *driver.Subscription - selectionState := newServerSelectionState(ss, ssTimeoutCh) // Record the start time. startTime := time.Now() @@ -539,7 +514,7 @@ func (t *Topology) SelectServer(ctx context.Context, ss description.ServerSelect // for the first pass, select a server from the current description. // this improves selection speed for up-to-date topology descriptions. - suitable, selectErr = t.selectServerFromDescription(t.Description(), selectionState) + suitable, selectErr = t.selectServerFromDescription(t.Description(), ss) doneOnce = true } else { // if the first pass didn't select a server, the previous description did not contain a suitable server, so @@ -557,7 +532,7 @@ func (t *Topology) SelectServer(ctx context.Context, ss description.ServerSelect defer func() { _ = t.Unsubscribe(sub) }() } - suitable, selectErr = t.selectServerFromSubscription(ctx, sub.Updates, selectionState) + suitable, selectErr = t.selectServerFromSubscription(ctx, sub.Updates, ss) } if selectErr != nil { if mustLogServerSelection(t, logger.LevelDebug) { @@ -704,20 +679,22 @@ func (t *Topology) FindServer(selected description.Server) (*SelectedServer, err // selectServerFromSubscription loops until a topology description is available for server selection. It returns // when the given context expires, server selection timeout is reached, or a description containing a selectable // server is available. -func (t *Topology) selectServerFromSubscription(ctx context.Context, subscriptionCh <-chan description.Topology, - selectionState serverSelectionState) ([]description.Server, error) { +func (t *Topology) selectServerFromSubscription( + ctx context.Context, + subscriptionCh <-chan description.Topology, + srvSelector description.ServerSelector, +) ([]description.Server, error) { current := t.Description() for { select { case <-ctx.Done(): return nil, ServerSelectionError{Wrapped: ctx.Err(), Desc: current} - case <-selectionState.timeoutChan: - return nil, ServerSelectionError{Wrapped: ErrServerSelectionTimeout, Desc: current} case current = <-subscriptionCh: + default: } - suitable, err := t.selectServerFromDescription(current, selectionState) + suitable, err := t.selectServerFromDescription(current, srvSelector) if err != nil { return nil, err } @@ -730,8 +707,10 @@ func (t *Topology) selectServerFromSubscription(ctx context.Context, subscriptio } // selectServerFromDescription process the given topology description and returns a slice of suitable servers. -func (t *Topology) selectServerFromDescription(desc description.Topology, - selectionState serverSelectionState) ([]description.Server, error) { +func (t *Topology) selectServerFromDescription( + desc description.Topology, + srvSelector description.ServerSelector, +) ([]description.Server, error) { // Unlike selectServerFromSubscription, this code path does not check ctx.Done or selectionState.timeoutChan because // selecting a server from a description is not a blocking operation. @@ -759,7 +738,7 @@ func (t *Topology) selectServerFromDescription(desc description.Topology, allowed[i] = desc.Servers[idx] } - suitable, err := selectionState.selector.SelectServer(desc, allowed) + suitable, err := srvSelector.SelectServer(desc, allowed) if err != nil { return nil, ServerSelectionError{Wrapped: err, Desc: desc} } @@ -769,7 +748,7 @@ func (t *Topology) selectServerFromDescription(desc description.Topology, func (t *Topology) pollSRVRecords(hosts string) { defer t.pollingwg.Done() - serverConfig := newServerConfig(t.cfg.ServerOpts...) + serverConfig := newServerConfig(t.cfg.ConnectTimeout, t.cfg.ServerOpts...) heartbeatInterval := serverConfig.heartbeatInterval pollTicker := time.NewTicker(t.rescanSRVInterval) @@ -992,7 +971,7 @@ func (t *Topology) addServer(addr address.Address) error { return nil } - svr, err := ConnectServer(addr, t.updateCallback, t.id, t.cfg.ServerOpts...) + svr, err := ConnectServer(addr, t.updateCallback, t.id, t.cfg.ConnectTimeout, t.cfg.ServerOpts...) if err != nil { return err } @@ -1104,6 +1083,16 @@ func (t *Topology) publishTopologyClosedEvent() { } } +// GetServerSelectionTimeout returns the server selection timeout defined on +// the client options. +func (t *Topology) GetServerSelectionTimeout() time.Duration { + if t.cfg == nil { + return 0 + } + + return t.cfg.ServerSelectionTimeout +} + func newEventServerDescription(srv description.Server) event.ServerDescription { evtSrv := event.ServerDescription{ Addr: srv.Addr, diff --git a/x/mongo/driver/topology/topology_errors_test.go b/x/mongo/driver/topology/topology_errors_test.go index dddada9c37..612735bd3b 100644 --- a/x/mongo/driver/topology/topology_errors_test.go +++ b/x/mongo/driver/topology/topology_errors_test.go @@ -51,9 +51,8 @@ func TestTopologyErrors(t *testing.T) { selectServerCtx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) defer cancel() - state := newServerSelectionState(selectNone, make(<-chan time.Time)) subCh := make(<-chan description.Topology) - _, serverSelectionErr = topo.selectServerFromSubscription(selectServerCtx, subCh, state) + _, serverSelectionErr = topo.selectServerFromSubscription(selectServerCtx, subCh, selectNone) return true } assert.Eventually(t, diff --git a/x/mongo/driver/topology/topology_options.go b/x/mongo/driver/topology/topology_options.go index 705ec3f7e1..dd35ec80fb 100644 --- a/x/mongo/driver/topology/topology_options.go +++ b/x/mongo/driver/topology/topology_options.go @@ -24,6 +24,7 @@ import ( ) const defaultServerSelectionTimeout = 30 * time.Second +const defaultConnectionTimeout = 30 * time.Second // Config is used to construct a topology. type Config struct { @@ -32,6 +33,8 @@ type Config struct { SeedList []string ServerOpts []ServerOption URI string + ConnectTimeout time.Duration + Timeout *time.Duration ServerSelectionTimeout time.Duration ServerMonitor *event.ServerMonitor SRVMaxHosts int @@ -82,11 +85,16 @@ func NewConfig(co *options.ClientOptions, clock *session.ClusterClock) (*Config, var connOpts []ConnectionOption var serverOpts []ServerOption - cfgp := &Config{} + cfgp := &Config{ + Timeout: co.Timeout, + } // Set the default "ServerSelectionTimeout" to 30 seconds. cfgp.ServerSelectionTimeout = defaultServerSelectionTimeout + // Set the default "ConnectionTimeout" to 30 seconds. + cfgp.ConnectTimeout = defaultConnectionTimeout + // Set the default "SeedList" to localhost. cfgp.SeedList = []string{"localhost:27017"} @@ -204,15 +212,7 @@ func NewConfig(co *options.ClientOptions, clock *session.ClusterClock) (*Config, } } connOpts = append(connOpts, WithHandshaker(handshaker)) - // ConnectTimeout - if co.ConnectTimeout != nil { - serverOpts = append(serverOpts, WithHeartbeatTimeout( - func(time.Duration) time.Duration { return *co.ConnectTimeout }, - )) - connOpts = append(connOpts, WithConnectTimeout( - func(time.Duration) time.Duration { return *co.ConnectTimeout }, - )) - } + // Dialer if co.Dialer != nil { connOpts = append(connOpts, WithDialer( @@ -292,13 +292,9 @@ func NewConfig(co *options.ClientOptions, clock *session.ClusterClock) (*Config, if co.ServerSelectionTimeout != nil { cfgp.ServerSelectionTimeout = *co.ServerSelectionTimeout } - // SocketTimeout - if co.SocketTimeout != nil { - connOpts = append( - connOpts, - WithReadTimeout(func(time.Duration) time.Duration { return *co.SocketTimeout }), - WithWriteTimeout(func(time.Duration) time.Duration { return *co.SocketTimeout }), - ) + //ConnectionTimeout + if co.ConnectTimeout != nil { + cfgp.ConnectTimeout = *co.ConnectTimeout } // TLSConfig if co.TLSConfig != nil { diff --git a/x/mongo/driver/topology/topology_options_test.go b/x/mongo/driver/topology/topology_options_test.go index e57c75bcb0..1b6140f5dc 100644 --- a/x/mongo/driver/topology/topology_options_test.go +++ b/x/mongo/driver/topology/topology_options_test.go @@ -73,7 +73,7 @@ func TestLoadBalancedFromConnString(t *testing.T) { assert.Nil(t, err, "topology.New error: %v", err) assert.Equal(t, tc.loadBalanced, topo.cfg.LoadBalanced, "expected loadBalanced %v, got %v", tc.loadBalanced, topo.cfg.LoadBalanced) - srvr := NewServer("", topo.id, topo.cfg.ServerOpts...) + srvr := NewServer("", topo.id, defaultConnectionTimeout, topo.cfg.ServerOpts...) assert.Equal(t, tc.loadBalanced, srvr.cfg.loadBalanced, "expected loadBalanced %v, got %v", tc.loadBalanced, srvr.cfg.loadBalanced) conn := newConnection("", srvr.cfg.connectionOpts...) diff --git a/x/mongo/driver/topology/topology_test.go b/x/mongo/driver/topology/topology_test.go index 937824d4dd..0e4920d88b 100644 --- a/x/mongo/driver/topology/topology_test.go +++ b/x/mongo/driver/topology/topology_test.go @@ -9,7 +9,6 @@ package topology import ( "context" "encoding/json" - "errors" "fmt" "io/ioutil" "path" @@ -25,7 +24,6 @@ import ( "go.mongodb.org/mongo-driver/mongo/address" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/readpref" - "go.mongodb.org/mongo-driver/x/mongo/driver" "go.mongodb.org/mongo-driver/x/mongo/driver/description" ) @@ -65,10 +63,6 @@ func TestServerSelection(t *testing.T) { var selectNone serverselector.Func = func(description.Topology, []description.Server) ([]description.Server, error) { return []description.Server{}, nil } - var errSelectionError = errors.New("encountered an error in the selector") - var selectError serverselector.Func = func(description.Topology, []description.Server) ([]description.Server, error) { - return nil, errSelectionError - } t.Run("Success", func(t *testing.T) { topo, err := New(nil) @@ -83,8 +77,7 @@ func TestServerSelection(t *testing.T) { subCh := make(chan description.Topology, 1) subCh <- desc - state := newServerSelectionState(selectFirst, nil) - srvs, err := topo.selectServerFromSubscription(context.Background(), subCh, state) + srvs, err := topo.selectServerFromSubscription(context.Background(), subCh, selectFirst) noerr(t, err) if len(srvs) != 1 { t.Errorf("Incorrect number of descriptions returned. got %d; want %d", len(srvs), 1) @@ -148,8 +141,7 @@ func TestServerSelection(t *testing.T) { resp := make(chan []description.Server) go func() { - state := newServerSelectionState(selectFirst, nil) - srvs, err := topo.selectServerFromSubscription(context.Background(), subCh, state) + srvs, err := topo.selectServerFromSubscription(context.Background(), subCh, selectFirst) noerr(t, err) resp <- srvs }() @@ -196,8 +188,7 @@ func TestServerSelection(t *testing.T) { resp := make(chan error) ctx, cancel := context.WithCancel(context.Background()) go func() { - state := newServerSelectionState(selectNone, nil) - _, err := topo.selectServerFromSubscription(ctx, subCh, state) + _, err := topo.selectServerFromSubscription(ctx, subCh, selectNone) resp <- err }() @@ -218,77 +209,11 @@ func TestServerSelection(t *testing.T) { want := ServerSelectionError{Wrapped: context.Canceled, Desc: desc} assert.Equal(t, err, want, "Incorrect error received. got %v; want %v", err, want) }) - t.Run("Timeout", func(t *testing.T) { - desc := description.Topology{ - Servers: []description.Server{ - {Addr: address.Address("one"), Kind: description.ServerKindStandalone}, - {Addr: address.Address("two"), Kind: description.ServerKindStandalone}, - {Addr: address.Address("three"), Kind: description.ServerKindStandalone}, - }, - } - topo, err := New(nil) - noerr(t, err) - subCh := make(chan description.Topology, 1) - subCh <- desc - resp := make(chan error) - timeout := make(chan time.Time) - go func() { - state := newServerSelectionState(selectNone, timeout) - _, err := topo.selectServerFromSubscription(context.Background(), subCh, state) - resp <- err - }() - - select { - case err := <-resp: - t.Errorf("Received error from server selection too soon: %v", err) - case timeout <- time.Now(): - } - - select { - case err = <-resp: - case <-time.After(100 * time.Millisecond): - t.Errorf("Timed out while trying to retrieve selected servers") - } - - if err == nil { - t.Fatalf("did not receive error from server selection") - } - }) - t.Run("Error", func(t *testing.T) { - desc := description.Topology{ - Servers: []description.Server{ - {Addr: address.Address("one"), Kind: description.ServerKindStandalone}, - {Addr: address.Address("two"), Kind: description.ServerKindStandalone}, - {Addr: address.Address("three"), Kind: description.ServerKindStandalone}, - }, - } - topo, err := New(nil) - noerr(t, err) - subCh := make(chan description.Topology, 1) - subCh <- desc - resp := make(chan error) - timeout := make(chan time.Time) - go func() { - state := newServerSelectionState(selectError, timeout) - _, err := topo.selectServerFromSubscription(context.Background(), subCh, state) - resp <- err - }() - - select { - case err = <-resp: - case <-time.After(100 * time.Millisecond): - t.Errorf("Timed out while trying to retrieve selected servers") - } - - if err == nil { - t.Fatalf("did not receive error from server selection") - } - }) t.Run("findServer returns topology kind", func(t *testing.T) { topo, err := New(nil) noerr(t, err) atomic.StoreInt64(&topo.state, topologyConnected) - srvr, err := ConnectServer(address.Address("one"), topo.updateCallback, topo.id) + srvr, err := ConnectServer(address.Address("one"), topo.updateCallback, topo.id, defaultConnectionTimeout) noerr(t, err) topo.servers[address.Address("one")] = srvr desc := topo.desc.Load().(description.Topology) @@ -303,71 +228,6 @@ func TestServerSelection(t *testing.T) { t.Errorf("findServer does not properly set the topology description kind. got %v; want %v", ss.Kind, description.TopologyKindSingle) } }) - t.Run("Update on not primary error", func(t *testing.T) { - topo, err := New(nil) - noerr(t, err) - atomic.StoreInt64(&topo.state, topologyConnected) - - addr1 := address.Address("one") - addr2 := address.Address("two") - addr3 := address.Address("three") - desc := description.Topology{ - Servers: []description.Server{ - {Addr: addr1, Kind: description.ServerKindRSPrimary}, - {Addr: addr2, Kind: description.ServerKindRSSecondary}, - {Addr: addr3, Kind: description.ServerKindRSSecondary}, - }, - } - - // manually add the servers to the topology - for _, srv := range desc.Servers { - s, err := ConnectServer(srv.Addr, topo.updateCallback, topo.id) - noerr(t, err) - topo.servers[srv.Addr] = s - } - - // Send updated description - desc = description.Topology{ - Servers: []description.Server{ - {Addr: addr1, Kind: description.ServerKindRSSecondary}, - {Addr: addr2, Kind: description.ServerKindRSPrimary}, - {Addr: addr3, Kind: description.ServerKindRSSecondary}, - }, - } - - subCh := make(chan description.Topology, 1) - subCh <- desc - - // send a not primary error to the server forcing an update - serv, err := topo.FindServer(desc.Servers[0]) - noerr(t, err) - atomic.StoreInt64(&serv.state, serverConnected) - _ = serv.ProcessError(driver.Error{Message: driver.LegacyNotPrimaryErrMsg}, initConnection{}) - - resp := make(chan []description.Server) - - go func() { - // server selection should discover the new topology - state := newServerSelectionState(&serverselector.Write{}, nil) - srvs, err := topo.selectServerFromSubscription(context.Background(), subCh, state) - noerr(t, err) - resp <- srvs - }() - - var srvs []description.Server - select { - case srvs = <-resp: - case <-time.After(100 * time.Millisecond): - t.Errorf("Timed out while trying to retrieve selected servers") - } - - if len(srvs) != 1 { - t.Errorf("Incorrect number of descriptions returned. got %d; want %d", len(srvs), 1) - } - if srvs[0].Addr != desc.Servers[1].Addr { - t.Errorf("Incorrect sever selected. got %s; want %s", srvs[0].Addr, desc.Servers[1].Addr) - } - }) t.Run("fast path does not subscribe or check timeouts", func(t *testing.T) { // Assert that the server selection fast path does not create a Subscription or check for timeout errors. topo, err := New(nil) @@ -382,7 +242,7 @@ func TestServerSelection(t *testing.T) { } topo.desc.Store(desc) for _, srv := range desc.Servers { - s, err := ConnectServer(srv.Addr, topo.updateCallback, topo.id) + s, err := ConnectServer(srv.Addr, topo.updateCallback, topo.id, defaultConnectionTimeout) noerr(t, err) topo.servers[srv.Addr] = s } @@ -983,6 +843,7 @@ func runInWindowTest(t *testing.T, directory string, filename string) { server := NewServer( address.Address(testDesc.Address), bson.NilObjectID, + defaultConnectionTimeout, withMonitoringDisabled(func(bool) bool { return true })) servers[testDesc.Address] = server @@ -1176,13 +1037,12 @@ func BenchmarkSelectServerFromDescription(b *testing.B) { Servers: servers, } - timeout := make(chan time.Time) b.ResetTimer() b.RunParallel(func(p *testing.PB) { b.ReportAllocs() for p.Next() { var c Topology - _, _ = c.selectServerFromDescription(desc, newServerSelectionState(selectNone, timeout)) + _, _ = c.selectServerFromDescription(desc, selectNone) } }) })