diff --git a/driver-core/src/main/com/mongodb/internal/ExceptionUtils.java b/driver-core/src/main/com/mongodb/internal/ExceptionUtils.java index 96083f66833..9ccb5ef0c8b 100644 --- a/driver-core/src/main/com/mongodb/internal/ExceptionUtils.java +++ b/driver-core/src/main/com/mongodb/internal/ExceptionUtils.java @@ -17,6 +17,8 @@ package com.mongodb.internal; import com.mongodb.MongoCommandException; +import com.mongodb.MongoOperationTimeoutException; +import com.mongodb.MongoSocketException; import org.bson.BsonArray; import org.bson.BsonDocument; import org.bson.BsonInt32; @@ -35,6 +37,15 @@ *

This class is not part of the public API and may be removed or changed at any time

*/ public final class ExceptionUtils { + + public static boolean isMongoSocketException(final Throwable e) { + return e instanceof MongoSocketException; + } + + public static boolean isOperationTimeoutFromSocketException(final Throwable e) { + return e instanceof MongoOperationTimeoutException && e.getCause() instanceof MongoSocketException; + } + public static final class MongoCommandExceptionUtils { public static int extractErrorCode(final BsonDocument response) { return extractErrorCodeAsBson(response).intValue(); diff --git a/driver-core/src/main/com/mongodb/internal/connection/AbstractProtocolExecutor.java b/driver-core/src/main/com/mongodb/internal/connection/AbstractProtocolExecutor.java new file mode 100644 index 00000000000..ba200933860 --- /dev/null +++ b/driver-core/src/main/com/mongodb/internal/connection/AbstractProtocolExecutor.java @@ -0,0 +1,35 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.internal.connection; + +import com.mongodb.internal.session.SessionContext; + +import static com.mongodb.internal.ExceptionUtils.isMongoSocketException; +import static com.mongodb.internal.ExceptionUtils.isOperationTimeoutFromSocketException; + +/** + *

This class is not part of the public API and may be removed or changed at any time

+ */ +public abstract class AbstractProtocolExecutor implements ProtocolExecutor { + + protected boolean shouldMarkSessionDirty(final Throwable e, final SessionContext sessionContext) { + if (!sessionContext.hasSession()) { + return false; + } + return isMongoSocketException(e) || isOperationTimeoutFromSocketException(e); + } +} diff --git a/driver-core/src/main/com/mongodb/internal/connection/DefaultServer.java b/driver-core/src/main/com/mongodb/internal/connection/DefaultServer.java index 458507ee11c..8f3d0f09fd9 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/DefaultServer.java +++ b/driver-core/src/main/com/mongodb/internal/connection/DefaultServer.java @@ -18,7 +18,6 @@ import com.mongodb.MongoException; import com.mongodb.MongoServerUnavailableException; -import com.mongodb.MongoSocketException; import com.mongodb.ReadPreference; import com.mongodb.connection.ClusterConnectionMode; import com.mongodb.connection.ConnectionDescription; @@ -197,7 +196,7 @@ ServerId serverId() { return serverId; } - private class DefaultServerProtocolExecutor implements ProtocolExecutor { + private class DefaultServerProtocolExecutor extends AbstractProtocolExecutor { @SuppressWarnings("unchecked") @Override @@ -216,9 +215,9 @@ public T execute(final CommandProtocol protocol, final InternalConnection if (e instanceof MongoWriteConcernWithResponseException) { return (T) ((MongoWriteConcernWithResponseException) e).getResponse(); } else { - if (e instanceof MongoSocketException && sessionContext.hasSession()) { + if (shouldMarkSessionDirty(e, sessionContext)) { sessionContext.markSessionDirty(); - } + } throw e; } } @@ -239,7 +238,7 @@ public void executeAsync(final CommandProtocol protocol, final InternalCo if (t instanceof MongoWriteConcernWithResponseException) { callback.onResult((T) ((MongoWriteConcernWithResponseException) t).getResponse(), null); } else { - if (t instanceof MongoSocketException && sessionContext.hasSession()) { + if (shouldMarkSessionDirty(t, sessionContext)) { sessionContext.markSessionDirty(); } callback.onResult(null, t); diff --git a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java index bdc77ad72a6..a0eeb39d31d 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java +++ b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java @@ -76,6 +76,7 @@ import static com.mongodb.assertions.Assertions.isTrue; import static com.mongodb.assertions.Assertions.notNull; import static com.mongodb.internal.async.AsyncRunnable.beginAsync; +import static com.mongodb.internal.TimeoutContext.createMongoTimeoutException; import static com.mongodb.internal.async.ErrorHandlingResultCallback.errorHandlingCallback; import static com.mongodb.internal.connection.Authenticator.shouldAuthenticate; import static com.mongodb.internal.connection.CommandHelper.HELLO; @@ -775,7 +776,7 @@ private void throwTranslatedWriteException(final Throwable e, final OperationCon private MongoException translateWriteException(final Throwable e, final OperationContext operationContext) { if (e instanceof MongoSocketWriteTimeoutException && operationContext.getTimeoutContext().hasExpired()) { - return TimeoutContext.createMongoTimeoutException(e); + return createMongoTimeoutException(e); } if (e instanceof MongoException) { @@ -792,9 +793,12 @@ private MongoException translateWriteException(final Throwable e, final Operatio } private MongoException translateReadException(final Throwable e, final OperationContext operationContext) { - if (operationContext.getTimeoutContext().hasTimeoutMS() - && (e instanceof SocketTimeoutException || e instanceof MongoSocketReadTimeoutException)) { - return TimeoutContext.createMongoTimeoutException(e); + if (operationContext.getTimeoutContext().hasTimeoutMS()) { + if (e instanceof SocketTimeoutException) { + return createMongoTimeoutException(createReadTimeoutException((SocketTimeoutException) e)); + } else if (e instanceof MongoSocketReadTimeoutException) { + return createMongoTimeoutException((e)); + } } if (e instanceof MongoException) { @@ -804,7 +808,7 @@ private MongoException translateReadException(final Throwable e, final Operation if (interruptedException.isPresent()) { return interruptedException.get(); } else if (e instanceof SocketTimeoutException) { - return new MongoSocketReadTimeoutException("Timeout while receiving message", getServerAddress(), e); + return createReadTimeoutException((SocketTimeoutException) e); } else if (e instanceof IOException) { return new MongoSocketReadException("Exception receiving message", getServerAddress(), e); } else if (e instanceof RuntimeException) { @@ -814,6 +818,11 @@ private MongoException translateReadException(final Throwable e, final Operation } } + private MongoSocketReadTimeoutException createReadTimeoutException(final SocketTimeoutException e) { + return new MongoSocketReadTimeoutException("Timeout while receiving message", + getServerAddress(), e); + } + private ResponseBuffers receiveResponseBuffers(final OperationContext operationContext) { try { ByteBuf messageHeaderBuffer = stream.read(MESSAGE_HEADER_LENGTH, operationContext); diff --git a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionInitializer.java b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionInitializer.java index 45122331645..ee509873e40 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionInitializer.java +++ b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionInitializer.java @@ -50,6 +50,7 @@ *

This class is not part of the public API and may be removed or changed at any time

*/ public class InternalStreamConnectionInitializer implements InternalConnectionInitializer { + private static final int INITIAL_MIN_RTT = 0; private final ClusterConnectionMode clusterConnectionMode; private final Authenticator authenticator; private final BsonDocument clientMetadataDocument; @@ -160,7 +161,7 @@ private InternalConnectionInitializationDescription createInitializationDescript helloResult); ServerDescription serverDescription = createServerDescription(internalConnection.getDescription().getServerAddress(), helloResult, - System.nanoTime() - startTime, 0); + System.nanoTime() - startTime, INITIAL_MIN_RTT); return new InternalConnectionInitializationDescription(connectionDescription, serverDescription); } diff --git a/driver-core/src/main/com/mongodb/internal/connection/LoadBalancedServer.java b/driver-core/src/main/com/mongodb/internal/connection/LoadBalancedServer.java index 026cc5a61a1..3820810ab9f 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/LoadBalancedServer.java +++ b/driver-core/src/main/com/mongodb/internal/connection/LoadBalancedServer.java @@ -154,7 +154,7 @@ ConnectionPool getConnectionPool() { return connectionPool; } - private class LoadBalancedServerProtocolExecutor implements ProtocolExecutor { + private class LoadBalancedServerProtocolExecutor extends AbstractProtocolExecutor { @SuppressWarnings("unchecked") @Override public T execute(final CommandProtocol protocol, final InternalConnection connection, final SessionContext sessionContext) { @@ -191,7 +191,7 @@ public void executeAsync(final CommandProtocol protocol, final InternalCo private void handleExecutionException(final InternalConnection connection, final SessionContext sessionContext, final Throwable t) { invalidate(t, connection.getDescription().getServiceId(), connection.getGeneration()); - if (t instanceof MongoSocketException && sessionContext.hasSession()) { + if (shouldMarkSessionDirty(t, sessionContext)) { sessionContext.markSessionDirty(); } } diff --git a/driver-core/src/main/com/mongodb/internal/operation/AsyncCommandBatchCursor.java b/driver-core/src/main/com/mongodb/internal/operation/AsyncCommandBatchCursor.java index 4bbde3ff036..eec8721fbf1 100644 --- a/driver-core/src/main/com/mongodb/internal/operation/AsyncCommandBatchCursor.java +++ b/driver-core/src/main/com/mongodb/internal/operation/AsyncCommandBatchCursor.java @@ -18,6 +18,7 @@ import com.mongodb.MongoCommandException; import com.mongodb.MongoNamespace; +import com.mongodb.MongoOperationTimeoutException; import com.mongodb.MongoSocketException; import com.mongodb.ReadPreference; import com.mongodb.ServerAddress; @@ -286,8 +287,8 @@ void executeWithConnection(final AsyncCallableConnectionWithCallback call return; } callable.call(assertNotNull(connection), (result, t1) -> { - if (t1 instanceof MongoSocketException) { - onCorruptedConnection(connection, (MongoSocketException) t1); + if (t1 != null) { + handleException(connection, t1); } connection.release(); callback.onResult(result, t1); @@ -295,6 +296,14 @@ void executeWithConnection(final AsyncCallableConnectionWithCallback call }); } + private void handleException(final AsyncConnection connection, final Throwable exception) { + if (exception instanceof MongoOperationTimeoutException && exception.getCause() instanceof MongoSocketException) { + onCorruptedConnection(connection, (MongoSocketException) exception.getCause()); + } else if (exception instanceof MongoSocketException) { + onCorruptedConnection(connection, (MongoSocketException) exception); + } + } + private void getConnection(final SingleResultCallback callback) { assertTrue(getState() != State.IDLE); AsyncConnection pinnedConnection = getPinnedConnection(); diff --git a/driver-core/src/main/com/mongodb/internal/operation/CommandBatchCursor.java b/driver-core/src/main/com/mongodb/internal/operation/CommandBatchCursor.java index a2bb4fdb8c7..410098db2c0 100644 --- a/driver-core/src/main/com/mongodb/internal/operation/CommandBatchCursor.java +++ b/driver-core/src/main/com/mongodb/internal/operation/CommandBatchCursor.java @@ -19,6 +19,7 @@ import com.mongodb.MongoCommandException; import com.mongodb.MongoException; import com.mongodb.MongoNamespace; +import com.mongodb.MongoOperationTimeoutException; import com.mongodb.MongoSocketException; import com.mongodb.ReadPreference; import com.mongodb.ServerAddress; @@ -334,6 +335,12 @@ void executeWithConnection(final Consumer action) { } catch (MongoSocketException e) { onCorruptedConnection(connection, e); throw e; + } catch (MongoOperationTimeoutException e) { + Throwable cause = e.getCause(); + if (cause instanceof MongoSocketException) { + onCorruptedConnection(connection, (MongoSocketException) cause); + } + throw e; } finally { connection.release(); } diff --git a/driver-core/src/test/functional/com/mongodb/internal/operation/AsyncCommandBatchCursorTest.java b/driver-core/src/test/functional/com/mongodb/internal/operation/AsyncCommandBatchCursorTest.java new file mode 100644 index 00000000000..53b2d78eae2 --- /dev/null +++ b/driver-core/src/test/functional/com/mongodb/internal/operation/AsyncCommandBatchCursorTest.java @@ -0,0 +1,202 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.internal.operation; + +import com.mongodb.MongoNamespace; +import com.mongodb.MongoOperationTimeoutException; +import com.mongodb.MongoSocketException; +import com.mongodb.ServerAddress; +import com.mongodb.client.cursor.TimeoutMode; +import com.mongodb.connection.ConnectionDescription; +import com.mongodb.connection.ServerDescription; +import com.mongodb.connection.ServerType; +import com.mongodb.connection.ServerVersion; +import com.mongodb.internal.TimeoutContext; +import com.mongodb.internal.async.SingleResultCallback; +import com.mongodb.internal.binding.AsyncConnectionSource; +import com.mongodb.internal.connection.AsyncConnection; +import com.mongodb.internal.connection.OperationContext; +import org.bson.BsonArray; +import org.bson.BsonDocument; +import org.bson.BsonInt32; +import org.bson.BsonInt64; +import org.bson.BsonString; +import org.bson.Document; +import org.bson.codecs.Decoder; +import org.bson.codecs.DocumentCodec; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import static com.mongodb.internal.operation.OperationUnitSpecification.getMaxWireVersionForServerVersion; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.argThat; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +class AsyncCommandBatchCursorTest { + + private static final MongoNamespace NAMESPACE = new MongoNamespace("test", "test"); + private static final BsonInt64 CURSOR_ID = new BsonInt64(1); + private static final BsonDocument COMMAND_CURSOR_DOCUMENT = new BsonDocument("ok", new BsonInt32(1)) + .append("cursor", + new BsonDocument("ns", new BsonString(NAMESPACE.getFullName())) + .append("id", CURSOR_ID) + .append("firstBatch", new BsonArrayWrapper<>(new BsonArray()))); + + private static final Decoder DOCUMENT_CODEC = new DocumentCodec(); + + + private AsyncConnection mockConnection; + private ConnectionDescription mockDescription; + private AsyncConnectionSource connectionSource; + private OperationContext operationContext; + private TimeoutContext timeoutContext; + private ServerDescription serverDescription; + + @BeforeEach + void setUp() { + ServerVersion serverVersion = new ServerVersion(3, 6); + + mockConnection = mock(AsyncConnection.class, "connection"); + mockDescription = mock(ConnectionDescription.class); + when(mockDescription.getMaxWireVersion()).thenReturn(getMaxWireVersionForServerVersion(serverVersion.getVersionList())); + when(mockDescription.getServerType()).thenReturn(ServerType.LOAD_BALANCER); + when(mockConnection.getDescription()).thenReturn(mockDescription); + when(mockConnection.retain()).thenReturn(mockConnection); + + connectionSource = mock(AsyncConnectionSource.class); + operationContext = mock(OperationContext.class); + timeoutContext = mock(TimeoutContext.class); + serverDescription = mock(ServerDescription.class); + when(operationContext.getTimeoutContext()).thenReturn(timeoutContext); + when(connectionSource.getOperationContext()).thenReturn(operationContext); + doAnswer(invocation -> { + SingleResultCallback callback = invocation.getArgument(0); + callback.onResult(mockConnection, null); + return null; + }).when(connectionSource).getConnection(any()); + when(connectionSource.getServerDescription()).thenReturn(serverDescription); + } + + + @Test + void shouldSkipKillsCursorsCommandWhenNetworkErrorOccurs() { + //given + doAnswer(invocation -> { + SingleResultCallback argument = invocation.getArgument(6); + argument.onResult(null, new MongoSocketException("test", new ServerAddress())); + return null; + }).when(mockConnection).commandAsync(eq(NAMESPACE.getDatabaseName()), any(), any(), any(), any(), any(), any()); + when(serverDescription.getType()).thenReturn(ServerType.LOAD_BALANCER); + AsyncCommandBatchCursor commandBatchCursor = createBatchCursor(); + + //when + commandBatchCursor.next((result, t) -> { + Assertions.assertNull(result); + Assertions.assertNotNull(t); + Assertions.assertEquals(MongoSocketException.class, t.getClass()); + }); + + //then + commandBatchCursor.close(); + verify(mockConnection, times(1)).commandAsync(eq(NAMESPACE.getDatabaseName()), any(), any(), any(), any(), any(), any()); + } + + + @Test + void shouldNotSkipKillsCursorsCommandWhenTimeoutExceptionDoesNotHaveNetworkErrorCause() { + //given + doAnswer(invocation -> { + SingleResultCallback argument = invocation.getArgument(6); + argument.onResult(null, new MongoOperationTimeoutException("test")); + return null; + }).when(mockConnection).commandAsync(eq(NAMESPACE.getDatabaseName()), any(), any(), any(), any(), any(), any()); + when(serverDescription.getType()).thenReturn(ServerType.LOAD_BALANCER); + when(timeoutContext.hasTimeoutMS()).thenReturn(true); + + AsyncCommandBatchCursor commandBatchCursor = createBatchCursor(); + + //when + commandBatchCursor.next((result, t) -> { + Assertions.assertNull(result); + Assertions.assertNotNull(t); + Assertions.assertEquals(MongoOperationTimeoutException.class, t.getClass()); + }); + + commandBatchCursor.close(); + + + //then + verify(mockConnection, times(2)).commandAsync(any(), + any(), any(), any(), any(), any(), any()); + verify(mockConnection, times(1)).commandAsync(eq(NAMESPACE.getDatabaseName()), + argThat(bsonDocument -> bsonDocument.containsKey("getMore")), any(), any(), any(), any(), any()); + verify(mockConnection, times(1)).commandAsync(eq(NAMESPACE.getDatabaseName()), + argThat(bsonDocument -> bsonDocument.containsKey("killCursors")), any(), any(), any(), any(), any()); + } + + @Test + void shouldSkipKillsCursorsCommandWhenTimeoutExceptionHaveNetworkErrorCause() { + //given + doAnswer(invocation -> { + SingleResultCallback argument = invocation.getArgument(6); + argument.onResult(null, new MongoOperationTimeoutException("test", new MongoSocketException("test", new ServerAddress()))); + return null; + }).when(mockConnection).commandAsync(eq(NAMESPACE.getDatabaseName()), any(), any(), any(), any(), any(), any()); + when(serverDescription.getType()).thenReturn(ServerType.LOAD_BALANCER); + when(timeoutContext.hasTimeoutMS()).thenReturn(true); + + AsyncCommandBatchCursor commandBatchCursor = createBatchCursor(); + + //when + commandBatchCursor.next((result, t) -> { + Assertions.assertNull(result); + Assertions.assertNotNull(t); + Assertions.assertEquals(MongoOperationTimeoutException.class, t.getClass()); + }); + + commandBatchCursor.close(); + + //then + verify(mockConnection, times(1)).commandAsync(any(), + any(), any(), any(), any(), any(), any()); + verify(mockConnection, times(1)).commandAsync(eq(NAMESPACE.getDatabaseName()), + argThat(bsonDocument -> bsonDocument.containsKey("getMore")), any(), any(), any(), any(), any()); + verify(mockConnection, never()).commandAsync(eq(NAMESPACE.getDatabaseName()), + argThat(bsonDocument -> bsonDocument.containsKey("killCursors")), any(), any(), any(), any(), any()); + } + + + private AsyncCommandBatchCursor createBatchCursor() { + return new AsyncCommandBatchCursor( + TimeoutMode.CURSOR_LIFETIME, + COMMAND_CURSOR_DOCUMENT, + 0, + 0, + DOCUMENT_CODEC, + null, + connectionSource, + mockConnection); + } + +} diff --git a/driver-core/src/test/unit/com/mongodb/internal/connection/InternalStreamConnectionSpecification.groovy b/driver-core/src/test/unit/com/mongodb/internal/connection/InternalStreamConnectionSpecification.groovy index 036a3632316..7a0dca34526 100644 --- a/driver-core/src/test/unit/com/mongodb/internal/connection/InternalStreamConnectionSpecification.groovy +++ b/driver-core/src/test/unit/com/mongodb/internal/connection/InternalStreamConnectionSpecification.groovy @@ -20,9 +20,11 @@ import com.mongodb.MongoCommandException import com.mongodb.MongoInternalException import com.mongodb.MongoInterruptedException import com.mongodb.MongoNamespace +import com.mongodb.MongoOperationTimeoutException import com.mongodb.MongoSocketClosedException import com.mongodb.MongoSocketException import com.mongodb.MongoSocketReadException +import com.mongodb.MongoSocketReadTimeoutException import com.mongodb.MongoSocketWriteException import com.mongodb.ReadConcern import com.mongodb.ServerAddress @@ -38,6 +40,7 @@ import com.mongodb.event.CommandFailedEvent import com.mongodb.event.CommandStartedEvent import com.mongodb.event.CommandSucceededEvent import com.mongodb.internal.ExceptionUtils.MongoCommandExceptionUtils +import com.mongodb.internal.TimeoutContext import com.mongodb.internal.session.SessionContext import com.mongodb.internal.validator.NoOpFieldNameValidator import org.bson.BsonDocument @@ -57,6 +60,7 @@ import java.util.concurrent.ExecutorService import java.util.concurrent.Executors import static com.mongodb.ClusterFixture.OPERATION_CONTEXT +import static com.mongodb.ClusterFixture.TIMEOUT_SETTINGS_WITH_INFINITE_TIMEOUT import static com.mongodb.ReadPreference.primary import static com.mongodb.connection.ClusterConnectionMode.MULTIPLE import static com.mongodb.connection.ClusterConnectionMode.SINGLE @@ -455,6 +459,88 @@ class InternalStreamConnectionSpecification extends Specification { connection.isClosed() } + def 'Should throw timeout exception with underlying socket exception as a cause when Stream.read throws SocketException'() { + given: + stream.read(_, _) >> { throw new SocketTimeoutException() } + def connection = getOpenedConnection() + + when: + connection.receiveMessage(1, OPERATION_CONTEXT.withTimeoutContext( + new TimeoutContext(TIMEOUT_SETTINGS_WITH_INFINITE_TIMEOUT))) + + then: + def timeoutException = thrown(MongoOperationTimeoutException) + def mongoSocketReadTimeoutException = timeoutException.getCause() + mongoSocketReadTimeoutException instanceof MongoSocketReadTimeoutException + mongoSocketReadTimeoutException.getCause() instanceof SocketTimeoutException + + connection.isClosed() + } + + def 'Should wrap MongoSocketReadTimeoutException with MongoOperationTimeoutException'() { + given: + stream.read(_, _) >> { throw new MongoSocketReadTimeoutException("test", new ServerAddress(), null) } + def connection = getOpenedConnection() + + when: + connection.receiveMessage(1, OPERATION_CONTEXT.withTimeoutContext( + new TimeoutContext(TIMEOUT_SETTINGS_WITH_INFINITE_TIMEOUT))) + + then: + def timeoutException = thrown(MongoOperationTimeoutException) + def mongoSocketReadTimeoutException = timeoutException.getCause() + mongoSocketReadTimeoutException instanceof MongoSocketReadTimeoutException + mongoSocketReadTimeoutException.getCause() == null + + connection.isClosed() + } + + + def 'Should wrap SocketException with timeout exception when Stream.read throws SocketException async'() { + given: + stream.readAsync(_ , _, _) >> { numBytes, operationContext, handler -> + handler.failed(new SocketTimeoutException()) + } + def connection = getOpenedConnection() + def callback = new FutureResultCallback() + def operationContext = OPERATION_CONTEXT.withTimeoutContext( + new TimeoutContext(TIMEOUT_SETTINGS_WITH_INFINITE_TIMEOUT)) + when: + connection.receiveMessageAsync(1, operationContext, callback) + callback.get() + + then: + def timeoutException = thrown(MongoOperationTimeoutException) + def mongoSocketReadTimeoutException = timeoutException.getCause() + mongoSocketReadTimeoutException instanceof MongoSocketReadTimeoutException + mongoSocketReadTimeoutException.getCause() instanceof SocketTimeoutException + + connection.isClosed() + } + + def 'Should wrap MongoSocketReadTimeoutException with MongoOperationTimeoutException async'() { + given: + stream.readAsync(_, _, _) >> { numBytes, operationContext, handler -> + handler.failed(new MongoSocketReadTimeoutException("test", new ServerAddress(), null)) + } + + def connection = getOpenedConnection() + def callback = new FutureResultCallback() + def operationContext = OPERATION_CONTEXT.withTimeoutContext( + new TimeoutContext(TIMEOUT_SETTINGS_WITH_INFINITE_TIMEOUT)) + when: + connection.receiveMessageAsync(1, operationContext, callback) + callback.get() + + then: + def timeoutException = thrown(MongoOperationTimeoutException) + def mongoSocketReadTimeoutException = timeoutException.getCause() + mongoSocketReadTimeoutException instanceof MongoSocketReadTimeoutException + mongoSocketReadTimeoutException.getCause() == null + + connection.isClosed() + } + def 'should close the stream when reading the message header throws an exception asynchronously'() { given: int seen = 0 diff --git a/driver-core/src/test/unit/com/mongodb/internal/operation/CommandBatchCursorTest.java b/driver-core/src/test/unit/com/mongodb/internal/operation/CommandBatchCursorTest.java new file mode 100644 index 00000000000..3380785bd70 --- /dev/null +++ b/driver-core/src/test/unit/com/mongodb/internal/operation/CommandBatchCursorTest.java @@ -0,0 +1,172 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.internal.operation; + + +import com.mongodb.MongoNamespace; +import com.mongodb.MongoOperationTimeoutException; +import com.mongodb.MongoSocketException; +import com.mongodb.ServerAddress; +import com.mongodb.client.cursor.TimeoutMode; +import com.mongodb.connection.ConnectionDescription; +import com.mongodb.connection.ServerDescription; +import com.mongodb.connection.ServerType; +import com.mongodb.connection.ServerVersion; +import com.mongodb.internal.TimeoutContext; +import com.mongodb.internal.binding.ConnectionSource; +import com.mongodb.internal.connection.Connection; +import com.mongodb.internal.connection.OperationContext; +import org.bson.BsonArray; +import org.bson.BsonDocument; +import org.bson.BsonInt32; +import org.bson.BsonInt64; +import org.bson.BsonString; +import org.bson.Document; +import org.bson.codecs.Decoder; +import org.bson.codecs.DocumentCodec; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import static com.mongodb.internal.operation.OperationUnitSpecification.getMaxWireVersionForServerVersion; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.argThat; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +class CommandBatchCursorTest { + + private static final MongoNamespace NAMESPACE = new MongoNamespace("test", "test"); + private static final BsonInt64 CURSOR_ID = new BsonInt64(1); + private static final BsonDocument COMMAND_CURSOR_DOCUMENT = new BsonDocument("ok", new BsonInt32(1)) + .append("cursor", + new BsonDocument("ns", new BsonString(NAMESPACE.getFullName())) + .append("id", CURSOR_ID) + .append("firstBatch", new BsonArrayWrapper<>(new BsonArray()))); + + private static final Decoder DOCUMENT_CODEC = new DocumentCodec(); + + + private Connection mockConnection; + private ConnectionDescription mockDescription; + private ConnectionSource connectionSource; + private OperationContext operationContext; + private TimeoutContext timeoutContext; + private ServerDescription serverDescription; + + @BeforeEach + void setUp() { + ServerVersion serverVersion = new ServerVersion(3, 6); + + mockConnection = mock(Connection.class, "connection"); + mockDescription = mock(ConnectionDescription.class); + when(mockDescription.getMaxWireVersion()).thenReturn(getMaxWireVersionForServerVersion(serverVersion.getVersionList())); + when(mockDescription.getServerType()).thenReturn(ServerType.LOAD_BALANCER); + when(mockConnection.getDescription()).thenReturn(mockDescription); + when(mockConnection.retain()).thenReturn(mockConnection); + + connectionSource = mock(ConnectionSource.class); + operationContext = mock(OperationContext.class); + timeoutContext = mock(TimeoutContext.class); + serverDescription = mock(ServerDescription.class); + when(operationContext.getTimeoutContext()).thenReturn(timeoutContext); + when(connectionSource.getOperationContext()).thenReturn(operationContext); + when(connectionSource.getConnection()).thenReturn(mockConnection); + when(connectionSource.getServerDescription()).thenReturn(serverDescription); + } + + + @Test + void shouldSkipKillsCursorsCommandWhenNetworkErrorOccurs() { + //given + when(mockConnection.command(eq(NAMESPACE.getDatabaseName()), any(), any(), any(), any(), any())).thenThrow( + new MongoSocketException("test", new ServerAddress())); + when(serverDescription.getType()).thenReturn(ServerType.LOAD_BALANCER); + + CommandBatchCursor commandBatchCursor = createBatchCursor(); + //when + Assertions.assertThrows(MongoSocketException.class, commandBatchCursor::next); + + //then + commandBatchCursor.close(); + verify(mockConnection, times(1)).command(eq(NAMESPACE.getDatabaseName()), any(), any(), any(), any(), any()); + } + + private CommandBatchCursor createBatchCursor() { + return new CommandBatchCursor<>( + TimeoutMode.CURSOR_LIFETIME, + COMMAND_CURSOR_DOCUMENT, + 0, + 0, + DOCUMENT_CODEC, + null, + connectionSource, + mockConnection); + } + + @Test + void shouldNotSkipKillsCursorsCommandWhenTimeoutExceptionDoesNotHaveNetworkErrorCause() { + //given + when(mockConnection.command(eq(NAMESPACE.getDatabaseName()), any(), any(), any(), any(), any())).thenThrow( + new MongoOperationTimeoutException("test")); + when(serverDescription.getType()).thenReturn(ServerType.LOAD_BALANCER); + when(timeoutContext.hasTimeoutMS()).thenReturn(true); + + CommandBatchCursor commandBatchCursor = createBatchCursor(); + + //when + Assertions.assertThrows(MongoOperationTimeoutException.class, commandBatchCursor::next); + + commandBatchCursor.close(); + + + //then + verify(mockConnection, times(2)).command(any(), + any(), any(), any(), any(), any()); + verify(mockConnection, times(1)).command(eq(NAMESPACE.getDatabaseName()), + argThat(bsonDocument -> bsonDocument.containsKey("getMore")), any(), any(), any(), any()); + verify(mockConnection, times(1)).command(eq(NAMESPACE.getDatabaseName()), + argThat(bsonDocument -> bsonDocument.containsKey("killCursors")), any(), any(), any(), any()); + } + + @Test + void shouldSkipKillsCursorsCommandWhenTimeoutExceptionHaveNetworkErrorCause() { + //given + when(mockConnection.command(eq(NAMESPACE.getDatabaseName()), any(), any(), any(), any(), any())).thenThrow( + new MongoOperationTimeoutException("test", new MongoSocketException("test", new ServerAddress()))); + when(serverDescription.getType()).thenReturn(ServerType.LOAD_BALANCER); + when(timeoutContext.hasTimeoutMS()).thenReturn(true); + + CommandBatchCursor commandBatchCursor = createBatchCursor(); + + //when + Assertions.assertThrows(MongoOperationTimeoutException.class, commandBatchCursor::next); + commandBatchCursor.close(); + + //then + verify(mockConnection, times(1)).command(any(), + any(), any(), any(), any(), any()); + verify(mockConnection, times(1)).command(eq(NAMESPACE.getDatabaseName()), + argThat(bsonDocument -> bsonDocument.containsKey("getMore")), any(), any(), any(), any()); + verify(mockConnection, never()).command(eq(NAMESPACE.getDatabaseName()), + argThat(bsonDocument -> bsonDocument.containsKey("killCursors")), any(), any(), any(), any()); + } +} diff --git a/driver-sync/src/test/functional/com/mongodb/client/AbstractClientSideOperationsTimeoutProseTest.java b/driver-sync/src/test/functional/com/mongodb/client/AbstractClientSideOperationsTimeoutProseTest.java index c2bc6c59411..119ad7f4470 100644 --- a/driver-sync/src/test/functional/com/mongodb/client/AbstractClientSideOperationsTimeoutProseTest.java +++ b/driver-sync/src/test/functional/com/mongodb/client/AbstractClientSideOperationsTimeoutProseTest.java @@ -24,6 +24,7 @@ import com.mongodb.MongoCredential; import com.mongodb.MongoNamespace; import com.mongodb.MongoOperationTimeoutException; +import com.mongodb.MongoSocketReadTimeoutException; import com.mongodb.MongoTimeoutException; import com.mongodb.ReadConcern; import com.mongodb.ReadPreference; @@ -746,6 +747,102 @@ public void shouldIgnoreWtimeoutMsOfWriteConcernToInitialAndSubsequentCommitTran }}); } + + /** + * Not a prose spec test. However, it is additional test case for better coverage. + */ + @Tag("setsFailPoint") + @DisplayName("KillCursors is not executed after getMore network error when timeout is not enabled") + @Test + public void testKillCursorsIsNotExecutedAfterGetMoreNetworkErrorWhenTimeoutIsNotEnabled() { + assumeTrue(serverVersionAtLeast(4, 4)); + assumeTrue(isServerlessTest()); + + long rtt = ClusterFixture.getPrimaryRTT(); + collectionHelper.create(namespace.getCollectionName(), new CreateCollectionOptions()); + collectionHelper.insertDocuments(new Document(), new Document()); + collectionHelper.runAdminCommand("{" + + " configureFailPoint: \"failCommand\"," + + " mode: { times: 1}," + + " data: {" + + " failCommands: [\"getMore\" ]," + + " blockConnection: true," + + " blockTimeMS: " + (rtt + 600) + + " }" + + "}"); + + try (MongoClient mongoClient = createMongoClient(getMongoClientSettingsBuilder() + .retryReads(true) + .applyToSocketSettings(builder -> builder.readTimeout(500, TimeUnit.MILLISECONDS)))) { + + MongoCollection collection = mongoClient.getDatabase(namespace.getDatabaseName()) + .getCollection(namespace.getCollectionName()).withReadPreference(ReadPreference.primary()); + + MongoCursor cursor = collection.find() + .batchSize(1) + .cursor(); + + cursor.next(); + assertThrows(MongoSocketReadTimeoutException.class, cursor::next); + cursor.close(); + } + + List events = commandListener.getCommandStartedEvents(); + assertEquals(2, events.size(), "Actual events: " + events.stream() + .map(CommandStartedEvent::getCommandName) + .collect(Collectors.toList())); + assertEquals(1, events.stream().filter(e -> e.getCommandName().equals("find")).count()); + assertEquals(1, events.stream().filter(e -> e.getCommandName().equals("getMore")).count()); + + } + + /** + * Not a prose spec test. However, it is additional test case for better coverage. + */ + @Tag("setsFailPoint") + @DisplayName("KillCursors is not executed after getMore network error") + @Test + public void testKillCursorsIsNotExecutedAfterGetMoreNetworkError() { + assumeTrue(serverVersionAtLeast(4, 4)); + assumeTrue(isServerlessTest()); + + long rtt = ClusterFixture.getPrimaryRTT(); + collectionHelper.create(namespace.getCollectionName(), new CreateCollectionOptions()); + collectionHelper.insertDocuments(new Document(), new Document()); + collectionHelper.runAdminCommand("{" + + " configureFailPoint: \"failCommand\"," + + " mode: { times: 1}," + + " data: {" + + " failCommands: [\"getMore\" ]," + + " blockConnection: true," + + " blockTimeMS: " + (rtt + 600) + + " }" + + "}"); + + try (MongoClient mongoClient = createMongoClient(getMongoClientSettingsBuilder() + .timeout(500, TimeUnit.MILLISECONDS))) { + + MongoCollection collection = mongoClient.getDatabase(namespace.getDatabaseName()) + .getCollection(namespace.getCollectionName()).withReadPreference(ReadPreference.primary()); + + MongoCursor cursor = collection.find() + .batchSize(1) + .cursor(); + + cursor.next(); + assertThrows(MongoOperationTimeoutException.class, cursor::next); + cursor.close(); + } + + List events = commandListener.getCommandStartedEvents(); + assertEquals(2, events.size(), "Actual events: " + events.stream() + .map(CommandStartedEvent::getCommandName) + .collect(Collectors.toList())); + assertEquals(1, events.stream().filter(e -> e.getCommandName().equals("find")).count()); + assertEquals(1, events.stream().filter(e -> e.getCommandName().equals("getMore")).count()); + + } + private static Stream test8ServerSelectionArguments() { return Stream.of( Arguments.of(Named.of("serverSelectionTimeoutMS honored if timeoutMS is not set", @@ -802,6 +899,7 @@ public void tearDown(final TestInfo info) { collectionHelper.drop(); filesCollectionHelper.drop(); chunksCollectionHelper.drop(); + commandListener.reset(); try { ServerHelper.checkPool(getPrimary()); } catch (InterruptedException e) { diff --git a/driver-sync/src/test/functional/com/mongodb/client/ClientSideOperationTimeoutTest.java b/driver-sync/src/test/functional/com/mongodb/client/ClientSideOperationTimeoutTest.java index 77caabfdbbb..734377f41d1 100644 --- a/driver-sync/src/test/functional/com/mongodb/client/ClientSideOperationTimeoutTest.java +++ b/driver-sync/src/test/functional/com/mongodb/client/ClientSideOperationTimeoutTest.java @@ -16,6 +16,7 @@ package com.mongodb.client; +import com.mongodb.ClusterFixture; import com.mongodb.client.unified.UnifiedSyncTest; import com.mongodb.lang.Nullable; import org.bson.BsonArray; @@ -71,6 +72,16 @@ public static boolean racyTestAssertion(final String testDescription, final Asse } public static void checkSkipCSOTTest(final String fileDescription, final String testDescription) { + if (ClusterFixture.isServerlessTest()) { + + // It is not possible to create capped collections on serverless instances. + assumeFalse(fileDescription.equals("timeoutMS behaves correctly for tailable awaitData cursors")); + assumeFalse(fileDescription.equals("timeoutMS behaves correctly for tailable non-awaitData cursors")); + + /* Drivers MUST NOT execute a killCursors command because the pinned connection is no longer under a load balancer. */ + assumeFalse(testDescription.equals("timeoutMS is refreshed for close")); + } + assumeFalse("No maxTimeMS parameter for createIndex() method", testDescription.contains("maxTimeMS is ignored if timeoutMS is set - createIndex on collection")); assumeFalse("TODO (CSOT) CRUD Failure",