11package org.phoenixframework
22
33import com.google.common.truth.Truth.assertThat
4+ import okhttp3.OkHttpClient
5+ import okhttp3.Request
6+ import okhttp3.WebSocket
7+ import okhttp3.WebSocketListener
48import org.junit.Before
59import org.junit.Test
10+ import org.mockito.ArgumentMatchers
611import org.mockito.Mockito
12+ import org.mockito.Mockito.`when`
713import org.mockito.MockitoAnnotations
814import org.mockito.Spy
915import java.util.concurrent.CompletableFuture
1016import java.util.concurrent.TimeUnit
17+ import java.util.concurrent.atomic.AtomicInteger
1118
1219class PhxChannelTest {
1320
1421 private val defaultRef = " 1"
22+ private val topic = " topic"
1523
1624 @Spy
1725 var socket: PhxSocket = PhxSocket (" http://localhost:4000/socket/websocket" )
@@ -23,7 +31,7 @@ class PhxChannelTest {
2331 Mockito .doReturn(defaultRef).`when `(socket).makeRef()
2432
2533 socket.timeout = 1234
26- channel = PhxChannel (" topic" , hashMapOf(" one" to " two" ), socket)
34+ channel = PhxChannel (topic, hashMapOf(" one" to " two" ), socket)
2735 }
2836
2937
@@ -149,4 +157,46 @@ class PhxChannelTest {
149157
150158 CompletableFuture .allOf(f1, f3).get(10 , TimeUnit .SECONDS )
151159 }
160+
161+ @Test
162+ fun `issue 36 - verify timeouts remove bindings` () {
163+ // mock okhttp to get isConnected to return true for the socket
164+ val mockOkHttp = Mockito .mock(OkHttpClient ::class .java)
165+ val mockSocket = Mockito .mock(WebSocket ::class .java)
166+ `when `(mockOkHttp.newWebSocket(ArgumentMatchers .any(Request ::class .java), ArgumentMatchers .any(WebSocketListener ::class .java))).thenReturn(mockSocket)
167+
168+ // local mocks for this test
169+ val localSocket = Mockito .spy(PhxSocket (url = " http://localhost:4000/socket/websocket" , client = mockOkHttp))
170+ val localChannel = PhxChannel (topic, hashMapOf(" one" to " two" ), localSocket)
171+
172+ // setup makeRef so it increments
173+ val refCounter = AtomicInteger (1 )
174+ Mockito .doAnswer {
175+ refCounter.getAndIncrement().toString()
176+ }.`when `(localSocket).makeRef()
177+
178+ // connect the socket
179+ localSocket.connect()
180+
181+ // join the channel
182+ val joinPush = localChannel.join()
183+ localChannel.trigger(PhxMessage (
184+ ref = joinPush.ref!! ,
185+ joinRef = joinPush.ref!! ,
186+ event = PhxChannel .PhxEvent .REPLY .value,
187+ topic = topic,
188+ payload = mutableMapOf (" status" to " ok" )))
189+
190+ // get bindings
191+ val originalBindingsSize = localChannel.bindings.size
192+ val pushCount = 100
193+ repeat(pushCount) {
194+ localChannel.push(" some-event" , mutableMapOf (), timeout = 500 )
195+ }
196+ // verify binding count before timeouts
197+ assertThat(localChannel.bindings.size).isEqualTo(originalBindingsSize + pushCount)
198+ Thread .sleep(1000 )
199+ // verify binding count after timeouts
200+ assertThat(localChannel.bindings.size).isEqualTo(originalBindingsSize)
201+ }
152202}
0 commit comments