26
26
import org .junit .jupiter .api .BeforeEach ;
27
27
import org .junit .jupiter .params .ParameterizedTest ;
28
28
import org .junit .jupiter .params .provider .ValueSource ;
29
+ import org .mockito .MockedStatic ;
30
+ import org .mockito .Mockito ;
31
+ import org .mockito .invocation .InvocationOnMock ;
32
+ import org .mockito .stubbing .Answer ;
29
33
30
34
import java .io .IOException ;
31
35
import java .net .ServerSocket ;
32
36
import java .nio .channels .InterruptedByTimeoutException ;
37
+ import java .nio .channels .SocketChannel ;
33
38
import java .util .concurrent .TimeUnit ;
34
39
35
- import static com .mongodb .assertions .Assertions .assertTrue ;
40
+ import static java .lang .String .format ;
41
+ import static java .util .concurrent .TimeUnit .SECONDS ;
36
42
import static org .junit .Assert .assertThrows ;
37
43
import static org .junit .jupiter .api .Assertions .assertFalse ;
38
44
import static org .junit .jupiter .api .Assertions .assertInstanceOf ;
45
+ import static org .junit .jupiter .api .Assertions .assertNotNull ;
46
+ import static org .junit .jupiter .api .Assertions .assertTrue ;
47
+ import static org .junit .jupiter .api .Assertions .fail ;
48
+ import static org .mockito .Mockito .atLeast ;
49
+ import static org .mockito .Mockito .verify ;
39
50
40
51
class TlsChannelStreamFunctionalTest {
41
52
private static final SslSettings SSL_SETTINGS = SslSettings .builder ().enabled (true ).build ();
@@ -50,16 +61,22 @@ void setUp() throws IOException {
50
61
}
51
62
52
63
@ AfterEach
64
+ @ SuppressWarnings ("try" )
53
65
void cleanUp () throws IOException {
54
- try (ServerSocket ignore = serverSocket ) {
66
+ try (ServerSocket ignored = serverSocket ) {
67
+ //ignored
55
68
}
56
69
}
57
70
58
71
@ ParameterizedTest
59
72
@ ValueSource (ints = {500 , 1000 , 2000 })
60
- void shouldInterruptConnectionEstablishmentWhenConnectionTimeoutExpires (final int connectTimeout ) {
73
+ void shouldInterruptConnectionEstablishmentWhenConnectionTimeoutExpires (final int connectTimeout ) throws IOException {
61
74
//given
62
- try (TlsChannelStreamFactoryFactory factory = new TlsChannelStreamFactoryFactory (new DefaultInetAddressResolver ())) {
75
+ try (TlsChannelStreamFactoryFactory factory = new TlsChannelStreamFactoryFactory (new DefaultInetAddressResolver ());
76
+ MockedStatic <SocketChannel > socketChannelMockedStatic = Mockito .mockStatic (SocketChannel .class )) {
77
+ SingleResultSpyCaptor <SocketChannel > singleResultSpyCaptor = new SingleResultSpyCaptor <>();
78
+ socketChannelMockedStatic .when (SocketChannel ::open ).thenAnswer (singleResultSpyCaptor );
79
+
63
80
StreamFactory streamFactory = factory .create (SocketSettings .builder ()
64
81
.connectTimeout (connectTimeout , TimeUnit .MILLISECONDS )
65
82
.build (), SSL_SETTINGS );
@@ -75,39 +92,72 @@ void shouldInterruptConnectionEstablishmentWhenConnectionTimeoutExpires(final in
75
92
76
93
//then
77
94
long elapsedMs = TimeUnit .NANOSECONDS .toMillis (System .nanoTime () - connectOpenStart );
78
- long diff = elapsedMs - connectTimeout ;
79
- // Allowed difference, with test overhead setup is 300MS.
80
- int epsilonMs = 300 ;
95
+ // Allow for some timing imprecision due to test overhead.
96
+ int maximumAcceptableTimeoutOvershoot = 300 ;
81
97
82
98
assertInstanceOf (InterruptedByTimeoutException .class , mongoSocketOpenException .getCause (),
83
99
"Actual cause: " + mongoSocketOpenException .getCause ());
84
- assertFalse (diff < 0 ,
85
- String .format ("Connection timed-out sooner than expected. Difference: %d ms" , diff ));
86
- assertTrue (diff < epsilonMs ,
87
- String .format ("Elapsed time %d ms should be within %d ms of the connect timeout" , elapsedMs , epsilonMs ));
100
+ assertFalse (connectTimeout > elapsedMs ,
101
+ format ("Connection timed-out sooner than expected. ConnectTimeoutMS: %d, elapsedMs: %d" , connectTimeout , elapsedMs ));
102
+ assertTrue (elapsedMs - connectTimeout < maximumAcceptableTimeoutOvershoot ,
103
+ format ("Connection timeout overshoot time %d ms should be within %d ms" , elapsedMs - connectTimeout ,
104
+ maximumAcceptableTimeoutOvershoot ));
105
+
106
+ SocketChannel actualSpySocketChannel = singleResultSpyCaptor .getResult ();
107
+ assertNotNull (actualSpySocketChannel , "SocketChannel was not opened" );
108
+ verify (actualSpySocketChannel , atLeast (1 )).close ();
88
109
}
89
110
}
90
111
91
112
@ ParameterizedTest
92
113
@ ValueSource (ints = {0 , 500 , 1000 , 2000 })
93
- void shouldEstablishConnection (final int connectTimeout ) throws IOException {
114
+ void shouldEstablishConnection (final int connectTimeout ) throws IOException , InterruptedException {
94
115
//given
95
- try (TlsChannelStreamFactoryFactory factory = new TlsChannelStreamFactoryFactory (new DefaultInetAddressResolver ())) {
116
+ try (TlsChannelStreamFactoryFactory factory = new TlsChannelStreamFactoryFactory (new DefaultInetAddressResolver ());
117
+ MockedStatic <SocketChannel > socketChannelMockedStatic = Mockito .mockStatic (SocketChannel .class )) {
118
+ SingleResultSpyCaptor <SocketChannel > singleResultSpyCaptor = new SingleResultSpyCaptor <>();
119
+ socketChannelMockedStatic .when (SocketChannel ::open ).thenAnswer (singleResultSpyCaptor );
120
+
96
121
StreamFactory streamFactory = factory .create (SocketSettings .builder ()
97
122
.connectTimeout (connectTimeout , TimeUnit .MILLISECONDS )
98
123
.build (), SSL_SETTINGS );
99
124
100
- Stream stream = streamFactory .create (new ServerAddress ("localhost" , port ));
125
+ Stream stream = streamFactory .create (new ServerAddress (serverSocket . getInetAddress () , port ));
101
126
try {
102
127
//when
103
128
stream .open (OperationContext .simpleOperationContext (
104
129
new TimeoutContext (TimeoutSettings .DEFAULT .withConnectTimeoutMS (connectTimeout ))));
105
130
106
131
//then
132
+ SocketChannel actualSpySocketChannel = singleResultSpyCaptor .getResult ();
133
+ assertNotNull (actualSpySocketChannel , "SocketChannel was not opened" );
134
+ assertTrue (actualSpySocketChannel .isConnected ());
135
+
136
+ // Wait to verify that socket was not closed by timeout.
137
+ SECONDS .sleep (3 );
138
+ assertTrue (actualSpySocketChannel .isConnected ());
107
139
assertFalse (stream .isClosed ());
108
140
} finally {
109
141
stream .close ();
110
142
}
111
143
}
112
144
}
145
+
146
+ private static final class SingleResultSpyCaptor <T > implements Answer <T > {
147
+ private volatile T result = null ;
148
+
149
+ public T getResult () {
150
+ return result ;
151
+ }
152
+
153
+ @ Override
154
+ public T answer (InvocationOnMock invocationOnMock ) throws Throwable {
155
+ if (result != null ) {
156
+ fail (invocationOnMock .getMethod ().getName () + " was called more then once" );
157
+ }
158
+ T returnedValue = (T ) invocationOnMock .callRealMethod ();
159
+ result = Mockito .spy (returnedValue );
160
+ return result ;
161
+ }
162
+ }
113
163
}
0 commit comments