Skip to content

Commit 0ef564d

Browse files
committed
Add SocketChannel verifications.
JAVA-5856
1 parent b6244f4 commit 0ef564d

File tree

1 file changed

+64
-14
lines changed

1 file changed

+64
-14
lines changed

driver-core/src/test/functional/com/mongodb/internal/connection/TlsChannelStreamFunctionalTest.java

Lines changed: 64 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -26,16 +26,27 @@
2626
import org.junit.jupiter.api.BeforeEach;
2727
import org.junit.jupiter.params.ParameterizedTest;
2828
import org.junit.jupiter.params.provider.ValueSource;
29+
import org.mockito.MockedStatic;
30+
import org.mockito.Mockito;
31+
import org.mockito.invocation.InvocationOnMock;
32+
import org.mockito.stubbing.Answer;
2933

3034
import java.io.IOException;
3135
import java.net.ServerSocket;
3236
import java.nio.channels.InterruptedByTimeoutException;
37+
import java.nio.channels.SocketChannel;
3338
import java.util.concurrent.TimeUnit;
3439

35-
import static com.mongodb.assertions.Assertions.assertTrue;
40+
import static java.lang.String.format;
41+
import static java.util.concurrent.TimeUnit.SECONDS;
3642
import static org.junit.Assert.assertThrows;
3743
import static org.junit.jupiter.api.Assertions.assertFalse;
3844
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
45+
import static org.junit.jupiter.api.Assertions.assertNotNull;
46+
import static org.junit.jupiter.api.Assertions.assertTrue;
47+
import static org.junit.jupiter.api.Assertions.fail;
48+
import static org.mockito.Mockito.atLeast;
49+
import static org.mockito.Mockito.verify;
3950

4051
class TlsChannelStreamFunctionalTest {
4152
private static final SslSettings SSL_SETTINGS = SslSettings.builder().enabled(true).build();
@@ -50,16 +61,22 @@ void setUp() throws IOException {
5061
}
5162

5263
@AfterEach
64+
@SuppressWarnings("try")
5365
void cleanUp() throws IOException {
54-
try (ServerSocket ignore = serverSocket) {
66+
try (ServerSocket ignored = serverSocket) {
67+
//ignored
5568
}
5669
}
5770

5871
@ParameterizedTest
5972
@ValueSource(ints = {500, 1000, 2000})
60-
void shouldInterruptConnectionEstablishmentWhenConnectionTimeoutExpires(final int connectTimeout) {
73+
void shouldInterruptConnectionEstablishmentWhenConnectionTimeoutExpires(final int connectTimeout) throws IOException {
6174
//given
62-
try (TlsChannelStreamFactoryFactory factory = new TlsChannelStreamFactoryFactory(new DefaultInetAddressResolver())) {
75+
try (TlsChannelStreamFactoryFactory factory = new TlsChannelStreamFactoryFactory(new DefaultInetAddressResolver());
76+
MockedStatic<SocketChannel> socketChannelMockedStatic = Mockito.mockStatic(SocketChannel.class)) {
77+
SingleResultSpyCaptor<SocketChannel> singleResultSpyCaptor = new SingleResultSpyCaptor<>();
78+
socketChannelMockedStatic.when(SocketChannel::open).thenAnswer(singleResultSpyCaptor);
79+
6380
StreamFactory streamFactory = factory.create(SocketSettings.builder()
6481
.connectTimeout(connectTimeout, TimeUnit.MILLISECONDS)
6582
.build(), SSL_SETTINGS);
@@ -75,39 +92,72 @@ void shouldInterruptConnectionEstablishmentWhenConnectionTimeoutExpires(final in
7592

7693
//then
7794
long elapsedMs = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - connectOpenStart);
78-
long diff = elapsedMs - connectTimeout;
79-
// Allowed difference, with test overhead setup is 300MS.
80-
int epsilonMs = 300;
95+
// Allow for some timing imprecision due to test overhead.
96+
int maximumAcceptableTimeoutOvershoot = 300;
8197

8298
assertInstanceOf(InterruptedByTimeoutException.class, mongoSocketOpenException.getCause(),
8399
"Actual cause: " + mongoSocketOpenException.getCause());
84-
assertFalse(diff < 0,
85-
String.format("Connection timed-out sooner than expected. Difference: %d ms", diff));
86-
assertTrue(diff < epsilonMs,
87-
String.format("Elapsed time %d ms should be within %d ms of the connect timeout", elapsedMs, epsilonMs));
100+
assertFalse(connectTimeout > elapsedMs,
101+
format("Connection timed-out sooner than expected. ConnectTimeoutMS: %d, elapsedMs: %d", connectTimeout, elapsedMs));
102+
assertTrue(elapsedMs - connectTimeout < maximumAcceptableTimeoutOvershoot,
103+
format("Connection timeout overshoot time %d ms should be within %d ms", elapsedMs - connectTimeout,
104+
maximumAcceptableTimeoutOvershoot));
105+
106+
SocketChannel actualSpySocketChannel = singleResultSpyCaptor.getResult();
107+
assertNotNull(actualSpySocketChannel, "SocketChannel was not opened");
108+
verify(actualSpySocketChannel, atLeast(1)).close();
88109
}
89110
}
90111

91112
@ParameterizedTest
92113
@ValueSource(ints = {0, 500, 1000, 2000})
93-
void shouldEstablishConnection(final int connectTimeout) throws IOException {
114+
void shouldEstablishConnection(final int connectTimeout) throws IOException, InterruptedException {
94115
//given
95-
try (TlsChannelStreamFactoryFactory factory = new TlsChannelStreamFactoryFactory(new DefaultInetAddressResolver())) {
116+
try (TlsChannelStreamFactoryFactory factory = new TlsChannelStreamFactoryFactory(new DefaultInetAddressResolver());
117+
MockedStatic<SocketChannel> socketChannelMockedStatic = Mockito.mockStatic(SocketChannel.class)) {
118+
SingleResultSpyCaptor<SocketChannel> singleResultSpyCaptor = new SingleResultSpyCaptor<>();
119+
socketChannelMockedStatic.when(SocketChannel::open).thenAnswer(singleResultSpyCaptor);
120+
96121
StreamFactory streamFactory = factory.create(SocketSettings.builder()
97122
.connectTimeout(connectTimeout, TimeUnit.MILLISECONDS)
98123
.build(), SSL_SETTINGS);
99124

100-
Stream stream = streamFactory.create(new ServerAddress("localhost", port));
125+
Stream stream = streamFactory.create(new ServerAddress(serverSocket.getInetAddress(), port));
101126
try {
102127
//when
103128
stream.open(OperationContext.simpleOperationContext(
104129
new TimeoutContext(TimeoutSettings.DEFAULT.withConnectTimeoutMS(connectTimeout))));
105130

106131
//then
132+
SocketChannel actualSpySocketChannel = singleResultSpyCaptor.getResult();
133+
assertNotNull(actualSpySocketChannel, "SocketChannel was not opened");
134+
assertTrue(actualSpySocketChannel.isConnected());
135+
136+
// Wait to verify that socket was not closed by timeout.
137+
SECONDS.sleep(3);
138+
assertTrue(actualSpySocketChannel.isConnected());
107139
assertFalse(stream.isClosed());
108140
} finally {
109141
stream.close();
110142
}
111143
}
112144
}
145+
146+
private static final class SingleResultSpyCaptor<T> implements Answer<T> {
147+
private volatile T result = null;
148+
149+
public T getResult() {
150+
return result;
151+
}
152+
153+
@Override
154+
public T answer(InvocationOnMock invocationOnMock) throws Throwable {
155+
if (result != null) {
156+
fail(invocationOnMock.getMethod().getName() + " was called more then once");
157+
}
158+
T returnedValue = (T) invocationOnMock.callRealMethod();
159+
result = Mockito.spy(returnedValue);
160+
return result;
161+
}
162+
}
113163
}

0 commit comments

Comments
 (0)