diff --git a/src/main/kotlin/org/phoenixframework/PhxPush.kt b/src/main/kotlin/org/phoenixframework/PhxPush.kt index 5666119..7f23bab 100644 --- a/src/main/kotlin/org/phoenixframework/PhxPush.kt +++ b/src/main/kotlin/org/phoenixframework/PhxPush.kt @@ -184,8 +184,8 @@ class PhxPush( val mutPayload = payload.toMutableMap() mutPayload["status"] = status - refEvent?.let { - val message = PhxMessage(it, "", "", mutPayload) + refEvent?.let { safeRefEvent -> + val message = PhxMessage(event = safeRefEvent, payload = mutPayload) this.channel.trigger(message) } } diff --git a/src/test/kotlin/org/phoenixframework/PhxChannelTest.kt b/src/test/kotlin/org/phoenixframework/PhxChannelTest.kt index 879fe68..a32d76d 100644 --- a/src/test/kotlin/org/phoenixframework/PhxChannelTest.kt +++ b/src/test/kotlin/org/phoenixframework/PhxChannelTest.kt @@ -1,17 +1,25 @@ package org.phoenixframework import com.google.common.truth.Truth.assertThat +import okhttp3.OkHttpClient +import okhttp3.Request +import okhttp3.WebSocket +import okhttp3.WebSocketListener import org.junit.Before import org.junit.Test +import org.mockito.ArgumentMatchers import org.mockito.Mockito +import org.mockito.Mockito.`when` import org.mockito.MockitoAnnotations import org.mockito.Spy import java.util.concurrent.CompletableFuture import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicInteger class PhxChannelTest { private val defaultRef = "1" + private val topic = "topic" @Spy var socket: PhxSocket = PhxSocket("http://localhost:4000/socket/websocket") @@ -23,7 +31,7 @@ class PhxChannelTest { Mockito.doReturn(defaultRef).`when`(socket).makeRef() socket.timeout = 1234 - channel = PhxChannel("topic", hashMapOf("one" to "two"), socket) + channel = PhxChannel(topic, hashMapOf("one" to "two"), socket) } @@ -149,4 +157,46 @@ class PhxChannelTest { CompletableFuture.allOf(f1, f3).get(10, TimeUnit.SECONDS) } + + @Test + fun `issue 36 - verify timeouts remove bindings`() { + // mock okhttp to get isConnected to return true for the socket + val mockOkHttp = Mockito.mock(OkHttpClient::class.java) + val mockSocket = Mockito.mock(WebSocket::class.java) + `when`(mockOkHttp.newWebSocket(ArgumentMatchers.any(Request::class.java), ArgumentMatchers.any(WebSocketListener::class.java))).thenReturn(mockSocket) + + // local mocks for this test + val localSocket = Mockito.spy(PhxSocket(url = "http://localhost:4000/socket/websocket", client = mockOkHttp)) + val localChannel = PhxChannel(topic, hashMapOf("one" to "two"), localSocket) + + // setup makeRef so it increments + val refCounter = AtomicInteger(1) + Mockito.doAnswer { + refCounter.getAndIncrement().toString() + }.`when`(localSocket).makeRef() + + //connect the socket + localSocket.connect() + + //join the channel + val joinPush = localChannel.join() + localChannel.trigger(PhxMessage( + ref = joinPush.ref!!, + joinRef = joinPush.ref!!, + event = PhxChannel.PhxEvent.REPLY.value, + topic = topic, + payload = mutableMapOf("status" to "ok"))) + + //get bindings + val originalBindingsSize = localChannel.bindings.size + val pushCount = 100 + repeat(pushCount) { + localChannel.push("some-event", mutableMapOf(), timeout = 500) + } + //verify binding count before timeouts + assertThat(localChannel.bindings.size).isEqualTo(originalBindingsSize + pushCount) + Thread.sleep(1000) + //verify binding count after timeouts + assertThat(localChannel.bindings.size).isEqualTo(originalBindingsSize) + } }