From 1a1941d63d0c7b219434a5212f99d3ae41060b6c Mon Sep 17 00:00:00 2001 From: Oleg Yukhnevich Date: Fri, 12 Apr 2024 00:00:28 +0300 Subject: [PATCH] migrate ktor tcp transport to new API --- .../api/rsocket-transport-ktor-tcp.api | 40 +++++ .../ktor/tcp/KtorTcpClientTransport.kt | 107 ++++++++++++ .../transport/ktor/tcp/KtorTcpConnection.kt | 111 ++++++++++++ .../transport/ktor/tcp/KtorTcpSelector.kt | 49 ++++++ .../ktor/tcp/KtorTcpServerTransport.kt | 161 ++++++++++++++++++ .../transport/ktor/tcp/TcpClientTransport.kt | 9 +- .../transport/ktor/tcp/TcpServerTransport.kt | 2 +- .../transport/ktor/tcp/TcpServerTest.kt | 27 ++- .../transport/ktor/tcp/TcpTransportTest.kt | 35 ++++ .../transport/ktor/tcp/defaultDispatcher.kt | 21 --- .../transport/ktor/tcp/defaultDispatcher.kt | 21 --- 11 files changed, 517 insertions(+), 66 deletions(-) create mode 100644 rsocket-transports/ktor-tcp/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/KtorTcpClientTransport.kt create mode 100644 rsocket-transports/ktor-tcp/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/KtorTcpConnection.kt create mode 100644 rsocket-transports/ktor-tcp/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/KtorTcpSelector.kt create mode 100644 rsocket-transports/ktor-tcp/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/KtorTcpServerTransport.kt delete mode 100644 rsocket-transports/ktor-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/defaultDispatcher.kt delete mode 100644 rsocket-transports/ktor-tcp/src/nativeMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/defaultDispatcher.kt diff --git a/rsocket-transports/ktor-tcp/api/rsocket-transport-ktor-tcp.api b/rsocket-transports/ktor-tcp/api/rsocket-transport-ktor-tcp.api index 5833bcf1..d6153045 100644 --- a/rsocket-transports/ktor-tcp/api/rsocket-transport-ktor-tcp.api +++ b/rsocket-transports/ktor-tcp/api/rsocket-transport-ktor-tcp.api @@ -1,3 +1,43 @@ +public abstract interface class io/rsocket/kotlin/transport/ktor/tcp/KtorTcpClientTransport : io/rsocket/kotlin/transport/RSocketTransport { + public static final field Factory Lio/rsocket/kotlin/transport/ktor/tcp/KtorTcpClientTransport$Factory; + public abstract fun target (Lio/ktor/network/sockets/SocketAddress;)Lio/rsocket/kotlin/transport/RSocketClientTarget; + public abstract fun target (Ljava/lang/String;I)Lio/rsocket/kotlin/transport/RSocketClientTarget; +} + +public final class io/rsocket/kotlin/transport/ktor/tcp/KtorTcpClientTransport$Factory : io/rsocket/kotlin/transport/RSocketTransportFactory { +} + +public abstract interface class io/rsocket/kotlin/transport/ktor/tcp/KtorTcpClientTransportBuilder : io/rsocket/kotlin/transport/RSocketTransportBuilder { + public abstract fun dispatcher (Lkotlin/coroutines/CoroutineContext;)V + public fun inheritDispatcher ()V + public abstract fun selectorManager (Lio/ktor/network/selector/SelectorManager;Z)V + public abstract fun selectorManagerDispatcher (Lkotlin/coroutines/CoroutineContext;)V + public abstract fun socketOptions (Lkotlin/jvm/functions/Function1;)V +} + +public abstract interface class io/rsocket/kotlin/transport/ktor/tcp/KtorTcpServerInstance : io/rsocket/kotlin/transport/RSocketServerInstance { + public abstract fun getLocalAddress ()Lio/ktor/network/sockets/SocketAddress; +} + +public abstract interface class io/rsocket/kotlin/transport/ktor/tcp/KtorTcpServerTransport : io/rsocket/kotlin/transport/RSocketTransport { + public static final field Factory Lio/rsocket/kotlin/transport/ktor/tcp/KtorTcpServerTransport$Factory; + public abstract fun target (Lio/ktor/network/sockets/SocketAddress;)Lio/rsocket/kotlin/transport/RSocketServerTarget; + public abstract fun target (Ljava/lang/String;I)Lio/rsocket/kotlin/transport/RSocketServerTarget; + public static synthetic fun target$default (Lio/rsocket/kotlin/transport/ktor/tcp/KtorTcpServerTransport;Lio/ktor/network/sockets/SocketAddress;ILjava/lang/Object;)Lio/rsocket/kotlin/transport/RSocketServerTarget; + public static synthetic fun target$default (Lio/rsocket/kotlin/transport/ktor/tcp/KtorTcpServerTransport;Ljava/lang/String;IILjava/lang/Object;)Lio/rsocket/kotlin/transport/RSocketServerTarget; +} + +public final class io/rsocket/kotlin/transport/ktor/tcp/KtorTcpServerTransport$Factory : io/rsocket/kotlin/transport/RSocketTransportFactory { +} + +public abstract interface class io/rsocket/kotlin/transport/ktor/tcp/KtorTcpServerTransportBuilder : io/rsocket/kotlin/transport/RSocketTransportBuilder { + public abstract fun dispatcher (Lkotlin/coroutines/CoroutineContext;)V + public fun inheritDispatcher ()V + public abstract fun selectorManager (Lio/ktor/network/selector/SelectorManager;Z)V + public abstract fun selectorManagerDispatcher (Lkotlin/coroutines/CoroutineContext;)V + public abstract fun socketOptions (Lkotlin/jvm/functions/Function1;)V +} + public final class io/rsocket/kotlin/transport/ktor/tcp/TcpClientTransportKt { public static final fun TcpClientTransport (Lio/ktor/network/sockets/InetSocketAddress;Lkotlin/coroutines/CoroutineContext;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function1;)Lio/rsocket/kotlin/transport/ClientTransport; public static final fun TcpClientTransport (Ljava/lang/String;ILkotlin/coroutines/CoroutineContext;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function1;)Lio/rsocket/kotlin/transport/ClientTransport; diff --git a/rsocket-transports/ktor-tcp/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/KtorTcpClientTransport.kt b/rsocket-transports/ktor-tcp/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/KtorTcpClientTransport.kt new file mode 100644 index 00000000..b4a44005 --- /dev/null +++ b/rsocket-transports/ktor-tcp/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/KtorTcpClientTransport.kt @@ -0,0 +1,107 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.ktor.tcp + +import io.ktor.network.selector.* +import io.ktor.network.sockets.* +import io.rsocket.kotlin.internal.io.* +import io.rsocket.kotlin.transport.* +import kotlinx.coroutines.* +import kotlin.coroutines.* + +public sealed interface KtorTcpClientTransport : RSocketTransport { + public fun target(remoteAddress: SocketAddress): RSocketClientTarget + public fun target(host: String, port: Int): RSocketClientTarget + + public companion object Factory : + RSocketTransportFactory(::KtorTcpClientTransportBuilderImpl) +} + +public sealed interface KtorTcpClientTransportBuilder : RSocketTransportBuilder { + public fun dispatcher(context: CoroutineContext) + public fun inheritDispatcher(): Unit = dispatcher(EmptyCoroutineContext) + + public fun selectorManagerDispatcher(context: CoroutineContext) + public fun selectorManager(manager: SelectorManager, manage: Boolean) + + public fun socketOptions(block: SocketOptions.TCPClientSocketOptions.() -> Unit) + + //TODO: TLS support +} + +private class KtorTcpClientTransportBuilderImpl : KtorTcpClientTransportBuilder { + private var dispatcher: CoroutineContext = Dispatchers.IO + private var selector: KtorTcpSelector? = null + private var socketOptions: SocketOptions.TCPClientSocketOptions.() -> Unit = {} + + override fun dispatcher(context: CoroutineContext) { + check(context[Job] == null) { "Dispatcher shouldn't contain job" } + this.dispatcher = context + } + + override fun socketOptions(block: SocketOptions.TCPClientSocketOptions.() -> Unit) { + this.socketOptions = block + } + + override fun selectorManagerDispatcher(context: CoroutineContext) { + check(context[Job] == null) { "Dispatcher shouldn't contain job" } + this.selector = KtorTcpSelector.FromContext(context) + } + + override fun selectorManager(manager: SelectorManager, manage: Boolean) { + this.selector = KtorTcpSelector.FromInstance(manager, manage) + } + + @RSocketTransportApi + override fun buildTransport(context: CoroutineContext): KtorTcpClientTransport { + val transportContext = context.supervisorContext() + dispatcher + return KtorTcpClientTransportImpl( + coroutineContext = transportContext, + socketOptions = socketOptions, + selectorManager = selector.createFor(transportContext) + ) + } +} + +private class KtorTcpClientTransportImpl( + override val coroutineContext: CoroutineContext, + private val socketOptions: SocketOptions.TCPClientSocketOptions.() -> Unit, + private val selectorManager: SelectorManager, +) : KtorTcpClientTransport { + override fun target(remoteAddress: SocketAddress): RSocketClientTarget = KtorTcpClientTargetImpl( + coroutineContext = coroutineContext.supervisorContext(), + socketOptions = socketOptions, + selectorManager = selectorManager, + remoteAddress = remoteAddress + ) + + override fun target(host: String, port: Int): RSocketClientTarget = target(InetSocketAddress(host, port)) +} + +private class KtorTcpClientTargetImpl( + override val coroutineContext: CoroutineContext, + private val socketOptions: SocketOptions.TCPClientSocketOptions.() -> Unit, + private val selectorManager: SelectorManager, + private val remoteAddress: SocketAddress, +) : RSocketClientTarget { + + @RSocketTransportApi + override fun connectClient(handler: RSocketConnectionHandler): Job = launch { + val socket = aSocket(selectorManager).tcp().connect(remoteAddress, socketOptions) + handler.handleKtorTcpConnection(socket) + } +} diff --git a/rsocket-transports/ktor-tcp/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/KtorTcpConnection.kt b/rsocket-transports/ktor-tcp/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/KtorTcpConnection.kt new file mode 100644 index 00000000..b65c1077 --- /dev/null +++ b/rsocket-transports/ktor-tcp/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/KtorTcpConnection.kt @@ -0,0 +1,111 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.ktor.tcp + +import io.ktor.network.sockets.* +import io.ktor.utils.io.* +import io.ktor.utils.io.core.* +import io.rsocket.kotlin.internal.io.* +import io.rsocket.kotlin.transport.* +import io.rsocket.kotlin.transport.internal.* +import kotlinx.coroutines.* +import kotlinx.coroutines.channels.* + +@RSocketTransportApi +internal suspend fun RSocketConnectionHandler.handleKtorTcpConnection(socket: Socket): Unit = coroutineScope { + val outboundQueue = PrioritizationFrameQueue(Channel.BUFFERED) + val inbound = channelForCloseable(Channel.BUFFERED) + + val readerJob = launch { + val input = socket.openReadChannel() + try { + while (true) inbound.send(input.readFrame() ?: break) + input.cancel(null) + } catch (cause: Throwable) { + input.cancel(cause) + throw cause + } + }.onCompletion { inbound.cancel() } + + val writerJob = launch { + val output = socket.openWriteChannel() + try { + while (true) { + // we write all available frames here, and only after it flush + // in this case, if there are several buffered frames we can send them in one go + // avoiding unnecessary flushes + output.writeFrame(outboundQueue.dequeueFrame() ?: break) + while (true) output.writeFrame(outboundQueue.tryDequeueFrame() ?: break) + output.flush() + } + output.close(null) + } catch (cause: Throwable) { + output.close(cause) + throw cause + } + }.onCompletion { outboundQueue.cancel() } + + try { + handleConnection(KtorTcpConnection(outboundQueue, inbound)) + } finally { + readerJob.cancel() + outboundQueue.close() // will cause `writerJob` completion + // even if it was cancelled, we still need to close socket and await it closure + withContext(NonCancellable) { + // await completion of read/write and then close socket + readerJob.join() + writerJob.join() + // close socket + socket.close() + socket.socketContext.join() + } + } +} + +@RSocketTransportApi +private class KtorTcpConnection( + private val outboundQueue: PrioritizationFrameQueue, + private val inbound: ReceiveChannel, +) : RSocketSequentialConnection { + override val isClosedForSend: Boolean get() = outboundQueue.isClosedForSend + override suspend fun sendFrame(streamId: Int, frame: ByteReadPacket) { + return outboundQueue.enqueueFrame(streamId, frame) + } + + override suspend fun receiveFrame(): ByteReadPacket? { + return inbound.receiveCatching().getOrNull() + } +} + +private suspend fun ByteWriteChannel.writeFrame(frame: ByteReadPacket) { + val packet = buildPacket { + writeInt24(frame.remaining.toInt()) + writePacket(frame) + } + try { + writePacket(packet) + } catch (cause: Throwable) { + packet.close() + throw cause + } +} + +private suspend fun ByteReadChannel.readFrame(): ByteReadPacket? { + val lengthPacket = readRemaining(3) + if (lengthPacket.remaining == 0L) return null + return readPacket(lengthPacket.readInt24()) +} diff --git a/rsocket-transports/ktor-tcp/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/KtorTcpSelector.kt b/rsocket-transports/ktor-tcp/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/KtorTcpSelector.kt new file mode 100644 index 00000000..6de116e1 --- /dev/null +++ b/rsocket-transports/ktor-tcp/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/KtorTcpSelector.kt @@ -0,0 +1,49 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.ktor.tcp + +import io.ktor.network.selector.* +import kotlinx.coroutines.* +import kotlin.coroutines.* + +internal sealed class KtorTcpSelector { + class FromContext(val context: CoroutineContext) : KtorTcpSelector() + class FromInstance(val selectorManager: SelectorManager, val manage: Boolean) : KtorTcpSelector() +} + +internal fun KtorTcpSelector?.createFor(parentContext: CoroutineContext): SelectorManager { + val selectorManager: SelectorManager + val manage: Boolean + when (this) { + null -> { + selectorManager = SelectorManager(parentContext) + manage = true + } + + is KtorTcpSelector.FromContext -> { + selectorManager = SelectorManager(parentContext + context) + manage = true + } + + is KtorTcpSelector.FromInstance -> { + selectorManager = this.selectorManager + manage = this.manage + } + } + if (manage) Job(parentContext.job).invokeOnCompletion { selectorManager.close() } + return selectorManager +} diff --git a/rsocket-transports/ktor-tcp/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/KtorTcpServerTransport.kt b/rsocket-transports/ktor-tcp/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/KtorTcpServerTransport.kt new file mode 100644 index 00000000..feb00bd9 --- /dev/null +++ b/rsocket-transports/ktor-tcp/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/KtorTcpServerTransport.kt @@ -0,0 +1,161 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.ktor.tcp + +import io.ktor.network.selector.* +import io.ktor.network.sockets.* +import io.rsocket.kotlin.internal.io.* +import io.rsocket.kotlin.transport.* +import kotlinx.coroutines.* +import kotlin.coroutines.* + +public sealed interface KtorTcpServerInstance : RSocketServerInstance { + public val localAddress: SocketAddress +} + +public sealed interface KtorTcpServerTransport : RSocketTransport { + public fun target(localAddress: SocketAddress? = null): RSocketServerTarget + public fun target(host: String = "0.0.0.0", port: Int = 0): RSocketServerTarget + + public companion object Factory : + RSocketTransportFactory(::KtorTcpServerTransportBuilderImpl) +} + +public sealed interface KtorTcpServerTransportBuilder : RSocketTransportBuilder { + public fun dispatcher(context: CoroutineContext) + public fun inheritDispatcher(): Unit = dispatcher(EmptyCoroutineContext) + + public fun selectorManagerDispatcher(context: CoroutineContext) + public fun selectorManager(manager: SelectorManager, manage: Boolean) + + public fun socketOptions(block: SocketOptions.AcceptorOptions.() -> Unit) +} + +private class KtorTcpServerTransportBuilderImpl : KtorTcpServerTransportBuilder { + private var dispatcher: CoroutineContext = Dispatchers.IO + private var selector: KtorTcpSelector? = null + private var socketOptions: SocketOptions.AcceptorOptions.() -> Unit = {} + + override fun dispatcher(context: CoroutineContext) { + check(context[Job] == null) { "Dispatcher shouldn't contain job" } + this.dispatcher = context + } + + override fun socketOptions(block: SocketOptions.AcceptorOptions.() -> Unit) { + this.socketOptions = block + } + + override fun selectorManagerDispatcher(context: CoroutineContext) { + check(context[Job] == null) { "Dispatcher shouldn't contain job" } + this.selector = KtorTcpSelector.FromContext(context) + } + + override fun selectorManager(manager: SelectorManager, manage: Boolean) { + this.selector = KtorTcpSelector.FromInstance(manager, manage) + } + + @RSocketTransportApi + override fun buildTransport(context: CoroutineContext): KtorTcpServerTransport { + val transportContext = context.supervisorContext() + dispatcher + return KtorTcpServerTransportImpl( + coroutineContext = transportContext, + socketOptions = socketOptions, + selectorManager = selector.createFor(transportContext) + ) + } +} + +private class KtorTcpServerTransportImpl( + override val coroutineContext: CoroutineContext, + private val socketOptions: SocketOptions.AcceptorOptions.() -> Unit, + private val selectorManager: SelectorManager, +) : KtorTcpServerTransport { + override fun target(localAddress: SocketAddress?): RSocketServerTarget = KtorTcpServerTargetImpl( + coroutineContext = coroutineContext.supervisorContext(), + socketOptions = socketOptions, + selectorManager = selectorManager, + localAddress = localAddress + ) + + override fun target(host: String, port: Int): RSocketServerTarget = target(InetSocketAddress(host, port)) +} + +private class KtorTcpServerTargetImpl( + override val coroutineContext: CoroutineContext, + private val socketOptions: SocketOptions.AcceptorOptions.() -> Unit, + private val selectorManager: SelectorManager, + private val localAddress: SocketAddress?, +) : RSocketServerTarget { + + @RSocketTransportApi + override suspend fun startServer(handler: RSocketConnectionHandler): KtorTcpServerInstance { + currentCoroutineContext().ensureActive() + coroutineContext.ensureActive() + return startKtorTcpServer(this, bindSocket(), handler) + } + + @OptIn(ExperimentalCoroutinesApi::class) + private suspend fun bindSocket(): ServerSocket = launchCoroutine { cont -> + val socket = aSocket(selectorManager).tcp().bind(localAddress, socketOptions) + cont.resume(socket) { socket.close() } + } +} + +@RSocketTransportApi +private fun startKtorTcpServer( + scope: CoroutineScope, + serverSocket: ServerSocket, + handler: RSocketConnectionHandler, +): KtorTcpServerInstance { + val serverJob = scope.launch { + try { + // the failure of one connection should not stop all other connections + supervisorScope { + while (true) { + val socket = serverSocket.accept() + launch { handler.handleKtorTcpConnection(socket) } + } + } + } finally { + // even if it was cancelled, we still need to close socket and await it closure + withContext(NonCancellable) { + serverSocket.close() + serverSocket.socketContext.join() + } + } + } + return KtorTcpServerInstanceImpl( + coroutineContext = scope.coroutineContext + serverJob, + localAddress = serverSocket.localAddress + ) +} + +@RSocketTransportApi +private class KtorTcpServerInstanceImpl( + override val coroutineContext: CoroutineContext, + override val localAddress: SocketAddress, +) : KtorTcpServerInstance + +@Suppress("SuspendFunctionOnCoroutineScope") +private suspend inline fun CoroutineScope.launchCoroutine( + context: CoroutineContext = EmptyCoroutineContext, + crossinline block: suspend (CancellableContinuation) -> Unit, +): T = suspendCancellableCoroutine { cont -> + val job = launch(context) { block(cont) } + job.invokeOnCompletion { if (it != null && cont.isActive) cont.resumeWithException(it) } + cont.invokeOnCancellation { job.cancel("launchCoroutine was cancelled", it) } +} diff --git a/rsocket-transports/ktor-tcp/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/TcpClientTransport.kt b/rsocket-transports/ktor-tcp/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/TcpClientTransport.kt index cd3a208a..8194bd8d 100644 --- a/rsocket-transports/ktor-tcp/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/TcpClientTransport.kt +++ b/rsocket-transports/ktor-tcp/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/TcpClientTransport.kt @@ -26,11 +26,6 @@ import io.rsocket.kotlin.transport.* import kotlinx.coroutines.* import kotlin.coroutines.* -//TODO user should close ClientTransport manually if there is no job provided in context - -//this dispatcher will be used, if no dispatcher were provided by user in client and server -internal expect val defaultDispatcher: CoroutineDispatcher - public fun TcpClientTransport( hostname: String, port: Int, context: CoroutineContext = EmptyCoroutineContext, @@ -42,10 +37,10 @@ public fun TcpClientTransport( remoteAddress: InetSocketAddress, context: CoroutineContext = EmptyCoroutineContext, intercept: (Socket) -> Socket = { it }, //f.e. for tls, which is currently supported by ktor only on JVM - configure: SocketOptions.TCPClientSocketOptions.() -> Unit = {} + configure: SocketOptions.TCPClientSocketOptions.() -> Unit = {}, ): ClientTransport { val transportJob = SupervisorJob(context[Job]) - val transportContext = defaultDispatcher + context + transportJob + CoroutineName("rSocket-tcp-client") + val transportContext = Dispatchers.IO + context + transportJob + CoroutineName("rSocket-tcp-client") val selector = SelectorManager(transportContext) Job(transportJob).invokeOnCompletion { selector.close() } return ClientTransport(transportContext) { diff --git a/rsocket-transports/ktor-tcp/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/TcpServerTransport.kt b/rsocket-transports/ktor-tcp/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/TcpServerTransport.kt index 7282068f..f9be46d8 100644 --- a/rsocket-transports/ktor-tcp/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/TcpServerTransport.kt +++ b/rsocket-transports/ktor-tcp/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/TcpServerTransport.kt @@ -41,7 +41,7 @@ public fun TcpServerTransport( configure: SocketOptions.AcceptorOptions.() -> Unit = {}, ): ServerTransport = ServerTransport { accept -> val serverSocketDeferred = CompletableDeferred() - val handlerJob = launch(defaultDispatcher + coroutineContext) { + val handlerJob = launch(Dispatchers.IO + coroutineContext) { SelectorManager(coroutineContext).use { selector -> aSocket(selector).tcp().bind(localAddress, configure).use { serverSocket -> serverSocketDeferred.complete(serverSocket) diff --git a/rsocket-transports/ktor-tcp/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/tcp/TcpServerTest.kt b/rsocket-transports/ktor-tcp/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/tcp/TcpServerTest.kt index 95667dfb..df99dc02 100644 --- a/rsocket-transports/ktor-tcp/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/tcp/TcpServerTest.kt +++ b/rsocket-transports/ktor-tcp/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/tcp/TcpServerTest.kt @@ -16,7 +16,6 @@ package io.rsocket.kotlin.transport.ktor.tcp -import io.ktor.network.sockets.* import io.rsocket.kotlin.* import io.rsocket.kotlin.test.* import kotlinx.coroutines.* @@ -25,11 +24,9 @@ import kotlin.test.* class TcpServerTest : SuspendTest, TestWithLeakCheck { private val testJob = Job() private val testContext = testJob + TestExceptionHandler - private val serverTransport = TcpServerTransport() - private suspend fun clientTransport(server: TcpServer) = TcpClientTransport( - server.serverSocket.await().localAddress as InetSocketAddress, - testContext - ) + private val serverTransport = KtorTcpServerTransport(testContext).target() + private fun KtorTcpServerInstance.clientTransport() = + KtorTcpClientTransport(testContext).target(localAddress) override suspend fun after() { testJob.cancelAndJoin() @@ -37,13 +34,13 @@ class TcpServerTest : SuspendTest, TestWithLeakCheck { @Test fun testFailedConnection() = test { - val server = TestServer().bindIn(CoroutineScope(testContext), serverTransport) { + val server = TestServer().startServer(serverTransport) { if (config.setupPayload.data.readText() == "ok") { RSocketRequestHandler { requestResponse { it } } } else error("FAILED") - }.also { it.serverSocket.await() } + } suspend fun newClient(text: String) = TestConnector { connectionConfig { @@ -51,7 +48,7 @@ class TcpServerTest : SuspendTest, TestWithLeakCheck { payload(text) } } - }.connect(clientTransport(server)) + }.connect(server.clientTransport()) val client1 = newClient("ok") client1.requestResponse(payload("ok")).close() @@ -70,8 +67,7 @@ class TcpServerTest : SuspendTest, TestWithLeakCheck { assertFalse(client2.isActive) assertTrue(client3.isActive) - assertTrue(server.serverSocket.await().socketContext.isActive) - assertTrue(server.handlerJob.isActive) + assertTrue(server.isActive) client1.coroutineContext.job.cancelAndJoin() client2.coroutineContext.job.cancelAndJoin() @@ -81,13 +77,13 @@ class TcpServerTest : SuspendTest, TestWithLeakCheck { @Test fun testFailedHandler() = test { val handlers = mutableListOf() - val server = TestServer().bindIn(CoroutineScope(testContext), serverTransport) { + val server = TestServer().startServer(serverTransport) { RSocketRequestHandler { requestResponse { it } }.also { handlers += it } - }.also { it.serverSocket.await() } + } - suspend fun newClient() = TestConnector().connect(clientTransport(server)) + suspend fun newClient() = TestConnector().connect(server.clientTransport()) val client1 = newClient() @@ -118,8 +114,7 @@ class TcpServerTest : SuspendTest, TestWithLeakCheck { assertFalse(client2.isActive) assertTrue(client3.isActive) - assertTrue(server.serverSocket.await().socketContext.isActive) - assertTrue(server.handlerJob.isActive) + assertTrue(server.isActive) client1.coroutineContext.job.cancelAndJoin() client2.coroutineContext.job.cancelAndJoin() diff --git a/rsocket-transports/ktor-tcp/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/tcp/TcpTransportTest.kt b/rsocket-transports/ktor-tcp/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/tcp/TcpTransportTest.kt index a55440e5..f3e49288 100644 --- a/rsocket-transports/ktor-tcp/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/tcp/TcpTransportTest.kt +++ b/rsocket-transports/ktor-tcp/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/tcp/TcpTransportTest.kt @@ -16,8 +16,21 @@ package io.rsocket.kotlin.transport.ktor.tcp +import io.ktor.network.selector.* import io.ktor.network.sockets.* import io.rsocket.kotlin.transport.tests.* +import kotlinx.coroutines.* + +// tcp on native works very bad when there is a big thread dispatcher with multiple selectors +// TBD: may be there is a bug in ktor, +// so we use a smaller dispatcher. +// why 2 threads? +// because on JVM, 1 thread is needed for selector manager to work (blocking), +// and so 1 will be used to execute everything +//@OptIn(ExperimentalCoroutinesApi::class) +//private val serverDispatcher = Dispatchers.IO.limitedParallelism(2) +//@OptIn(ExperimentalCoroutinesApi::class) +//private val clientDispatcher = Dispatchers.IO.limitedParallelism(2) class TcpTransportTest : TransportTest() { override suspend fun before() { @@ -25,3 +38,25 @@ class TcpTransportTest : TransportTest() { client = connectClient(TcpClientTransport(serverSocket.localAddress as InetSocketAddress, testContext)) } } + +class KtorTcpTransportTest : TransportTest() { + // a single SelectorManager for both client and server works much better in K/N + // in user code in most of the cases, only one SelectorManager will be created + private val selector = SelectorManager(Dispatchers.IO) + override suspend fun before() { + val server = startServer(KtorTcpServerTransport(testContext) { + selectorManager(selector, false) + }.target()) + client = connectClient(KtorTcpClientTransport(testContext) { + selectorManager(selector, false) + }.target(server.localAddress)) + } + + override suspend fun after() { + try { + super.after() + } finally { + selector.close() + } + } +} diff --git a/rsocket-transports/ktor-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/defaultDispatcher.kt b/rsocket-transports/ktor-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/defaultDispatcher.kt deleted file mode 100644 index aec80603..00000000 --- a/rsocket-transports/ktor-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/defaultDispatcher.kt +++ /dev/null @@ -1,21 +0,0 @@ -/* - * Copyright 2015-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.kotlin.transport.ktor.tcp - -import kotlinx.coroutines.* - -internal actual val defaultDispatcher: CoroutineDispatcher get() = Dispatchers.IO diff --git a/rsocket-transports/ktor-tcp/src/nativeMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/defaultDispatcher.kt b/rsocket-transports/ktor-tcp/src/nativeMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/defaultDispatcher.kt deleted file mode 100644 index 5fa3a8e2..00000000 --- a/rsocket-transports/ktor-tcp/src/nativeMain/kotlin/io/rsocket/kotlin/transport/ktor/tcp/defaultDispatcher.kt +++ /dev/null @@ -1,21 +0,0 @@ -/* - * Copyright 2015-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.kotlin.transport.ktor.tcp - -import kotlinx.coroutines.* - -internal actual val defaultDispatcher: CoroutineDispatcher get() = Dispatchers.Unconfined