diff --git a/src/MongoDB.Driver/Authentication/MongoDBX509Authenticator.cs b/src/MongoDB.Driver/Authentication/MongoDBX509Authenticator.cs index 07612031c97..dd4b9ca0688 100644 --- a/src/MongoDB.Driver/Authentication/MongoDBX509Authenticator.cs +++ b/src/MongoDB.Driver/Authentication/MongoDBX509Authenticator.cs @@ -58,7 +58,8 @@ public void Authenticate(IConnection connection, ConnectionDescription descripti try { var protocol = CreateAuthenticateProtocol(); - protocol.Execute(connection, cancellationToken); + // TODO: CSOT: implement operationContext support for Auth. + protocol.Execute(new OperationContext(Timeout.InfiniteTimeSpan, cancellationToken), connection); } catch (MongoCommandException ex) { @@ -79,7 +80,8 @@ public async Task AuthenticateAsync(IConnection connection, ConnectionDescriptio try { var protocol = CreateAuthenticateProtocol(); - await protocol.ExecuteAsync(connection, cancellationToken).ConfigureAwait(false); + // TODO: CSOT: implement operationContext support for Auth. + await protocol.ExecuteAsync(new OperationContext(Timeout.InfiniteTimeSpan, cancellationToken), connection).ConfigureAwait(false); } catch (MongoCommandException ex) { diff --git a/src/MongoDB.Driver/Authentication/SaslAuthenticator.cs b/src/MongoDB.Driver/Authentication/SaslAuthenticator.cs index fddb5953b60..d42558ddee6 100644 --- a/src/MongoDB.Driver/Authentication/SaslAuthenticator.cs +++ b/src/MongoDB.Driver/Authentication/SaslAuthenticator.cs @@ -109,7 +109,8 @@ public void Authenticate(IConnection connection, ConnectionDescription descripti try { var protocol = CreateCommandProtocol(command); - result = protocol.Execute(connection, cancellationToken); + // TODO: CSOT: implement operationContext support for Auth. + result = protocol.Execute(new OperationContext(Timeout.InfiniteTimeSpan, cancellationToken), connection); conversationId ??= result?.GetValue("conversationId").AsInt32; } catch (MongoException ex) @@ -172,7 +173,8 @@ public async Task AuthenticateAsync(IConnection connection, ConnectionDescriptio try { var protocol = CreateCommandProtocol(command); - result = await protocol.ExecuteAsync(connection, cancellationToken).ConfigureAwait(false); + // TODO: CSOT: implement operationContext support for Auth. + result = await protocol.ExecuteAsync(new OperationContext(Timeout.InfiniteTimeSpan, cancellationToken), connection).ConfigureAwait(false); conversationId ??= result?.GetValue("conversationId").AsInt32; } catch (MongoException ex) diff --git a/src/MongoDB.Driver/Core/Bindings/IChannel.cs b/src/MongoDB.Driver/Core/Bindings/IChannel.cs index 275d6cdebbc..3a1dca12020 100644 --- a/src/MongoDB.Driver/Core/Bindings/IChannel.cs +++ b/src/MongoDB.Driver/Core/Bindings/IChannel.cs @@ -15,7 +15,6 @@ using System; using System.Collections.Generic; -using System.Threading; using System.Threading.Tasks; using MongoDB.Bson; using MongoDB.Bson.IO; @@ -33,6 +32,7 @@ internal interface IChannel : IDisposable ConnectionDescription ConnectionDescription { get; } TResult Command( + OperationContext operationContext, ICoreSession session, ReadPreference readPreference, DatabaseNamespace databaseNamespace, @@ -43,10 +43,10 @@ TResult Command( Action postWriteAction, CommandResponseHandling responseHandling, IBsonSerializer resultSerializer, - MessageEncoderSettings messageEncoderSettings, - CancellationToken cancellationToken); + MessageEncoderSettings messageEncoderSettings); Task CommandAsync( + OperationContext operationContext, ICoreSession session, ReadPreference readPreference, DatabaseNamespace databaseNamespace, @@ -57,8 +57,7 @@ Task CommandAsync( Action postWriteAction, CommandResponseHandling responseHandling, IBsonSerializer resultSerializer, - MessageEncoderSettings messageEncoderSettings, - CancellationToken cancellationToken); + MessageEncoderSettings messageEncoderSettings); } internal interface IChannelHandle : IChannel diff --git a/src/MongoDB.Driver/Core/Bindings/ServerChannelSource.cs b/src/MongoDB.Driver/Core/Bindings/ServerChannelSource.cs index c5fbc55cea1..bd8080404ca 100644 --- a/src/MongoDB.Driver/Core/Bindings/ServerChannelSource.cs +++ b/src/MongoDB.Driver/Core/Bindings/ServerChannelSource.cs @@ -35,20 +35,11 @@ public ServerChannelSource(IServer server, ICoreSessionHandle session) } // properties - public IServer Server - { - get { return _server; } - } + public IServer Server => _server; - public ServerDescription ServerDescription - { - get { return _server.Description; } - } + public ServerDescription ServerDescription => _server.Description; - public ICoreSessionHandle Session - { - get { return _session; } - } + public ICoreSessionHandle Session => _session; // methods public void Dispose() diff --git a/src/MongoDB.Driver/Core/Clusters/Cluster.cs b/src/MongoDB.Driver/Core/Clusters/Cluster.cs index 287a40657aa..83591f43ea3 100644 --- a/src/MongoDB.Driver/Core/Clusters/Cluster.cs +++ b/src/MongoDB.Driver/Core/Clusters/Cluster.cs @@ -159,7 +159,7 @@ public IServer SelectServer(OperationContext operationContext, IServerSelector s Ensure.IsNotNull(operationContext, nameof(operationContext)); ThrowIfDisposedOrNotOpen(); - operationContext = operationContext.WithTimeout(Settings.ServerSelectionTimeout); + using var serverSelectionOperationContext = operationContext.WithTimeout(Settings.ServerSelectionTimeout); var expirableClusterDescription = _expirableClusterDescription; IDisposable serverSelectionWaitQueueDisposer = null; (selector, var operationCountSelector, var stopwatch) = BeginServerSelection(expirableClusterDescription.ClusterDescription, selector); @@ -168,16 +168,16 @@ public IServer SelectServer(OperationContext operationContext, IServerSelector s { while (true) { - var result = SelectServer(expirableClusterDescription, selector, operationCountSelector); - if (result != default) + var server = SelectServer(expirableClusterDescription, selector, operationCountSelector); + if (server != null) { - EndServerSelection(expirableClusterDescription.ClusterDescription, selector, result.ServerDescription, stopwatch); - return result.Server; + EndServerSelection(expirableClusterDescription.ClusterDescription, selector, server.Description, stopwatch); + return server; } - serverSelectionWaitQueueDisposer ??= _serverSelectionWaitQueue.Enter(operationContext, selector, expirableClusterDescription.ClusterDescription, EventContext.OperationId); + serverSelectionWaitQueueDisposer ??= _serverSelectionWaitQueue.Enter(serverSelectionOperationContext, selector, expirableClusterDescription.ClusterDescription, EventContext.OperationId); - operationContext.WaitTask(expirableClusterDescription.Expired); + serverSelectionOperationContext.WaitTask(expirableClusterDescription.Expired); expirableClusterDescription = _expirableClusterDescription; } } @@ -197,7 +197,7 @@ public async Task SelectServerAsync(OperationContext operationContext, Ensure.IsNotNull(operationContext, nameof(operationContext)); ThrowIfDisposedOrNotOpen(); - operationContext = operationContext.WithTimeout(Settings.ServerSelectionTimeout); + using var serverSelectionOperationContext = operationContext.WithTimeout(Settings.ServerSelectionTimeout); var expirableClusterDescription = _expirableClusterDescription; IDisposable serverSelectionWaitQueueDisposer = null; (selector, var operationCountSelector, var stopwatch) = BeginServerSelection(expirableClusterDescription.ClusterDescription, selector); @@ -206,16 +206,16 @@ public async Task SelectServerAsync(OperationContext operationContext, { while (true) { - var result = SelectServer(expirableClusterDescription, selector, operationCountSelector); - if (result != default) + var server = SelectServer(expirableClusterDescription, selector, operationCountSelector); + if (server != null) { - EndServerSelection(expirableClusterDescription.ClusterDescription, selector, result.ServerDescription, stopwatch); - return result.Server; + EndServerSelection(expirableClusterDescription.ClusterDescription, selector, server.Description, stopwatch); + return server; } - serverSelectionWaitQueueDisposer ??= _serverSelectionWaitQueue.Enter(operationContext, selector, expirableClusterDescription.ClusterDescription, EventContext.OperationId); + serverSelectionWaitQueueDisposer ??= _serverSelectionWaitQueue.Enter(serverSelectionOperationContext, selector, expirableClusterDescription.ClusterDescription, EventContext.OperationId); - await operationContext.WaitTaskAsync(expirableClusterDescription.Expired).ConfigureAwait(false); + await serverSelectionOperationContext.WaitTaskAsync(expirableClusterDescription.Expired).ConfigureAwait(false); expirableClusterDescription = _expirableClusterDescription; } } @@ -306,7 +306,7 @@ private Exception HandleServerSelectionException(ClusterDescription clusterDescr return exception; } - private (IClusterableServer Server, ServerDescription ServerDescription) SelectServer(ExpirableClusterDescription clusterDescriptionChangeSource, IServerSelector selector, OperationsCountServerSelector operationCountSelector) + private SelectedServer SelectServer(ExpirableClusterDescription clusterDescriptionChangeSource, IServerSelector selector, OperationsCountServerSelector operationCountSelector) { MongoIncompatibleDriverException.ThrowIfNotSupported(clusterDescriptionChangeSource.ClusterDescription); @@ -320,7 +320,7 @@ private Exception HandleServerSelectionException(ClusterDescription clusterDescr var selectedServer = clusterDescriptionChangeSource.ConnectedServers.FirstOrDefault(s => EndPointHelper.Equals(s.EndPoint, selectedServerDescription.EndPoint)); if (selectedServer != null) { - return (selectedServer, selectedServerDescription); + return new(selectedServer, selectedServerDescription); } } diff --git a/src/MongoDB.Driver/Core/Clusters/LoadBalancedCluster.cs b/src/MongoDB.Driver/Core/Clusters/LoadBalancedCluster.cs index c77d2d45241..7fa3b29b038 100644 --- a/src/MongoDB.Driver/Core/Clusters/LoadBalancedCluster.cs +++ b/src/MongoDB.Driver/Core/Clusters/LoadBalancedCluster.cs @@ -176,7 +176,7 @@ public IServer SelectServer(OperationContext operationContext, IServerSelector s Ensure.IsNotNull(operationContext, nameof(operationContext)); ThrowIfDisposed(); - var serverSelectionOperationContext = operationContext.WithTimeout(_settings.ServerSelectionTimeout); + using var serverSelectionOperationContext = operationContext.WithTimeout(_settings.ServerSelectionTimeout); _serverSelectionEventLogger.LogAndPublish(new ClusterSelectingServerEvent( _description, @@ -205,10 +205,11 @@ public IServer SelectServer(OperationContext operationContext, IServerSelector s stopwatch.Elapsed, null, EventContext.OperationName)); + + return new SelectedServer(_server, _server.Description); } - return _server ?? - throw new InvalidOperationException("The server must be created before usage."); // should not be reached + throw new InvalidOperationException("The server must be created before usage."); // should not be reached } public async Task SelectServerAsync(OperationContext operationContext, IServerSelector selector) @@ -217,7 +218,7 @@ public async Task SelectServerAsync(OperationContext operationContext, Ensure.IsNotNull(operationContext, nameof(operationContext)); ThrowIfDisposed(); - var serverSelectionOperationContext = operationContext.WithTimeout(_settings.ServerSelectionTimeout); + using var serverSelectionOperationContext = operationContext.WithTimeout(_settings.ServerSelectionTimeout); _serverSelectionEventLogger.LogAndPublish(new ClusterSelectingServerEvent( _description, @@ -245,10 +246,11 @@ public async Task SelectServerAsync(OperationContext operationContext, stopwatch.Elapsed, null, EventContext.OperationName)); + + return new SelectedServer(_server, _server.Description); } - return _server ?? - throw new InvalidOperationException("The server must be created before usage."); // should not be reached + throw new InvalidOperationException("The server must be created before usage."); // should not be reached } public ICoreSessionHandle StartSession(CoreSessionOptions options = null) diff --git a/src/MongoDB.Driver/Core/Compression/SnappyCompressor.cs b/src/MongoDB.Driver/Core/Compression/SnappyCompressor.cs index 7419244a46c..44542d88da8 100644 --- a/src/MongoDB.Driver/Core/Compression/SnappyCompressor.cs +++ b/src/MongoDB.Driver/Core/Compression/SnappyCompressor.cs @@ -15,7 +15,6 @@ using Snappier; using System.IO; -using System.Threading; using MongoDB.Driver.Core.Misc; namespace MongoDB.Driver.Core.Compression @@ -34,7 +33,7 @@ public void Compress(Stream input, Stream output) { var uncompressedSize = (int)(input.Length - input.Position); var uncompressedBytes = new byte[uncompressedSize]; // does not include uncompressed message headers - input.ReadBytes(uncompressedBytes, offset: 0, count: uncompressedSize, Timeout.InfiniteTimeSpan, CancellationToken.None); + input.ReadBytes(OperationContext.NoTimeout, uncompressedBytes, offset: 0, count: uncompressedSize); var maxCompressedSize = Snappy.GetMaxCompressedLength(uncompressedSize); var compressedBytes = new byte[maxCompressedSize]; var compressedSize = Snappy.Compress(uncompressedBytes, compressedBytes); @@ -50,7 +49,7 @@ public void Decompress(Stream input, Stream output) { var compressedSize = (int)(input.Length - input.Position); var compressedBytes = new byte[compressedSize]; - input.ReadBytes(compressedBytes, offset: 0, count: compressedSize, Timeout.InfiniteTimeSpan, CancellationToken.None); + input.ReadBytes(OperationContext.NoTimeout, compressedBytes, offset: 0, count: compressedSize); var uncompressedSize = Snappy.GetUncompressedLength(compressedBytes); var decompressedBytes = new byte[uncompressedSize]; var decompressedSize = Snappy.Decompress(compressedBytes, decompressedBytes); diff --git a/src/MongoDB.Driver/Core/ConnectionPools/ExclusiveConnectionPool.Helpers.cs b/src/MongoDB.Driver/Core/ConnectionPools/ExclusiveConnectionPool.Helpers.cs index 5ff4a0f0845..da18a88012d 100644 --- a/src/MongoDB.Driver/Core/ConnectionPools/ExclusiveConnectionPool.Helpers.cs +++ b/src/MongoDB.Driver/Core/ConnectionPools/ExclusiveConnectionPool.Helpers.cs @@ -401,11 +401,11 @@ public void Dispose() } } - public void Open(CancellationToken cancellationToken) + public void Open(OperationContext operationContext) { try { - _connection.Open(cancellationToken); + _connection.Open(operationContext); SetEffectiveGenerationIfRequired(_connection.Description); } catch (MongoConnectionException ex) @@ -416,11 +416,11 @@ public void Open(CancellationToken cancellationToken) } } - public async Task OpenAsync(CancellationToken cancellationToken) + public async Task OpenAsync(OperationContext operationContext) { try { - await _connection.OpenAsync(cancellationToken).ConfigureAwait(false); + await _connection.OpenAsync(operationContext).ConfigureAwait(false); SetEffectiveGenerationIfRequired(_connection.Description); } catch (MongoConnectionException ex) @@ -435,11 +435,11 @@ public async Task OpenAsync(CancellationToken cancellationToken) public Task ReauthenticateAsync(CancellationToken cancellationToken) => _connection.ReauthenticateAsync(cancellationToken); - public ResponseMessage ReceiveMessage(int responseTo, IMessageEncoderSelector encoderSelector, MessageEncoderSettings messageEncoderSettings, CancellationToken cancellationToken) + public ResponseMessage ReceiveMessage(OperationContext operationContext, int responseTo, IMessageEncoderSelector encoderSelector, MessageEncoderSettings messageEncoderSettings) { try { - return _connection.ReceiveMessage(responseTo, encoderSelector, messageEncoderSettings, cancellationToken); + return _connection.ReceiveMessage(operationContext, responseTo, encoderSelector, messageEncoderSettings); } catch (MongoConnectionException ex) { @@ -448,11 +448,11 @@ public ResponseMessage ReceiveMessage(int responseTo, IMessageEncoderSelector en } } - public async Task ReceiveMessageAsync(int responseTo, IMessageEncoderSelector encoderSelector, MessageEncoderSettings messageEncoderSettings, CancellationToken cancellationToken) + public async Task ReceiveMessageAsync(OperationContext operationContext, int responseTo, IMessageEncoderSelector encoderSelector, MessageEncoderSettings messageEncoderSettings) { try { - return await _connection.ReceiveMessageAsync(responseTo, encoderSelector, messageEncoderSettings, cancellationToken).ConfigureAwait(false); + return await _connection.ReceiveMessageAsync(operationContext, responseTo, encoderSelector, messageEncoderSettings).ConfigureAwait(false); } catch (MongoConnectionException ex) { @@ -461,11 +461,11 @@ public async Task ReceiveMessageAsync(int responseTo, IMessageE } } - public void SendMessage(RequestMessage message, MessageEncoderSettings messageEncoderSettings, CancellationToken cancellationToken) + public void SendMessage(OperationContext operationContext, RequestMessage message, MessageEncoderSettings messageEncoderSettings) { try { - _connection.SendMessage(message, messageEncoderSettings, cancellationToken); + _connection.SendMessage(operationContext, message, messageEncoderSettings); } catch (MongoConnectionException ex) { @@ -474,11 +474,11 @@ public void SendMessage(RequestMessage message, MessageEncoderSettings messageEn } } - public async Task SendMessageAsync(RequestMessage message, MessageEncoderSettings messageEncoderSettings, CancellationToken cancellationToken) + public async Task SendMessageAsync(OperationContext operationContext, RequestMessage message, MessageEncoderSettings messageEncoderSettings) { try { - await _connection.SendMessageAsync(message, messageEncoderSettings, cancellationToken).ConfigureAwait(false); + await _connection.SendMessageAsync(operationContext, message, messageEncoderSettings).ConfigureAwait(false); } catch (MongoConnectionException ex) { @@ -587,16 +587,16 @@ public IConnectionHandle Fork() return new AcquiredConnection(_connectionPool, _reference); } - public void Open(CancellationToken cancellationToken) + public void Open(OperationContext operationContext) { ThrowIfDisposed(); - _reference.Instance.Open(cancellationToken); + _reference.Instance.Open(operationContext); } - public Task OpenAsync(CancellationToken cancellationToken) + public Task OpenAsync(OperationContext operationContext) { ThrowIfDisposed(); - return _reference.Instance.OpenAsync(cancellationToken); + return _reference.Instance.OpenAsync(operationContext); } public void Reauthenticate(CancellationToken cancellationToken) @@ -611,28 +611,28 @@ public Task ReauthenticateAsync(CancellationToken cancellationToken) return _reference.Instance.ReauthenticateAsync(cancellationToken); } - public Task ReceiveMessageAsync(int responseTo, IMessageEncoderSelector encoderSelector, MessageEncoderSettings messageEncoderSettings, CancellationToken cancellationToken) + public Task ReceiveMessageAsync(OperationContext operationContext, int responseTo, IMessageEncoderSelector encoderSelector, MessageEncoderSettings messageEncoderSettings) { ThrowIfDisposed(); - return _reference.Instance.ReceiveMessageAsync(responseTo, encoderSelector, messageEncoderSettings, cancellationToken); + return _reference.Instance.ReceiveMessageAsync(operationContext, responseTo, encoderSelector, messageEncoderSettings); } - public ResponseMessage ReceiveMessage(int responseTo, IMessageEncoderSelector encoderSelector, MessageEncoderSettings messageEncoderSettings, CancellationToken cancellationToken) + public ResponseMessage ReceiveMessage(OperationContext operationContext, int responseTo, IMessageEncoderSelector encoderSelector, MessageEncoderSettings messageEncoderSettings) { ThrowIfDisposed(); - return _reference.Instance.ReceiveMessage(responseTo, encoderSelector, messageEncoderSettings, cancellationToken); + return _reference.Instance.ReceiveMessage(operationContext, responseTo, encoderSelector, messageEncoderSettings); } - public void SendMessage(RequestMessage message, MessageEncoderSettings messageEncoderSettings, CancellationToken cancellationToken) + public void SendMessage(OperationContext operationContext, RequestMessage message, MessageEncoderSettings messageEncoderSettings) { ThrowIfDisposed(); - _reference.Instance.SendMessage(message, messageEncoderSettings, cancellationToken); + _reference.Instance.SendMessage(operationContext, message, messageEncoderSettings); } - public Task SendMessageAsync(RequestMessage message, MessageEncoderSettings messageEncoderSettings, CancellationToken cancellationToken) + public Task SendMessageAsync(OperationContext operationContext, RequestMessage message, MessageEncoderSettings messageEncoderSettings) { ThrowIfDisposed(); - return _reference.Instance.SendMessageAsync(message, messageEncoderSettings, cancellationToken); + return _reference.Instance.SendMessageAsync(operationContext, message, messageEncoderSettings); } public void SetCheckOutReasonIfNotAlreadySet(CheckOutReason reason) @@ -974,8 +974,7 @@ private PooledConnection CreateOpenedInternal(OperationContext operationContext) { var stopwatch = StartCreating(operationContext); - // TODO: CSOT add support of CSOT timeout in connection open code too. - _connection.Open(operationContext.CancellationToken); + _connection.Open(operationContext); FinishCreating(_connection.Description, stopwatch); @@ -986,8 +985,7 @@ private async Task CreateOpenedInternalAsync(OperationContext { var stopwatch = StartCreating(operationContext); - // TODO: CSOT add support of CSOT timeout in connection open code too. - await _connection.OpenAsync(operationContext.CancellationToken).ConfigureAwait(false); + await _connection.OpenAsync(operationContext).ConfigureAwait(false); FinishCreating(_connection.Description, stopwatch); diff --git a/src/MongoDB.Driver/Core/ConnectionPools/ExclusiveConnectionPool.cs b/src/MongoDB.Driver/Core/ConnectionPools/ExclusiveConnectionPool.cs index 7489d714081..c22fcb2f431 100644 --- a/src/MongoDB.Driver/Core/ConnectionPools/ExclusiveConnectionPool.cs +++ b/src/MongoDB.Driver/Core/ConnectionPools/ExclusiveConnectionPool.cs @@ -141,16 +141,16 @@ public int UsedCount // public methods public IConnectionHandle AcquireConnection(OperationContext operationContext) { - operationContext = operationContext.WithTimeout(Settings.WaitQueueTimeout); + using var waitQueueOperationContext = operationContext.WithTimeout(Settings.WaitQueueTimeout); using var helper = new AcquireConnectionHelper(this); - return helper.AcquireConnection(operationContext); + return helper.AcquireConnection(waitQueueOperationContext); } public async Task AcquireConnectionAsync(OperationContext operationContext) { - operationContext = operationContext.WithTimeout(Settings.WaitQueueTimeout); + using var waitQueueOperationContext = operationContext.WithTimeout(Settings.WaitQueueTimeout); using var helper = new AcquireConnectionHelper(this); - return await helper.AcquireConnectionAsync(operationContext).ConfigureAwait(false); + return await helper.AcquireConnectionAsync(waitQueueOperationContext).ConfigureAwait(false); } public void Clear(bool closeInUseConnections = false) diff --git a/src/MongoDB.Driver/Core/Connections/BinaryConnection.cs b/src/MongoDB.Driver/Core/Connections/BinaryConnection.cs index 210e33cc14c..2c49c91891c 100644 --- a/src/MongoDB.Driver/Core/Connections/BinaryConnection.cs +++ b/src/MongoDB.Driver/Core/Connections/BinaryConnection.cs @@ -62,7 +62,8 @@ internal sealed class BinaryConnection : IConnection private readonly EventLogger _eventLogger; // constructors - public BinaryConnection(ServerId serverId, + public BinaryConnection( + ServerId serverId, EndPoint endPoint, ConnectionSettings settings, IStreamFactory streamFactory, @@ -203,9 +204,9 @@ private void EnsureMessageSizeIsValid(int messageSize) } } - public void Open(CancellationToken cancellationToken) + public void Open(OperationContext operationContext) { - ThrowIfCancelledOrDisposed(cancellationToken); + ThrowIfCancelledOrDisposed(operationContext); TaskCompletionSource taskCompletionSource = null; var connecting = false; @@ -225,7 +226,7 @@ public void Open(CancellationToken cancellationToken) { try { - OpenHelper(cancellationToken); + OpenHelper(operationContext); taskCompletionSource.TrySetResult(true); } catch (Exception ex) @@ -240,33 +241,37 @@ public void Open(CancellationToken cancellationToken) } } - public Task OpenAsync(CancellationToken cancellationToken) + public Task OpenAsync(OperationContext operationContext) { - ThrowIfCancelledOrDisposed(cancellationToken); + ThrowIfCancelledOrDisposed(operationContext); lock (_openLock) { if (_state.TryChange(State.Initial, State.Connecting)) { _openedAtUtc = DateTime.UtcNow; - _openTask = OpenHelperAsync(cancellationToken); + _openTask = OpenHelperAsync(operationContext); } return _openTask; } } - private void OpenHelper(CancellationToken cancellationToken) + private void OpenHelper(OperationContext operationContext) { var helper = new OpenConnectionHelper(this); ConnectionDescription handshakeDescription = null; try { helper.OpeningConnection(); - _stream = _streamFactory.CreateStream(_endPoint, cancellationToken); +#pragma warning disable CS0618 // Type or member is obsolete + _stream = _streamFactory.CreateStream(_endPoint, operationContext.CombinedCancellationToken); +#pragma warning restore CS0618 // Type or member is obsolete helper.InitializingConnection(); - _connectionInitializerContext = _connectionInitializer.SendHello(this, cancellationToken); + // TODO: CSOT: Implement operation context support for MongoDB Handshake + _connectionInitializerContext = _connectionInitializer.SendHello(this, operationContext.CancellationToken); handshakeDescription = _connectionInitializerContext.Description; - _connectionInitializerContext = _connectionInitializer.Authenticate(this, _connectionInitializerContext, cancellationToken); + // TODO: CSOT: Implement operation context support for Auth + _connectionInitializerContext = _connectionInitializer.Authenticate(this, _connectionInitializerContext, operationContext.CancellationToken); _description = _connectionInitializerContext.Description; _sendCompressorType = ChooseSendCompressorTypeIfAny(_description); @@ -281,18 +286,22 @@ private void OpenHelper(CancellationToken cancellationToken) } } - private async Task OpenHelperAsync(CancellationToken cancellationToken) + private async Task OpenHelperAsync(OperationContext operationContext) { var helper = new OpenConnectionHelper(this); ConnectionDescription handshakeDescription = null; try { helper.OpeningConnection(); - _stream = await _streamFactory.CreateStreamAsync(_endPoint, cancellationToken).ConfigureAwait(false); +#pragma warning disable CS0618 // Type or member is obsolete + _stream = await _streamFactory.CreateStreamAsync(_endPoint, operationContext.CombinedCancellationToken).ConfigureAwait(false); +#pragma warning restore CS0618 // Type or member is obsolete helper.InitializingConnection(); - _connectionInitializerContext = await _connectionInitializer.SendHelloAsync(this, cancellationToken).ConfigureAwait(false); + // TODO: CSOT: Implement operation context support for MongoDB Handshake + _connectionInitializerContext = await _connectionInitializer.SendHelloAsync(this, operationContext.CancellationToken).ConfigureAwait(false); handshakeDescription = _connectionInitializerContext.Description; - _connectionInitializerContext = await _connectionInitializer.AuthenticateAsync(this, _connectionInitializerContext, cancellationToken).ConfigureAwait(false); + // TODO: CSOT: Implement operation context support for Auth + _connectionInitializerContext = await _connectionInitializer.AuthenticateAsync(this, _connectionInitializerContext, operationContext.CancellationToken).ConfigureAwait(false); _description = _connectionInitializerContext.Description; _sendCompressorType = ChooseSendCompressorTypeIfAny(_description); helper.OpenedConnection(); @@ -326,20 +335,19 @@ private void InvalidateAuthenticator() } } - private IByteBuffer ReceiveBuffer(CancellationToken cancellationToken) + private IByteBuffer ReceiveBuffer(OperationContext operationContext) { try { var messageSizeBytes = new byte[4]; - var readTimeout = _stream.CanTimeout ? TimeSpan.FromMilliseconds(_stream.ReadTimeout) : Timeout.InfiniteTimeSpan; - _stream.ReadBytes(messageSizeBytes, 0, 4, readTimeout, cancellationToken); + _stream.ReadBytes(operationContext, messageSizeBytes, 0, 4); var messageSize = BinaryPrimitives.ReadInt32LittleEndian(messageSizeBytes); EnsureMessageSizeIsValid(messageSize); var inputBufferChunkSource = new InputBufferChunkSource(BsonChunkPool.Default); var buffer = ByteBufferFactory.Create(inputBufferChunkSource, messageSize); buffer.Length = messageSize; buffer.SetBytes(0, messageSizeBytes, 0, 4); - _stream.ReadBytes(buffer, 4, messageSize - 4, readTimeout, cancellationToken); + _stream.ReadBytes(operationContext, buffer, 4, messageSize - 4); _lastUsedAtUtc = DateTime.UtcNow; buffer.MakeReadOnly(); return buffer; @@ -352,9 +360,9 @@ private IByteBuffer ReceiveBuffer(CancellationToken cancellationToken) } } - private IByteBuffer ReceiveBuffer(int responseTo, CancellationToken cancellationToken) + private IByteBuffer ReceiveBuffer(OperationContext operationContext, int responseTo) { - using (var receiveLockRequest = new SemaphoreSlimRequest(_receiveLock, cancellationToken)) + using (var receiveLockRequest = new SemaphoreSlimRequest(_receiveLock, operationContext.RemainingTimeout, operationContext.CancellationToken)) { var messageTask = _dropbox.GetMessageAsync(responseTo); try @@ -370,7 +378,7 @@ private IByteBuffer ReceiveBuffer(int responseTo, CancellationToken cancellation { try { - var buffer = ReceiveBuffer(cancellationToken); + var buffer = ReceiveBuffer(operationContext); _dropbox.AddMessage(buffer); } catch (Exception ex) @@ -383,7 +391,7 @@ private IByteBuffer ReceiveBuffer(int responseTo, CancellationToken cancellation return _dropbox.RemoveMessage(responseTo); // also propagates exception if any } - cancellationToken.ThrowIfCancellationRequested(); + operationContext.ThrowIfTimedOutOrCanceled(); } } catch @@ -396,20 +404,19 @@ private IByteBuffer ReceiveBuffer(int responseTo, CancellationToken cancellation } } - private async Task ReceiveBufferAsync(CancellationToken cancellationToken) + private async Task ReceiveBufferAsync(OperationContext operationContext) { try { var messageSizeBytes = new byte[4]; - var readTimeout = _stream.CanTimeout ? TimeSpan.FromMilliseconds(_stream.ReadTimeout) : Timeout.InfiniteTimeSpan; - await _stream.ReadBytesAsync(messageSizeBytes, 0, 4, readTimeout, cancellationToken).ConfigureAwait(false); + await _stream.ReadBytesAsync(operationContext, messageSizeBytes, 0, 4).ConfigureAwait(false); var messageSize = BinaryPrimitives.ReadInt32LittleEndian(messageSizeBytes); EnsureMessageSizeIsValid(messageSize); var inputBufferChunkSource = new InputBufferChunkSource(BsonChunkPool.Default); var buffer = ByteBufferFactory.Create(inputBufferChunkSource, messageSize); buffer.Length = messageSize; buffer.SetBytes(0, messageSizeBytes, 0, 4); - await _stream.ReadBytesAsync(buffer, 4, messageSize - 4, readTimeout, cancellationToken).ConfigureAwait(false); + await _stream.ReadBytesAsync(operationContext, buffer, 4, messageSize - 4).ConfigureAwait(false); _lastUsedAtUtc = DateTime.UtcNow; buffer.MakeReadOnly(); return buffer; @@ -422,9 +429,9 @@ private async Task ReceiveBufferAsync(CancellationToken cancellatio } } - private async Task ReceiveBufferAsync(int responseTo, CancellationToken cancellationToken) + private async Task ReceiveBufferAsync(OperationContext operationContext, int responseTo) { - using (var receiveLockRequest = new SemaphoreSlimRequest(_receiveLock, cancellationToken)) + using (var receiveLockRequest = new SemaphoreSlimRequest(_receiveLock, operationContext.RemainingTimeout, operationContext.CancellationToken)) { var messageTask = _dropbox.GetMessageAsync(responseTo); try @@ -435,12 +442,12 @@ private async Task ReceiveBufferAsync(int responseTo, CancellationT return _dropbox.RemoveMessage(responseTo); // also propagates exception if any } - receiveLockRequest.Task.GetAwaiter().GetResult(); // propagate exceptions + await receiveLockRequest.Task.ConfigureAwait(false); // propagate exceptions while (true) { try { - var buffer = await ReceiveBufferAsync(cancellationToken).ConfigureAwait(false); + var buffer = await ReceiveBufferAsync(operationContext).ConfigureAwait(false); _dropbox.AddMessage(buffer); } catch (Exception ex) @@ -453,7 +460,7 @@ private async Task ReceiveBufferAsync(int responseTo, CancellationT return _dropbox.RemoveMessage(responseTo); // also propagates exception if any } - cancellationToken.ThrowIfCancellationRequested(); + operationContext.ThrowIfTimedOutOrCanceled(); } } catch @@ -467,21 +474,21 @@ private async Task ReceiveBufferAsync(int responseTo, CancellationT } public ResponseMessage ReceiveMessage( + OperationContext operationContext, int responseTo, IMessageEncoderSelector encoderSelector, - MessageEncoderSettings messageEncoderSettings, - CancellationToken cancellationToken) + MessageEncoderSettings messageEncoderSettings) { Ensure.IsNotNull(encoderSelector, nameof(encoderSelector)); - ThrowIfCancelledOrDisposedOrNotOpen(cancellationToken); + ThrowIfCancelledOrDisposedOrNotOpen(operationContext); var helper = new ReceiveMessageHelper(this, responseTo, messageEncoderSettings, _compressorSource); try { helper.ReceivingMessage(); - using (var buffer = ReceiveBuffer(responseTo, cancellationToken)) + using (var buffer = ReceiveBuffer(operationContext, responseTo)) { - var message = helper.DecodeMessage(buffer, encoderSelector, cancellationToken); + var message = helper.DecodeMessage(operationContext, buffer, encoderSelector); helper.ReceivedMessage(buffer, message); return message; } @@ -494,22 +501,20 @@ public ResponseMessage ReceiveMessage( } } - public async Task ReceiveMessageAsync( - int responseTo, + public async Task ReceiveMessageAsync(OperationContext operationContext, int responseTo, IMessageEncoderSelector encoderSelector, - MessageEncoderSettings messageEncoderSettings, - CancellationToken cancellationToken) + MessageEncoderSettings messageEncoderSettings) { Ensure.IsNotNull(encoderSelector, nameof(encoderSelector)); - ThrowIfCancelledOrDisposedOrNotOpen(cancellationToken); + ThrowIfCancelledOrDisposedOrNotOpen(operationContext); var helper = new ReceiveMessageHelper(this, responseTo, messageEncoderSettings, _compressorSource); try { helper.ReceivingMessage(); - using (var buffer = await ReceiveBufferAsync(responseTo, cancellationToken).ConfigureAwait(false)) + using (var buffer = await ReceiveBufferAsync(operationContext, responseTo).ConfigureAwait(false)) { - var message = helper.DecodeMessage(buffer, encoderSelector, cancellationToken); + var message = helper.DecodeMessage(operationContext, buffer, encoderSelector); helper.ReceivedMessage(buffer, message); return message; } @@ -522,9 +527,9 @@ public async Task ReceiveMessageAsync( } } - private void SendBuffer(IByteBuffer buffer, CancellationToken cancellationToken) + private void SendBuffer(OperationContext operationContext, IByteBuffer buffer) { - _sendLock.Wait(cancellationToken); + _sendLock.Wait(operationContext.RemainingTimeout, operationContext.CancellationToken); try { if (_state.Value == State.Failed) @@ -534,8 +539,7 @@ private void SendBuffer(IByteBuffer buffer, CancellationToken cancellationToken) try { - var writeTimeout = _stream.CanTimeout ? TimeSpan.FromMilliseconds(_stream.WriteTimeout) : Timeout.InfiniteTimeSpan; - _stream.WriteBytes(buffer, 0, buffer.Length, writeTimeout, cancellationToken); + _stream.WriteBytes(operationContext, buffer, 0, buffer.Length); _lastUsedAtUtc = DateTime.UtcNow; } catch (Exception ex) @@ -551,9 +555,9 @@ private void SendBuffer(IByteBuffer buffer, CancellationToken cancellationToken) } } - private async Task SendBufferAsync(IByteBuffer buffer, CancellationToken cancellationToken) + private async Task SendBufferAsync(OperationContext operationContext, IByteBuffer buffer) { - await _sendLock.WaitAsync(cancellationToken).ConfigureAwait(false); + await _sendLock.WaitAsync(operationContext.RemainingTimeout, operationContext.CancellationToken).ConfigureAwait(false); try { if (_state.Value == State.Failed) @@ -563,8 +567,7 @@ private async Task SendBufferAsync(IByteBuffer buffer, CancellationToken cancell try { - var writeTimeout = _stream.CanTimeout ? TimeSpan.FromMilliseconds(_stream.WriteTimeout) : Timeout.InfiniteTimeSpan; - await _stream.WriteBytesAsync(buffer, 0, buffer.Length, writeTimeout, cancellationToken).ConfigureAwait(false); + await _stream.WriteBytesAsync(operationContext, buffer, 0, buffer.Length).ConfigureAwait(false); _lastUsedAtUtc = DateTime.UtcNow; } catch (Exception ex) @@ -580,16 +583,16 @@ private async Task SendBufferAsync(IByteBuffer buffer, CancellationToken cancell } } - public void SendMessage(RequestMessage message, MessageEncoderSettings messageEncoderSettings, CancellationToken cancellationToken) + public void SendMessage(OperationContext operationContext, RequestMessage message, MessageEncoderSettings messageEncoderSettings) { Ensure.IsNotNull(message, nameof(message)); - ThrowIfCancelledOrDisposedOrNotOpen(cancellationToken); + ThrowIfCancelledOrDisposedOrNotOpen(operationContext); var helper = new SendMessageHelper(this, message, messageEncoderSettings); try { helper.EncodingMessage(); - using (var uncompressedBuffer = helper.EncodeMessage(cancellationToken, out var sentMessage)) + using (var uncompressedBuffer = helper.EncodeMessage(operationContext, out var sentMessage)) { helper.SendingMessage(uncompressedBuffer); int sentLength; @@ -597,13 +600,13 @@ public void SendMessage(RequestMessage message, MessageEncoderSettings messageEn { using (var compressedBuffer = CompressMessage(sentMessage, uncompressedBuffer, messageEncoderSettings)) { - SendBuffer(compressedBuffer, cancellationToken); + SendBuffer(operationContext, compressedBuffer); sentLength = compressedBuffer.Length; } } else { - SendBuffer(uncompressedBuffer, cancellationToken); + SendBuffer(operationContext, uncompressedBuffer); sentLength = uncompressedBuffer.Length; } helper.SentMessage(sentLength); @@ -617,16 +620,16 @@ public void SendMessage(RequestMessage message, MessageEncoderSettings messageEn } } - public async Task SendMessageAsync(RequestMessage message, MessageEncoderSettings messageEncoderSettings, CancellationToken cancellationToken) + public async Task SendMessageAsync(OperationContext operationContext, RequestMessage message, MessageEncoderSettings messageEncoderSettings) { Ensure.IsNotNull(message, nameof(message)); - ThrowIfCancelledOrDisposedOrNotOpen(cancellationToken); + ThrowIfCancelledOrDisposedOrNotOpen(operationContext); var helper = new SendMessageHelper(this, message, messageEncoderSettings); try { helper.EncodingMessage(); - using (var uncompressedBuffer = helper.EncodeMessage(cancellationToken, out var sentMessage)) + using (var uncompressedBuffer = helper.EncodeMessage(operationContext, out var sentMessage)) { helper.SendingMessage(uncompressedBuffer); int sentLength; @@ -634,13 +637,13 @@ public async Task SendMessageAsync(RequestMessage message, MessageEncoderSetting { using (var compressedBuffer = CompressMessage(sentMessage, uncompressedBuffer, messageEncoderSettings)) { - await SendBufferAsync(compressedBuffer, cancellationToken).ConfigureAwait(false); + await SendBufferAsync(operationContext, compressedBuffer).ConfigureAwait(false); sentLength = compressedBuffer.Length; } } else { - await SendBufferAsync(uncompressedBuffer, cancellationToken).ConfigureAwait(false); + await SendBufferAsync(operationContext, uncompressedBuffer).ConfigureAwait(false); sentLength = uncompressedBuffer.Length; } helper.SentMessage(sentLength); @@ -717,15 +720,15 @@ private void CompressMessage( compressedMessageEncoder.WriteMessage(compressedMessage); } - private void ThrowIfCancelledOrDisposed(CancellationToken cancellationToken = default) + private void ThrowIfCancelledOrDisposed(OperationContext operationContext) { - cancellationToken.ThrowIfCancellationRequested(); + operationContext.ThrowIfTimedOutOrCanceled(); ThrowIfDisposed(); } - private void ThrowIfCancelledOrDisposedOrNotOpen(CancellationToken cancellationToken) + private void ThrowIfCancelledOrDisposedOrNotOpen(OperationContext operationContext) { - ThrowIfCancelledOrDisposed(cancellationToken); + ThrowIfCancelledOrDisposed(operationContext); if (_state.Value == State.Failed) { throw new MongoConnectionClosedException(_connectionId); @@ -905,9 +908,9 @@ public ReceiveMessageHelper(BinaryConnection connection, int responseTo, Message _messageEncoderSettings = messageEncoderSettings; } - public ResponseMessage DecodeMessage(IByteBuffer buffer, IMessageEncoderSelector encoderSelector, CancellationToken cancellationToken) + public ResponseMessage DecodeMessage(OperationContext operationContext, IByteBuffer buffer, IMessageEncoderSelector encoderSelector) { - cancellationToken.ThrowIfCancellationRequested(); + operationContext.ThrowIfTimedOutOrCanceled(); _stopwatch.Stop(); _networkDuration = _stopwatch.Elapsed; @@ -992,10 +995,10 @@ public SendMessageHelper(BinaryConnection connection, RequestMessage message, Me _commandStopwatch = Stopwatch.StartNew(); } - public IByteBuffer EncodeMessage(CancellationToken cancellationToken, out RequestMessage sentMessage) + public IByteBuffer EncodeMessage(OperationContext operationContext, out RequestMessage sentMessage) { sentMessage = null; - cancellationToken.ThrowIfCancellationRequested(); + operationContext.ThrowIfTimedOutOrCanceled(); var serializationStopwatch = Stopwatch.StartNew(); var outputBufferChunkSource = new OutputBufferChunkSource(BsonChunkPool.Default); @@ -1012,7 +1015,7 @@ public IByteBuffer EncodeMessage(CancellationToken cancellationToken, out Reques // Encoding messages includes serializing the // documents, so encoding message could be expensive // and worthy of us honoring cancellation here. - cancellationToken.ThrowIfCancellationRequested(); + operationContext.ThrowIfTimedOutOrCanceled(); buffer.Length = (int)stream.Length; buffer.MakeReadOnly(); diff --git a/src/MongoDB.Driver/Core/Connections/ConnectionInitializer.cs b/src/MongoDB.Driver/Core/Connections/ConnectionInitializer.cs index 851fd96b82f..fdf95dbfe05 100644 --- a/src/MongoDB.Driver/Core/Connections/ConnectionInitializer.cs +++ b/src/MongoDB.Driver/Core/Connections/ConnectionInitializer.cs @@ -68,7 +68,8 @@ public ConnectionInitializerContext Authenticate(IConnection connection, Connect try { var getLastErrorProtocol = CreateGetLastErrorProtocol(_serverApi); - var getLastErrorResult = getLastErrorProtocol.Execute(connection, cancellationToken); + // TODO: CSOT: Implement operation context support for MongoDB Handshake + var getLastErrorResult = getLastErrorProtocol.Execute(new OperationContext(Timeout.InfiniteTimeSpan, cancellationToken), connection); description = UpdateConnectionIdWithServerValue(description, getLastErrorResult); } @@ -103,8 +104,9 @@ public async Task AuthenticateAsync(IConnection co try { var getLastErrorProtocol = CreateGetLastErrorProtocol(_serverApi); + // TODO: CSOT: Implement operation context support for MongoDB Handshake var getLastErrorResult = await getLastErrorProtocol - .ExecuteAsync(connection, cancellationToken) + .ExecuteAsync(new OperationContext(Timeout.InfiniteTimeSpan, cancellationToken), connection) .ConfigureAwait(false); description = UpdateConnectionIdWithServerValue(description, getLastErrorResult); diff --git a/src/MongoDB.Driver/Core/Connections/HelloHelper.cs b/src/MongoDB.Driver/Core/Connections/HelloHelper.cs index 70194498f5c..2ebebe12078 100644 --- a/src/MongoDB.Driver/Core/Connections/HelloHelper.cs +++ b/src/MongoDB.Driver/Core/Connections/HelloHelper.cs @@ -90,7 +90,8 @@ internal static HelloResult GetResult( { try { - var helloResultDocument = helloProtocol.Execute(connection, cancellationToken); + // TODO: CSOT: Implement operation context support for MongoDB Handshake + var helloResultDocument = helloProtocol.Execute(new OperationContext(Timeout.InfiniteTimeSpan, cancellationToken), connection); return new HelloResult(helloResultDocument); } catch (MongoCommandException ex) when (ex.Code == 11) @@ -109,7 +110,8 @@ internal static async Task GetResultAsync( { try { - var helloResultDocument = await helloProtocol.ExecuteAsync(connection, cancellationToken).ConfigureAwait(false); + // TODO: CSOT: Implement operation context support for MongoDB Handshake + var helloResultDocument = await helloProtocol.ExecuteAsync(new OperationContext(Timeout.InfiniteTimeSpan, cancellationToken), connection).ConfigureAwait(false); return new HelloResult(helloResultDocument); } catch (MongoCommandException ex) when (ex.Code == 11) diff --git a/src/MongoDB.Driver/Core/Connections/IConnection.cs b/src/MongoDB.Driver/Core/Connections/IConnection.cs index a82dfc3eda3..5a7af78169f 100644 --- a/src/MongoDB.Driver/Core/Connections/IConnection.cs +++ b/src/MongoDB.Driver/Core/Connections/IConnection.cs @@ -14,7 +14,6 @@ */ using System; -using System.Collections.Generic; using System.Net; using System.Threading; using System.Threading.Tasks; @@ -33,15 +32,16 @@ internal interface IConnection : IDisposable bool IsExpired { get; } ConnectionSettings Settings { get; } + // TODO: CSOT: remove this in scope of MongoDB Handshake void SetReadTimeout(TimeSpan timeout); - void Open(CancellationToken cancellationToken); - Task OpenAsync(CancellationToken cancellationToken); + void Open(OperationContext operationContext); + Task OpenAsync(OperationContext operationContext); void Reauthenticate(CancellationToken cancellationToken); Task ReauthenticateAsync(CancellationToken cancellationToken); - ResponseMessage ReceiveMessage(int responseTo, IMessageEncoderSelector encoderSelector, MessageEncoderSettings messageEncoderSettings, CancellationToken cancellationToken); - Task ReceiveMessageAsync(int responseTo, IMessageEncoderSelector encoderSelector, MessageEncoderSettings messageEncoderSettings, CancellationToken cancellationToken); - void SendMessage(RequestMessage message, MessageEncoderSettings messageEncoderSettings, CancellationToken cancellationToken); - Task SendMessageAsync(RequestMessage message, MessageEncoderSettings messageEncoderSettings, CancellationToken cancellationToken); + ResponseMessage ReceiveMessage(OperationContext operationContext, int responseTo, IMessageEncoderSelector encoderSelector, MessageEncoderSettings messageEncoderSettings); + Task ReceiveMessageAsync(OperationContext operationContext, int responseTo, IMessageEncoderSelector encoderSelector, MessageEncoderSettings messageEncoderSettings); + void SendMessage(OperationContext operationContext, RequestMessage message, MessageEncoderSettings messageEncoderSettings); + Task SendMessageAsync(OperationContext operationContext, RequestMessage message, MessageEncoderSettings messageEncoderSettings); } internal interface IConnectionHandle : IConnection diff --git a/src/MongoDB.Driver/Core/Misc/SemaphoreSlimRequest.cs b/src/MongoDB.Driver/Core/Misc/SemaphoreSlimRequest.cs index 70e6895661d..6ef047c725a 100644 --- a/src/MongoDB.Driver/Core/Misc/SemaphoreSlimRequest.cs +++ b/src/MongoDB.Driver/Core/Misc/SemaphoreSlimRequest.cs @@ -39,12 +39,23 @@ public sealed class SemaphoreSlimRequest : IDisposable /// The semaphore. /// The cancellation token. public SemaphoreSlimRequest(SemaphoreSlim semaphore, CancellationToken cancellationToken) + : this(semaphore, Timeout.InfiniteTimeSpan, cancellationToken) + { + } + + /// + /// Initializes a new instance of the class. + /// + /// The semaphore. + /// The timeout. + /// The cancellation token. + public SemaphoreSlimRequest(SemaphoreSlim semaphore, TimeSpan timeout, CancellationToken cancellationToken) { _semaphore = Ensure.IsNotNull(semaphore, nameof(semaphore)); _disposeCancellationTokenSource = new CancellationTokenSource(); _linkedCancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _disposeCancellationTokenSource.Token); - _task = semaphore.WaitAsync(_linkedCancellationTokenSource.Token); + _task = semaphore.WaitAsync(timeout, _linkedCancellationTokenSource.Token); } // public properties diff --git a/src/MongoDB.Driver/Core/Misc/StreamExtensionMethods.cs b/src/MongoDB.Driver/Core/Misc/StreamExtensionMethods.cs index 1cb4cd5181f..1c331546c6b 100644 --- a/src/MongoDB.Driver/Core/Misc/StreamExtensionMethods.cs +++ b/src/MongoDB.Driver/Core/Misc/StreamExtensionMethods.cs @@ -109,16 +109,20 @@ public static async Task ReadAsync(this Stream stream, byte[] buffer, int o } } - public static void ReadBytes(this Stream stream, byte[] buffer, int offset, int count, TimeSpan timeout, CancellationToken cancellationToken) + public static void ReadBytes(this Stream stream, OperationContext operationContext, byte[] buffer, int offset, int count) { Ensure.IsNotNull(stream, nameof(stream)); Ensure.IsNotNull(buffer, nameof(buffer)); Ensure.IsBetween(offset, 0, buffer.Length, nameof(offset)); Ensure.IsBetween(count, 0, buffer.Length - offset, nameof(count)); + var hasOperationTimeout = operationContext.IsRootContextTimeoutConfigured(); + var streamTimeout = stream.CanTimeout ? TimeSpan.FromMilliseconds(stream.ReadTimeout) : Timeout.InfiniteTimeSpan; + while (count > 0) { - var bytesRead = stream.Read(buffer, offset, count, timeout, cancellationToken); + var timeout = hasOperationTimeout ? operationContext.RemainingTimeout : streamTimeout; + var bytesRead = stream.Read(buffer, offset, count, timeout, operationContext.CancellationToken); if (bytesRead == 0) { throw new EndOfStreamException(); @@ -128,18 +132,22 @@ public static void ReadBytes(this Stream stream, byte[] buffer, int offset, int } } - public static void ReadBytes(this Stream stream, IByteBuffer buffer, int offset, int count, TimeSpan timeout, CancellationToken cancellationToken) + public static void ReadBytes(this Stream stream, OperationContext operationContext, IByteBuffer buffer, int offset, int count) { Ensure.IsNotNull(stream, nameof(stream)); Ensure.IsNotNull(buffer, nameof(buffer)); Ensure.IsBetween(offset, 0, buffer.Length, nameof(offset)); Ensure.IsBetween(count, 0, buffer.Length - offset, nameof(count)); + var hasOperationTimeout = operationContext.IsRootContextTimeoutConfigured(); + var streamTimeout = stream.CanTimeout ? TimeSpan.FromMilliseconds(stream.ReadTimeout) : Timeout.InfiniteTimeSpan; + while (count > 0) { + var timeout = hasOperationTimeout ? operationContext.RemainingTimeout : streamTimeout; var backingBytes = buffer.AccessBackingBytes(offset); var bytesToRead = Math.Min(count, backingBytes.Count); - var bytesRead = stream.Read(backingBytes.Array, backingBytes.Offset, bytesToRead, timeout, cancellationToken); + var bytesRead = stream.Read(backingBytes.Array, backingBytes.Offset, bytesToRead, timeout, operationContext.CancellationToken); if (bytesRead == 0) { throw new EndOfStreamException(); @@ -149,16 +157,20 @@ public static void ReadBytes(this Stream stream, IByteBuffer buffer, int offset, } } - public static async Task ReadBytesAsync(this Stream stream, byte[] buffer, int offset, int count, TimeSpan timeout, CancellationToken cancellationToken) + public static async Task ReadBytesAsync(this Stream stream, OperationContext operationContext, byte[] buffer, int offset, int count) { Ensure.IsNotNull(stream, nameof(stream)); Ensure.IsNotNull(buffer, nameof(buffer)); Ensure.IsBetween(offset, 0, buffer.Length, nameof(offset)); Ensure.IsBetween(count, 0, buffer.Length - offset, nameof(count)); + var hasOperationTimeout = operationContext.IsRootContextTimeoutConfigured(); + var streamTimeout = stream.CanTimeout ? TimeSpan.FromMilliseconds(stream.ReadTimeout) : Timeout.InfiniteTimeSpan; + while (count > 0) { - var bytesRead = await stream.ReadAsync(buffer, offset, count, timeout, cancellationToken).ConfigureAwait(false); + var timeout = hasOperationTimeout ? operationContext.RemainingTimeout : streamTimeout; + var bytesRead = await stream.ReadAsync(buffer, offset, count, timeout, operationContext.CancellationToken).ConfigureAwait(false); if (bytesRead == 0) { throw new EndOfStreamException(); @@ -168,18 +180,22 @@ public static async Task ReadBytesAsync(this Stream stream, byte[] buffer, int o } } - public static async Task ReadBytesAsync(this Stream stream, IByteBuffer buffer, int offset, int count, TimeSpan timeout, CancellationToken cancellationToken) + public static async Task ReadBytesAsync(this Stream stream, OperationContext operationContext, IByteBuffer buffer, int offset, int count) { Ensure.IsNotNull(stream, nameof(stream)); Ensure.IsNotNull(buffer, nameof(buffer)); Ensure.IsBetween(offset, 0, buffer.Length, nameof(offset)); Ensure.IsBetween(count, 0, buffer.Length - offset, nameof(count)); + var hasOperationTimeout = operationContext.IsRootContextTimeoutConfigured(); + var streamTimeout = stream.CanTimeout ? TimeSpan.FromMilliseconds(stream.ReadTimeout) : Timeout.InfiniteTimeSpan; + while (count > 0) { + var timeout = hasOperationTimeout ? operationContext.RemainingTimeout : streamTimeout; var backingBytes = buffer.AccessBackingBytes(offset); var bytesToRead = Math.Min(count, backingBytes.Count); - var bytesRead = await stream.ReadAsync(backingBytes.Array, backingBytes.Offset, bytesToRead, timeout, cancellationToken).ConfigureAwait(false); + var bytesRead = await stream.ReadAsync(backingBytes.Array, backingBytes.Offset, bytesToRead, timeout, operationContext.CancellationToken).ConfigureAwait(false); if (bytesRead == 0) { throw new EndOfStreamException(); @@ -264,36 +280,43 @@ public static async Task WriteAsync(this Stream stream, byte[] buffer, int offse } } - public static void WriteBytes(this Stream stream, IByteBuffer buffer, int offset, int count, TimeSpan timeout, CancellationToken cancellationToken) + public static void WriteBytes(this Stream stream, OperationContext operationContext, IByteBuffer buffer, int offset, int count) { Ensure.IsNotNull(stream, nameof(stream)); Ensure.IsNotNull(buffer, nameof(buffer)); Ensure.IsBetween(offset, 0, buffer.Length, nameof(offset)); Ensure.IsBetween(count, 0, buffer.Length - offset, nameof(count)); + var hasOperationTimeout = operationContext.IsRootContextTimeoutConfigured(); + var streamTimeout = stream.CanTimeout ? TimeSpan.FromMilliseconds(stream.WriteTimeout) : Timeout.InfiniteTimeSpan; + while (count > 0) { - cancellationToken.ThrowIfCancellationRequested(); + var timeout = hasOperationTimeout ? operationContext.RemainingTimeout : streamTimeout; var backingBytes = buffer.AccessBackingBytes(offset); var bytesToWrite = Math.Min(count, backingBytes.Count); - stream.Write(backingBytes.Array, backingBytes.Offset, bytesToWrite, timeout, cancellationToken); + stream.Write(backingBytes.Array, backingBytes.Offset, bytesToWrite, timeout, operationContext.CancellationToken); offset += bytesToWrite; count -= bytesToWrite; } } - public static async Task WriteBytesAsync(this Stream stream, IByteBuffer buffer, int offset, int count, TimeSpan timeout, CancellationToken cancellationToken) + public static async Task WriteBytesAsync(this Stream stream, OperationContext operationContext, IByteBuffer buffer, int offset, int count) { Ensure.IsNotNull(stream, nameof(stream)); Ensure.IsNotNull(buffer, nameof(buffer)); Ensure.IsBetween(offset, 0, buffer.Length, nameof(offset)); Ensure.IsBetween(count, 0, buffer.Length - offset, nameof(count)); + var hasOperationTimeout = operationContext.IsRootContextTimeoutConfigured(); + var streamTimeout = stream.CanTimeout ? TimeSpan.FromMilliseconds(stream.WriteTimeout) : Timeout.InfiniteTimeSpan; + while (count > 0) { + var timeout = hasOperationTimeout ? operationContext.RemainingTimeout : streamTimeout; var backingBytes = buffer.AccessBackingBytes(offset); var bytesToWrite = Math.Min(count, backingBytes.Count); - await stream.WriteAsync(backingBytes.Array, backingBytes.Offset, bytesToWrite, timeout, cancellationToken).ConfigureAwait(false); + await stream.WriteAsync(backingBytes.Array, backingBytes.Offset, bytesToWrite, timeout, operationContext.CancellationToken).ConfigureAwait(false); offset += bytesToWrite; count -= bytesToWrite; } diff --git a/src/MongoDB.Driver/Core/Operations/AsyncCursor.cs b/src/MongoDB.Driver/Core/Operations/AsyncCursor.cs index bb0b2daa92a..dd5c7e0ba9f 100644 --- a/src/MongoDB.Driver/Core/Operations/AsyncCursor.cs +++ b/src/MongoDB.Driver/Core/Operations/AsyncCursor.cs @@ -219,7 +219,10 @@ private CursorBatch ExecuteGetMoreCommand(IChannelHandle channel, Can BsonDocument result; try { + // TODO: CSOT: Implement operation context support for Cursors + var operationContext = new OperationContext(Timeout.InfiniteTimeSpan, cancellationToken); result = channel.Command( + operationContext, _channelSource.Session, null, // readPreference _collectionNamespace.DatabaseNamespace, @@ -230,8 +233,7 @@ private CursorBatch ExecuteGetMoreCommand(IChannelHandle channel, Can null, // postWriteAction CommandResponseHandling.Return, __getMoreCommandResultSerializer, - _messageEncoderSettings, - cancellationToken); + _messageEncoderSettings); } catch (MongoCommandException ex) when (IsMongoCursorNotFoundException(ex)) { @@ -247,7 +249,10 @@ private async Task> ExecuteGetMoreCommandAsync(IChannelHa BsonDocument result; try { + // TODO: CSOT: Implement operation context support for Cursors + var operationContext = new OperationContext(Timeout.InfiniteTimeSpan, cancellationToken); result = await channel.CommandAsync( + operationContext, _channelSource.Session, null, // readPreference _collectionNamespace.DatabaseNamespace, @@ -258,8 +263,7 @@ private async Task> ExecuteGetMoreCommandAsync(IChannelHa null, // postWriteAction CommandResponseHandling.Return, __getMoreCommandResultSerializer, - _messageEncoderSettings, - cancellationToken).ConfigureAwait(false); + _messageEncoderSettings).ConfigureAwait(false); } catch (MongoCommandException ex) when (IsMongoCursorNotFoundException(ex)) { @@ -271,8 +275,11 @@ private async Task> ExecuteGetMoreCommandAsync(IChannelHa private void ExecuteKillCursorsCommand(IChannelHandle channel, CancellationToken cancellationToken) { + // TODO: CSOT: Implement operation context support for Cursors + var operationContext = new OperationContext(Timeout.InfiniteTimeSpan, cancellationToken); var command = CreateKillCursorsCommand(); var result = channel.Command( + operationContext, _channelSource.Session, null, // readPreference _collectionNamespace.DatabaseNamespace, @@ -283,16 +290,18 @@ private void ExecuteKillCursorsCommand(IChannelHandle channel, CancellationToken null, // postWriteAction CommandResponseHandling.Return, BsonDocumentSerializer.Instance, - _messageEncoderSettings, - cancellationToken); + _messageEncoderSettings); ThrowIfKillCursorsCommandFailed(result, channel.ConnectionDescription.ConnectionId); } private async Task ExecuteKillCursorsCommandAsync(IChannelHandle channel, CancellationToken cancellationToken) { + // TODO: CSOT: Implement operation context support for Cursors + var operationContext = new OperationContext(Timeout.InfiniteTimeSpan, cancellationToken); var command = CreateKillCursorsCommand(); var result = await channel.CommandAsync( + operationContext, _channelSource.Session, null, // readPreference _collectionNamespace.DatabaseNamespace, @@ -303,8 +312,7 @@ private async Task ExecuteKillCursorsCommandAsync(IChannelHandle channel, Cancel null, // postWriteAction CommandResponseHandling.Return, BsonDocumentSerializer.Instance, - _messageEncoderSettings, - cancellationToken) + _messageEncoderSettings) .ConfigureAwait(false); ThrowIfKillCursorsCommandFailed(result, channel.ConnectionDescription.ConnectionId); diff --git a/src/MongoDB.Driver/Core/Operations/CommandOperationBase.cs b/src/MongoDB.Driver/Core/Operations/CommandOperationBase.cs index 62eca2d6992..e5f94ebe200 100644 --- a/src/MongoDB.Driver/Core/Operations/CommandOperationBase.cs +++ b/src/MongoDB.Driver/Core/Operations/CommandOperationBase.cs @@ -13,7 +13,6 @@ * limitations under the License. */ -using System.Threading; using System.Threading.Tasks; using MongoDB.Bson; using MongoDB.Bson.IO; @@ -85,11 +84,12 @@ public IBsonSerializer ResultSerializer get { return _resultSerializer; } } - protected TCommandResult ExecuteProtocol(IChannelHandle channel, ICoreSessionHandle session, ReadPreference readPreference, CancellationToken cancellationToken) + protected TCommandResult ExecuteProtocol(OperationContext operationContext, IChannelHandle channel, ICoreSessionHandle session, ReadPreference readPreference) { var additionalOptions = GetEffectiveAdditionalOptions(); return channel.Command( + operationContext, session, readPreference, _databaseNamespace, @@ -100,8 +100,7 @@ protected TCommandResult ExecuteProtocol(IChannelHandle channel, ICoreSessionHan null, // postWriteAction, CommandResponseHandling.Return, _resultSerializer, - _messageEncoderSettings, - cancellationToken); + _messageEncoderSettings); } protected TCommandResult ExecuteProtocol( @@ -112,15 +111,16 @@ protected TCommandResult ExecuteProtocol( { using (var channel = channelSource.GetChannel(operationContext)) { - return ExecuteProtocol(channel, session, readPreference, operationContext.CancellationToken); + return ExecuteProtocol(operationContext, channel, session, readPreference); } } - protected Task ExecuteProtocolAsync(IChannelHandle channel, ICoreSessionHandle session, ReadPreference readPreference, CancellationToken cancellationToken) + protected Task ExecuteProtocolAsync(OperationContext operationContext, IChannelHandle channel, ICoreSessionHandle session, ReadPreference readPreference) { var additionalOptions = GetEffectiveAdditionalOptions(); return channel.CommandAsync( + operationContext, session, readPreference, _databaseNamespace, @@ -131,8 +131,7 @@ protected Task ExecuteProtocolAsync(IChannelHandle channel, ICor null, // postWriteAction, CommandResponseHandling.Return, _resultSerializer, - _messageEncoderSettings, - cancellationToken); + _messageEncoderSettings); } protected async Task ExecuteProtocolAsync( @@ -143,7 +142,7 @@ protected async Task ExecuteProtocolAsync( { using (var channel = await channelSource.GetChannelAsync(operationContext).ConfigureAwait(false)) { - return await ExecuteProtocolAsync(channel, session, readPreference, operationContext.CancellationToken).ConfigureAwait(false); + return await ExecuteProtocolAsync(operationContext, channel, session, readPreference).ConfigureAwait(false); } } diff --git a/src/MongoDB.Driver/Core/Operations/ReadCommandOperation.cs b/src/MongoDB.Driver/Core/Operations/ReadCommandOperation.cs index 84305df6f25..711d31c37df 100644 --- a/src/MongoDB.Driver/Core/Operations/ReadCommandOperation.cs +++ b/src/MongoDB.Driver/Core/Operations/ReadCommandOperation.cs @@ -84,12 +84,12 @@ public async Task ExecuteAsync(OperationContext operationContext public TCommandResult ExecuteAttempt(OperationContext operationContext, RetryableReadContext context, int attempt, long? transactionNumber) { - return ExecuteProtocol(context.Channel, context.Binding.Session, context.Binding.ReadPreference, operationContext.CancellationToken); + return ExecuteProtocol(operationContext, context.Channel, context.Binding.Session, context.Binding.ReadPreference); } public Task ExecuteAttemptAsync(OperationContext operationContext, RetryableReadContext context, int attempt, long? transactionNumber) { - return ExecuteProtocolAsync(context.Channel, context.Binding.Session, context.Binding.ReadPreference, operationContext.CancellationToken); + return ExecuteProtocolAsync(operationContext, context.Channel, context.Binding.Session, context.Binding.ReadPreference); } } } diff --git a/src/MongoDB.Driver/Core/Operations/RetryableWriteCommandOperationBase.cs b/src/MongoDB.Driver/Core/Operations/RetryableWriteCommandOperationBase.cs index bb3ec94e32d..bcb0e72b291 100644 --- a/src/MongoDB.Driver/Core/Operations/RetryableWriteCommandOperationBase.cs +++ b/src/MongoDB.Driver/Core/Operations/RetryableWriteCommandOperationBase.cs @@ -115,8 +115,8 @@ public virtual Task ExecuteAsync(OperationContext operationContext public BsonDocument ExecuteAttempt(OperationContext operationContext, RetryableWriteContext context, int attempt, long? transactionNumber) { var args = GetCommandArgs(context, attempt, transactionNumber); - // TODO: CSOT implement timeout in Command Execution return context.Channel.Command( + operationContext, context.ChannelSource.Session, ReadPreference.Primary, _databaseNamespace, @@ -127,15 +127,14 @@ public BsonDocument ExecuteAttempt(OperationContext operationContext, RetryableW args.PostWriteAction, args.ResponseHandling, BsonDocumentSerializer.Instance, - args.MessageEncoderSettings, - operationContext.CancellationToken); + args.MessageEncoderSettings); } public Task ExecuteAttemptAsync(OperationContext operationContext, RetryableWriteContext context, int attempt, long? transactionNumber) { var args = GetCommandArgs(context, attempt, transactionNumber); - // TODO: CSOT implement timeout in Command Execution return context.Channel.CommandAsync( + operationContext, context.ChannelSource.Session, ReadPreference.Primary, _databaseNamespace, @@ -146,8 +145,7 @@ public Task ExecuteAttemptAsync(OperationContext operationContext, args.PostWriteAction, args.ResponseHandling, BsonDocumentSerializer.Instance, - args.MessageEncoderSettings, - operationContext.CancellationToken); + args.MessageEncoderSettings); } protected abstract BsonDocument CreateCommand(ICoreSessionHandle session, int attempt, long? transactionNumber); diff --git a/src/MongoDB.Driver/Core/Servers/IServer.cs b/src/MongoDB.Driver/Core/Servers/IServer.cs index da1e6f49138..643a5883126 100644 --- a/src/MongoDB.Driver/Core/Servers/IServer.cs +++ b/src/MongoDB.Driver/Core/Servers/IServer.cs @@ -15,9 +15,10 @@ using System; using System.Net; -using System.Threading; using System.Threading.Tasks; using MongoDB.Driver.Core.Bindings; +using MongoDB.Driver.Core.Clusters; +using MongoDB.Driver.Core.Connections; namespace MongoDB.Driver.Core.Servers { @@ -25,12 +26,16 @@ internal interface IServer { event EventHandler DescriptionChanged; + IClusterClock ClusterClock { get; } ServerDescription Description { get; } EndPoint EndPoint { get; } ServerId ServerId { get; } + ServerApi ServerApi { get; } + void DecrementOutstandingOperationsCount(); IChannelHandle GetChannel(OperationContext operationContext); Task GetChannelAsync(OperationContext operationContext); + void HandleChannelException(IConnectionHandle connection, Exception exception); } internal interface IClusterableServer : IServer, IDisposable @@ -42,4 +47,9 @@ internal interface IClusterableServer : IServer, IDisposable void Invalidate(string reasonInvalidated, TopologyVersion responseTopologyVersion); void RequestHeartbeat(); } + + internal interface ISelectedServer : IServer + { + ServerDescription DescriptionWhenSelected { get; } + } } diff --git a/src/MongoDB.Driver/Core/Servers/RoundTripTimeMonitor.cs b/src/MongoDB.Driver/Core/Servers/RoundTripTimeMonitor.cs index 23a306a9dcf..4ac2bba1bf7 100644 --- a/src/MongoDB.Driver/Core/Servers/RoundTripTimeMonitor.cs +++ b/src/MongoDB.Driver/Core/Servers/RoundTripTimeMonitor.cs @@ -170,7 +170,9 @@ private void InitializeConnection() { // if we are cancelling, it's because the server has // been shut down and we really don't need to wait. - roundTripTimeConnection.Open(_cancellationToken); + // TODO: CSOT: Implement proper operation context handling in scope of Server Discovery and Monitoring + var operationContext = new OperationContext(Timeout.InfiniteTimeSpan, _cancellationToken); + roundTripTimeConnection.Open(operationContext); _cancellationToken.ThrowIfCancellationRequested(); } catch diff --git a/src/MongoDB.Driver/Core/Servers/SelectedServer.cs b/src/MongoDB.Driver/Core/Servers/SelectedServer.cs new file mode 100644 index 00000000000..c0e8b6edbeb --- /dev/null +++ b/src/MongoDB.Driver/Core/Servers/SelectedServer.cs @@ -0,0 +1,65 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Net; +using System.Threading.Tasks; +using MongoDB.Driver.Core.Bindings; +using MongoDB.Driver.Core.Clusters; +using MongoDB.Driver.Core.Connections; + +namespace MongoDB.Driver.Core.Servers; + +internal class SelectedServer : ISelectedServer +{ + private readonly ServerDescription _descriptionWhenSelected; + private readonly IServer _server; + + public SelectedServer(IServer server, ServerDescription descriptionWhenSelected) + { + _server = server; + _descriptionWhenSelected = descriptionWhenSelected; + } + + public event EventHandler DescriptionChanged + { + add { _server.DescriptionChanged += value; } + remove => _server.DescriptionChanged -= value; + } + + public IClusterClock ClusterClock => _server.ClusterClock; + public ServerDescription Description => _server.Description; + public EndPoint EndPoint => _server.EndPoint; + public ServerId ServerId => _server.ServerId; + public ServerApi ServerApi => _server.ServerApi; + public ServerDescription DescriptionWhenSelected => _descriptionWhenSelected; + + public void DecrementOutstandingOperationsCount() + => _server.DecrementOutstandingOperationsCount(); + + public IChannelHandle GetChannel(OperationContext operationContext) + { + var channel = _server.GetChannel(operationContext); + return new ServerChannel(this, channel.Connection); + } + + public async Task GetChannelAsync(OperationContext operationContext) + { + var channel = await _server.GetChannelAsync(operationContext).ConfigureAwait(false); + return new ServerChannel(this, channel.Connection); + } + + public void HandleChannelException(IConnectionHandle channel, Exception exception) => _server.HandleChannelException(channel, exception); +} diff --git a/src/MongoDB.Driver/Core/Servers/Server.cs b/src/MongoDB.Driver/Core/Servers/Server.cs index c7fe1a94b28..50c1710e566 100644 --- a/src/MongoDB.Driver/Core/Servers/Server.cs +++ b/src/MongoDB.Driver/Core/Servers/Server.cs @@ -14,14 +14,10 @@ */ using System; -using System.Collections.Generic; using System.Diagnostics; using System.Net; using System.Threading; using System.Threading.Tasks; -using MongoDB.Bson; -using MongoDB.Bson.IO; -using MongoDB.Bson.Serialization; using MongoDB.Driver.Core.Bindings; using MongoDB.Driver.Core.Clusters; using MongoDB.Driver.Core.Configuration; @@ -30,9 +26,6 @@ using MongoDB.Driver.Core.Events; using MongoDB.Driver.Core.Logging; using MongoDB.Driver.Core.Misc; -using MongoDB.Driver.Core.WireProtocol; -using MongoDB.Driver.Core.WireProtocol.Messages; -using MongoDB.Driver.Core.WireProtocol.Messages.Encoders; namespace MongoDB.Driver.Core.Servers { @@ -82,6 +75,7 @@ public Server( public abstract ServerDescription Description { get; } public EndPoint EndPoint => _endPoint; public bool IsInitialized => _state.Value != State.Initial; + public ServerApi ServerApi => _serverApi; public ServerId ServerId => _serverId; protected EventLogger EventLogger => _eventLogger; @@ -104,6 +98,38 @@ public void Dispose() } } + public void DecrementOutstandingOperationsCount() + { + Interlocked.Decrement(ref _outstandingOperationsCount); + } + + public void HandleChannelException(IConnectionHandle connection, Exception ex) + { + if (!IsOpen() || ShouldIgnoreException(ex)) + { + return; + } + + ex = GetEffectiveException(ex); + + HandleAfterHandshakeCompletesException(connection, ex); + + bool ShouldIgnoreException(Exception ex) + { + // For most connection exceptions, we are going to immediately + // invalidate the server. However, we aren't going to invalidate + // because of OperationCanceledExceptions. We trust that the + // implementations of connection don't leave themselves in a state + // where they can't be used based on user cancellation. + return ex is OperationCanceledException; + } + + Exception GetEffectiveException(Exception ex) => + ex is AggregateException aggregateException && aggregateException.InnerExceptions.Count == 1 + ? aggregateException.InnerException + : ex; + } + public void HandleExceptionOnOpen(Exception exception) => HandleBeforeHandshakeCompletesException(exception); @@ -114,7 +140,6 @@ public IChannelHandle GetChannel(OperationContext operationContext) try { Interlocked.Increment(ref _outstandingOperationsCount); - var connection = _connectionPool.AcquireConnection(operationContext); return new ServerChannel(this, connection); } @@ -173,7 +198,6 @@ public void Invalidate(string reasonInvalidated, TopologyVersion responseTopolog public abstract void RequestHeartbeat(); // protected methods - protected abstract void Invalidate(string reasonInvalidated, bool clearConnectionPool, TopologyVersion responseTopologyDescription); protected abstract void Dispose(bool disposing); @@ -222,33 +246,6 @@ protected bool ShouldClearConnectionPoolForChannelException(Exception ex, int ma } // private methods - private void HandleChannelException(IConnection connection, Exception ex) - { - if (!IsOpen() || ShouldIgnoreException(ex)) - { - return; - } - - ex = GetEffectiveException(ex); - - HandleAfterHandshakeCompletesException(connection, ex); - - bool ShouldIgnoreException(Exception ex) - { - // For most connection exceptions, we are going to immediately - // invalidate the server. However, we aren't going to invalidate - // because of OperationCanceledExceptions. We trust that the - // implementations of connection don't leave themselves in a state - // where they can't be used based on user cancellation. - return ex is OperationCanceledException; - } - - Exception GetEffectiveException(Exception ex) => - ex is AggregateException aggregateException && aggregateException.InnerExceptions.Count == 1 - ? aggregateException.InnerException - : ex; - } - private bool IsOpen() => _state.Value == State.Open; private void ThrowIfDisposed() @@ -275,172 +272,5 @@ private static class State public const int Open = 1; public const int Disposed = 2; } - - private sealed class ServerChannel : IChannelHandle - { - // fields - private readonly IConnectionHandle _connection; - private readonly Server _server; - - private readonly InterlockedInt32 _state; - private readonly bool _decrementOperationsCount; - - // constructors - public ServerChannel(Server server, IConnectionHandle connection, bool decrementOperationsCount = true) - { - _server = server; - _connection = connection; - - _state = new InterlockedInt32(ChannelState.Initial); - _decrementOperationsCount = decrementOperationsCount; - } - - // properties - public IConnectionHandle Connection => _connection; - - public ConnectionDescription ConnectionDescription - { - get { return _connection.Description; } - } - - // methods - public TResult Command( - ICoreSession session, - ReadPreference readPreference, - DatabaseNamespace databaseNamespace, - BsonDocument command, - IEnumerable commandPayloads, - IElementNameValidator commandValidator, - BsonDocument additionalOptions, - Action postWriteAction, - CommandResponseHandling responseHandling, - IBsonSerializer resultSerializer, - MessageEncoderSettings messageEncoderSettings, - CancellationToken cancellationToken) - { - var protocol = new CommandWireProtocol( - CreateClusterClockAdvancingCoreSession(session), - readPreference, - databaseNamespace, - command, - commandPayloads, - commandValidator, - additionalOptions, - postWriteAction, - responseHandling, - resultSerializer, - messageEncoderSettings, - _server._serverApi); - - return ExecuteProtocol(protocol, session, cancellationToken); - } - - public Task CommandAsync( - ICoreSession session, - ReadPreference readPreference, - DatabaseNamespace databaseNamespace, - BsonDocument command, - IEnumerable commandPayloads, - IElementNameValidator commandValidator, - BsonDocument additionalOptions, - Action postWriteAction, - CommandResponseHandling responseHandling, - IBsonSerializer resultSerializer, - MessageEncoderSettings messageEncoderSettings, - CancellationToken cancellationToken) - { - var protocol = new CommandWireProtocol( - CreateClusterClockAdvancingCoreSession(session), - readPreference, - databaseNamespace, - command, - commandPayloads, - commandValidator, - additionalOptions, - postWriteAction, - responseHandling, - resultSerializer, - messageEncoderSettings, - _server._serverApi); - - return ExecuteProtocolAsync(protocol, session, cancellationToken); - } - - public void Dispose() - { - if (_state.TryChange(ChannelState.Initial, ChannelState.Disposed)) - { - if (_decrementOperationsCount) - { - Interlocked.Decrement(ref _server._outstandingOperationsCount); - } - - _connection.Dispose(); - } - } - - private ICoreSession CreateClusterClockAdvancingCoreSession(ICoreSession session) - { - return new ClusterClockAdvancingCoreSession(session, _server.ClusterClock); - } - - private TResult ExecuteProtocol(IWireProtocol protocol, ICoreSession session, CancellationToken cancellationToken) - { - try - { - return protocol.Execute(_connection, cancellationToken); - } - catch (Exception ex) - { - MarkSessionDirtyIfNeeded(session, ex); - _server.HandleChannelException(_connection, ex); - throw; - } - } - - private async Task ExecuteProtocolAsync(IWireProtocol protocol, ICoreSession session, CancellationToken cancellationToken) - { - try - { - return await protocol.ExecuteAsync(_connection, cancellationToken).ConfigureAwait(false); - } - catch (Exception ex) - { - MarkSessionDirtyIfNeeded(session, ex); - _server.HandleChannelException(_connection, ex); - throw; - } - } - - public IChannelHandle Fork() - { - ThrowIfDisposed(); - - return new ServerChannel(_server, _connection.Fork(), false); - } - - private void MarkSessionDirtyIfNeeded(ICoreSession session, Exception ex) - { - if (ex is MongoConnectionException) - { - session.MarkDirty(); - } - } - - private void ThrowIfDisposed() - { - if (_state.Value == ChannelState.Disposed) - { - throw new ObjectDisposedException(GetType().Name); - } - } - - // nested types - private static class ChannelState - { - public const int Initial = 0; - public const int Disposed = 1; - } - } } } diff --git a/src/MongoDB.Driver/Core/Servers/ServerChannel.cs b/src/MongoDB.Driver/Core/Servers/ServerChannel.cs new file mode 100644 index 00000000000..bf0ce569e87 --- /dev/null +++ b/src/MongoDB.Driver/Core/Servers/ServerChannel.cs @@ -0,0 +1,206 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Collections.Generic; +using System.Threading.Tasks; +using MongoDB.Bson; +using MongoDB.Bson.IO; +using MongoDB.Bson.Serialization; +using MongoDB.Driver.Core.Bindings; +using MongoDB.Driver.Core.Connections; +using MongoDB.Driver.Core.Misc; +using MongoDB.Driver.Core.WireProtocol; +using MongoDB.Driver.Core.WireProtocol.Messages; +using MongoDB.Driver.Core.WireProtocol.Messages.Encoders; + +namespace MongoDB.Driver.Core.Servers +{ + internal sealed class ServerChannel : IChannelHandle + { + // fields + private readonly IConnectionHandle _connection; + private readonly IServer _server; + private readonly InterlockedInt32 _state; + private readonly bool _decrementOperationsCount; + + // constructors + public ServerChannel(IServer server, IConnectionHandle connection, bool decrementOperationsCount = true) + { + _server = server; + _connection = connection; + _state = new InterlockedInt32(ChannelState.Initial); + _decrementOperationsCount = decrementOperationsCount; + } + + // properties + public IConnectionHandle Connection => _connection; + + public ConnectionDescription ConnectionDescription => _connection.Description; + + // methods + public TResult Command( + OperationContext operationContext, + ICoreSession session, + ReadPreference readPreference, + DatabaseNamespace databaseNamespace, + BsonDocument command, + IEnumerable commandPayloads, + IElementNameValidator commandValidator, + BsonDocument additionalOptions, + Action postWriteAction, + CommandResponseHandling responseHandling, + IBsonSerializer resultSerializer, + MessageEncoderSettings messageEncoderSettings) + { + var roundTripTime = TimeSpan.Zero; + if (_server is ISelectedServer selectedServer) + { + roundTripTime = selectedServer.DescriptionWhenSelected.AverageRoundTripTime; + } + + var protocol = new CommandWireProtocol( + CreateClusterClockAdvancingCoreSession(session), + readPreference, + databaseNamespace, + command, + commandPayloads, + commandValidator, + additionalOptions, + postWriteAction, + responseHandling, + resultSerializer, + messageEncoderSettings, + _server.ServerApi, + roundTripTime); + + return ExecuteProtocol(operationContext, protocol, session); + } + + public Task CommandAsync( + OperationContext operationContext, + ICoreSession session, + ReadPreference readPreference, + DatabaseNamespace databaseNamespace, + BsonDocument command, + IEnumerable commandPayloads, + IElementNameValidator commandValidator, + BsonDocument additionalOptions, + Action postWriteAction, + CommandResponseHandling responseHandling, + IBsonSerializer resultSerializer, + MessageEncoderSettings messageEncoderSettings) + { + var roundTripTime = TimeSpan.Zero; + if (_server is ISelectedServer selectedServer) + { + roundTripTime = selectedServer.DescriptionWhenSelected.AverageRoundTripTime; + } + + var protocol = new CommandWireProtocol( + CreateClusterClockAdvancingCoreSession(session), + readPreference, + databaseNamespace, + command, + commandPayloads, + commandValidator, + additionalOptions, + postWriteAction, + responseHandling, + resultSerializer, + messageEncoderSettings, + _server.ServerApi, + roundTripTime); + + return ExecuteProtocolAsync(operationContext, protocol, session); + } + + public void Dispose() + { + if (_state.TryChange(ChannelState.Initial, ChannelState.Disposed)) + { + if (_decrementOperationsCount) + { + _server.DecrementOutstandingOperationsCount(); + } + + _connection.Dispose(); + } + } + + private ICoreSession CreateClusterClockAdvancingCoreSession(ICoreSession session) + { + return new ClusterClockAdvancingCoreSession(session, _server.ClusterClock); + } + + private TResult ExecuteProtocol(OperationContext operationContext, IWireProtocol protocol, ICoreSession session) + { + try + { + return protocol.Execute(operationContext, _connection); + } + catch (Exception ex) + { + MarkSessionDirtyIfNeeded(session, ex); + _server.HandleChannelException(_connection, ex); + throw; + } + } + + private async Task ExecuteProtocolAsync(OperationContext operationContext, IWireProtocol protocol, ICoreSession session) + { + try + { + return await protocol.ExecuteAsync(operationContext, _connection).ConfigureAwait(false); + } + catch (Exception ex) + { + MarkSessionDirtyIfNeeded(session, ex); + _server.HandleChannelException(_connection, ex); + throw; + } + } + + public IChannelHandle Fork() + { + ThrowIfDisposed(); + + return new ServerChannel(_server, _connection.Fork(), false); + } + + private void MarkSessionDirtyIfNeeded(ICoreSession session, Exception ex) + { + if (ex is MongoConnectionException) + { + session.MarkDirty(); + } + } + + private void ThrowIfDisposed() + { + if (_state.Value == ChannelState.Disposed) + { + throw new ObjectDisposedException(GetType().Name); + } + } + + // nested types + private static class ChannelState + { + public const int Initial = 0; + public const int Disposed = 1; + } + } +} diff --git a/src/MongoDB.Driver/Core/Servers/ServerMonitor.cs b/src/MongoDB.Driver/Core/Servers/ServerMonitor.cs index ca043166746..e2e7c64163a 100644 --- a/src/MongoDB.Driver/Core/Servers/ServerMonitor.cs +++ b/src/MongoDB.Driver/Core/Servers/ServerMonitor.cs @@ -216,7 +216,8 @@ private IConnection InitializeConnection(CancellationToken cancellationToken) // { // if we are cancelling, it's because the server has // been shut down and we really don't need to wait. - connection.Open(cancellationToken); + // TODO: CSOT: Implement operation context support for Server Discovery and Monitoring + connection.Open(new OperationContext(Timeout.InfiniteTimeSpan, cancellationToken)); _eventLoggerSdam.LogAndPublish(new ServerHeartbeatSucceededEvent(connection.ConnectionId, stopwatch.Elapsed, false, connection.Description.HelloResult.Wrapped)); } diff --git a/src/MongoDB.Driver/Core/WireProtocol/CommandUsingCommandMessageWireProtocol.cs b/src/MongoDB.Driver/Core/WireProtocol/CommandUsingCommandMessageWireProtocol.cs index 159fe57c70c..276610571d1 100644 --- a/src/MongoDB.Driver/Core/WireProtocol/CommandUsingCommandMessageWireProtocol.cs +++ b/src/MongoDB.Driver/Core/WireProtocol/CommandUsingCommandMessageWireProtocol.cs @@ -49,6 +49,7 @@ internal sealed class CommandUsingCommandMessageWireProtocol : I private readonly CommandResponseHandling _responseHandling; private readonly IBsonSerializer _resultSerializer; private readonly ServerApi _serverApi; + private readonly TimeSpan _roundTripTime; private readonly ICoreSession _session; // streamable fields private bool _moreToCome = false; // MoreToCome from the previous response @@ -67,7 +68,8 @@ public CommandUsingCommandMessageWireProtocol( IBsonSerializer resultSerializer, MessageEncoderSettings messageEncoderSettings, Action postWriteAction, - ServerApi serverApi) + ServerApi serverApi, + TimeSpan roundTripTime) { if (responseHandling != CommandResponseHandling.Return && responseHandling != CommandResponseHandling.NoResponseExpected && @@ -88,6 +90,7 @@ public CommandUsingCommandMessageWireProtocol( _messageEncoderSettings = messageEncoderSettings; _postWriteAction = postWriteAction; // can be null _serverApi = serverApi; // can be null + _roundTripTime = roundTripTime; if (messageEncoderSettings != null) { @@ -100,7 +103,7 @@ public CommandUsingCommandMessageWireProtocol( public bool MoreToCome => _moreToCome; // public methods - public TCommandResult Execute(IConnection connection, CancellationToken cancellationToken) + public TCommandResult Execute(OperationContext operationContext, IConnection connection) { try { @@ -113,19 +116,21 @@ public TCommandResult Execute(IConnection connection, CancellationToken cancella } else { - message = CreateCommandMessage(connection.Description); - message = AutoEncryptFieldsIfNecessary(message, connection, cancellationToken); + message = CreateCommandMessage(operationContext, connection.Description); + // TODO: CSOT: Propagate operationContext into Encryption + message = AutoEncryptFieldsIfNecessary(message, connection, operationContext.CancellationToken); responseTo = message.WrappedMessage.RequestId; } try { - return SendMessageAndProcessResponse(message, responseTo, connection, cancellationToken); + return SendMessageAndProcessResponse(operationContext, message, responseTo, connection); } catch (MongoCommandException commandException) when (RetryabilityHelper.IsReauthenticationRequested(commandException, _command)) { - connection.Reauthenticate(cancellationToken); - return SendMessageAndProcessResponse(message, responseTo, connection, cancellationToken); + // TODO: CSOT: support operationContext in auth + connection.Reauthenticate(operationContext.CancellationToken); + return SendMessageAndProcessResponse(operationContext, message, responseTo, connection); } } catch (Exception exception) @@ -137,7 +142,7 @@ public TCommandResult Execute(IConnection connection, CancellationToken cancella } } - public async Task ExecuteAsync(IConnection connection, CancellationToken cancellationToken) + public async Task ExecuteAsync(OperationContext operationContext, IConnection connection) { try { @@ -150,19 +155,21 @@ public async Task ExecuteAsync(IConnection connection, Cancellat } else { - message = CreateCommandMessage(connection.Description); - message = await AutoEncryptFieldsIfNecessaryAsync(message, connection, cancellationToken).ConfigureAwait(false); + message = CreateCommandMessage(operationContext, connection.Description); + // TODO: CSOT: Propagate operationContext into Encryption + message = await AutoEncryptFieldsIfNecessaryAsync(message, connection, operationContext.CancellationToken).ConfigureAwait(false); responseTo = message.WrappedMessage.RequestId; } try { - return await SendMessageAndProcessResponseAsync(message, responseTo, connection, cancellationToken).ConfigureAwait(false); + return await SendMessageAndProcessResponseAsync(operationContext, message, responseTo, connection).ConfigureAwait(false); } catch (MongoCommandException commandException) when (RetryabilityHelper.IsReauthenticationRequested(commandException, _command)) { - await connection.ReauthenticateAsync(cancellationToken).ConfigureAwait(false); - return await SendMessageAndProcessResponseAsync(message, responseTo, connection, cancellationToken).ConfigureAwait(false); + // TODO: CSOT: support operationContext in auth + await connection.ReauthenticateAsync(operationContext.CancellationToken).ConfigureAwait(false); + return await SendMessageAndProcessResponseAsync(operationContext, message, responseTo, connection).ConfigureAwait(false); } } catch (Exception exception) @@ -253,11 +260,11 @@ private async Task AutoEncryptFieldsIfNecessaryAsync(Comm } } - private CommandRequestMessage CreateCommandMessage(ConnectionDescription connectionDescription) + private CommandRequestMessage CreateCommandMessage(OperationContext operationContext, ConnectionDescription connectionDescription) { var requestId = RequestMessage.GetNextRequestId(); var responseTo = 0; - var sections = CreateSections(connectionDescription); + var sections = CreateSections(operationContext, connectionDescription); var moreToComeRequest = _responseHandling == CommandResponseHandling.NoResponseExpected; @@ -270,9 +277,9 @@ private CommandRequestMessage CreateCommandMessage(ConnectionDescription connect return new CommandRequestMessage(wrappedMessage); } - private IEnumerable CreateSections(ConnectionDescription connectionDescription) + private IEnumerable CreateSections(OperationContext operationContext, ConnectionDescription connectionDescription) { - var type0Section = CreateType0Section(connectionDescription); + var type0Section = CreateType0Section(operationContext, connectionDescription); if (_commandPayloads == null) { return new[] { type0Section }; @@ -283,7 +290,7 @@ private IEnumerable CreateSections(ConnectionDescription } } - private Type0CommandMessageSection CreateType0Section(ConnectionDescription connectionDescription) + private Type0CommandMessageSection CreateType0Section(OperationContext operationContext, ConnectionDescription connectionDescription) { var extraElements = new List(); @@ -369,6 +376,17 @@ private Type0CommandMessageSection CreateType0Section(ConnectionDe } } + if (operationContext.IsRootContextTimeoutConfigured()) + { + var serverTimeout = operationContext.RemainingTimeout - _roundTripTime; + if (serverTimeout < TimeSpan.Zero) + { + throw new TimeoutException(); + } + + AddIfNotAlreadyAdded("maxTimeMS", (int)serverTimeout.TotalMilliseconds); + } + var elementAppendingSerializer = new ElementAppendingSerializer(BsonDocumentSerializer.Instance, extraElements); return new Type0CommandMessageSection(_command, elementAppendingSerializer); @@ -526,14 +544,15 @@ private void SaveResponseInfo(CommandResponseMessage response) _moreToCome = response.WrappedMessage.MoreToCome; } - private TCommandResult SendMessageAndProcessResponse(CommandRequestMessage message, int responseTo, IConnection connection, CancellationToken cancellationToken) + private TCommandResult SendMessageAndProcessResponse(OperationContext operationContext, CommandRequestMessage message, int responseTo, IConnection connection) { var responseExpected = true; if (message != null) { try { - connection.SendMessage(message, _messageEncoderSettings, cancellationToken); + ThrowIfRemainingTimeoutLessThenRoundTripTime(operationContext); + connection.SendMessage(operationContext, message, _messageEncoderSettings); } finally { @@ -549,8 +568,9 @@ private TCommandResult SendMessageAndProcessResponse(CommandRequestMessage messa if (responseExpected) { var encoderSelector = new CommandResponseMessageEncoderSelector(); - var response = (CommandResponseMessage)connection.ReceiveMessage(responseTo, encoderSelector, _messageEncoderSettings, cancellationToken); - response = AutoDecryptFieldsIfNecessary(response, cancellationToken); + var response = (CommandResponseMessage)connection.ReceiveMessage(operationContext, responseTo, encoderSelector, _messageEncoderSettings); + // TODO: CSOT: Propagate operationContext into Encryption + response = AutoDecryptFieldsIfNecessary(response, operationContext.CancellationToken); var result = ProcessResponse(connection.ConnectionId, response.WrappedMessage); SaveResponseInfo(response); return result; @@ -561,14 +581,15 @@ private TCommandResult SendMessageAndProcessResponse(CommandRequestMessage messa } } - private async Task SendMessageAndProcessResponseAsync(CommandRequestMessage message, int responseTo, IConnection connection, CancellationToken cancellationToken) + private async Task SendMessageAndProcessResponseAsync(OperationContext operationContext, CommandRequestMessage message, int responseTo, IConnection connection) { var responseExpected = true; if (message != null) { try { - await connection.SendMessageAsync(message, _messageEncoderSettings, cancellationToken).ConfigureAwait(false); + ThrowIfRemainingTimeoutLessThenRoundTripTime(operationContext); + await connection.SendMessageAsync(operationContext, message, _messageEncoderSettings).ConfigureAwait(false); } finally { @@ -583,8 +604,9 @@ private async Task SendMessageAndProcessResponseAsync(CommandReq if (responseExpected) { var encoderSelector = new CommandResponseMessageEncoderSelector(); - var response = (CommandResponseMessage)await connection.ReceiveMessageAsync(responseTo, encoderSelector, _messageEncoderSettings, cancellationToken).ConfigureAwait(false); - response = await AutoDecryptFieldsIfNecessaryAsync(response, cancellationToken).ConfigureAwait(false); + var response = (CommandResponseMessage)await connection.ReceiveMessageAsync(operationContext, responseTo, encoderSelector, _messageEncoderSettings).ConfigureAwait(false); + // TODO: CSOT: Propagate operationContext into Encryption + response = await AutoDecryptFieldsIfNecessaryAsync(response, operationContext.CancellationToken).ConfigureAwait(false); var result = ProcessResponse(connection.ConnectionId, response.WrappedMessage); SaveResponseInfo(response); return result; @@ -608,6 +630,16 @@ private bool ShouldAddTransientTransactionError(MongoException exception) return false; } + private void ThrowIfRemainingTimeoutLessThenRoundTripTime(OperationContext operationContext) + { + if (operationContext.RemainingTimeout == Timeout.InfiniteTimeSpan || operationContext.RemainingTimeout > _roundTripTime) + { + return; + } + + throw new TimeoutException(); + } + private MongoException WrapNotSupportedRetryableWriteException(MongoCommandException exception) { const string friendlyErrorMessage = diff --git a/src/MongoDB.Driver/Core/WireProtocol/CommandUsingQueryMessageWireProtocol.cs b/src/MongoDB.Driver/Core/WireProtocol/CommandUsingQueryMessageWireProtocol.cs index 6d18ada9747..ecfb53f0f70 100644 --- a/src/MongoDB.Driver/Core/WireProtocol/CommandUsingQueryMessageWireProtocol.cs +++ b/src/MongoDB.Driver/Core/WireProtocol/CommandUsingQueryMessageWireProtocol.cs @@ -18,7 +18,6 @@ using System.Linq; using System.Reflection; using System.Text; -using System.Threading; using System.Threading.Tasks; using MongoDB.Bson; using MongoDB.Bson.IO; @@ -116,11 +115,11 @@ private QueryMessage CreateMessage(ConnectionDescription connectionDescription, #pragma warning restore 618 } - public TCommandResult Execute(IConnection connection, CancellationToken cancellationToken) + public TCommandResult Execute(OperationContext operationContext, IConnection connection) { bool messageContainsSessionId; var message = CreateMessage(connection.Description, out messageContainsSessionId); - connection.SendMessage(message, _messageEncoderSettings, cancellationToken); + connection.SendMessage(operationContext, message, _messageEncoderSettings); if (messageContainsSessionId) { _session.WasUsed(); @@ -129,20 +128,20 @@ public TCommandResult Execute(IConnection connection, CancellationToken cancella switch (message.ResponseHandling) { case CommandResponseHandling.Ignore: - IgnoreResponse(connection, message, cancellationToken); + IgnoreResponse(operationContext, connection, message); return default(TCommandResult); default: var encoderSelector = new ReplyMessageEncoderSelector(RawBsonDocumentSerializer.Instance); - var reply = connection.ReceiveMessage(message.RequestId, encoderSelector, _messageEncoderSettings, cancellationToken); + var reply = connection.ReceiveMessage(operationContext, message.RequestId, encoderSelector, _messageEncoderSettings); return ProcessReply(connection.ConnectionId, (ReplyMessage)reply); } } - public async Task ExecuteAsync(IConnection connection, CancellationToken cancellationToken) + public async Task ExecuteAsync(OperationContext operationContext, IConnection connection) { bool messageContainsSessionId; var message = CreateMessage(connection.Description, out messageContainsSessionId); - await connection.SendMessageAsync(message, _messageEncoderSettings, cancellationToken).ConfigureAwait(false); + await connection.SendMessageAsync(operationContext, message, _messageEncoderSettings).ConfigureAwait(false); if (messageContainsSessionId) { _session.WasUsed(); @@ -151,11 +150,11 @@ public async Task ExecuteAsync(IConnection connection, Cancellat switch (message.ResponseHandling) { case CommandResponseHandling.Ignore: - IgnoreResponse(connection, message, cancellationToken); + IgnoreResponse(operationContext, connection, message); return default(TCommandResult); default: var encoderSelector = new ReplyMessageEncoderSelector(RawBsonDocumentSerializer.Instance); - var reply = await connection.ReceiveMessageAsync(message.RequestId, encoderSelector, _messageEncoderSettings, cancellationToken).ConfigureAwait(false); + var reply = await connection.ReceiveMessageAsync(operationContext, message.RequestId, encoderSelector, _messageEncoderSettings).ConfigureAwait(false); return ProcessReply(connection.ConnectionId, (ReplyMessage)reply); } } @@ -230,10 +229,10 @@ private IBsonSerializer CreateSizeLimitingPayloadSerializer(Type1CommandMessageS return (IBsonSerializer)constructorInfo.Invoke(new object[] { itemSerializer, itemElementNameValidator, maxBatchCount, maxItemSize, maxBatchSize }); } - private void IgnoreResponse(IConnection connection, QueryMessage message, CancellationToken cancellationToken) + private void IgnoreResponse(OperationContext operationContext, IConnection connection, QueryMessage message) { var encoderSelector = new ReplyMessageEncoderSelector(IgnoredReplySerializer.Instance); - connection.ReceiveMessageAsync(message.RequestId, encoderSelector, _messageEncoderSettings, cancellationToken).IgnoreExceptions(); + connection.ReceiveMessageAsync(operationContext, message.RequestId, encoderSelector, _messageEncoderSettings).IgnoreExceptions(); } [System.Diagnostics.CodeAnalysis.SuppressMessage("Microsoft.Usage", "CA2202:Do not dispose objects multiple times")] diff --git a/src/MongoDB.Driver/Core/WireProtocol/CommandWireProtocol.cs b/src/MongoDB.Driver/Core/WireProtocol/CommandWireProtocol.cs index 3844f13aec0..093fd8b1ae7 100644 --- a/src/MongoDB.Driver/Core/WireProtocol/CommandWireProtocol.cs +++ b/src/MongoDB.Driver/Core/WireProtocol/CommandWireProtocol.cs @@ -16,7 +16,6 @@ using System; using System.Collections.Generic; using System.Linq; -using System.Threading; using System.Threading.Tasks; using MongoDB.Bson; using MongoDB.Bson.IO; @@ -45,6 +44,7 @@ internal sealed class CommandWireProtocol : IWireProtocol _resultSerializer; private readonly ServerApi _serverApi; + private readonly TimeSpan _roundTripTime; private readonly ICoreSession _session; // constructors @@ -86,7 +86,8 @@ public CommandWireProtocol( commandResponseHandling, resultSerializer, messageEncoderSettings, - serverApi) + serverApi, + roundTripTime: TimeSpan.Zero) { } @@ -102,7 +103,8 @@ public CommandWireProtocol( CommandResponseHandling responseHandling, IBsonSerializer resultSerializer, MessageEncoderSettings messageEncoderSettings, - ServerApi serverApi) + ServerApi serverApi, + TimeSpan roundTripTime) { if (responseHandling != CommandResponseHandling.Return && responseHandling != CommandResponseHandling.NoResponseExpected && @@ -123,22 +125,23 @@ public CommandWireProtocol( _messageEncoderSettings = messageEncoderSettings; _postWriteAction = postWriteAction; // can be null _serverApi = serverApi; // can be null + _roundTripTime = roundTripTime; } // public properties public bool MoreToCome => _cachedWireProtocol?.MoreToCome ?? false; // public methods - public TCommandResult Execute(IConnection connection, CancellationToken cancellationToken) + public TCommandResult Execute(OperationContext operationContext, IConnection connection) { var supportedProtocol = CreateSupportedWireProtocol(connection); - return supportedProtocol.Execute(connection, cancellationToken); + return supportedProtocol.Execute(operationContext, connection); } - public Task ExecuteAsync(IConnection connection, CancellationToken cancellationToken) + public Task ExecuteAsync(OperationContext operationContext, IConnection connection) { var supportedProtocol = CreateSupportedWireProtocol(connection); - return supportedProtocol.ExecuteAsync(connection, cancellationToken); + return supportedProtocol.ExecuteAsync(operationContext, connection); } // private methods @@ -156,7 +159,8 @@ private IWireProtocol CreateCommandUsingCommandMessageWireProtoc _resultSerializer, _messageEncoderSettings, _postWriteAction, - _serverApi); + _serverApi, + _roundTripTime); } private IWireProtocol CreateCommandUsingQueryMessageWireProtocol() diff --git a/src/MongoDB.Driver/Core/WireProtocol/IWireProtocol.cs b/src/MongoDB.Driver/Core/WireProtocol/IWireProtocol.cs index 025e7dd5acd..dee26e7dd87 100644 --- a/src/MongoDB.Driver/Core/WireProtocol/IWireProtocol.cs +++ b/src/MongoDB.Driver/Core/WireProtocol/IWireProtocol.cs @@ -13,7 +13,6 @@ * limitations under the License. */ -using System.Threading; using System.Threading.Tasks; using MongoDB.Driver.Core.Connections; @@ -22,14 +21,14 @@ namespace MongoDB.Driver.Core.WireProtocol internal interface IWireProtocol { bool MoreToCome { get; } - void Execute(IConnection connection, CancellationToken cancellationToken = default(CancellationToken)); - Task ExecuteAsync(IConnection connection, CancellationToken cancellationToken = default(CancellationToken)); + void Execute(OperationContext operationContext, IConnection connection); + Task ExecuteAsync(OperationContext operationContext, IConnection connection); } internal interface IWireProtocol { bool MoreToCome { get; } - TResult Execute(IConnection connection, CancellationToken cancellationToken = default(CancellationToken)); - Task ExecuteAsync(IConnection connection, CancellationToken cancellationToken = default(CancellationToken)); + TResult Execute(OperationContext operationContext, IConnection connection); + Task ExecuteAsync(OperationContext operationContext, IConnection connection); } } diff --git a/src/MongoDB.Driver/OperationContext.cs b/src/MongoDB.Driver/OperationContext.cs index d2359de1c9f..7611357aee8 100644 --- a/src/MongoDB.Driver/OperationContext.cs +++ b/src/MongoDB.Driver/OperationContext.cs @@ -21,11 +21,14 @@ namespace MongoDB.Driver { - internal sealed class OperationContext + internal sealed class OperationContext : IDisposable { // TODO: this static field is temporary here and will be removed in a future PRs in scope of CSOT. public static readonly OperationContext NoTimeout = new(System.Threading.Timeout.InfiniteTimeSpan, CancellationToken.None); + private CancellationTokenSource _remainingTimeoutCancellationTokenSource; + private CancellationTokenSource _combinedCancellationTokenSource; + public OperationContext(TimeSpan timeout, CancellationToken cancellationToken) : this(Stopwatch.StartNew(), timeout, cancellationToken) { @@ -62,21 +65,39 @@ public TimeSpan RemainingTimeout } } + [Obsolete("Do not use this property, unless it's needed to avoid breaking changes in public API")] + public CancellationToken CombinedCancellationToken + { + get + { + if (_combinedCancellationTokenSource != null) + { + return _combinedCancellationTokenSource.Token; + } + + if (RemainingTimeout == System.Threading.Timeout.InfiniteTimeSpan) + { + return CancellationToken; + } + + _remainingTimeoutCancellationTokenSource = new CancellationTokenSource(RemainingTimeout); + _combinedCancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(CancellationToken, _remainingTimeoutCancellationTokenSource.Token); + return _combinedCancellationTokenSource.Token; + } + } private Stopwatch Stopwatch { get; } public TimeSpan Timeout { get; } - public bool IsTimedOut() + public void Dispose() { - var remainingTimeout = RemainingTimeout; - if (remainingTimeout == System.Threading.Timeout.InfiniteTimeSpan) - { - return false; - } - - return remainingTimeout == TimeSpan.Zero; + _remainingTimeoutCancellationTokenSource?.Dispose(); + _combinedCancellationTokenSource?.Dispose(); } + public bool IsTimedOut() + => RemainingTimeout == TimeSpan.Zero; + public void ThrowIfTimedOutOrCanceled() { CancellationToken.ThrowIfCancellationRequested(); @@ -95,7 +116,7 @@ public void WaitTask(Task task) } var timeout = RemainingTimeout; - if (timeout != System.Threading.Timeout.InfiniteTimeSpan && timeout < TimeSpan.Zero) + if (timeout == TimeSpan.Zero) { throw new TimeoutException(); } @@ -128,7 +149,7 @@ public async Task WaitTaskAsync(Task task) } var timeout = RemainingTimeout; - if (timeout != System.Threading.Timeout.InfiniteTimeSpan && timeout < TimeSpan.Zero) + if (timeout == TimeSpan.Zero) { throw new TimeoutException(); } diff --git a/src/MongoDB.Driver/OperationExecutor.cs b/src/MongoDB.Driver/OperationExecutor.cs index 7025097de70..84ddcf287bd 100644 --- a/src/MongoDB.Driver/OperationExecutor.cs +++ b/src/MongoDB.Driver/OperationExecutor.cs @@ -50,7 +50,7 @@ public TResult ExecuteReadOperation( Ensure.IsNotNull(session, nameof(session)); ThrowIfDisposed(); - var operationContext = options.ToOperationContext(cancellationToken); + using var operationContext = options.ToOperationContext(cancellationToken); var readPreference = options.GetEffectiveReadPreference(session); using var binding = CreateReadBinding(session, readPreference, allowChannelPinning); return operation.Execute(operationContext, binding); @@ -68,7 +68,7 @@ public async Task ExecuteReadOperationAsync( Ensure.IsNotNull(session, nameof(session)); ThrowIfDisposed(); - var operationContext = options.ToOperationContext(cancellationToken); + using var operationContext = options.ToOperationContext(cancellationToken); var readPreference = options.GetEffectiveReadPreference(session); using var binding = CreateReadBinding(session, readPreference, allowChannelPinning); return await operation.ExecuteAsync(operationContext, binding).ConfigureAwait(false); @@ -86,7 +86,7 @@ public TResult ExecuteWriteOperation( Ensure.IsNotNull(session, nameof(session)); ThrowIfDisposed(); - var operationContext = options.ToOperationContext(cancellationToken); + using var operationContext = options.ToOperationContext(cancellationToken); using var binding = CreateReadWriteBinding(session, allowChannelPinning); return operation.Execute(operationContext, binding); } @@ -103,7 +103,7 @@ public async Task ExecuteWriteOperationAsync( Ensure.IsNotNull(session, nameof(session)); ThrowIfDisposed(); - var operationContext = options.ToOperationContext(cancellationToken); + using var operationContext = options.ToOperationContext(cancellationToken); using var binding = CreateReadWriteBinding(session, allowChannelPinning); return await operation.ExecuteAsync(operationContext, binding).ConfigureAwait(false); } diff --git a/tests/MongoDB.Driver.TestHelpers/Core/MockConnection.cs b/tests/MongoDB.Driver.TestHelpers/Core/MockConnection.cs index 6ee9c5ad81a..f8298aaa62d 100644 --- a/tests/MongoDB.Driver.TestHelpers/Core/MockConnection.cs +++ b/tests/MongoDB.Driver.TestHelpers/Core/MockConnection.cs @@ -184,7 +184,7 @@ public List GetSentMessages() return _sentMessages; } - public void Open(CancellationToken cancellationToken) + public void Open(OperationContext operationContext) { _openingEventHandler?.Invoke(new ConnectionOpeningEvent(_connectionId, _connectionSettings, null)); @@ -196,7 +196,7 @@ public void Open(CancellationToken cancellationToken) _openedEventHandler?.Invoke(new ConnectionOpenedEvent(_connectionId, _connectionSettings, TimeSpan.FromTicks(1), null)); } - public Task OpenAsync(CancellationToken cancellationToken) + public Task OpenAsync(OperationContext operationContext) { _openingEventHandler?.Invoke(new ConnectionOpeningEvent(_connectionId, _connectionSettings, null)); @@ -220,24 +220,24 @@ public async Task ReauthenticateAsync(CancellationToken cancellationToken) await _replyActions.Dequeue().GetEffectiveMessageAsync().ConfigureAwait(false); } - public ResponseMessage ReceiveMessage(int responseTo, IMessageEncoderSelector encoderSelector, MessageEncoderSettings messageEncoderSettings, CancellationToken cancellationToken) + public ResponseMessage ReceiveMessage(OperationContext operationContext, int responseTo, IMessageEncoderSelector encoderSelector, MessageEncoderSettings messageEncoderSettings) { var action = _replyActions.Dequeue(); return (ResponseMessage)action.GetEffectiveMessage(); } - public async Task ReceiveMessageAsync(int responseTo, IMessageEncoderSelector encoderSelector, MessageEncoderSettings messageEncoderSettings, CancellationToken cancellationToken) + public async Task ReceiveMessageAsync(OperationContext operationContext, int responseTo, IMessageEncoderSelector encoderSelector, MessageEncoderSettings messageEncoderSettings) { var action = _replyActions.Dequeue(); return (ResponseMessage)await action.GetEffectiveMessageAsync().ConfigureAwait(false); } - public void SendMessage(RequestMessage message, MessageEncoderSettings messageEncoderSettings, CancellationToken cancellationToken) + public void SendMessage(OperationContext operationContext, RequestMessage message, MessageEncoderSettings messageEncoderSettings) { _sentMessages.Add(message); } - public Task SendMessageAsync(RequestMessage message, MessageEncoderSettings messageEncoderSettings, CancellationToken cancellationToken) + public Task SendMessageAsync(OperationContext operationContext, RequestMessage message, MessageEncoderSettings messageEncoderSettings) { _sentMessages.Add(message); return Task.CompletedTask; diff --git a/tests/MongoDB.Driver.Tests/AuthenticationTests.cs b/tests/MongoDB.Driver.Tests/AuthenticationTests.cs index 0ad423e16eb..56e644f6f1a 100644 --- a/tests/MongoDB.Driver.Tests/AuthenticationTests.cs +++ b/tests/MongoDB.Driver.Tests/AuthenticationTests.cs @@ -16,7 +16,6 @@ using System; using System.Linq; using System.Security.Cryptography.X509Certificates; -using System.Threading; using FluentAssertions; using MongoDB.Bson; using MongoDB.Driver.Core.Clusters.ServerSelectors; diff --git a/tests/MongoDB.Driver.Tests/ClusterTests.cs b/tests/MongoDB.Driver.Tests/ClusterTests.cs index e1a453acefa..a7c31d484c2 100644 --- a/tests/MongoDB.Driver.Tests/ClusterTests.cs +++ b/tests/MongoDB.Driver.Tests/ClusterTests.cs @@ -91,7 +91,6 @@ public void SelectServer_loadbalancing_prose_test([Values(false, true)] bool asy var fastServer = client.GetClusterInternal().SelectServer(OperationContext.NoTimeout, new DelegateServerSelector((_, servers) => servers.Where(s => s.ServerId != slowServer.ServerId))); using var failPoint = FailPoint.Configure(slowServer, NoCoreSession.NewHandle(), failCommand, async); - var database = client.GetDatabase(_databaseName); CreateCollection(); var collection = database.GetCollection(_collectionName); diff --git a/tests/MongoDB.Driver.Tests/Core/Bindings/ServerChannelSourceTests.cs b/tests/MongoDB.Driver.Tests/Core/Bindings/ServerChannelSourceTests.cs index 62bd5cfd535..bd6737efdb1 100644 --- a/tests/MongoDB.Driver.Tests/Core/Bindings/ServerChannelSourceTests.cs +++ b/tests/MongoDB.Driver.Tests/Core/Bindings/ServerChannelSourceTests.cs @@ -27,13 +27,6 @@ namespace MongoDB.Driver.Core.Bindings { public class ServerChannelSourceTests { - private Mock _mockServer; - - public ServerChannelSourceTests() - { - _mockServer = new Mock(); - } - [Fact] public void Constructor_should_throw_when_server_is_null() { @@ -47,7 +40,9 @@ public void Constructor_should_throw_when_server_is_null() [Fact] public void Constructor_should_throw_when_session_is_null() { - var exception = Record.Exception(() => new ServerChannelSource(_mockServer.Object, null)); + var server = Mock.Of(); + + var exception = Record.Exception(() => new ServerChannelSource(server, null)); exception.Should().BeOfType() .Subject.ParamName.Should().Be("session"); @@ -57,12 +52,11 @@ public void Constructor_should_throw_when_session_is_null() public void ServerDescription_should_return_description_of_server() { var session = new Mock().Object; - var subject = new ServerChannelSource(_mockServer.Object, session); - var desc = ServerDescriptionHelper.Disconnected(new ClusterId()); + var serverMock = new Mock(); + serverMock.SetupGet(s => s.Description).Returns(desc); - _mockServer.SetupGet(s => s.Description).Returns(desc); - + var subject = new ServerChannelSource(serverMock.Object, session); var result = subject.ServerDescription; result.Should().BeSameAs(desc); @@ -72,7 +66,7 @@ public void ServerDescription_should_return_description_of_server() public void Session_should_return_expected_result() { var session = new Mock().Object; - var subject = new ServerChannelSource(_mockServer.Object, session); + var subject = new ServerChannelSource(Mock.Of(), session); var result = subject.Session; @@ -86,7 +80,7 @@ public async Task GetChannel_should_throw_if_disposed( bool async) { var session = new Mock().Object; - var subject = new ServerChannelSource(_mockServer.Object, session); + var subject = new ServerChannelSource(Mock.Of(), session); subject.Dispose(); var exception = async ? @@ -102,20 +96,21 @@ public async Task GetChannel_should_get_connection_from_server( [Values(false, true)] bool async) { + var serverMock = new Mock(); var session = new Mock().Object; - var subject = new ServerChannelSource(_mockServer.Object, session); + var subject = new ServerChannelSource(serverMock.Object, session); if (async) { await subject.GetChannelAsync(OperationContext.NoTimeout); - _mockServer.Verify(s => s.GetChannelAsync(It.IsAny()), Times.Once); + serverMock.Verify(s => s.GetChannelAsync(It.IsAny()), Times.Once); } else { subject.GetChannel(OperationContext.NoTimeout); - _mockServer.Verify(s => s.GetChannel(It.IsAny()), Times.Once); + serverMock.Verify(s => s.GetChannel(It.IsAny()), Times.Once); } } @@ -123,7 +118,7 @@ public async Task GetChannel_should_get_connection_from_server( public void Dispose_should_dispose_session() { var mockSession = new Mock(); - var subject = new ServerChannelSource(_mockServer.Object, mockSession.Object); + var subject = new ServerChannelSource(Mock.Of(), mockSession.Object); subject.Dispose(); diff --git a/tests/MongoDB.Driver.Tests/Core/ConnectionPools/ExclusiveConnectionPoolTests.cs b/tests/MongoDB.Driver.Tests/Core/ConnectionPools/ExclusiveConnectionPoolTests.cs index e467384ef70..7da9b2f2627 100644 --- a/tests/MongoDB.Driver.Tests/Core/ConnectionPools/ExclusiveConnectionPoolTests.cs +++ b/tests/MongoDB.Driver.Tests/Core/ConnectionPools/ExclusiveConnectionPoolTests.cs @@ -312,10 +312,10 @@ public async Task AcquireConnection_should_invoke_error_handling_before_releasin .Setup(c => c.Settings) .Returns(new ConnectionSettings()); connectionMock - .Setup(c => c.Open(It.IsAny())) + .Setup(c => c.Open(It.IsAny())) .Throws(exception); connectionMock - .Setup(c => c.OpenAsync(It.IsAny())) + .Setup(c => c.OpenAsync(It.IsAny())) .Throws(exception); return connectionMock.Object; @@ -582,7 +582,7 @@ public void AcquireConnection_should_timeout_when_non_sufficient_reused_connecti .Setup(c => c.Settings) .Returns(new ConnectionSettings()); connectionMock - .Setup(c => c.Open(It.IsAny())) + .Setup(c => c.Open(It.IsAny())) .Callback(() => { if (establishingCount.CurrentCount > 0) @@ -593,7 +593,7 @@ public void AcquireConnection_should_timeout_when_non_sufficient_reused_connecti blockEstablishmentEvent.Wait(); }); connectionMock - .Setup(c => c.OpenAsync(It.IsAny())) + .Setup(c => c.OpenAsync(It.IsAny())) .Returns(() => { if (establishingCount.CurrentCount > 0) @@ -756,14 +756,14 @@ public void Acquire_and_release_connection_stress_test( .Setup(c => c.Settings) .Returns(new ConnectionSettings()); connectionMock - .Setup(c => c.Open(It.IsAny())) + .Setup(c => c.Open(It.IsAny())) .Callback(() => { var sleepMS = random.Next(minEstablishingTime, maxEstablishingTime); Thread.Sleep(sleepMS); }); connectionMock - .Setup(c => c.OpenAsync(It.IsAny())) + .Setup(c => c.OpenAsync(It.IsAny())) .Returns(async () => { var sleepMS = random.Next(minEstablishingTime, maxEstablishingTime); @@ -970,7 +970,7 @@ public void In_use_marker_should_work_as_expected( var mockConnection = new Mock(); mockConnection.SetupGet(c => c.ConnectionId).Returns(new ConnectionId(serverId, ci)); mockConnection - .Setup(c => c.Open(It.IsAny())) + .Setup(c => c.Open(It.IsAny())) .Callback(() => { if (minPoolSize == 0 || ci == 2) // ignore connection 1 created in minPoolSize logic @@ -984,7 +984,7 @@ public void In_use_marker_should_work_as_expected( }); mockConnection - .Setup(c => c.OpenAsync(It.IsAny())) + .Setup(c => c.OpenAsync(It.IsAny())) .Returns(async () => { if (minPoolSize == 0 || ci == 2) // ignore connection 1 created in minPoolSize logic @@ -1076,7 +1076,7 @@ public void Maintenance_should_call_connection_dispose_when_connection_authentic var authenticationException = new MongoAuthenticationException(connectionId, "test message"); var authenticationFailedConnection = new Mock(); authenticationFailedConnection - .Setup(c => c.Open(It.IsAny())) // an authentication exception is thrown from _connectionInitializer.InitializeConnection + .Setup(c => c.Open(It.IsAny())) // an authentication exception is thrown from _connectionInitializer.InitializeConnection // that in turn is called from OpenAsync .Throws(authenticationException); authenticationFailedConnection.SetupGet(c => c.ConnectionId).Returns(connectionId); @@ -1166,7 +1166,7 @@ public void MaxConnecting_queue_should_be_cleared_on_pool_clear( .Returns(new ConnectionSettings()); connectionMock - .Setup(c => c.Open(It.IsAny())) + .Setup(c => c.Open(It.IsAny())) .Callback(() => { allEstablishing.Signal(); @@ -1174,7 +1174,7 @@ public void MaxConnecting_queue_should_be_cleared_on_pool_clear( }); connectionMock - .Setup(c => c.OpenAsync(It.IsAny())) + .Setup(c => c.OpenAsync(It.IsAny())) .Returns(() => { allEstablishing.Signal(); @@ -1424,7 +1424,7 @@ public void WaitQueue_should_throw_when_full( .Returns(new ConnectionSettings()); connectionMock - .Setup(c => c.Open(It.IsAny())) + .Setup(c => c.Open(It.IsAny())) .Callback(() => { allAcquiringCountdownEvent.Signal(); @@ -1432,7 +1432,7 @@ public void WaitQueue_should_throw_when_full( }); connectionMock - .Setup(c => c.OpenAsync(It.IsAny())) + .Setup(c => c.OpenAsync(It.IsAny())) .Returns(() => { allAcquiringCountdownEvent.Signal(); @@ -1516,7 +1516,7 @@ public void WaitQueue_should_be_cleared_on_pool_clear( .Returns(new ConnectionSettings()); connectionMock - .Setup(c => c.Open(It.IsAny())) + .Setup(c => c.Open(It.IsAny())) .Callback(() => { allEstablishing.Signal(); @@ -1524,7 +1524,7 @@ public void WaitQueue_should_be_cleared_on_pool_clear( }); connectionMock - .Setup(c => c.OpenAsync(It.IsAny())) + .Setup(c => c.OpenAsync(It.IsAny())) .Returns(() => { allEstablishing.Signal(); diff --git a/tests/MongoDB.Driver.Tests/Core/Connections/BinaryConnectionTests.cs b/tests/MongoDB.Driver.Tests/Core/Connections/BinaryConnectionTests.cs index 8a0ae8b75d6..393d19c0b38 100644 --- a/tests/MongoDB.Driver.Tests/Core/Connections/BinaryConnectionTests.cs +++ b/tests/MongoDB.Driver.Tests/Core/Connections/BinaryConnectionTests.cs @@ -103,7 +103,7 @@ public void Dispose_should_raise_the_correct_events() [Theory] [ParameterAttributeData] - public void Open_should_always_create_description_if_handshake_was_successful([Values(false, true)] bool async) + public async Task Open_should_always_create_description_if_handshake_was_successful([Values(false, true)] bool async) { var serviceId = ObjectId.GenerateNewId(); var connectionDescription = new ConnectionDescription( @@ -124,15 +124,9 @@ public void Open_should_always_create_description_if_handshake_was_successful([V .Setup(i => i.AuthenticateAsync(It.IsAny(), It.IsAny(), CancellationToken.None)) .ThrowsAsync(socketException); - Exception exception; - if (async) - { - exception = Record.Exception(() => _subject.OpenAsync(CancellationToken.None).GetAwaiter().GetResult()); - } - else - { - exception = Record.Exception(() => _subject.Open(CancellationToken.None)); - } + var exception = async ? + await Record.ExceptionAsync(() => _subject.OpenAsync(OperationContext.NoTimeout)) : + Record.Exception(() => _subject.Open(OperationContext.NoTimeout)); _subject.Description.Should().Be(connectionDescription); var ex = exception.Should().BeOfType().Subject; @@ -185,11 +179,11 @@ public async Task Open_should_create_authenticators_only_once( if (async) { - await subject.OpenAsync(CancellationToken.None); + await subject.OpenAsync(OperationContext.NoTimeout); } else { - subject.Open(CancellationToken.None); + subject.Open(OperationContext.NoTimeout); } authenticatorFactoryMock.Verify(f => f.Create(), Times.Once()); @@ -206,52 +200,37 @@ ResponseMessage CreateResponseMessage() [Theory] [ParameterAttributeData] - public void Open_should_throw_an_ObjectDisposedException_if_the_connection_is_disposed( + public async Task Open_should_throw_an_ObjectDisposedException_if_the_connection_is_disposed( [Values(false, true)] bool async) { _subject.Dispose(); - Action act; - if (async) - { - act = () => _subject.OpenAsync(CancellationToken.None).GetAwaiter().GetResult(); - } - else - { - act = () => _subject.Open(CancellationToken.None); - } + var exception = async ? + await Record.ExceptionAsync(() => _subject.OpenAsync(OperationContext.NoTimeout)) : + Record.Exception(() => _subject.Open(OperationContext.NoTimeout)); - act.ShouldThrow(); + exception.Should().BeOfType(); } [Theory] [ParameterAttributeData] - public void Open_should_raise_the_correct_events_upon_failure( + public async Task Open_should_raise_the_correct_events_upon_failure( [Values(false, true)] bool async) { - Action act; - if (async) - { - var result = new TaskCompletionSource(); - result.SetException(new SocketException()); - _mockConnectionInitializer.Setup(i => i.SendHelloAsync(It.IsAny(), It.IsAny())) - .Returns(result.Task); - - act = () => _subject.OpenAsync(CancellationToken.None).GetAwaiter().GetResult(); - } - else - { - _mockConnectionInitializer.Setup(i => i.SendHello(It.IsAny(), It.IsAny())) - .Throws(); + _mockConnectionInitializer.Setup(i => i.SendHelloAsync(It.IsAny(), It.IsAny())) + .Throws(); + _mockConnectionInitializer.Setup(i => i.SendHello(It.IsAny(), It.IsAny())) + .Throws(); - act = () => _subject.Open(CancellationToken.None); - } + var exception = async ? + await Record.ExceptionAsync(() => _subject.OpenAsync(OperationContext.NoTimeout)) : + Record.Exception(() => _subject.Open(OperationContext.NoTimeout)); - act.ShouldThrow() - .WithInnerException() - .And.ConnectionId.Should().Be(_subject.ConnectionId); + exception.Should().BeOfType().Subject + .ConnectionId.Should().Be(_subject.ConnectionId); + exception.InnerException.Should().BeOfType(); _capturedEvents.Next().Should().BeOfType(); _capturedEvents.Next().Should().BeOfType(); @@ -260,17 +239,17 @@ public void Open_should_raise_the_correct_events_upon_failure( [Theory] [ParameterAttributeData] - public void Open_should_setup_the_description( + public async Task Open_should_setup_the_description( [Values(false, true)] bool async) { if (async) { - _subject.OpenAsync(CancellationToken.None).GetAwaiter().GetResult(); + await _subject.OpenAsync(OperationContext.NoTimeout); } else { - _subject.Open(CancellationToken.None); + _subject.Open(OperationContext.NoTimeout); } _subject.Description.Should().NotBeNull(); @@ -290,32 +269,27 @@ public void Open_should_not_complete_the_second_call_until_the_first_is_complete { var task1IsBlocked = false; var completionSource = new TaskCompletionSource(); - _mockStreamFactory.Setup(f => f.CreateStream(_endPoint, CancellationToken.None)) - .Returns(() => { task1IsBlocked = true; return completionSource.Task.GetAwaiter().GetResult(); }); - _mockStreamFactory.Setup(f => f.CreateStreamAsync(_endPoint, CancellationToken.None)) - .Returns(() => { task1IsBlocked = true; return completionSource.Task; }); - - Task openTask1; - if (async1) - { + _mockStreamFactory.Setup(f => f.CreateStream(_endPoint, It.IsAny())) + .Returns(() => + { + task1IsBlocked = true; + return completionSource.Task.GetAwaiter().GetResult(); + }); + _mockStreamFactory.Setup(f => f.CreateStreamAsync(_endPoint, It.IsAny())) + .Returns(() => + { + task1IsBlocked = true; + return completionSource.Task; + }); - openTask1 = _subject.OpenAsync(CancellationToken.None); - } - else - { - openTask1 = Task.Run(() => _subject.Open(CancellationToken.None)); - } + var openTask1 = async1 ? + _subject.OpenAsync(OperationContext.NoTimeout) : + Task.Run(() => _subject.Open(OperationContext.NoTimeout)); SpinWait.SpinUntil(() => task1IsBlocked, TimeSpan.FromSeconds(5)).Should().BeTrue(); - Task openTask2; - if (async2) - { - openTask2 = _subject.OpenAsync(CancellationToken.None); - } - else - { - openTask2 = Task.Run(() => _subject.Open(CancellationToken.None)); - } + var openTask2 = async2 ? + _subject.OpenAsync(OperationContext.NoTimeout) : + Task.Run(() => _subject.Open(OperationContext.NoTimeout)); openTask1.IsCompleted.Should().BeFalse(); openTask2.IsCompleted.Should().BeFalse(); @@ -340,11 +314,11 @@ public async Task Reauthentication_should_use_the_same_auth_context_as_in_initia if (async) { - await _subject.OpenAsync(CancellationToken.None); + await _subject.OpenAsync(OperationContext.NoTimeout); } else { - _subject.Open(CancellationToken.None); + _subject.Open(OperationContext.NoTimeout); } _subject._connectionInitializerContext().Should().Be(_connectionInitializerContextAfterAuthentication); @@ -365,7 +339,7 @@ public async Task Reauthentication_should_use_the_same_auth_context_as_in_initia [Theory] [ParameterAttributeData] - public void ReceiveMessage_should_throw_a_FormatException_when_message_is_an_invalid_size( + public async Task ReceiveMessage_should_throw_a_FormatException_when_message_is_an_invalid_size( [Values(-1, 48000001)] int length, [Values(false, true)] @@ -380,27 +354,15 @@ public void ReceiveMessage_should_throw_a_FormatException_when_message_is_an_inv } stream.Write(bytes, 0, bytes.Length); stream.Seek(0, SeekOrigin.Begin); + + _mockStreamFactory.Setup(f => f.CreateStreamAsync(_endPoint, It.IsAny())).ReturnsAsync(stream); + _mockStreamFactory.Setup(f => f.CreateStream(_endPoint, It.IsAny())).Returns(stream); + await _subject.OpenAsync(OperationContext.NoTimeout); var encoderSelector = new ReplyMessageEncoderSelector(BsonDocumentSerializer.Instance); - Exception exception; - if (async) - { - _mockStreamFactory - .Setup(f => f.CreateStreamAsync(_endPoint, CancellationToken.None)) - .ReturnsAsync(stream); - _subject.OpenAsync(CancellationToken.None).GetAwaiter().GetResult(); - exception = Record - .Exception(() => _subject.ReceiveMessageAsync(10, encoderSelector, _messageEncoderSettings, CancellationToken.None) - .GetAwaiter() - .GetResult()); - } - else - { - _mockStreamFactory.Setup(f => f.CreateStream(_endPoint, CancellationToken.None)) - .Returns(stream); - _subject.Open(CancellationToken.None); - exception = Record.Exception(() => _subject.ReceiveMessage(10, encoderSelector, _messageEncoderSettings, CancellationToken.None)); - } + var exception = async ? + await Record.ExceptionAsync(() => _subject.ReceiveMessageAsync(OperationContext.NoTimeout, 10, encoderSelector, _messageEncoderSettings)) : + Record.Exception(() => _subject.ReceiveMessage(OperationContext.NoTimeout, 10, encoderSelector, _messageEncoderSettings)); exception.Should().BeOfType(); var e = exception.InnerException.Should().BeOfType().Subject; @@ -410,71 +372,52 @@ public void ReceiveMessage_should_throw_a_FormatException_when_message_is_an_inv [Theory] [ParameterAttributeData] - public void ReceiveMessage_should_throw_an_ArgumentNullException_when_the_encoderSelector_is_null( + public async Task ReceiveMessage_should_throw_an_ArgumentNullException_when_the_encoderSelector_is_null( [Values(false, true)] bool async) { - IMessageEncoderSelector encoderSelector = null; - - Action act; - if (async) - { - act = () => _subject.ReceiveMessageAsync(10, encoderSelector, _messageEncoderSettings, CancellationToken.None).GetAwaiter().GetResult(); - } - else - { - act = () => _subject.ReceiveMessage(10, encoderSelector, _messageEncoderSettings, CancellationToken.None); - } + var exception = async ? + await Record.ExceptionAsync(() => _subject.ReceiveMessageAsync(OperationContext.NoTimeout, 10, null, _messageEncoderSettings)) : + Record.Exception(() => _subject.ReceiveMessage(OperationContext.NoTimeout, 10, null, _messageEncoderSettings)); - act.ShouldThrow(); + exception.Should().BeOfType().Subject + .ParamName.Should().Be("encoderSelector"); } [Theory] [ParameterAttributeData] - public void ReceiveMessage_should_throw_an_ObjectDisposedException_if_the_connection_is_disposed( + public async Task ReceiveMessage_should_throw_an_ObjectDisposedException_if_the_connection_is_disposed( [Values(false, true)] bool async) { var encoderSelector = new Mock().Object; _subject.Dispose(); - Action act; - if (async) - { - act = () => _subject.ReceiveMessageAsync(10, encoderSelector, _messageEncoderSettings, CancellationToken.None).GetAwaiter().GetResult(); - } - else - { - act = () => _subject.ReceiveMessage(10, encoderSelector, _messageEncoderSettings, CancellationToken.None); - } + var exception = async ? + await Record.ExceptionAsync(() => _subject.ReceiveMessageAsync(OperationContext.NoTimeout, 10, encoderSelector, _messageEncoderSettings)) : + Record.Exception(() => _subject.ReceiveMessage(OperationContext.NoTimeout, 10, encoderSelector, _messageEncoderSettings)); - act.ShouldThrow(); + exception.Should().BeOfType(); } [Theory] [ParameterAttributeData] - public void ReceiveMessage_should_throw_an_InvalidOperationException_if_the_connection_is_not_open( + public async Task ReceiveMessage_should_throw_an_InvalidOperationException_if_the_connection_is_not_open( [Values(false, true)] bool async) { var encoderSelector = new Mock().Object; - Action act; - if (async) - { - act = () => _subject.ReceiveMessageAsync(10, encoderSelector, _messageEncoderSettings, CancellationToken.None).GetAwaiter().GetResult(); - } - else - { - act = () => _subject.ReceiveMessage(10, encoderSelector, _messageEncoderSettings, CancellationToken.None); - } + var exception = async ? + await Record.ExceptionAsync(() => _subject.ReceiveMessageAsync(OperationContext.NoTimeout, 10, encoderSelector, _messageEncoderSettings)) : + Record.Exception(() => _subject.ReceiveMessage(OperationContext.NoTimeout, 10, encoderSelector, _messageEncoderSettings)); - act.ShouldThrow(); + exception.Should().BeOfType(); } [Theory] [ParameterAttributeData] - public void ReceiveMessage_should_complete_when_reply_is_already_on_the_stream( + public async Task ReceiveMessage_should_complete_when_reply_is_already_on_the_stream( [Values(false, true)] bool async) { @@ -483,27 +426,18 @@ public void ReceiveMessage_should_complete_when_reply_is_already_on_the_stream( var messageToReceive = MessageHelper.BuildReply(new BsonDocument(), BsonDocumentSerializer.Instance, responseTo: 10); MessageHelper.WriteResponsesToStream(stream, messageToReceive); - var encoderSelector = new ReplyMessageEncoderSelector(BsonDocumentSerializer.Instance); - - ResponseMessage received; - if (async) - { - _mockStreamFactory.Setup(f => f.CreateStreamAsync(_endPoint, CancellationToken.None)) - .Returns(Task.FromResult(stream)); - _subject.OpenAsync(CancellationToken.None).GetAwaiter().GetResult(); - _capturedEvents.Clear(); + _mockStreamFactory.Setup(f => f.CreateStreamAsync(_endPoint, It.IsAny())) + .ReturnsAsync(stream); + _mockStreamFactory.Setup(f => f.CreateStream(_endPoint, It.IsAny())) + .Returns(stream); + await _subject.OpenAsync(OperationContext.NoTimeout); + _capturedEvents.Clear(); - received = _subject.ReceiveMessageAsync(10, encoderSelector, _messageEncoderSettings, CancellationToken.None).GetAwaiter().GetResult(); - } - else - { - _mockStreamFactory.Setup(f => f.CreateStream(_endPoint, CancellationToken.None)) - .Returns(stream); - _subject.Open(CancellationToken.None); - _capturedEvents.Clear(); + var encoderSelector = new ReplyMessageEncoderSelector(BsonDocumentSerializer.Instance); - received = _subject.ReceiveMessage(10, encoderSelector, _messageEncoderSettings, CancellationToken.None); - } + var received = async ? + await _subject.ReceiveMessageAsync(OperationContext.NoTimeout, 10, encoderSelector, _messageEncoderSettings) : + _subject.ReceiveMessage(OperationContext.NoTimeout, 10, encoderSelector, _messageEncoderSettings); var expected = MessageHelper.TranslateMessagesToBsonDocuments(new[] { messageToReceive }); var actual = MessageHelper.TranslateMessagesToBsonDocuments(new[] { received }); @@ -518,40 +452,31 @@ public void ReceiveMessage_should_complete_when_reply_is_already_on_the_stream( [Theory] [ParameterAttributeData] - public void ReceiveMessage_should_complete_when_reply_is_not_already_on_the_stream( + public async Task ReceiveMessage_should_complete_when_reply_is_not_already_on_the_stream( [Values(false, true)] bool async) { using (var stream = new BlockingMemoryStream()) { - var encoderSelector = new ReplyMessageEncoderSelector(BsonDocumentSerializer.Instance); - - Task receiveMessageTask; - if (async) - { - _mockStreamFactory.Setup(f => f.CreateStreamAsync(_endPoint, CancellationToken.None)) - .Returns(Task.FromResult(stream)); - _subject.OpenAsync(CancellationToken.None).GetAwaiter().GetResult(); - _capturedEvents.Clear(); + _mockStreamFactory.Setup(f => f.CreateStreamAsync(_endPoint, It.IsAny())) + .ReturnsAsync(stream); + _mockStreamFactory.Setup(f => f.CreateStream(_endPoint, It.IsAny())) + .Returns(stream); + await _subject.OpenAsync(OperationContext.NoTimeout); + _capturedEvents.Clear(); - receiveMessageTask = _subject.ReceiveMessageAsync(10, encoderSelector, _messageEncoderSettings, CancellationToken.None); - } - else - { - _mockStreamFactory.Setup(f => f.CreateStream(_endPoint, CancellationToken.None)) - .Returns(stream); - _subject.Open(CancellationToken.None); - _capturedEvents.Clear(); + var encoderSelector = new ReplyMessageEncoderSelector(BsonDocumentSerializer.Instance); - receiveMessageTask = Task.Run(() => _subject.ReceiveMessage(10, encoderSelector, _messageEncoderSettings, CancellationToken.None)); - } + var receiveMessageTask = async ? + _subject.ReceiveMessageAsync(OperationContext.NoTimeout, 10, encoderSelector, _messageEncoderSettings) : + Task.Run(() => _subject.ReceiveMessage(OperationContext.NoTimeout, 10, encoderSelector, _messageEncoderSettings)); receiveMessageTask.IsCompleted.Should().BeFalse(); var messageToReceive = MessageHelper.BuildReply(new BsonDocument(), BsonDocumentSerializer.Instance, responseTo: 10); MessageHelper.WriteResponsesToStream(stream, messageToReceive); - var received = receiveMessageTask.GetAwaiter().GetResult(); + var received = await receiveMessageTask; var expected = MessageHelper.TranslateMessagesToBsonDocuments(new[] { messageToReceive }); var actual = MessageHelper.TranslateMessagesToBsonDocuments(new[] { received }); @@ -566,7 +491,7 @@ public void ReceiveMessage_should_complete_when_reply_is_not_already_on_the_stre [Theory] [ParameterAttributeData] - public void ReceiveMessage_should_handle_out_of_order_replies( + public async Task ReceiveMessage_should_handle_out_of_order_replies( [Values(false, true)] bool async1, [Values(false, true)] @@ -574,32 +499,19 @@ public void ReceiveMessage_should_handle_out_of_order_replies( { using (var stream = new BlockingMemoryStream()) { - _mockStreamFactory.Setup(f => f.CreateStream(_endPoint, CancellationToken.None)) - .Returns(stream); - _subject.Open(CancellationToken.None); + _mockStreamFactory.Setup(f => f.CreateStreamAsync(_endPoint, It.IsAny())).ReturnsAsync(stream); + await _subject.OpenAsync(OperationContext.NoTimeout); _capturedEvents.Clear(); var encoderSelector = new ReplyMessageEncoderSelector(BsonDocumentSerializer.Instance); - Task receivedTask10; - if (async1) - { - receivedTask10 = _subject.ReceiveMessageAsync(10, encoderSelector, _messageEncoderSettings, CancellationToken.None); - } - else - { - receivedTask10 = Task.Run(() => _subject.ReceiveMessage(10, encoderSelector, _messageEncoderSettings, CancellationToken.None)); - } + var receivedTask10 = async1 ? + _subject.ReceiveMessageAsync(OperationContext.NoTimeout, 10, encoderSelector, _messageEncoderSettings) : + Task.Run(() => _subject.ReceiveMessage(OperationContext.NoTimeout, 10, encoderSelector, _messageEncoderSettings)); - Task receivedTask11; - if (async2) - { - receivedTask11 = _subject.ReceiveMessageAsync(11, encoderSelector, _messageEncoderSettings, CancellationToken.None); - } - else - { - receivedTask11 = Task.Run(() => _subject.ReceiveMessage(11, encoderSelector, _messageEncoderSettings, CancellationToken.None)); - } + var receivedTask11 = async2 ? + _subject.ReceiveMessageAsync(OperationContext.NoTimeout, 11, encoderSelector, _messageEncoderSettings) : + Task.Run(() => _subject.ReceiveMessage(OperationContext.NoTimeout, 11, encoderSelector, _messageEncoderSettings)); SpinWait.SpinUntil(() => _capturedEvents.Count >= 2, TimeSpan.FromSeconds(5)).Should().BeTrue(); @@ -607,8 +519,8 @@ public void ReceiveMessage_should_handle_out_of_order_replies( var messageToReceive11 = MessageHelper.BuildReply(new BsonDocument("_id", 11), BsonDocumentSerializer.Instance, responseTo: 11); MessageHelper.WriteResponsesToStream(stream, messageToReceive11, messageToReceive10); // out of order - var received10 = receivedTask10.GetAwaiter().GetResult(); - var received11 = receivedTask11.GetAwaiter().GetResult(); + var received10 = await receivedTask10; + var received11 = await receivedTask11; var expected = MessageHelper.TranslateMessagesToBsonDocuments(new[] { messageToReceive10, messageToReceive11 }); var actual = MessageHelper.TranslateMessagesToBsonDocuments(new[] { received10, received11 }); @@ -649,9 +561,9 @@ public async Task ReceiveMessage_should_not_produce_unobserved_task_exceptions_o tcs.SetException(new SocketException()); SetupStreamRead(mockStream, tcs); - _subject.Open(CancellationToken.None); + _subject.Open(OperationContext.NoTimeout); - var exception = await Record.ExceptionAsync(() => _subject.ReceiveMessageAsync(1, encoderSelector, _messageEncoderSettings, CancellationToken.None)); + var exception = await Record.ExceptionAsync(() => _subject.ReceiveMessageAsync(OperationContext.NoTimeout, 1, encoderSelector, _messageEncoderSettings)); exception.Should().BeOfType(); GC.Collect(); // Collects the unobserved tasks @@ -685,14 +597,14 @@ public async Task ReceiveMessageAsync_should_not_produce_unobserved_task_excepti var encoderSelector = new ReplyMessageEncoderSelector(BsonDocumentSerializer.Instance); _mockStreamFactory - .Setup(f => f.CreateStream(_endPoint, CancellationToken.None)) + .Setup(f => f.CreateStream(_endPoint, It.IsAny())) .Returns(mockStream.Object); var tcs = new TaskCompletionSource(); SetupStreamRead(mockStream, tcs, 50); - _subject.Open(CancellationToken.None); + _subject.Open(OperationContext.NoTimeout); - var exception = await Record.ExceptionAsync(() => _subject.ReceiveMessageAsync(1, encoderSelector, _messageEncoderSettings, CancellationToken.None)); + var exception = await Record.ExceptionAsync(() => _subject.ReceiveMessageAsync(OperationContext.NoTimeout, 1, encoderSelector, _messageEncoderSettings)); exception.Should().BeOfType(); exception.InnerException.Should().BeOfType(); @@ -715,7 +627,7 @@ public async Task ReceiveMessageAsync_should_not_produce_unobserved_task_excepti [Theory] [ParameterAttributeData] - public void ReceiveMessage_should_throw_network_exception_to_all_awaiters( + public async Task ReceiveMessage_should_throw_network_exception_to_all_awaiters( [Values(false, true)] bool async1, [Values(false, true)] @@ -726,46 +638,35 @@ public void ReceiveMessage_should_throw_network_exception_to_all_awaiters( { var encoderSelector = new ReplyMessageEncoderSelector(BsonDocumentSerializer.Instance); - _mockStreamFactory.Setup(f => f.CreateStream(_endPoint, CancellationToken.None)) + _mockStreamFactory.Setup(f => f.CreateStream(_endPoint, It.IsAny())) .Returns(mockStream.Object); var readTcs = new TaskCompletionSource(); SetupStreamRead(mockStream, readTcs, readTimeoutMs: Timeout.Infinite); - _subject.Open(CancellationToken.None); + _subject.Open(OperationContext.NoTimeout); _capturedEvents.Clear(); - Task task1; - if (async1) - { - task1 = _subject.ReceiveMessageAsync(1, encoderSelector, _messageEncoderSettings, It.IsAny()); - } - else - { - task1 = Task.Run(() => _subject.ReceiveMessage(1, encoderSelector, _messageEncoderSettings, CancellationToken.None)); - } + var task1 = async1 ? + _subject.ReceiveMessageAsync(OperationContext.NoTimeout, 1, encoderSelector, _messageEncoderSettings) : + Task.Run(() => _subject.ReceiveMessage(OperationContext.NoTimeout, 1, encoderSelector, _messageEncoderSettings)); - Task task2; - if (async2) - { - task2 = _subject.ReceiveMessageAsync(2, encoderSelector, _messageEncoderSettings, CancellationToken.None); - } - else - { - task2 = Task.Run(() => _subject.ReceiveMessage(2, encoderSelector, _messageEncoderSettings, CancellationToken.None)); - } + var task2 = async2 ? + _subject.ReceiveMessageAsync(OperationContext.NoTimeout, 2, encoderSelector, _messageEncoderSettings) : + Task.Run(() => _subject.ReceiveMessage(OperationContext.NoTimeout, 2, encoderSelector, _messageEncoderSettings)); SpinWait.SpinUntil(() => _capturedEvents.Count >= 2, TimeSpan.FromSeconds(5)).Should().BeTrue(); readTcs.SetException(new SocketException()); - Func act1 = () => task1; - act1.ShouldThrow() - .WithInnerException() - .And.ConnectionId.Should().Be(_subject.ConnectionId); + var exception1 = await Record.ExceptionAsync(() => task1); + var exception2 = await Record.ExceptionAsync(() => task2); - Func act2 = () => task2; - act2.ShouldThrow() - .WithInnerException() - .And.ConnectionId.Should().Be(_subject.ConnectionId); + exception1.Should().BeOfType().Subject + .ConnectionId.Should().Be(_subject.ConnectionId); + exception1.InnerException.Should().BeOfType(); + + exception2.Should().BeOfType().Subject + .ConnectionId.Should().Be(_subject.ConnectionId); + exception2.InnerException.Should().BeOfType(); _capturedEvents.Next().Should().BeOfType(); _capturedEvents.Next().Should().BeOfType(); @@ -778,7 +679,7 @@ public void ReceiveMessage_should_throw_network_exception_to_all_awaiters( [Theory] [ParameterAttributeData] - public void ReceiveMessage_should_throw_MongoConnectionClosedException_when_connection_has_failed( + public async Task ReceiveMessage_should_throw_MongoConnectionClosedException_when_connection_has_failed( [Values(false, true)] bool async1, [Values(false, true)] @@ -787,42 +688,29 @@ public void ReceiveMessage_should_throw_MongoConnectionClosedException_when_conn var mockStream = new Mock(); using (mockStream.Object) { - _mockStreamFactory.Setup(f => f.CreateStream(_endPoint, CancellationToken.None)) - .Returns(mockStream.Object); + _mockStreamFactory.Setup(f => f.CreateStreamAsync(_endPoint, It.IsAny())).ReturnsAsync(mockStream.Object); var readTcs = new TaskCompletionSource(); readTcs.SetException(new SocketException()); SetupStreamRead(mockStream, readTcs); - _subject.Open(CancellationToken.None); + await _subject.OpenAsync(OperationContext.NoTimeout); _capturedEvents.Clear(); var encoderSelector = new ReplyMessageEncoderSelector(BsonDocumentSerializer.Instance); - Action act1; - if (async1) - { - act1 = () => _subject.ReceiveMessageAsync(1, encoderSelector, _messageEncoderSettings, CancellationToken.None).GetAwaiter().GetResult(); - } - else - { - act1 = () => _subject.ReceiveMessage(1, encoderSelector, _messageEncoderSettings, CancellationToken.None); - } + var exception1 = async1 ? + await Record.ExceptionAsync(() => _subject.ReceiveMessageAsync(OperationContext.NoTimeout, 1, encoderSelector, _messageEncoderSettings)) : + Record.Exception(() => _subject.ReceiveMessage(OperationContext.NoTimeout, 1, encoderSelector, _messageEncoderSettings)); - Action act2; - if (async2) - { - act2 = () => _subject.ReceiveMessageAsync(2, encoderSelector, _messageEncoderSettings, CancellationToken.None).GetAwaiter().GetResult(); - } - else - { - act2 = () => _subject.ReceiveMessage(2, encoderSelector, _messageEncoderSettings, CancellationToken.None); - } + var exception2 = async2 ? + await Record.ExceptionAsync(() => _subject.ReceiveMessageAsync(OperationContext.NoTimeout, 2, encoderSelector, _messageEncoderSettings)) : + Record.Exception(() => _subject.ReceiveMessage(OperationContext.NoTimeout, 2, encoderSelector, _messageEncoderSettings)); - act1.ShouldThrow() - .WithInnerException() - .And.ConnectionId.Should().Be(_subject.ConnectionId); + exception1.Should().BeOfType().Subject + .ConnectionId.Should().Be(_subject.ConnectionId); + exception1.InnerException.Should().BeOfType(); - act2.ShouldThrow() - .And.ConnectionId.Should().Be(_subject.ConnectionId); + exception2.Should().BeOfType().Subject + .ConnectionId.Should().Be(_subject.ConnectionId); _capturedEvents.Next().Should().BeOfType(); _capturedEvents.Next().Should().BeOfType(); @@ -838,8 +726,8 @@ public async Task SendMessage_should_throw_an_ArgumentNullException_if_message_i bool async) { var exception = async ? - await Record.ExceptionAsync(() => _subject.SendMessageAsync(null, _messageEncoderSettings, CancellationToken.None)) : - Record.Exception(() => _subject.SendMessage(null, _messageEncoderSettings, CancellationToken.None)); + await Record.ExceptionAsync(() => _subject.SendMessageAsync(OperationContext.NoTimeout, null, _messageEncoderSettings)) : + Record.Exception(() => _subject.SendMessage(OperationContext.NoTimeout, null, _messageEncoderSettings)); exception.Should().BeOfType(); } @@ -854,8 +742,8 @@ public async Task SendMessage_should_throw_an_ObjectDisposedException_if_the_con _subject.Dispose(); var exception = async ? - await Record.ExceptionAsync(() => _subject.SendMessageAsync(message, _messageEncoderSettings, CancellationToken.None)) : - Record.Exception(() => _subject.SendMessage(message, _messageEncoderSettings, CancellationToken.None)); + await Record.ExceptionAsync(() => _subject.SendMessageAsync(OperationContext.NoTimeout, message, _messageEncoderSettings)) : + Record.Exception(() => _subject.SendMessage(OperationContext.NoTimeout, message, _messageEncoderSettings)); exception.Should().BeOfType(); } @@ -869,39 +757,34 @@ public async Task SendMessage_should_throw_an_InvalidOperationException_if_the_c var message = MessageHelper.BuildQuery(); var exception = async ? - await Record.ExceptionAsync(() => _subject.SendMessageAsync(message, _messageEncoderSettings, CancellationToken.None)) : - Record.Exception(() => _subject.SendMessage(message, _messageEncoderSettings, CancellationToken.None)); + await Record.ExceptionAsync(() => _subject.SendMessageAsync(OperationContext.NoTimeout, message, _messageEncoderSettings)) : + Record.Exception(() => _subject.SendMessage(OperationContext.NoTimeout, message, _messageEncoderSettings)); exception.Should().BeOfType(); } [Theory] [ParameterAttributeData] - public void SendMessage_should_put_the_message_on_the_stream_and_raise_the_correct_events( + public async Task SendMessage_should_put_the_message_on_the_stream_and_raise_the_correct_events( [Values(false, true)] bool async) { using (var stream = new MemoryStream()) { + _mockStreamFactory.Setup(f => f.CreateStreamAsync(_endPoint, It.IsAny())).ReturnsAsync(stream); + _mockStreamFactory.Setup(f => f.CreateStream(_endPoint, It.IsAny())).Returns(stream); + _subject.OpenAsync(OperationContext.NoTimeout).GetAwaiter().GetResult(); + _capturedEvents.Clear(); + var message = MessageHelper.BuildQuery(query: new BsonDocument("x", 1)); if (async) { - _mockStreamFactory.Setup(f => f.CreateStreamAsync(_endPoint, CancellationToken.None)) - .ReturnsAsync(stream); - _subject.OpenAsync(CancellationToken.None).GetAwaiter().GetResult(); - _capturedEvents.Clear(); - - _subject.SendMessageAsync(message, _messageEncoderSettings, CancellationToken.None).GetAwaiter().GetResult(); + await _subject.SendMessageAsync(OperationContext.NoTimeout, message, _messageEncoderSettings); } else { - _mockStreamFactory.Setup(f => f.CreateStream(_endPoint, CancellationToken.None)) - .Returns(stream); - _subject.Open(CancellationToken.None); - _capturedEvents.Clear(); - - _subject.SendMessage(message, _messageEncoderSettings, CancellationToken.None); + _subject.SendMessage(OperationContext.NoTimeout, message, _messageEncoderSettings); } var expectedRequests = MessageHelper.TranslateMessagesToBsonDocuments(new[] { message }); diff --git a/tests/MongoDB.Driver.Tests/Core/Connections/BinaryConnection_CommandEventTests.cs b/tests/MongoDB.Driver.Tests/Core/Connections/BinaryConnection_CommandEventTests.cs index a01da9bb54a..4ebb258ee9f 100644 --- a/tests/MongoDB.Driver.Tests/Core/Connections/BinaryConnection_CommandEventTests.cs +++ b/tests/MongoDB.Driver.Tests/Core/Connections/BinaryConnection_CommandEventTests.cs @@ -87,9 +87,9 @@ public BinaryConnection_CommandEventTests(ITestOutputHelper output) : base(outpu new HelloResult(new BsonDocument { { "maxWireVersion", WireVersion.Server36 } })); _mockConnectionInitializer = new Mock(); - _mockConnectionInitializer.Setup(i => i.SendHelloAsync(It.IsAny(), CancellationToken.None)) + _mockConnectionInitializer.Setup(i => i.SendHelloAsync(It.IsAny(), It.IsAny())) .Returns(() => Task.FromResult(new ConnectionInitializerContext(connectionDescriptionFunc(), null))); - _mockConnectionInitializer.Setup(i => i.AuthenticateAsync(It.IsAny(), It.IsAny(), CancellationToken.None)) + _mockConnectionInitializer.Setup(i => i.AuthenticateAsync(It.IsAny(), It.IsAny(), It.IsAny())) .Returns(() => Task.FromResult(new ConnectionInitializerContext(connectionDescriptionFunc(), null))); _subject = new BinaryConnection( @@ -102,9 +102,9 @@ public BinaryConnection_CommandEventTests(ITestOutputHelper output) : base(outpu LoggerFactory); _stream = new BlockingMemoryStream(); - _mockStreamFactory.Setup(f => f.CreateStreamAsync(_endPoint, CancellationToken.None)) + _mockStreamFactory.Setup(f => f.CreateStreamAsync(_endPoint, It.IsAny())) .Returns(Task.FromResult(_stream)); - _subject.OpenAsync(CancellationToken.None).Wait(); + _subject.OpenAsync(OperationContext.NoTimeout).Wait(); _capturedEvents.Clear(); _operationIdDisposer = EventContext.BeginOperation(); @@ -484,14 +484,14 @@ public void Should_process_a_failed_query() private void SendMessage(RequestMessage message) { - _subject.SendMessageAsync(message, _messageEncoderSettings, CancellationToken.None).Wait(); + _subject.SendMessageAsync(OperationContext.NoTimeout, message, _messageEncoderSettings).Wait(); } private void ReceiveMessage(ReplyMessage message) { MessageHelper.WriteResponsesToStream(_stream, message); var encoderSelector = new ReplyMessageEncoderSelector(BsonDocumentSerializer.Instance); - _subject.ReceiveMessageAsync(message.ResponseTo, encoderSelector, _messageEncoderSettings, CancellationToken.None).Wait(); + _subject.ReceiveMessageAsync(OperationContext.NoTimeout, message.ResponseTo, encoderSelector, _messageEncoderSettings).Wait(); } } } diff --git a/tests/MongoDB.Driver.Tests/Core/Jira/CSharp3173Tests.cs b/tests/MongoDB.Driver.Tests/Core/Jira/CSharp3173Tests.cs index 8be118b9db5..fc7cac976c8 100644 --- a/tests/MongoDB.Driver.Tests/Core/Jira/CSharp3173Tests.cs +++ b/tests/MongoDB.Driver.Tests/Core/Jira/CSharp3173Tests.cs @@ -353,7 +353,7 @@ void SetupFailedConnection(Mock mockFaultyConnection) () => WaitForTaskOrTimeout(hasClusterBeenDisposed.Task, TimeSpan.FromMinutes(1), "cluster dispose") }); mockFaultyConnection - .Setup(c => c.Open(It.IsAny())) + .Setup(c => c.Open(It.IsAny())) .Callback(() => { var responseAction = faultyConnectionResponses.Dequeue(); @@ -361,7 +361,7 @@ void SetupFailedConnection(Mock mockFaultyConnection) }); mockFaultyConnection - .Setup(c => c.ReceiveMessage(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) + .Setup(c => c.ReceiveMessage(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) .Returns(() => { WaitForTaskOrTimeout( @@ -374,13 +374,13 @@ void SetupFailedConnection(Mock mockFaultyConnection) void SetupHealthyConnection(Mock mockHealthyConnection) { - mockHealthyConnection.Setup(c => c.Open(It.IsAny())); // no action is required - mockHealthyConnection.Setup(c => c.OpenAsync(It.IsAny())).Returns(Task.FromResult(true)); // no action is required + mockHealthyConnection.Setup(c => c.Open(It.IsAny())); // no action is required + mockHealthyConnection.Setup(c => c.OpenAsync(It.IsAny())).Returns(Task.FromResult(true)); // no action is required mockHealthyConnection - .Setup(c => c.ReceiveMessage(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) + .Setup(c => c.ReceiveMessage(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) .Returns(commandResponseAction); mockConnection - .Setup(c => c.ReceiveMessageAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) + .Setup(c => c.ReceiveMessageAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) .ReturnsAsync(commandResponseAction); } } diff --git a/tests/MongoDB.Driver.Tests/Core/Jira/CSharp3302Tests.cs b/tests/MongoDB.Driver.Tests/Core/Jira/CSharp3302Tests.cs index aa944ed0ca6..cfea173c619 100644 --- a/tests/MongoDB.Driver.Tests/Core/Jira/CSharp3302Tests.cs +++ b/tests/MongoDB.Driver.Tests/Core/Jira/CSharp3302Tests.cs @@ -303,9 +303,9 @@ private void SetupServerMonitorConnection( .SetupGet(c => c.Description) .Returns(GetConnectionDescription); - mockConnection.Setup(c => c.Open(It.IsAny())); // no action is required + mockConnection.Setup(c => c.Open(It.IsAny())); // no action is required mockConnection - .Setup(c => c.ReceiveMessage(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) + .Setup(c => c.ReceiveMessage(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) .Returns(GetHelloResponse); ResponseMessage GetHelloResponse() diff --git a/tests/MongoDB.Driver.Tests/Core/Misc/StreamExtensionMethodsTests.cs b/tests/MongoDB.Driver.Tests/Core/Misc/StreamExtensionMethodsTests.cs index bf17fa539c6..baa070ff455 100644 --- a/tests/MongoDB.Driver.Tests/Core/Misc/StreamExtensionMethodsTests.cs +++ b/tests/MongoDB.Driver.Tests/Core/Misc/StreamExtensionMethodsTests.cs @@ -14,10 +14,7 @@ */ using System; -using System.Collections.Generic; using System.IO; -using System.Linq; -using System.Text; using System.Threading; using System.Threading.Tasks; using FluentAssertions; @@ -31,272 +28,413 @@ namespace MongoDB.Driver.Core.Misc public class StreamExtensionMethodsTests { [Theory] - [InlineData(0, new byte[] { 0, 0 })] - [InlineData(1, new byte[] { 1, 0 })] - [InlineData(2, new byte[] { 1, 2 })] - public async Task ReadBytesAsync_with_byte_array_should_have_expected_effect_for_count(int count, byte[] expectedBytes) + [InlineData(true, 0, new byte[] { 0, 0 })] + [InlineData(true, 1, new byte[] { 1, 0 })] + [InlineData(true, 2, new byte[] { 1, 2 })] + [InlineData(false, 0, new byte[] { 0, 0 })] + [InlineData(false, 1, new byte[] { 1, 0 })] + [InlineData(false, 2, new byte[] { 1, 2 })] + public async Task ReadBytes_with_byte_array_should_have_expected_effect_for_count(bool async, int count, byte[] expectedBytes) { var bytes = new byte[] { 1, 2 }; var stream = new MemoryStream(bytes); var destination = new byte[2]; - await stream.ReadBytesAsync(destination, 0, count, Timeout.InfiniteTimeSpan, CancellationToken.None); + if (async) + { + await stream.ReadBytesAsync(OperationContext.NoTimeout, destination, 0, count); + } + else + { + stream.ReadBytes(OperationContext.NoTimeout, destination, 0, count); + } destination.Should().Equal(expectedBytes); } [Theory] - [InlineData(1, new byte[] { 0, 1, 0 })] - [InlineData(2, new byte[] { 0, 0, 1 })] - public async Task ReadBytesAsync_with_byte_array_should_have_expected_effect_for_offset(int offset, byte[] expectedBytes) + [InlineData(true, 1, new byte[] { 0, 1, 0 })] + [InlineData(true, 2, new byte[] { 0, 0, 1 })] + [InlineData(false, 1, new byte[] { 0, 1, 0 })] + [InlineData(false, 2, new byte[] { 0, 0, 1 })] + public async Task ReadBytes_with_byte_array_should_have_expected_effect_for_offset(bool async, int offset, byte[] expectedBytes) { var bytes = new byte[] { 1 }; var stream = new MemoryStream(bytes); var destination = new byte[3]; - await stream.ReadBytesAsync(destination, offset, 1, Timeout.InfiniteTimeSpan, CancellationToken.None); + if (async) + { + await stream.ReadBytesAsync(OperationContext.NoTimeout, destination, offset, 1); + } + else + { + stream.ReadBytes(OperationContext.NoTimeout, destination, offset, 1); + } destination.Should().Equal(expectedBytes); } [Theory] - [InlineData(1, new[] { 3 })] - [InlineData(2, new[] { 1, 2 })] - [InlineData(3, new[] { 2, 1 })] - [InlineData(4, new[] { 1, 1, 1 })] - public async Task ReadBytesAsync_with_byte_array_should_have_expected_effect_for_partial_reads(int testCase, int[] partition) + [InlineData(true, 1, new[] { 3 })] + [InlineData(true, 2, new[] { 1, 2 })] + [InlineData(true, 3, new[] { 2, 1 })] + [InlineData(true, 4, new[] { 1, 1, 1 })] + [InlineData(false, 1, new[] { 3 })] + [InlineData(false, 2, new[] { 1, 2 })] + [InlineData(false, 3, new[] { 2, 1 })] + [InlineData(false, 4, new[] { 1, 1, 1 })] + public async Task ReadBytes_with_byte_array_should_have_expected_effect_for_partial_reads(bool async, int testCase, int[] partition) { var mockStream = new Mock(); var bytes = new byte[] { 1, 2, 3 }; var n = 0; var position = 0; + Task ReadPartial (byte[] buffer, int offset, int count) + { + var length = partition[n++]; + Buffer.BlockCopy(bytes, position, buffer, offset, length); + position += length; + return Task.FromResult(length); + } + mockStream.Setup(s => s.ReadAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) - .Returns((byte[] buffer, int offset, int count, CancellationToken cancellationToken) => - { - var length = partition[n++]; - Buffer.BlockCopy(bytes, position, buffer, offset, length); - position += length; - return Task.FromResult(length); - }); + .Returns((byte[] buffer, int offset, int count, CancellationToken cancellationToken) => ReadPartial(buffer, offset, count)); + mockStream.Setup(s => s.BeginRead(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) + .Returns((byte[] buffer, int offset, int count, AsyncCallback callback, object state) => ReadPartial(buffer, offset, count)); + mockStream.Setup(s => s.EndRead(It.IsAny())) + .Returns(x => ((Task)x).GetAwaiter().GetResult()); var destination = new byte[3]; - await mockStream.Object.ReadBytesAsync(destination, 0, 3, Timeout.InfiniteTimeSpan, CancellationToken.None); + if (async) + { + await mockStream.Object.ReadBytesAsync(OperationContext.NoTimeout, destination, 0, 3); + } + else + { + mockStream.Object.ReadBytes(OperationContext.NoTimeout, destination, 0, 3); + } destination.Should().Equal(bytes); } - [Fact] - public void ReadBytesAsync_with_byte_array_should_throw_when_end_of_stream_is_reached() + [Theory] + [ParameterAttributeData] + public async Task ReadBytes_with_byte_array_should_throw_when_end_of_stream_is_reached([Values(true, false)]bool async) { var mockStream = new Mock(); var destination = new byte[1]; - mockStream.Setup(s => s.ReadAsync(destination, 0, 1, It.IsAny())).Returns(Task.FromResult(0)); + mockStream.Setup(s => s.ReadAsync(destination, 0, 1, It.IsAny())) + .ReturnsAsync(0); + mockStream.Setup(s => s.BeginRead(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) + .Returns(Task.FromResult(0)); - Func action = () => mockStream.Object.ReadBytesAsync(destination, 0, 1, Timeout.InfiniteTimeSpan, CancellationToken.None); + var exception = async ? + await Record.ExceptionAsync(() => mockStream.Object.ReadBytesAsync(OperationContext.NoTimeout, destination, 0, 1)) : + Record.Exception(() => mockStream.Object.ReadBytes(OperationContext.NoTimeout, destination, 0, 1)); - action.ShouldThrow(); + exception.Should().BeOfType(); } - [Fact] - public void ReadBytesAsync_with_byte_array_should_throw_when_buffer_is_null() + [Theory] + [ParameterAttributeData] + public async Task ReadBytes_with_byte_array_should_throw_when_buffer_is_null([Values(true, false)]bool async) { var stream = new Mock().Object; byte[] destination = null; - Func action = () => stream.ReadBytesAsync(destination, 0, 0, Timeout.InfiniteTimeSpan, CancellationToken.None); + var exception = async ? + await Record.ExceptionAsync(() => stream.ReadBytesAsync(OperationContext.NoTimeout, destination, 0, 0)) : + Record.Exception(() => stream.ReadBytes(OperationContext.NoTimeout, destination, 0, 0)); - action.ShouldThrow().And.ParamName.Should().Be("buffer"); + exception.Should().BeOfType().Subject + .ParamName.Should().Be("buffer"); } [Theory] - [InlineData(0, -1)] - [InlineData(1, 2)] - [InlineData(2, 1)] - public void ReadBytesAsync_with_byte_array_should_throw_when_count_is_invalid(int offset, int count) + [InlineData(true, 0, -1)] + [InlineData(true, 1, 2)] + [InlineData(true, 2, 1)] + [InlineData(false, 0, -1)] + [InlineData(false, 1, 2)] + [InlineData(false, 2, 1)] + public async Task ReadBytes_with_byte_array_should_throw_when_count_is_invalid(bool async, int offset, int count) { var stream = new Mock().Object; var destination = new byte[2]; - Func action = () => stream.ReadBytesAsync(destination, offset, count, Timeout.InfiniteTimeSpan, CancellationToken.None); + var exception = async ? + await Record.ExceptionAsync(() => stream.ReadBytesAsync(OperationContext.NoTimeout, destination, offset, count)) : + Record.Exception(() => stream.ReadBytes(OperationContext.NoTimeout, destination, offset, count)); - action.ShouldThrow().And.ParamName.Should().Be("count"); + exception.Should().BeOfType().Subject + .ParamName.Should().Be("count"); } [Theory] [ParameterAttributeData] - public void ReadBytesAsync_with_byte_array_should_throw_when_offset_is_invalid( - [Values(-1, 3)] - int offset) + public async Task ReadBytes_with_byte_array_should_throw_when_offset_is_invalid( + [Values(true, false)]bool async, + [Values(-1, 3)]int offset) { var stream = new Mock().Object; var destination = new byte[2]; - Func action = () => stream.ReadBytesAsync(destination, offset, 0, Timeout.InfiniteTimeSpan, CancellationToken.None); + var exception = async ? + await Record.ExceptionAsync(() => stream.ReadBytesAsync(OperationContext.NoTimeout, destination, offset, 0)) : + Record.Exception(() => stream.ReadBytes(OperationContext.NoTimeout, destination, offset, 0)); - action.ShouldThrow().And.ParamName.Should().Be("offset"); + exception.Should().BeOfType().Subject + .ParamName.Should().Be("offset"); } - [Fact] - public void ReadBytesAsync_with_byte_array_should_throw_when_stream_is_null() + [Theory] + [ParameterAttributeData] + public async Task ReadBytes_with_byte_array_should_throw_when_stream_is_null([Values(true, false)]bool async) { Stream stream = null; var destination = new byte[0]; - Func action = () => stream.ReadBytesAsync(destination, 0, 0, Timeout.InfiniteTimeSpan, CancellationToken.None); + var exception = async ? + await Record.ExceptionAsync(() => stream.ReadBytesAsync(OperationContext.NoTimeout, destination, 0, 0)) : + Record.Exception(() => stream.ReadBytes(OperationContext.NoTimeout, destination, 0, 0)); - action.ShouldThrow().And.ParamName.Should().Be("stream"); + exception.Should().BeOfType().Subject + .ParamName.Should().Be("stream"); } [Theory] - [InlineData(0, new byte[] { 0, 0 })] - [InlineData(1, new byte[] { 1, 0 })] - [InlineData(2, new byte[] { 1, 2 })] - public async Task ReadBytesAsync_with_byte_buffer_should_have_expected_effect_for_count(int count, byte[] expectedBytes) + [InlineData(true, 0, new byte[] { 0, 0 })] + [InlineData(true, 1, new byte[] { 1, 0 })] + [InlineData(true, 2, new byte[] { 1, 2 })] + [InlineData(false, 0, new byte[] { 0, 0 })] + [InlineData(false, 1, new byte[] { 1, 0 })] + [InlineData(false, 2, new byte[] { 1, 2 })] + public async Task ReadBytes_with_byte_buffer_should_have_expected_effect_for_count(bool async, int count, byte[] expectedBytes) { var bytes = new byte[] { 1, 2 }; var stream = new MemoryStream(bytes); var destination = new ByteArrayBuffer(new byte[2]); - await stream.ReadBytesAsync(destination, 0, count, Timeout.InfiniteTimeSpan, CancellationToken.None); + if (async) + { + await stream.ReadBytesAsync(OperationContext.NoTimeout, destination, 0, count); + } + else + { + stream.ReadBytes(OperationContext.NoTimeout, destination, 0, count); + } destination.AccessBackingBytes(0).Array.Should().Equal(expectedBytes); } [Theory] - [InlineData(1, new byte[] { 0, 1, 0 })] - [InlineData(2, new byte[] { 0, 0, 1 })] - public async Task ReadBytesAsync_with_byte_buffer_should_have_expected_effect_for_offset(int offset, byte[] expectedBytes) + [InlineData(true, 1, new byte[] { 0, 1, 0 })] + [InlineData(true, 2, new byte[] { 0, 0, 1 })] + [InlineData(false, 1, new byte[] { 0, 1, 0 })] + [InlineData(false, 2, new byte[] { 0, 0, 1 })] + public async Task ReadBytes_with_byte_buffer_should_have_expected_effect_for_offset(bool async, int offset, byte[] expectedBytes) { var bytes = new byte[] { 1 }; var stream = new MemoryStream(bytes); var destination = new ByteArrayBuffer(new byte[3]); - await stream.ReadBytesAsync(destination, offset, 1, Timeout.InfiniteTimeSpan, CancellationToken.None); + if (async) + { + await stream.ReadBytesAsync(OperationContext.NoTimeout, destination, offset, 1); + } + else + { + stream.ReadBytes(OperationContext.NoTimeout, destination, offset, 1); + } destination.AccessBackingBytes(0).Array.Should().Equal(expectedBytes); } [Theory] - [InlineData(1, new[] { 3 })] - [InlineData(2, new[] { 1, 2 })] - [InlineData(3, new[] { 2, 1 })] - [InlineData(4, new[] { 1, 1, 1 })] - public async Task ReadBytesAsync_with_byte_buffer_should_have_expected_effect_for_partial_reads(int testCase, int[] partition) + [InlineData(true, 1, new[] { 3 })] + [InlineData(true, 2, new[] { 1, 2 })] + [InlineData(true, 3, new[] { 2, 1 })] + [InlineData(true, 4, new[] { 1, 1, 1 })] + [InlineData(false, 1, new[] { 3 })] + [InlineData(false, 2, new[] { 1, 2 })] + [InlineData(false, 3, new[] { 2, 1 })] + [InlineData(false, 4, new[] { 1, 1, 1 })] + public async Task ReadBytes_with_byte_buffer_should_have_expected_effect_for_partial_reads(bool async, int testCase, int[] partition) { var bytes = new byte[] { 1, 2, 3 }; var mockStream = new Mock(); var destination = new ByteArrayBuffer(new byte[3], 3); var n = 0; var position = 0; - mockStream.Setup(s => s.ReadAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) - .Returns((byte[] buffer, int offset, int count, CancellationToken cancellationToken) => - { - var length = partition[n++]; - Buffer.BlockCopy(bytes, position, buffer, offset, length); - position += length; - return Task.FromResult(length); - }); + Task ReadPartial (byte[] buffer, int offset, int count) + { + var length = partition[n++]; + Buffer.BlockCopy(bytes, position, buffer, offset, length); + position += length; + return Task.FromResult(length); + } - await mockStream.Object.ReadBytesAsync(destination, 0, 3, Timeout.InfiniteTimeSpan, CancellationToken.None); + mockStream.Setup(s => s.ReadAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) + .Returns((byte[] buffer, int offset, int count, CancellationToken cancellationToken) => ReadPartial(buffer, offset, count)); + mockStream.Setup(s => s.BeginRead(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) + .Returns((byte[] buffer, int offset, int count, AsyncCallback callback, object state) => ReadPartial(buffer, offset, count)); + mockStream.Setup(s => s.EndRead(It.IsAny())) + .Returns(x => ((Task)x).GetAwaiter().GetResult()); + + if (async) + { + await mockStream.Object.ReadBytesAsync(OperationContext.NoTimeout, destination, 0, 3); + } + else + { + mockStream.Object.ReadBytes(OperationContext.NoTimeout, destination, 0, 3); + } destination.AccessBackingBytes(0).Array.Should().Equal(bytes); } - [Fact] - public void ReadBytesAsync_with_byte_buffer_should_throw_when_end_of_stream_is_reached() + [Theory] + [ParameterAttributeData] + public async Task ReadBytes_with_byte_buffer_should_throw_when_end_of_stream_is_reached([Values(true, false)]bool async) { var mockStream = new Mock(); var destination = CreateMockByteBuffer(1).Object; - mockStream.Setup(s => s.ReadAsync(It.IsAny(), 0, 1, It.IsAny())).Returns(Task.FromResult(0)); + mockStream.Setup(s => s.ReadAsync(It.IsAny(), 0, 1, It.IsAny())) + .ReturnsAsync(0); + mockStream.Setup(s => s.BeginRead(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) + .Returns(Task.FromResult(0)); - Func action = () => mockStream.Object.ReadBytesAsync(destination, 0, 1, Timeout.InfiniteTimeSpan, CancellationToken.None); + var exception = async ? + await Record.ExceptionAsync(() => mockStream.Object.ReadBytesAsync(OperationContext.NoTimeout, destination, 0, 1)) : + Record.Exception(() => mockStream.Object.ReadBytes(OperationContext.NoTimeout, destination, 0, 1)); - action.ShouldThrow(); + exception.Should().BeOfType(); } - [Fact] - public void ReadBytesAsync_with_byte_buffer_should_throw_when_buffer_is_null() + [Theory] + [ParameterAttributeData] + public async Task ReadBytes_with_byte_buffer_should_throw_when_buffer_is_null([Values(true, false)]bool async) { var stream = new Mock().Object; IByteBuffer destination = null; - Func action = () => stream.ReadBytesAsync(destination, 0, 0, Timeout.InfiniteTimeSpan, CancellationToken.None); + var exception = async ? + await Record.ExceptionAsync(() => stream.ReadBytesAsync(OperationContext.NoTimeout, destination, 0, 0)) : + Record.Exception(() => stream.ReadBytes(OperationContext.NoTimeout, destination, 0, 0)); - action.ShouldThrow().And.ParamName.Should().Be("buffer"); + exception.Should().BeOfType().Subject + .ParamName.Should().Be("buffer"); } [Theory] - [InlineData(0, -1)] - [InlineData(1, 2)] - [InlineData(2, 1)] - public void ReadBytesAsync_with_byte_buffer_should_throw_when_count_is_invalid(int offset, int count) + [InlineData(true, 0, -1)] + [InlineData(true, 1, 2)] + [InlineData(true, 2, 1)] + [InlineData(false, 0, -1)] + [InlineData(false, 1, 2)] + [InlineData(false, 2, 1)] + public async Task ReadBytes_with_byte_buffer_should_throw_when_count_is_invalid(bool async, int offset, int count) { var stream = new Mock().Object; var destination = CreateMockByteBuffer(2).Object; - Func action = () => stream.ReadBytesAsync(destination, offset, count, Timeout.InfiniteTimeSpan, CancellationToken.None); + var exception = async ? + await Record.ExceptionAsync(() => stream.ReadBytesAsync(OperationContext.NoTimeout, destination, offset, count)) : + Record.Exception(() => stream.ReadBytes(OperationContext.NoTimeout, destination, offset, count)); - action.ShouldThrow().And.ParamName.Should().Be("count"); + exception.Should().BeOfType().Subject + .ParamName.Should().Be("count"); } [Theory] [ParameterAttributeData] - public void ReadBytesAsync_with_byte_buffer_should_throw_when_offset_is_invalid( - [Values(-1, 3)] - int offset) + public async Task ReadBytes_with_byte_buffer_should_throw_when_offset_is_invalid( + [Values(true, false)] bool async, + [Values(-1, 3)]int offset) { var stream = new Mock().Object; var destination = CreateMockByteBuffer(2).Object; - Func action = () => stream.ReadBytesAsync(destination, offset, 0, Timeout.InfiniteTimeSpan, CancellationToken.None); + var exception = async ? + await Record.ExceptionAsync(() => stream.ReadBytesAsync(OperationContext.NoTimeout, destination, offset, 0)) : + Record.Exception(() => stream.ReadBytes(OperationContext.NoTimeout, destination, offset, 0)); - action.ShouldThrow().And.ParamName.Should().Be("offset"); + exception.Should().BeOfType().Subject + .ParamName.Should().Be("offset"); } - [Fact] - public void ReadBytesAsync_with_byte_buffer_should_throw_when_stream_is_null() + [Theory] + [ParameterAttributeData] + public async Task ReadBytes_with_byte_buffer_should_throw_when_stream_is_null([Values(true, false)]bool async) { Stream stream = null; var destination = new Mock().Object; - Func action = () => stream.ReadBytesAsync(destination, 0, 0, Timeout.InfiniteTimeSpan, CancellationToken.None); + var exception = async ? + await Record.ExceptionAsync(() => stream.ReadBytesAsync(OperationContext.NoTimeout, destination, 0, 0)) : + Record.Exception(() => stream.ReadBytes(OperationContext.NoTimeout, destination, 0, 0)); - action.ShouldThrow().And.ParamName.Should().Be("stream"); + exception.Should().BeOfType().Subject + .ParamName.Should().Be("stream"); } [Theory] - [InlineData(0, new byte[] { })] - [InlineData(1, new byte[] { 1 })] - [InlineData(2, new byte[] { 1, 2 })] - public async Task WriteBytesAsync_should_have_expected_effect_for_count(int count, byte[] expectedBytes) + [InlineData(true, 0, new byte[] { })] + [InlineData(true, 1, new byte[] { 1 })] + [InlineData(true, 2, new byte[] { 1, 2 })] + [InlineData(false, 0, new byte[] { })] + [InlineData(false, 1, new byte[] { 1 })] + [InlineData(false, 2, new byte[] { 1, 2 })] + public async Task WriteBytes_should_have_expected_effect_for_count(bool async, int count, byte[] expectedBytes) { var stream = new MemoryStream(); var source = new ByteArrayBuffer(new byte[] { 1, 2 }); - await stream.WriteBytesAsync(source, 0, count, Timeout.InfiniteTimeSpan, CancellationToken.None); + if (async) + { + await stream.WriteBytesAsync(OperationContext.NoTimeout, source, 0, count); + } + else + { + stream.WriteBytes(OperationContext.NoTimeout, source, 0, count); + } stream.ToArray().Should().Equal(expectedBytes); } [Theory] - [InlineData(1, new byte[] { 2 })] - [InlineData(2, new byte[] { 3 })] - public async Task WriteBytesAsync_should_have_expected_effect_for_offset(int offset, byte[] expectedBytes) + [InlineData(true, 1, new byte[] { 2 })] + [InlineData(true, 2, new byte[] { 3 })] + [InlineData(false, 1, new byte[] { 2 })] + [InlineData(false, 2, new byte[] { 3 })] + public async Task WriteBytes_should_have_expected_effect_for_offset(bool async, int offset, byte[] expectedBytes) { var stream = new MemoryStream(); var source = new ByteArrayBuffer(new byte[] { 1, 2, 3 }); - await stream.WriteBytesAsync(source, offset, 1, Timeout.InfiniteTimeSpan, CancellationToken.None); + if (async) + { + await stream.WriteBytesAsync(OperationContext.NoTimeout, source, offset, 1); + } + else + { + stream.WriteBytes(OperationContext.NoTimeout, source, offset, 1); + } stream.ToArray().Should().Equal(expectedBytes); } [Theory] - [InlineData(1, new[] { 3 })] - [InlineData(2, new[] { 1, 2 })] - [InlineData(3, new[] { 2, 1 })] - [InlineData(4, new[] { 1, 1, 1 })] - public async Task WriteBytesAsync_should_have_expected_effect_for_partial_writes(int testCase, int[] partition) + [InlineData(true, 1, new[] { 3 })] + [InlineData(true, 2, new[] { 1, 2 })] + [InlineData(true, 3, new[] { 2, 1 })] + [InlineData(true, 4, new[] { 1, 1, 1 })] + [InlineData(false, 1, new[] { 3 })] + [InlineData(false, 2, new[] { 1, 2 })] + [InlineData(false, 3, new[] { 2, 1 })] + [InlineData(false, 4, new[] { 1, 1, 1 })] + public async Task WriteBytes_should_have_expected_effect_for_partial_writes(bool async, int testCase, int[] partition) { var stream = new MemoryStream(); var mockSource = new Mock(); @@ -310,58 +448,82 @@ public async Task WriteBytesAsync_should_have_expected_effect_for_partial_writes return new ArraySegment(bytes, position, length); }); - await stream.WriteBytesAsync(mockSource.Object, 0, 3, Timeout.InfiniteTimeSpan, CancellationToken.None); + if (async) + { + await stream.WriteBytesAsync(OperationContext.NoTimeout, mockSource.Object, 0, 3); + } + else + { + stream.WriteBytes(OperationContext.NoTimeout, mockSource.Object, 0, 3); + } stream.ToArray().Should().Equal(bytes); } - [Fact] - public void WriteBytesAsync_should_throw_when_buffer_is_null() + [Theory] + [ParameterAttributeData] + public async Task WriteBytes_should_throw_when_buffer_is_null([Values(true, false)]bool async) { var stream = new Mock().Object; - Func action = () => stream.WriteBytesAsync(null, 0, 0, Timeout.InfiniteTimeSpan, CancellationToken.None); + var exception = async ? + await Record.ExceptionAsync(() => stream.WriteBytesAsync(OperationContext.NoTimeout, null, 0, 0)) : + Record.Exception(() => stream.WriteBytes(OperationContext.NoTimeout, null, 0, 0)); - action.ShouldThrow().And.ParamName.Should().Be("buffer"); + exception.Should().BeOfType().Subject + .ParamName.Should().Be("buffer"); } [Theory] - [InlineData(0, -1)] - [InlineData(1, 2)] - [InlineData(2, 1)] - public void WriteBytesAsync_should_throw_when_count_is_invalid(int offset, int count) + [InlineData(true, 0, -1)] + [InlineData(true, 1, 2)] + [InlineData(true, 2, 1)] + [InlineData(false, 0, -1)] + [InlineData(false, 1, 2)] + [InlineData(false, 2, 1)] + public async Task WriteBytes_should_throw_when_count_is_invalid(bool async, int offset, int count) { var stream = new Mock().Object; var source = CreateMockByteBuffer(2).Object; - Func action = () => stream.WriteBytesAsync(source, offset, count, Timeout.InfiniteTimeSpan, CancellationToken.None); + var exception = async ? + await Record.ExceptionAsync(() => stream.WriteBytesAsync(OperationContext.NoTimeout, source, offset, count)) : + Record.Exception(() => stream.WriteBytes(OperationContext.NoTimeout, source, offset, count)); - action.ShouldThrow().And.ParamName.Should().Be("count"); + exception.Should().BeOfType().Subject + .ParamName.Should().Be("count"); } [Theory] [ParameterAttributeData] - public void WriteBytesAsync_should_throw_when_offset_is_invalid( - [Values(-1, 3)] - int offset) + public async Task WriteBytes_should_throw_when_offset_is_invalid( + [Values(true, false)]bool async, + [Values(-1, 3)]int offset) { var stream = new Mock().Object; - var destination = CreateMockByteBuffer(2).Object; + var source = CreateMockByteBuffer(2).Object; - Func action = () => stream.WriteBytesAsync(destination, offset, 0, Timeout.InfiniteTimeSpan, CancellationToken.None); + var exception = async ? + await Record.ExceptionAsync(() => stream.WriteBytesAsync(OperationContext.NoTimeout, source, offset, 0)) : + Record.Exception(() => stream.WriteBytes(OperationContext.NoTimeout, source, offset, 0)); - action.ShouldThrow().And.ParamName.Should().Be("offset"); + exception.Should().BeOfType().Subject + .ParamName.Should().Be("offset"); } - [Fact] - public void WriteBytesAsync_should_throw_when_stream_is_null() + [Theory] + [ParameterAttributeData] + public async Task WriteBytes_should_throw_when_stream_is_null([Values(true, false)]bool async) { Stream stream = null; var source = new Mock().Object; - Func action = () => stream.WriteBytesAsync(source, 0, 0, Timeout.InfiniteTimeSpan, CancellationToken.None); + var exception = async ? + await Record.ExceptionAsync(() => stream.WriteBytesAsync(OperationContext.NoTimeout, source, 0, 0)) : + Record.Exception(() => stream.WriteBytes(OperationContext.NoTimeout, source, 0, 0)); - action.ShouldThrow().And.ParamName.Should().Be("stream"); + exception.Should().BeOfType().Subject + .ParamName.Should().Be("stream"); } // helper methods diff --git a/tests/MongoDB.Driver.Tests/Core/Operations/AsyncCursorTests.cs b/tests/MongoDB.Driver.Tests/Core/Operations/AsyncCursorTests.cs index a1b8943a276..3728762b355 100644 --- a/tests/MongoDB.Driver.Tests/Core/Operations/AsyncCursorTests.cs +++ b/tests/MongoDB.Driver.Tests/Core/Operations/AsyncCursorTests.cs @@ -432,6 +432,7 @@ public void GetMore_should_use_same_session( mockChannelSource.Setup(m => m.GetChannelAsync(It.IsAny())).Returns(Task.FromResult(channel)); mockChannel .Setup(m => m.CommandAsync( + It.IsAny(), session, null, databaseNamespace, @@ -442,8 +443,7 @@ public void GetMore_should_use_same_session( null, CommandResponseHandling.Return, It.IsAny>(), - It.IsAny(), - It.IsAny())) + It.IsAny())) .Callback(() => sameSessionWasUsed = true) .Returns(Task.FromResult(secondBatch)); @@ -454,6 +454,7 @@ public void GetMore_should_use_same_session( mockChannelSource.Setup(m => m.GetChannel(It.IsAny())).Returns(channel); mockChannel .Setup(m => m.Command( + It.IsAny(), session, null, databaseNamespace, @@ -464,8 +465,7 @@ public void GetMore_should_use_same_session( null, CommandResponseHandling.Return, It.IsAny>(), - It.IsAny(), - It.IsAny())) + It.IsAny())) .Callback(() => sameSessionWasUsed = true) .Returns(secondBatch); @@ -543,6 +543,7 @@ private void SetupChannelMocks(Mock mockChannelSource, Mock c.CommandAsync( + It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny(), @@ -553,8 +554,7 @@ private void SetupChannelMocks(Mock mockChannelSource, Mock>(), It.IsAny(), It.IsAny>(), - It.IsAny(), - It.IsAny())) + It.IsAny())) .ReturnsAsync(() => { var bsonDocument = commandResultFunc(); @@ -570,6 +570,7 @@ private void SetupChannelMocks(Mock mockChannelSource, Mock c.Command( + It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny(), @@ -580,8 +581,7 @@ private void SetupChannelMocks(Mock mockChannelSource, Mock>(), It.IsAny(), It.IsAny>(), - It.IsAny(), - It.IsAny())) + It.IsAny())) .Returns(() => { var bsonDocument = commandResultFunc(); @@ -596,6 +596,7 @@ private void VerifyHowManyTimesKillCursorsCommandWasCalled(Mock { mockChannelHandle.Verify( s => s.CommandAsync( + It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny(), @@ -606,16 +607,14 @@ private void VerifyHowManyTimesKillCursorsCommandWasCalled(Mock It.IsAny>(), It.IsAny(), It.IsAny>(), - It.IsAny(), - It.IsAny()), + It.IsAny()), times); - - } else { mockChannelHandle.Verify( s => s.Command( + It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny(), @@ -626,8 +625,7 @@ private void VerifyHowManyTimesKillCursorsCommandWasCalled(Mock It.IsAny>(), It.IsAny(), It.IsAny>(), - It.IsAny(), - It.IsAny()), + It.IsAny()), times); } } @@ -694,6 +692,7 @@ private IReadOnlyList GetFirstBatchUsingFindCommand(IChannelHandle { "batchSize", batchSize } }; var result = channel.Command( + new OperationContext(Timeout.InfiniteTimeSpan, cancellationToken), _session, ReadPreference.Primary, _databaseNamespace, @@ -704,8 +703,7 @@ private IReadOnlyList GetFirstBatchUsingFindCommand(IChannelHandle null, // postWriteAction CommandResponseHandling.Return, BsonDocumentSerializer.Instance, - _messageEncoderSettings, - cancellationToken); + _messageEncoderSettings); var cursor = result["cursor"].AsBsonDocument; var firstBatch = cursor["firstBatch"].AsBsonArray.Select(i => i.AsBsonDocument).ToList(); cursorId = cursor["id"].ToInt64(); diff --git a/tests/MongoDB.Driver.Tests/Core/Operations/ReadCommandOperationTests.cs b/tests/MongoDB.Driver.Tests/Core/Operations/ReadCommandOperationTests.cs index 325eb5536f7..48b4bed12bf 100644 --- a/tests/MongoDB.Driver.Tests/Core/Operations/ReadCommandOperationTests.cs +++ b/tests/MongoDB.Driver.Tests/Core/Operations/ReadCommandOperationTests.cs @@ -15,7 +15,6 @@ using System.Collections.Generic; using System.Net; -using System.Threading; using System.Threading.Tasks; using FluentAssertions; using MongoDB.Bson; @@ -97,6 +96,7 @@ public void Execute_should_call_channel_Command_with_unwrapped_command_when_wrap { mockChannel.Verify( c => c.CommandAsync( + It.IsAny(), binding.Session, readPreference, subject.DatabaseNamespace, @@ -107,14 +107,14 @@ public void Execute_should_call_channel_Command_with_unwrapped_command_when_wrap null, // postWriteAction CommandResponseHandling.Return, subject.ResultSerializer, - subject.MessageEncoderSettings, - It.IsAny()), + subject.MessageEncoderSettings), Times.Once); } else { mockChannel.Verify( c => c.Command( + It.IsAny(), binding.Session, readPreference, subject.DatabaseNamespace, @@ -125,8 +125,7 @@ public void Execute_should_call_channel_Command_with_unwrapped_command_when_wrap null, // postWriteAction CommandResponseHandling.Return, subject.ResultSerializer, - subject.MessageEncoderSettings, - It.IsAny()), + subject.MessageEncoderSettings), Times.Once); } } @@ -151,6 +150,7 @@ public void Execute_should_call_channel_Command_with_wrapped_command_when_additi { mockChannel.Verify( c => c.CommandAsync( + It.IsAny(), binding.Session, readPreference, subject.DatabaseNamespace, @@ -161,14 +161,14 @@ public void Execute_should_call_channel_Command_with_wrapped_command_when_additi null, // postWriteAction CommandResponseHandling.Return, subject.ResultSerializer, - subject.MessageEncoderSettings, - It.IsAny()), + subject.MessageEncoderSettings), Times.Once); } else { mockChannel.Verify( c => c.Command( + It.IsAny(), binding.Session, readPreference, subject.DatabaseNamespace, @@ -179,8 +179,7 @@ public void Execute_should_call_channel_Command_with_wrapped_command_when_additi null, // postWriteAction CommandResponseHandling.Return, subject.ResultSerializer, - subject.MessageEncoderSettings, - It.IsAny()), + subject.MessageEncoderSettings), Times.Once); } } @@ -204,6 +203,7 @@ public void Execute_should_call_channel_Command_with_wrapped_command_when_commen { mockChannel.Verify( c => c.CommandAsync( + It.IsAny(), binding.Session, readPreference, subject.DatabaseNamespace, @@ -214,14 +214,14 @@ public void Execute_should_call_channel_Command_with_wrapped_command_when_commen null, // postWriteAction CommandResponseHandling.Return, subject.ResultSerializer, - subject.MessageEncoderSettings, - It.IsAny()), + subject.MessageEncoderSettings), Times.Once); } else { mockChannel.Verify( c => c.Command( + It.IsAny(), binding.Session, readPreference, subject.DatabaseNamespace, @@ -232,8 +232,7 @@ public void Execute_should_call_channel_Command_with_wrapped_command_when_commen null, // postWriteAction CommandResponseHandling.Return, subject.ResultSerializer, - subject.MessageEncoderSettings, - It.IsAny()), + subject.MessageEncoderSettings), Times.Once); } } @@ -258,6 +257,7 @@ public void Execute_should_call_channel_Command_with_wrapped_command_when_readPr { mockChannel.Verify( c => c.CommandAsync( + It.IsAny(), binding.Session, readPreference, subject.DatabaseNamespace, @@ -268,14 +268,14 @@ public void Execute_should_call_channel_Command_with_wrapped_command_when_readPr null, // postWriteAction CommandResponseHandling.Return, subject.ResultSerializer, - subject.MessageEncoderSettings, - It.IsAny()), + subject.MessageEncoderSettings), Times.Once); } else { mockChannel.Verify( c => c.Command( + It.IsAny(), binding.Session, readPreference, subject.DatabaseNamespace, @@ -286,8 +286,7 @@ public void Execute_should_call_channel_Command_with_wrapped_command_when_readPr null, // postWriteAction CommandResponseHandling.Return, subject.ResultSerializer, - subject.MessageEncoderSettings, - It.IsAny()), + subject.MessageEncoderSettings), Times.Once); } } diff --git a/tests/MongoDB.Driver.Tests/Core/Operations/WriteCommandOperationTests.cs b/tests/MongoDB.Driver.Tests/Core/Operations/WriteCommandOperationTests.cs index 93b655b3e19..71c26aebdba 100644 --- a/tests/MongoDB.Driver.Tests/Core/Operations/WriteCommandOperationTests.cs +++ b/tests/MongoDB.Driver.Tests/Core/Operations/WriteCommandOperationTests.cs @@ -15,7 +15,6 @@ using System.Collections.Generic; using System.Net; -using System.Threading; using System.Threading.Tasks; using FluentAssertions; using MongoDB.Bson; @@ -71,6 +70,7 @@ public void Execute_should_call_channel_Command_with_unwrapped_command_when_wrap { mockChannel.Verify( c => c.CommandAsync( + It.IsAny(), binding.Session, ReadPreference.Primary, subject.DatabaseNamespace, @@ -81,14 +81,14 @@ public void Execute_should_call_channel_Command_with_unwrapped_command_when_wrap null, // postWriteAction CommandResponseHandling.Return, subject.ResultSerializer, - subject.MessageEncoderSettings, - It.IsAny()), + subject.MessageEncoderSettings), Times.Once); } else { mockChannel.Verify( c => c.Command( + It.IsAny(), binding.Session, ReadPreference.Primary, subject.DatabaseNamespace, @@ -99,8 +99,7 @@ public void Execute_should_call_channel_Command_with_unwrapped_command_when_wrap null, // postWriteAction CommandResponseHandling.Return, subject.ResultSerializer, - subject.MessageEncoderSettings, - It.IsAny()), + subject.MessageEncoderSettings), Times.Once); } } @@ -123,6 +122,7 @@ public void Execute_should_call_channel_Command_with_wrapped_command_when_additi { mockChannel.Verify( c => c.CommandAsync( + It.IsAny(), It.IsAny(), It.IsAny(), subject.DatabaseNamespace, @@ -133,14 +133,14 @@ public void Execute_should_call_channel_Command_with_wrapped_command_when_additi null, // postWriteAction CommandResponseHandling.Return, subject.ResultSerializer, - subject.MessageEncoderSettings, - It.IsAny()), + subject.MessageEncoderSettings), Times.Once); } else { mockChannel.Verify( c => c.Command( + It.IsAny(), It.IsAny(), It.IsAny(), subject.DatabaseNamespace, @@ -151,8 +151,7 @@ public void Execute_should_call_channel_Command_with_wrapped_command_when_additi null, // postWriteAction CommandResponseHandling.Return, subject.ResultSerializer, - subject.MessageEncoderSettings, - It.IsAny()), + subject.MessageEncoderSettings), Times.Once); } } @@ -175,6 +174,7 @@ public void Execute_should_call_channel_Command_with_wrapped_command_when_commen { mockChannel.Verify( c => c.CommandAsync( + It.IsAny(), binding.Session, ReadPreference.Primary, subject.DatabaseNamespace, @@ -185,14 +185,14 @@ public void Execute_should_call_channel_Command_with_wrapped_command_when_commen null, // postWriteAction CommandResponseHandling.Return, subject.ResultSerializer, - subject.MessageEncoderSettings, - It.IsAny()), + subject.MessageEncoderSettings), Times.Once); } else { mockChannel.Verify( c => c.Command( + It.IsAny(), binding.Session, ReadPreference.Primary, subject.DatabaseNamespace, @@ -203,8 +203,7 @@ public void Execute_should_call_channel_Command_with_wrapped_command_when_commen null, // postWriteAction CommandResponseHandling.Return, subject.ResultSerializer, - subject.MessageEncoderSettings, - It.IsAny()), + subject.MessageEncoderSettings), Times.Once); } } diff --git a/tests/MongoDB.Driver.Tests/Core/Servers/LoadBalancedServerTests.cs b/tests/MongoDB.Driver.Tests/Core/Servers/LoadBalancedServerTests.cs index 16709b65bf3..26c7a3b7eac 100644 --- a/tests/MongoDB.Driver.Tests/Core/Servers/LoadBalancedServerTests.cs +++ b/tests/MongoDB.Driver.Tests/Core/Servers/LoadBalancedServerTests.cs @@ -75,28 +75,6 @@ public LoadBalancedTests(ITestOutputHelper output) : base(output) _connectionId = new ConnectionId(_subject.ServerId); } - [Theory] - [ParameterAttributeData] - public async Task ChannelFork_should_not_affect_operations_count([Values(false, true)] bool async) - { - IClusterableServer server = SetupServer(false, false); - - var channel = async ? - await server.GetChannelAsync(OperationContext.NoTimeout) : - server.GetChannel(OperationContext.NoTimeout); - - server.OutstandingOperationsCount.Should().Be(1); - - var forkedChannel = channel.Fork(); - server.OutstandingOperationsCount.Should().Be(1); - - forkedChannel.Dispose(); - server.OutstandingOperationsCount.Should().Be(1); - - channel.Dispose(); - server.OutstandingOperationsCount.Should().Be(0); - } - [Fact] public void Constructor_should_not_throw_when_serverApi_is_null() { @@ -248,20 +226,20 @@ public async Task GetChannel_should_set_operations_count_correctly( { IClusterableServer server = SetupServer(false, false); - var channels = new List(); + var channels = new List(); for (int i = 0; i < operationsCount; i++) { - var channel = async ? + var connection = async ? await server.GetChannelAsync(OperationContext.NoTimeout) : server.GetChannel(OperationContext.NoTimeout); - channels.Add(channel); + channels.Add(connection); } server.OutstandingOperationsCount.Should().Be(operationsCount); foreach (var channel in channels) { - channel.Dispose(); + server.DecrementOutstandingOperationsCount(); server.OutstandingOperationsCount.Should().Be(--operationsCount); } } @@ -305,8 +283,8 @@ public async Task GetChannel_should_not_update_topology_and_clear_connection_poo var openConnectionException = new MongoConnectionException(connectionId, "Oops", new IOException("Cry", innerMostException)); var mockConnection = new Mock(); mockConnection.Setup(c => c.ConnectionId).Returns(connectionId); - mockConnection.Setup(c => c.Open(It.IsAny())).Throws(openConnectionException); - mockConnection.Setup(c => c.OpenAsync(It.IsAny())).ThrowsAsync(openConnectionException); + mockConnection.Setup(c => c.Open(It.IsAny())).Throws(openConnectionException); + mockConnection.Setup(c => c.OpenAsync(It.IsAny())).ThrowsAsync(openConnectionException); var connectionFactory = new Mock(); connectionFactory.Setup(cf => cf.CreateConnection(serverId, _endPoint)).Returns(mockConnection.Object); diff --git a/tests/MongoDB.Driver.Tests/Core/Servers/RoundTripTimeMonitorTests.cs b/tests/MongoDB.Driver.Tests/Core/Servers/RoundTripTimeMonitorTests.cs index 701fe502cdd..1fa4a5d2588 100644 --- a/tests/MongoDB.Driver.Tests/Core/Servers/RoundTripTimeMonitorTests.cs +++ b/tests/MongoDB.Driver.Tests/Core/Servers/RoundTripTimeMonitorTests.cs @@ -107,7 +107,7 @@ public void Round_trip_time_monitor_should_work_as_expected() }); mockConnection - .SetupSequence(c => c.ReceiveMessage(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) + .SetupSequence(c => c.ReceiveMessage(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) .Returns( () => { @@ -281,7 +281,7 @@ private ConnectionDescription CreateConnectionDescription() private RoundTripTimeMonitor CreateSubject(TimeSpan frequency, Mock mockConnection) { mockConnection - .Setup(c => c.ReceiveMessage(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) + .Setup(c => c.ReceiveMessage(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) .Returns(() => CreateResponseMessage()); var mockConnectionFactory = new Mock(); diff --git a/tests/MongoDB.Driver.Tests/Core/Servers/ServerTests.cs b/tests/MongoDB.Driver.Tests/Core/Servers/ServerTests.cs index cd564ab9773..39a1810e780 100644 --- a/tests/MongoDB.Driver.Tests/Core/Servers/ServerTests.cs +++ b/tests/MongoDB.Driver.Tests/Core/Servers/ServerTests.cs @@ -97,28 +97,6 @@ protected override void DisposeInternal() _subject.Dispose(); } - [Theory] - [ParameterAttributeData] - public async Task ChannelFork_should_not_affect_operations_count([Values(false, true)] bool async) - { - IClusterableServer server = SetupServer(false, false); - - var channel = async ? - await server.GetChannelAsync(OperationContext.NoTimeout) : - server.GetChannel(OperationContext.NoTimeout); - - server.OutstandingOperationsCount.Should().Be(1); - - var forkedChannel = channel.Fork(); - server.OutstandingOperationsCount.Should().Be(1); - - forkedChannel.Dispose(); - server.OutstandingOperationsCount.Should().Be(1); - - channel.Dispose(); - server.OutstandingOperationsCount.Should().Be(0); - } - [Fact] public void Constructor_should_not_throw_when_serverApi_is_null() { @@ -251,11 +229,11 @@ public async Task GetChannel_should_get_a_connection( { _subject.Initialize(); - var channel = async ? + var connection = async ? await _subject.GetChannelAsync(OperationContext.NoTimeout) : _subject.GetChannel(OperationContext.NoTimeout); - channel.Should().NotBeNull(); + connection.Should().NotBeNull(); } [Theory] @@ -282,20 +260,20 @@ public async Task GetChannel_should_set_operations_count_correctly( { IClusterableServer server = SetupServer(false, false); - var channels = new List(); + var channels = new List(); for (int i = 0; i < operationsCount; i++) { - var channel = async ? + var connection = async ? await server.GetChannelAsync(OperationContext.NoTimeout) : server.GetChannel(OperationContext.NoTimeout); - channels.Add(channel); + channels.Add(connection); } server.OutstandingOperationsCount.Should().Be(operationsCount); foreach (var channel in channels) { - channel.Dispose(); + server.DecrementOutstandingOperationsCount(); server.OutstandingOperationsCount.Should().Be(--operationsCount); } } @@ -340,8 +318,8 @@ public async Task GetChannel_should_update_topology_and_clear_connection_pool_on var openConnectionException = new MongoConnectionException(connectionId, "Oops", new IOException("Cry", innerMostException)); var mockConnection = new Mock(); mockConnection.Setup(c => c.ConnectionId).Returns(connectionId); - mockConnection.Setup(c => c.Open(It.IsAny())).Throws(openConnectionException); - mockConnection.Setup(c => c.OpenAsync(It.IsAny())).ThrowsAsync(openConnectionException); + mockConnection.Setup(c => c.Open(It.IsAny())).Throws(openConnectionException); + mockConnection.Setup(c => c.OpenAsync(It.IsAny())).ThrowsAsync(openConnectionException); var connectionFactory = new Mock(); connectionFactory.Setup(f => f.ConnectionSettings).Returns(() => new ConnectionSettings()); @@ -853,7 +831,7 @@ public void Command_should_send_the_greater_of_the_session_and_cluster_cluster_t using (var cluster = CoreTestConfiguration.CreateCluster(b => b.Subscribe(eventCapturer))) using (var session = cluster.StartSession()) { - var server = (Server)cluster.SelectServer(OperationContext.NoTimeout, WritableServerSelector.Instance); + var server = cluster.SelectServer(OperationContext.NoTimeout, WritableServerSelector.Instance); using (var channel = server.GetChannel(OperationContext.NoTimeout)) { session.AdvanceClusterTime(sessionClusterTime); @@ -863,6 +841,7 @@ public void Command_should_send_the_greater_of_the_session_and_cluster_cluster_t try { channel.Command( + OperationContext.NoTimeout, session, ReadPreference.Primary, DatabaseNamespace.Admin, @@ -873,8 +852,7 @@ public void Command_should_send_the_greater_of_the_session_and_cluster_cluster_t null, // postWriteAction CommandResponseHandling.Return, BsonDocumentSerializer.Instance, - new MessageEncoderSettings(), - It.IsAny()); + new MessageEncoderSettings()); } catch (MongoCommandException ex) { @@ -900,11 +878,12 @@ public void Command_should_update_the_session_and_cluster_cluster_times() using (var cluster = CoreTestConfiguration.CreateCluster(b => b.Subscribe(eventCapturer))) using (var session = cluster.StartSession()) { - var server = (Server)cluster.SelectServer(OperationContext.NoTimeout, WritableServerSelector.Instance); + var server = cluster.SelectServer(OperationContext.NoTimeout, WritableServerSelector.Instance); using (var channel = server.GetChannel(OperationContext.NoTimeout)) { var command = BsonDocument.Parse("{ ping : 1 }"); channel.Command( + OperationContext.NoTimeout, session, ReadPreference.Primary, DatabaseNamespace.Admin, @@ -915,15 +894,14 @@ public void Command_should_update_the_session_and_cluster_cluster_times() null, // postWriteAction CommandResponseHandling.Return, BsonDocumentSerializer.Instance, - new MessageEncoderSettings(), - It.IsAny()); - } + new MessageEncoderSettings()); - var commandSucceededEvent = eventCapturer.Next().Should().BeOfType().Subject; - var actualReply = commandSucceededEvent.Reply; - var actualClusterTime = actualReply["$clusterTime"].AsBsonDocument; - session.ClusterTime.Should().Be(actualClusterTime); - server.ClusterClock.ClusterTime.Should().Be(actualClusterTime); + var commandSucceededEvent = eventCapturer.Next().Should().BeOfType().Subject; + var actualReply = commandSucceededEvent.Reply; + var actualClusterTime = actualReply["$clusterTime"].AsBsonDocument; + session.ClusterTime.Should().Be(actualClusterTime); + server.ClusterClock.ClusterTime.Should().Be(actualClusterTime); + } } } @@ -943,7 +921,7 @@ public async Task Command_should_use_serverApi([Values(false, true)] bool async) using (var cluster = CoreTestConfiguration.CreateCluster(builder)) using (var session = cluster.StartSession()) { - var server = (Server)cluster.SelectServer(OperationContext.NoTimeout, WritableServerSelector.Instance); + var server = cluster.SelectServer(OperationContext.NoTimeout, WritableServerSelector.Instance); using (var channel = server.GetChannel(OperationContext.NoTimeout)) { var command = BsonDocument.Parse("{ ping : 1 }"); @@ -951,6 +929,7 @@ public async Task Command_should_use_serverApi([Values(false, true)] bool async) { await channel .CommandAsync( + OperationContext.NoTimeout, session, ReadPreference.Primary, DatabaseNamespace.Admin, @@ -961,12 +940,12 @@ await channel null, // postWriteAction CommandResponseHandling.Return, BsonDocumentSerializer.Instance, - new MessageEncoderSettings(), - It.IsAny()); + new MessageEncoderSettings()); } else { channel.Command( + OperationContext.NoTimeout, session, ReadPreference.Primary, DatabaseNamespace.Admin, @@ -977,8 +956,7 @@ await channel null, // postWriteAction CommandResponseHandling.Return, BsonDocumentSerializer.Instance, - new MessageEncoderSettings(), - It.IsAny()); + new MessageEncoderSettings()); } } } diff --git a/tests/MongoDB.Driver.Tests/Core/WireProtocol/CommandWriteProtocolTests.cs b/tests/MongoDB.Driver.Tests/Core/WireProtocol/CommandWriteProtocolTests.cs index 9eef407362b..2e989358f7e 100644 --- a/tests/MongoDB.Driver.Tests/Core/WireProtocol/CommandWriteProtocolTests.cs +++ b/tests/MongoDB.Driver.Tests/Core/WireProtocol/CommandWriteProtocolTests.cs @@ -71,13 +71,14 @@ public void Execute_should_use_cached_IWireProtocol_if_available([Values(false, responseHandling, BsonDocumentSerializer.Instance, messageEncoderSettings, - null); // serverApi + null, // serverApi + TimeSpan.FromMilliseconds(42)); var mockConnection = new Mock(); var commandResponse = MessageHelper.BuildCommandResponse(CreateRawBsonDocument(new BsonDocument("ok", 1))); var connectionId = SetupConnection(mockConnection); - var result = subject.Execute(mockConnection.Object, CancellationToken.None); + var result = subject.Execute(OperationContext.NoTimeout, mockConnection.Object); var cachedWireProtocol = subject._cachedWireProtocol(); cachedWireProtocol.Should().NotBeNull(); @@ -91,7 +92,7 @@ public void Execute_should_use_cached_IWireProtocol_if_available([Values(false, subject._responseHandling(CommandResponseHandling.Ignore); // will trigger the exception if the CommandUsingCommandMessageWireProtocol ctor will be called result = null; - var exception = Record.Exception(() => { result = subject.Execute(mockConnection.Object, CancellationToken.None); }); + var exception = Record.Exception(() => { result = subject.Execute(OperationContext.NoTimeout, mockConnection.Object); }); if (withSameConnection) { @@ -118,7 +119,7 @@ ConnectionId SetupConnection(Mock connection, ConnectionId id = nul } connection - .Setup(c => c.ReceiveMessage(It.IsAny(), It.IsAny(), messageEncoderSettings, CancellationToken.None)) + .Setup(c => c.ReceiveMessage(OperationContext.NoTimeout, It.IsAny(), It.IsAny(), messageEncoderSettings)) .Returns(commandResponse); connection.SetupGet(c => c.ConnectionId).Returns(id); connection @@ -133,7 +134,7 @@ ConnectionId SetupConnection(Mock connection, ConnectionId id = nul [Theory] [ParameterAttributeData] - public void Execute_should_use_serverApi_with_getMoreCommand( + public async Task Execute_should_use_serverApi_with_getMoreCommand( [Values(false, true)] bool useServerApi, [Values(false, true)] bool async) { @@ -155,15 +156,16 @@ public void Execute_should_use_serverApi_with_getMoreCommand( CommandResponseHandling.Return, BsonDocumentSerializer.Instance, new MessageEncoderSettings(), - serverApi); + serverApi, + TimeSpan.FromMilliseconds(42)); if (async) { - subject.ExecuteAsync(connection, CancellationToken.None).GetAwaiter().GetResult(); + await subject.ExecuteAsync(OperationContext.NoTimeout, connection); } else { - subject.Execute(connection, CancellationToken.None); + subject.Execute(OperationContext.NoTimeout, connection); } SpinWait.SpinUntil(() => connection.GetSentMessages().Count >= 1, TimeSpan.FromSeconds(4)).Should().BeTrue(); @@ -177,7 +179,7 @@ public void Execute_should_use_serverApi_with_getMoreCommand( [Theory] [ParameterAttributeData] - public void Execute_should_use_serverApi_in_transaction( + public async Task Execute_should_use_serverApi_in_transaction( [Values(false, true)] bool useServerApi, [Values(false, true)] bool async) { @@ -199,15 +201,16 @@ public void Execute_should_use_serverApi_in_transaction( CommandResponseHandling.Return, BsonDocumentSerializer.Instance, new MessageEncoderSettings(), - serverApi); + serverApi, + TimeSpan.FromMilliseconds(42)); if (async) { - subject.ExecuteAsync(connection, CancellationToken.None).GetAwaiter().GetResult(); + await subject.ExecuteAsync(OperationContext.NoTimeout, connection); } else { - subject.Execute(connection, CancellationToken.None); + subject.Execute(OperationContext.NoTimeout, connection); } SpinWait.SpinUntil(() => connection.GetSentMessages().Count >= 1, TimeSpan.FromSeconds(4)).Should().BeTrue(); @@ -247,17 +250,18 @@ public void Execute_should_wait_for_response_when_CommandResponseHandling_is_Ret CommandResponseHandling.Return, BsonDocumentSerializer.Instance, messageEncoderSettings, - null); // serverApi + null, // serverApi + TimeSpan.FromMilliseconds(42)); var mockConnection = new Mock(); mockConnection.Setup(c => c.Settings).Returns(() => new ConnectionSettings()); var commandResponse = MessageHelper.BuildReply(CreateRawBsonDocument(new BsonDocument("ok", 1))); mockConnection - .Setup(c => c.ReceiveMessage(It.IsAny(), It.IsAny(), messageEncoderSettings, CancellationToken.None)) + .Setup(c => c.ReceiveMessage(It.IsAny(), It.IsAny(), It.IsAny(), messageEncoderSettings)) .Returns(commandResponse); - var result = subject.Execute(mockConnection.Object, CancellationToken.None); + var result = subject.Execute(OperationContext.NoTimeout, mockConnection.Object); result.Should().Be("{ok: 1}"); } @@ -277,21 +281,22 @@ public void Execute_should_not_wait_for_response_when_CommandResponseHandling_is CommandResponseHandling.NoResponseExpected, BsonDocumentSerializer.Instance, messageEncoderSettings, - null); // serverApi + null, // serverApi + TimeSpan.FromMilliseconds(42)); var mockConnection = new Mock(); mockConnection.Setup(c => c.Settings).Returns(() => new ConnectionSettings()); - var result = subject.Execute(mockConnection.Object, CancellationToken.None); + var result = subject.Execute(OperationContext.NoTimeout, mockConnection.Object); result.Should().BeNull(); mockConnection.Verify( - c => c.ReceiveMessageAsync(It.IsAny(), It.IsAny(), messageEncoderSettings, CancellationToken.None), + c => c.ReceiveMessageAsync(It.IsAny(), It.IsAny(), It.IsAny(), messageEncoderSettings), Times.Once); } [Fact] - public void ExecuteAsync_should_wait_for_response_when_CommandResponseHandling_is_Return() + public async Task ExecuteAsync_should_wait_for_response_when_CommandResponseHandling_is_Return() { var messageEncoderSettings = new MessageEncoderSettings(); var subject = new CommandWireProtocol( @@ -306,22 +311,23 @@ public void ExecuteAsync_should_wait_for_response_when_CommandResponseHandling_i CommandResponseHandling.Return, BsonDocumentSerializer.Instance, messageEncoderSettings, - null); // serverApi + null, // serverApi + TimeSpan.FromMilliseconds(42)); var mockConnection = new Mock(); mockConnection.Setup(c => c.Settings).Returns(() => new ConnectionSettings()); var commandResponse = MessageHelper.BuildReply(CreateRawBsonDocument(new BsonDocument("ok", 1))); mockConnection - .Setup(c => c.ReceiveMessageAsync(It.IsAny(), It.IsAny(), messageEncoderSettings, CancellationToken.None)) + .Setup(c => c.ReceiveMessageAsync(It.IsAny(), It.IsAny(), It.IsAny(), messageEncoderSettings)) .Returns(Task.FromResult(commandResponse)); - var result = subject.ExecuteAsync(mockConnection.Object, CancellationToken.None).GetAwaiter().GetResult(); + var result = await subject.ExecuteAsync(OperationContext.NoTimeout, mockConnection.Object); result.Should().Be("{ok: 1}"); } [Fact] - public void ExecuteAsync_should_not_wait_for_response_when_CommandResponseHandling_is_NoResponseExpected() + public async Task ExecuteAsync_should_not_wait_for_response_when_CommandResponseHandling_is_NoResponseExpected() { var messageEncoderSettings = new MessageEncoderSettings(); var subject = new CommandWireProtocol( @@ -336,15 +342,16 @@ public void ExecuteAsync_should_not_wait_for_response_when_CommandResponseHandli CommandResponseHandling.NoResponseExpected, BsonDocumentSerializer.Instance, messageEncoderSettings, - null); // serverApi + null, // serverApi + TimeSpan.FromMilliseconds(42)); var mockConnection = new Mock(); mockConnection.Setup(c => c.Settings).Returns(() => new ConnectionSettings()); - var result = subject.ExecuteAsync(mockConnection.Object, CancellationToken.None).GetAwaiter().GetResult(); + var result = await subject.ExecuteAsync(OperationContext.NoTimeout, mockConnection.Object); result.Should().BeNull(); - mockConnection.Verify(c => c.ReceiveMessageAsync(It.IsAny(), It.IsAny(), messageEncoderSettings, CancellationToken.None), Times.Once); + mockConnection.Verify(c => c.ReceiveMessageAsync(It.IsAny(), It.IsAny(), It.IsAny(), messageEncoderSettings), Times.Once); } // private methods diff --git a/tests/MongoDB.Driver.Tests/Encryption/ClientEncryptionTests.cs b/tests/MongoDB.Driver.Tests/Encryption/ClientEncryptionTests.cs index fb09f34e818..01ed41022a4 100644 --- a/tests/MongoDB.Driver.Tests/Encryption/ClientEncryptionTests.cs +++ b/tests/MongoDB.Driver.Tests/Encryption/ClientEncryptionTests.cs @@ -136,9 +136,9 @@ public async Task CreateEncryptedCollection_should_handle_generated_key_when_sec mockCluster.SetupGet(c => c.Description).Returns(clusterDescription); var mockServer = new Mock(); mockServer.SetupGet(s => s.Description).Returns(serverDescription); - var channel = Mock.Of(c => c.ConnectionDescription == new ConnectionDescription(new ConnectionId(serverId), new HelloResult(new BsonDocument("maxWireVersion", serverDescription.WireVersionRange.Max)))); - mockServer.Setup(s => s.GetChannel(It.IsAny())).Returns(channel); - mockServer.Setup(s => s.GetChannelAsync(It.IsAny())).ReturnsAsync(channel); + var connection = Mock.Of(c => c.ConnectionDescription == new ConnectionDescription(new ConnectionId(serverId), new HelloResult(new BsonDocument("maxWireVersion", serverDescription.WireVersionRange.Max)))); + mockServer.Setup(s => s.GetChannel(It.IsAny())).Returns(connection); + mockServer.Setup(s => s.GetChannelAsync(It.IsAny())).ReturnsAsync(connection); mockCluster .Setup(m => m.SelectServer(It.IsAny(), It.IsAny())) @@ -225,9 +225,9 @@ public async Task CreateEncryptedCollection_should_handle_various_encryptedField mockCluster.SetupGet(c => c.Description).Returns(clusterDescription); var mockServer = new Mock(); mockServer.SetupGet(s => s.Description).Returns(serverDescription); - var channel = Mock.Of(c => c.ConnectionDescription == new ConnectionDescription(new ConnectionId(serverId), new HelloResult(new BsonDocument("maxWireVersion", serverDescription.WireVersionRange.Max)))); - mockServer.Setup(s => s.GetChannel(It.IsAny())).Returns(channel); - mockServer.Setup(s => s.GetChannelAsync(It.IsAny())).ReturnsAsync(channel); + var connection = Mock.Of(c => c.ConnectionDescription == new ConnectionDescription(new ConnectionId(serverId), new HelloResult(new BsonDocument("maxWireVersion", serverDescription.WireVersionRange.Max)))); + mockServer.Setup(s => s.GetChannel(It.IsAny())).Returns(connection); + mockServer.Setup(s => s.GetChannelAsync(It.IsAny())).ReturnsAsync(connection); mockCluster .Setup(m => m.SelectServer(It.IsAny(), It.IsAny())) diff --git a/tests/MongoDB.Driver.Tests/IJsonDrivenTestRunner.cs b/tests/MongoDB.Driver.Tests/IJsonDrivenTestRunner.cs index 85fde463079..f2bea86436c 100644 --- a/tests/MongoDB.Driver.Tests/IJsonDrivenTestRunner.cs +++ b/tests/MongoDB.Driver.Tests/IJsonDrivenTestRunner.cs @@ -27,8 +27,6 @@ namespace MongoDB.Driver.Tests internal interface IJsonDrivenTestRunner { IClusterInternal FailPointCluster { get; } - IServer FailPointServer { get; } - void ConfigureFailPoint(IServer server, ICoreSessionHandle session, BsonDocument failCommand); Task ConfigureFailPointAsync(IServer server, ICoreSessionHandle session, BsonDocument failCommand); } @@ -49,8 +47,6 @@ public IClusterInternal FailPointCluster } } - public IServer FailPointServer => null; - public void ConfigureFailPoint(IServer server, ICoreSessionHandle session, BsonDocument failCommand) { var failPoint = FailPoint.Configure(server, session, failCommand, withAsync: false); diff --git a/tests/MongoDB.Driver.Tests/JsonDrivenTests/JsonDrivenConfigureFailPointTest.cs b/tests/MongoDB.Driver.Tests/JsonDrivenTests/JsonDrivenConfigureFailPointTest.cs index 389449d592d..cda1386f787 100644 --- a/tests/MongoDB.Driver.Tests/JsonDrivenTests/JsonDrivenConfigureFailPointTest.cs +++ b/tests/MongoDB.Driver.Tests/JsonDrivenTests/JsonDrivenConfigureFailPointTest.cs @@ -46,22 +46,12 @@ protected override async Task CallMethodAsync(CancellationToken cancellationToke protected virtual IServer GetFailPointServer() { - if (TestRunner.FailPointServer != null) - { - return TestRunner.FailPointServer; - } - var cluster = TestRunner.FailPointCluster; return cluster.SelectServer(OperationContext.NoTimeout, WritableServerSelector.Instance); } protected async virtual Task GetFailPointServerAsync() { - if (TestRunner.FailPointServer != null) - { - return TestRunner.FailPointServer; - } - var cluster = TestRunner.FailPointCluster; return await cluster.SelectServerAsync(OperationContext.NoTimeout, WritableServerSelector.Instance).ConfigureAwait(false); } diff --git a/tests/MongoDB.Driver.Tests/Specifications/connection-monitoring-and-pooling/ConnectionMonitoringAndPoolingTestRunner.cs b/tests/MongoDB.Driver.Tests/Specifications/connection-monitoring-and-pooling/ConnectionMonitoringAndPoolingTestRunner.cs index cad223d7c2d..6da8fcb73af 100644 --- a/tests/MongoDB.Driver.Tests/Specifications/connection-monitoring-and-pooling/ConnectionMonitoringAndPoolingTestRunner.cs +++ b/tests/MongoDB.Driver.Tests/Specifications/connection-monitoring-and-pooling/ConnectionMonitoringAndPoolingTestRunner.cs @@ -83,12 +83,6 @@ private static class Schema public readonly static string ignore = nameof(ignore); public readonly static string async = nameof(async); - public static class Operations - { - public const string runOn = nameof(runOn); - public readonly static string failPoint = nameof(failPoint); - } - public static class Intergration { public readonly static string runOn = nameof(runOn); @@ -101,12 +95,6 @@ public static class Styles public readonly static string integration = nameof(integration); } - public sealed class FailPoint - { - public readonly static string appName = nameof(appName); - public readonly static string data = nameof(data); - } - public readonly static string[] AllFields = new[] { _path, @@ -745,33 +733,6 @@ o is ServerHeartbeatSucceededEvent || return (connectionPool, failPoint, cluster, eventsFilter); } - private IConnectionPool SetupConnectionPoolMock(BsonDocument test, IEventSubscriber eventSubscriber) - { - var endPoint = new DnsEndPoint("localhost", 27017); - var serverId = new ServerId(new ClusterId(), endPoint); - ParseSettings(test, out var connectionPoolSettings, out var connectionSettings); - - var connectionFactory = new Mock(); - var exceptionHandler = new Mock(); - connectionFactory.Setup(f => f.ConnectionSettings).Returns(() => new ConnectionSettings()); - connectionFactory - .Setup(c => c.CreateConnection(serverId, endPoint)) - .Returns(() => - { - var connection = new MockConnection(serverId, connectionSettings, eventSubscriber); - return connection; - }); - var connectionPool = new ExclusiveConnectionPool( - serverId, - endPoint, - connectionPoolSettings, - connectionFactory.Object, - exceptionHandler.Object, - eventSubscriber.ToEventLogger()); - - return connectionPool; - } - private void Start(BsonDocument operation, ConcurrentDictionary tasks) { var startTarget = operation.GetValue("target").ToString(); @@ -855,6 +816,14 @@ public static ServerId ServerId(this object @event) internal static class IServerReflector { - public static IConnectionPool _connectionPool(this IServer server) => (IConnectionPool)Reflector.GetFieldValue(server, nameof(_connectionPool)); + public static IConnectionPool _connectionPool(this IServer server) + { + if (server is SelectedServer) + { + server = (IServer)Reflector.GetFieldValue(server, "_server"); + } + + return (IConnectionPool)Reflector.GetFieldValue(server, nameof(_connectionPool)); + } } } diff --git a/tests/MongoDB.Driver.Tests/Specifications/mongodb-handshake/MongoDbHandshakeProseTests.cs b/tests/MongoDB.Driver.Tests/Specifications/mongodb-handshake/MongoDbHandshakeProseTests.cs index 8117f67bcc5..960d24bd773 100644 --- a/tests/MongoDB.Driver.Tests/Specifications/mongodb-handshake/MongoDbHandshakeProseTests.cs +++ b/tests/MongoDB.Driver.Tests/Specifications/mongodb-handshake/MongoDbHandshakeProseTests.cs @@ -79,11 +79,11 @@ public async Task DriverAcceptsArbitraryAuthMechanism([Values(false, true)] bool if (async) { - await subject.OpenAsync(CancellationToken.None); + await subject.OpenAsync(OperationContext.NoTimeout); } else { - subject.Open(CancellationToken.None); + subject.Open(OperationContext.NoTimeout); } subject._state().Should().Be(3); // 3 - open. diff --git a/tests/MongoDB.Driver.Tests/Specifications/server-discovery-and-monitoring/ServerDiscoveryAndMonitoringTestRunner.cs b/tests/MongoDB.Driver.Tests/Specifications/server-discovery-and-monitoring/ServerDiscoveryAndMonitoringTestRunner.cs index bb9b59e94e7..9fc61009e7b 100644 --- a/tests/MongoDB.Driver.Tests/Specifications/server-discovery-and-monitoring/ServerDiscoveryAndMonitoringTestRunner.cs +++ b/tests/MongoDB.Driver.Tests/Specifications/server-discovery-and-monitoring/ServerDiscoveryAndMonitoringTestRunner.cs @@ -22,6 +22,7 @@ using MongoDB.Bson.TestHelpers; using MongoDB.Bson.TestHelpers.JsonDrivenTests; using MongoDB.Driver.Core; +using MongoDB.Driver.Core.Bindings; using MongoDB.Driver.Core.Clusters; using MongoDB.Driver.Core.Configuration; using MongoDB.Driver.Core.ConnectionPools; @@ -617,6 +618,11 @@ public static IConnectionPool _connectionPool(this Server server) public static IServerMonitor _monitor(this IServer server) { + if (server is SelectedServer) + { + server = (IServer)Reflector.GetFieldValue(server, "_server"); + } + return (IServerMonitor)Reflector.GetFieldValue(server, nameof(_monitor)); } @@ -625,7 +631,7 @@ public static void HandleBeforeHandshakeCompletesException(this Server server, E Reflector.Invoke(server, nameof(HandleBeforeHandshakeCompletesException), ex); } - public static void HandleChannelException(this Server server, IConnection connection, Exception ex) + public static void HandleChannelException(this Server server, IConnectionHandle connection, Exception ex) { Reflector.Invoke(server, nameof(HandleChannelException), connection, ex, checkBaseClass: true); }