From f664dc9a6ece4ba1674ccc3f716e9910e8cac427 Mon Sep 17 00:00:00 2001 From: Oleg Yukhnevich Date: Fri, 12 Apr 2024 00:12:55 +0300 Subject: [PATCH] migrate ktor websocket transport to new API --- ...socket-transport-ktor-websocket-client.api | 21 ++ .../ktor-websocket-client/build.gradle.kts | 1 + .../client/KtorWebSocketClientTransport.kt | 180 +++++++++++++++ ...cket-transport-ktor-websocket-internal.api | 4 + .../ktor-websocket-internal/build.gradle.kts | 1 + .../internal/KtorWebSocketConnection.kt | 63 ++++++ ...socket-transport-ktor-websocket-server.api | 27 +++ .../ktor-websocket-server/build.gradle.kts | 1 + .../server/KtorWebSocketServerTransport.kt | 210 ++++++++++++++++++ .../server/WebSocketServerTransport.kt | 2 - .../server/CIOWebSocketTransportTest.kt | 2 + .../server/WebSocketTransportTest.kt | 19 ++ .../server/WebSocketTransportTests.kt | 12 +- 13 files changed, 538 insertions(+), 5 deletions(-) create mode 100644 rsocket-transports/ktor-websocket-client/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/websocket/client/KtorWebSocketClientTransport.kt create mode 100644 rsocket-transports/ktor-websocket-internal/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/websocket/internal/KtorWebSocketConnection.kt create mode 100644 rsocket-transports/ktor-websocket-server/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/websocket/server/KtorWebSocketServerTransport.kt diff --git a/rsocket-transports/ktor-websocket-client/api/rsocket-transport-ktor-websocket-client.api b/rsocket-transports/ktor-websocket-client/api/rsocket-transport-ktor-websocket-client.api index 9aaa4d11..1761b00d 100644 --- a/rsocket-transports/ktor-websocket-client/api/rsocket-transport-ktor-websocket-client.api +++ b/rsocket-transports/ktor-websocket-client/api/rsocket-transport-ktor-websocket-client.api @@ -1,3 +1,24 @@ +public abstract interface class io/rsocket/kotlin/transport/ktor/websocket/client/KtorWebSocketClientTransport : io/rsocket/kotlin/transport/RSocketTransport { + public static final field Factory Lio/rsocket/kotlin/transport/ktor/websocket/client/KtorWebSocketClientTransport$Factory; + public abstract fun target (Lio/ktor/http/HttpMethod;Ljava/lang/String;Ljava/lang/Integer;Ljava/lang/String;Lkotlin/jvm/functions/Function1;)Lio/rsocket/kotlin/transport/RSocketClientTarget; + public abstract fun target (Ljava/lang/String;Lkotlin/jvm/functions/Function1;)Lio/rsocket/kotlin/transport/RSocketClientTarget; + public abstract fun target (Lkotlin/jvm/functions/Function1;)Lio/rsocket/kotlin/transport/RSocketClientTarget; + public static synthetic fun target$default (Lio/rsocket/kotlin/transport/ktor/websocket/client/KtorWebSocketClientTransport;Lio/ktor/http/HttpMethod;Ljava/lang/String;Ljava/lang/Integer;Ljava/lang/String;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lio/rsocket/kotlin/transport/RSocketClientTarget; + public static synthetic fun target$default (Lio/rsocket/kotlin/transport/ktor/websocket/client/KtorWebSocketClientTransport;Ljava/lang/String;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lio/rsocket/kotlin/transport/RSocketClientTarget; +} + +public final class io/rsocket/kotlin/transport/ktor/websocket/client/KtorWebSocketClientTransport$Factory : io/rsocket/kotlin/transport/RSocketTransportFactory { +} + +public abstract interface class io/rsocket/kotlin/transport/ktor/websocket/client/KtorWebSocketClientTransportBuilder : io/rsocket/kotlin/transport/RSocketTransportBuilder { + public abstract fun httpEngine (Lio/ktor/client/engine/HttpClientEngine;Lkotlin/jvm/functions/Function1;)V + public abstract fun httpEngine (Lio/ktor/client/engine/HttpClientEngineFactory;Lkotlin/jvm/functions/Function1;)V + public abstract fun httpEngine (Lkotlin/jvm/functions/Function1;)V + public static synthetic fun httpEngine$default (Lio/rsocket/kotlin/transport/ktor/websocket/client/KtorWebSocketClientTransportBuilder;Lio/ktor/client/engine/HttpClientEngine;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)V + public static synthetic fun httpEngine$default (Lio/rsocket/kotlin/transport/ktor/websocket/client/KtorWebSocketClientTransportBuilder;Lio/ktor/client/engine/HttpClientEngineFactory;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)V + public abstract fun webSocketsConfig (Lkotlin/jvm/functions/Function1;)V +} + public final class io/rsocket/kotlin/transport/ktor/websocket/client/WebSocketClientTransportKt { public static final fun WebSocketClientTransport (Lio/ktor/client/engine/HttpClientEngineFactory;Ljava/lang/String;Ljava/lang/Integer;Ljava/lang/String;ZLkotlin/coroutines/CoroutineContext;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function1;)Lio/rsocket/kotlin/transport/ClientTransport; public static final fun WebSocketClientTransport (Lio/ktor/client/engine/HttpClientEngineFactory;Ljava/lang/String;ZLkotlin/coroutines/CoroutineContext;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function1;)Lio/rsocket/kotlin/transport/ClientTransport; diff --git a/rsocket-transports/ktor-websocket-client/build.gradle.kts b/rsocket-transports/ktor-websocket-client/build.gradle.kts index 86ee3265..f9bcbf6a 100644 --- a/rsocket-transports/ktor-websocket-client/build.gradle.kts +++ b/rsocket-transports/ktor-websocket-client/build.gradle.kts @@ -30,6 +30,7 @@ kotlin { sourceSets { commonMain.dependencies { implementation(projects.rsocketTransportKtorWebsocketInternal) + implementation(projects.rsocketInternalIo) api(projects.rsocketCore) api(libs.ktor.client.core) api(libs.ktor.client.websockets) diff --git a/rsocket-transports/ktor-websocket-client/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/websocket/client/KtorWebSocketClientTransport.kt b/rsocket-transports/ktor-websocket-client/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/websocket/client/KtorWebSocketClientTransport.kt new file mode 100644 index 00000000..c444ec65 --- /dev/null +++ b/rsocket-transports/ktor-websocket-client/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/websocket/client/KtorWebSocketClientTransport.kt @@ -0,0 +1,180 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.ktor.websocket.client + +import io.ktor.client.* +import io.ktor.client.engine.* +import io.ktor.client.plugins.websocket.* +import io.ktor.client.request.* +import io.ktor.http.* +import io.rsocket.kotlin.internal.io.* +import io.rsocket.kotlin.transport.* +import io.rsocket.kotlin.transport.ktor.websocket.internal.* +import kotlinx.coroutines.* +import kotlin.coroutines.* + +public sealed interface KtorWebSocketClientTransport : RSocketTransport { + public fun target(request: HttpRequestBuilder.() -> Unit): RSocketClientTarget + public fun target(urlString: String, request: HttpRequestBuilder.() -> Unit = {}): RSocketClientTarget + + public fun target( + method: HttpMethod = HttpMethod.Get, + host: String? = null, + port: Int? = null, + path: String? = null, + request: HttpRequestBuilder.() -> Unit = {}, + ): RSocketClientTarget + + public companion object Factory : + RSocketTransportFactory(::KtorWebSocketClientTransportBuilderImpl) +} + +public sealed interface KtorWebSocketClientTransportBuilder : RSocketTransportBuilder { + public fun httpEngine(configure: HttpClientEngineConfig.() -> Unit) + public fun httpEngine(engine: HttpClientEngine, configure: HttpClientEngineConfig.() -> Unit = {}) + public fun httpEngine(factory: HttpClientEngineFactory, configure: T.() -> Unit = {}) + + public fun webSocketsConfig(block: WebSockets.Config.() -> Unit) +} + +private class KtorWebSocketClientTransportBuilderImpl : KtorWebSocketClientTransportBuilder { + private var httpClientFactory: HttpClientFactory = HttpClientFactory.Default + private var webSocketsConfig: WebSockets.Config.() -> Unit = {} + + override fun httpEngine(configure: HttpClientEngineConfig.() -> Unit) { + this.httpClientFactory = HttpClientFactory.FromConfiguration(configure) + } + + override fun httpEngine(engine: HttpClientEngine, configure: HttpClientEngineConfig.() -> Unit) { + this.httpClientFactory = HttpClientFactory.FromEngine(engine, configure) + } + + override fun httpEngine(factory: HttpClientEngineFactory, configure: T.() -> Unit) { + this.httpClientFactory = HttpClientFactory.FromFactory(factory, configure) + } + + override fun webSocketsConfig(block: WebSockets.Config.() -> Unit) { + this.webSocketsConfig = block + } + + @RSocketTransportApi + override fun buildTransport(context: CoroutineContext): KtorWebSocketClientTransport { + val httpClient = httpClientFactory.createHttpClient { + install(WebSockets, webSocketsConfig) + } + // only dispatcher of a client is used - it looks like it's Dispatchers.IO now + val newContext = context.supervisorContext() + (httpClient.coroutineContext[ContinuationInterceptor] ?: EmptyCoroutineContext) + val newJob = newContext.job + val httpClientJob = httpClient.coroutineContext.job + + httpClientJob.invokeOnCompletion { newJob.cancel("HttpClient closed", it) } + newJob.invokeOnCompletion { httpClientJob.cancel("KtorWebSocketClientTransport closed", it) } + + return KtorWebSocketClientTransportImpl( + coroutineContext = newContext, + httpClient = httpClient, + ) + } +} + +private class KtorWebSocketClientTransportImpl( + override val coroutineContext: CoroutineContext, + private val httpClient: HttpClient, +) : KtorWebSocketClientTransport { + override fun target(request: HttpRequestBuilder.() -> Unit): RSocketClientTarget = KtorWebSocketClientTargetImpl( + coroutineContext = coroutineContext, + httpClient = httpClient, + request = request + ) + + override fun target( + urlString: String, + request: HttpRequestBuilder.() -> Unit, + ): RSocketClientTarget = target( + method = HttpMethod.Get, host = null, port = null, path = null, + request = { + url.protocol = URLProtocol.WS + url.port = port + + url.takeFrom(urlString) + request() + }, + ) + + override fun target( + method: HttpMethod, + host: String?, + port: Int?, + path: String?, + request: HttpRequestBuilder.() -> Unit, + ): RSocketClientTarget = target { + this.method = method + url("ws", host, port, path) + request() + } +} + +private class KtorWebSocketClientTargetImpl( + override val coroutineContext: CoroutineContext, + private val httpClient: HttpClient, + private val request: HttpRequestBuilder.() -> Unit, +) : RSocketClientTarget { + + @RSocketTransportApi + override fun connectClient(handler: RSocketConnectionHandler): Job = launch { + httpClient.webSocket(request) { + handler.handleKtorWebSocketConnection(this) + } + } +} + +private sealed class HttpClientFactory { + abstract fun createHttpClient(block: HttpClientConfig<*>.() -> Unit): HttpClient + + object Default : HttpClientFactory() { + override fun createHttpClient(block: HttpClientConfig<*>.() -> Unit): HttpClient = HttpClient(block) + } + + class FromConfiguration( + private val configure: HttpClientEngineConfig.() -> Unit, + ) : HttpClientFactory() { + override fun createHttpClient(block: HttpClientConfig<*>.() -> Unit): HttpClient = HttpClient { + engine(configure) + block() + } + } + + class FromEngine( + private val engine: HttpClientEngine, + private val configure: HttpClientEngineConfig.() -> Unit, + ) : HttpClientFactory() { + override fun createHttpClient(block: HttpClientConfig<*>.() -> Unit): HttpClient = HttpClient(engine) { + engine(configure) + block() + } + } + + class FromFactory( + private val factory: HttpClientEngineFactory, + private val configure: T.() -> Unit, + ) : HttpClientFactory() { + override fun createHttpClient(block: HttpClientConfig<*>.() -> Unit): HttpClient = HttpClient(factory) { + engine(configure) + block() + } + } +} diff --git a/rsocket-transports/ktor-websocket-internal/api/rsocket-transport-ktor-websocket-internal.api b/rsocket-transports/ktor-websocket-internal/api/rsocket-transport-ktor-websocket-internal.api index 9ca3776f..8607860c 100644 --- a/rsocket-transports/ktor-websocket-internal/api/rsocket-transport-ktor-websocket-internal.api +++ b/rsocket-transports/ktor-websocket-internal/api/rsocket-transport-ktor-websocket-internal.api @@ -1,3 +1,7 @@ +public final class io/rsocket/kotlin/transport/ktor/websocket/internal/KtorWebSocketConnectionKt { + public static final fun handleKtorWebSocketConnection (Lio/rsocket/kotlin/transport/RSocketConnectionHandler;Lio/ktor/websocket/WebSocketSession;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; +} + public final class io/rsocket/kotlin/transport/ktor/websocket/internal/WebSocketConnection : io/rsocket/kotlin/Connection, kotlinx/coroutines/CoroutineScope { public fun (Lio/ktor/websocket/WebSocketSession;)V public fun getCoroutineContext ()Lkotlin/coroutines/CoroutineContext; diff --git a/rsocket-transports/ktor-websocket-internal/build.gradle.kts b/rsocket-transports/ktor-websocket-internal/build.gradle.kts index 0cb0129d..c3583f62 100644 --- a/rsocket-transports/ktor-websocket-internal/build.gradle.kts +++ b/rsocket-transports/ktor-websocket-internal/build.gradle.kts @@ -29,6 +29,7 @@ kotlin { sourceSets { commonMain.dependencies { + implementation(projects.rsocketInternalIo) api(projects.rsocketCore) api(libs.ktor.websockets) } diff --git a/rsocket-transports/ktor-websocket-internal/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/websocket/internal/KtorWebSocketConnection.kt b/rsocket-transports/ktor-websocket-internal/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/websocket/internal/KtorWebSocketConnection.kt new file mode 100644 index 00000000..05e351ae --- /dev/null +++ b/rsocket-transports/ktor-websocket-internal/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/websocket/internal/KtorWebSocketConnection.kt @@ -0,0 +1,63 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.ktor.websocket.internal + +import io.ktor.utils.io.core.* +import io.ktor.websocket.* +import io.rsocket.kotlin.internal.io.* +import io.rsocket.kotlin.transport.* +import io.rsocket.kotlin.transport.internal.* +import kotlinx.coroutines.* +import kotlinx.coroutines.channels.* + +@RSocketTransportApi +public suspend fun RSocketConnectionHandler.handleKtorWebSocketConnection(webSocketSession: WebSocketSession): Unit = coroutineScope { + val outboundQueue = PrioritizationFrameQueue(Channel.BUFFERED) + + val senderJob = launch { + while (true) webSocketSession.send(outboundQueue.dequeueFrame()?.readBytes() ?: break) + }.onCompletion { outboundQueue.cancel() } + + try { + handleConnection(KtorWebSocketConnection(outboundQueue, webSocketSession.incoming)) + } finally { + webSocketSession.incoming.cancel() + outboundQueue.close() + withContext(NonCancellable) { + senderJob.join() // await all frames sent + webSocketSession.close() + webSocketSession.coroutineContext.job.join() + } + } +} + +@RSocketTransportApi +private class KtorWebSocketConnection( + private val outboundQueue: PrioritizationFrameQueue, + private val inbound: ReceiveChannel, +) : RSocketSequentialConnection { + override val isClosedForSend: Boolean get() = outboundQueue.isClosedForSend + + override suspend fun sendFrame(streamId: Int, frame: ByteReadPacket) { + return outboundQueue.enqueueFrame(streamId, frame) + } + + override suspend fun receiveFrame(): ByteReadPacket? { + val frame = inbound.receiveCatching().getOrNull() ?: return null + return ByteReadPacket(frame.data) + } +} diff --git a/rsocket-transports/ktor-websocket-server/api/rsocket-transport-ktor-websocket-server.api b/rsocket-transports/ktor-websocket-server/api/rsocket-transport-ktor-websocket-server.api index 4c9b6cde..7dece198 100644 --- a/rsocket-transports/ktor-websocket-server/api/rsocket-transport-ktor-websocket-server.api +++ b/rsocket-transports/ktor-websocket-server/api/rsocket-transport-ktor-websocket-server.api @@ -1,3 +1,30 @@ +public abstract interface class io/rsocket/kotlin/transport/ktor/websocket/server/KtorWebSocketServerInstance : io/rsocket/kotlin/transport/RSocketServerInstance { + public abstract fun getConnectors ()Ljava/util/List; + public abstract fun getPath ()Ljava/lang/String; + public abstract fun getProtocol ()Ljava/lang/String; +} + +public abstract interface class io/rsocket/kotlin/transport/ktor/websocket/server/KtorWebSocketServerTransport : io/rsocket/kotlin/transport/RSocketTransport { + public static final field Factory Lio/rsocket/kotlin/transport/ktor/websocket/server/KtorWebSocketServerTransport$Factory; + public abstract fun target (Lio/ktor/server/engine/EngineConnectorConfig;Ljava/lang/String;Ljava/lang/String;)Lio/rsocket/kotlin/transport/RSocketServerTarget; + public abstract fun target (Ljava/lang/String;ILjava/lang/String;Ljava/lang/String;)Lio/rsocket/kotlin/transport/RSocketServerTarget; + public abstract fun target (Ljava/lang/String;Ljava/lang/String;Lkotlin/jvm/functions/Function1;)Lio/rsocket/kotlin/transport/RSocketServerTarget; + public abstract fun target (Ljava/util/List;Ljava/lang/String;Ljava/lang/String;)Lio/rsocket/kotlin/transport/RSocketServerTarget; + public static synthetic fun target$default (Lio/rsocket/kotlin/transport/ktor/websocket/server/KtorWebSocketServerTransport;Lio/ktor/server/engine/EngineConnectorConfig;Ljava/lang/String;Ljava/lang/String;ILjava/lang/Object;)Lio/rsocket/kotlin/transport/RSocketServerTarget; + public static synthetic fun target$default (Lio/rsocket/kotlin/transport/ktor/websocket/server/KtorWebSocketServerTransport;Ljava/lang/String;ILjava/lang/String;Ljava/lang/String;ILjava/lang/Object;)Lio/rsocket/kotlin/transport/RSocketServerTarget; + public static synthetic fun target$default (Lio/rsocket/kotlin/transport/ktor/websocket/server/KtorWebSocketServerTransport;Ljava/lang/String;Ljava/lang/String;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lio/rsocket/kotlin/transport/RSocketServerTarget; + public static synthetic fun target$default (Lio/rsocket/kotlin/transport/ktor/websocket/server/KtorWebSocketServerTransport;Ljava/util/List;Ljava/lang/String;Ljava/lang/String;ILjava/lang/Object;)Lio/rsocket/kotlin/transport/RSocketServerTarget; +} + +public final class io/rsocket/kotlin/transport/ktor/websocket/server/KtorWebSocketServerTransport$Factory : io/rsocket/kotlin/transport/RSocketTransportFactory { +} + +public abstract interface class io/rsocket/kotlin/transport/ktor/websocket/server/KtorWebSocketServerTransportBuilder : io/rsocket/kotlin/transport/RSocketTransportBuilder { + public abstract fun httpEngine (Lio/ktor/server/engine/ApplicationEngineFactory;Lkotlin/jvm/functions/Function1;)V + public static synthetic fun httpEngine$default (Lio/rsocket/kotlin/transport/ktor/websocket/server/KtorWebSocketServerTransportBuilder;Lio/ktor/server/engine/ApplicationEngineFactory;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)V + public abstract fun webSocketsConfig (Lkotlin/jvm/functions/Function1;)V +} + public final class io/rsocket/kotlin/transport/ktor/websocket/server/WebSocketServerTransportKt { public static final fun WebSocketServerTransport (Lio/ktor/server/engine/ApplicationEngineFactory;ILjava/lang/String;Ljava/lang/String;Ljava/lang/String;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function1;)Lio/rsocket/kotlin/transport/ServerTransport; public static final fun WebSocketServerTransport (Lio/ktor/server/engine/ApplicationEngineFactory;[Lio/ktor/server/engine/EngineConnectorConfig;Ljava/lang/String;Ljava/lang/String;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function1;)Lio/rsocket/kotlin/transport/ServerTransport; diff --git a/rsocket-transports/ktor-websocket-server/build.gradle.kts b/rsocket-transports/ktor-websocket-server/build.gradle.kts index a5868053..157ab04f 100644 --- a/rsocket-transports/ktor-websocket-server/build.gradle.kts +++ b/rsocket-transports/ktor-websocket-server/build.gradle.kts @@ -29,6 +29,7 @@ kotlin { sourceSets { commonMain.dependencies { implementation(projects.rsocketTransportKtorWebsocketInternal) + implementation(projects.rsocketInternalIo) api(projects.rsocketCore) api(libs.ktor.server.host.common) api(libs.ktor.server.websockets) diff --git a/rsocket-transports/ktor-websocket-server/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/websocket/server/KtorWebSocketServerTransport.kt b/rsocket-transports/ktor-websocket-server/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/websocket/server/KtorWebSocketServerTransport.kt new file mode 100644 index 00000000..e8d66e6f --- /dev/null +++ b/rsocket-transports/ktor-websocket-server/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/websocket/server/KtorWebSocketServerTransport.kt @@ -0,0 +1,210 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.ktor.websocket.server + +import io.ktor.server.application.* +import io.ktor.server.engine.* +import io.ktor.server.routing.* +import io.ktor.server.websocket.* +import io.rsocket.kotlin.internal.io.* +import io.rsocket.kotlin.transport.* +import io.rsocket.kotlin.transport.ktor.websocket.internal.* +import kotlinx.coroutines.* +import kotlin.coroutines.* + +public sealed interface KtorWebSocketServerInstance : RSocketServerInstance { + public val connectors: List + public val path: String + public val protocol: String? +} + +public sealed interface KtorWebSocketServerTransport : RSocketTransport { + + public fun target( + host: String = "0.0.0.0", + port: Int = 80, + path: String = "", + protocol: String? = null, + ): RSocketServerTarget + + public fun target( + path: String = "", + protocol: String? = null, + connectorBuilder: EngineConnectorBuilder.() -> Unit, + ): RSocketServerTarget + + public fun target( + connector: EngineConnectorConfig, + path: String = "", + protocol: String? = null, + ): RSocketServerTarget + + public fun target( + connectors: List, + path: String = "", + protocol: String? = null, + ): RSocketServerTarget + + public companion object Factory : + RSocketTransportFactory(::KtorWebSocketServerTransportBuilderImpl) +} + +public sealed interface KtorWebSocketServerTransportBuilder : RSocketTransportBuilder { + public fun httpEngine( + factory: ApplicationEngineFactory, + configure: T.() -> Unit = {}, + ) + + public fun webSocketsConfig(block: WebSockets.WebSocketOptions.() -> Unit) +} + +private class KtorWebSocketServerTransportBuilderImpl : KtorWebSocketServerTransportBuilder { + private var httpServerFactory: HttpServerFactory<*, *>? = null + private var webSocketsConfig: WebSockets.WebSocketOptions.() -> Unit = {} + + override fun httpEngine( + factory: ApplicationEngineFactory, + configure: T.() -> Unit, + ) { + this.httpServerFactory = HttpServerFactory(factory, configure) + } + + override fun webSocketsConfig(block: WebSockets.WebSocketOptions.() -> Unit) { + this.webSocketsConfig = block + } + + @RSocketTransportApi + override fun buildTransport(context: CoroutineContext): KtorWebSocketServerTransport = KtorWebSocketServerTransportImpl( + // we always add IO - as it's the best choice here, server will use it's own dispatcher anyway + coroutineContext = context.supervisorContext() + Dispatchers.IO, + factory = requireNotNull(httpServerFactory) { "httpEngine is required" }, + webSocketsConfig = webSocketsConfig, + ) +} + +private class KtorWebSocketServerTransportImpl( + override val coroutineContext: CoroutineContext, + private val factory: HttpServerFactory<*, *>, + private val webSocketsConfig: WebSockets.WebSocketOptions.() -> Unit, +) : KtorWebSocketServerTransport { + override fun target( + connectors: List, + path: String, + protocol: String?, + ): RSocketServerTarget = KtorWebSocketServerTargetImpl( + coroutineContext = coroutineContext.supervisorContext(), + factory = factory, + webSocketsConfig = webSocketsConfig, + connectors = connectors, + path = path, + protocol = protocol + ) + + override fun target( + host: String, + port: Int, + path: String, + protocol: String?, + ): RSocketServerTarget = target(path, protocol) { + this.host = host + this.port = port + } + + override fun target( + path: String, + protocol: String?, + connectorBuilder: EngineConnectorBuilder.() -> Unit, + ): RSocketServerTarget = target(EngineConnectorBuilder().apply(connectorBuilder), path, protocol) + + override fun target( + connector: EngineConnectorConfig, + path: String, + protocol: String?, + ): RSocketServerTarget = target(listOf(connector), path, protocol) +} + +private class KtorWebSocketServerTargetImpl( + override val coroutineContext: CoroutineContext, + private val factory: HttpServerFactory<*, *>, + private val webSocketsConfig: WebSockets.WebSocketOptions.() -> Unit, + private val connectors: List, + private val path: String, + private val protocol: String?, +) : RSocketServerTarget { + + @RSocketTransportApi + override suspend fun startServer(handler: RSocketConnectionHandler): KtorWebSocketServerInstance { + currentCoroutineContext().ensureActive() + coroutineContext.ensureActive() + + val engine = createServerEngine(handler) + val resolvedConnectors = startServerEngine(engine) + + return KtorWebSocketServerInstanceImpl( + coroutineContext = engine.environment.parentCoroutineContext, + connectors = resolvedConnectors, + path = path, + protocol = protocol + ) + } + + // parentCoroutineContext is the context of server instance + @RSocketTransportApi + private fun createServerEngine(handler: RSocketConnectionHandler): ApplicationEngine = factory.createServer( + applicationEngineEnvironment { + val target = this@KtorWebSocketServerTargetImpl + parentCoroutineContext = target.coroutineContext.childContext() + connectors.addAll(target.connectors) + module { + install(WebSockets, webSocketsConfig) + routing { + webSocket(target.path, target.protocol) { + handler.handleKtorWebSocketConnection(this) + } + } + } + } + ) + + @OptIn(ExperimentalCoroutinesApi::class) + private suspend fun startServerEngine( + applicationEngine: ApplicationEngine, + ): List = launchCoroutine { cont -> + applicationEngine.start().stopServerOnCancellation() + cont.resume(applicationEngine.resolvedConnectors()) { + // will cause stopping of the server + applicationEngine.environment.parentCoroutineContext.job.cancel("Cancelled", it) + } + } +} + +private class KtorWebSocketServerInstanceImpl( + override val coroutineContext: CoroutineContext, + override val connectors: List, + override val path: String, + override val protocol: String?, +) : KtorWebSocketServerInstance + +private class HttpServerFactory( + private val factory: ApplicationEngineFactory, + private val configure: T.() -> Unit = {}, +) { + @RSocketTransportApi + fun createServer(environment: ApplicationEngineEnvironment): ApplicationEngine { + return factory.create(environment, configure) + } +} diff --git a/rsocket-transports/ktor-websocket-server/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/websocket/server/WebSocketServerTransport.kt b/rsocket-transports/ktor-websocket-server/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/websocket/server/WebSocketServerTransport.kt index 24c2278c..28ccb4c7 100644 --- a/rsocket-transports/ktor-websocket-server/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/websocket/server/WebSocketServerTransport.kt +++ b/rsocket-transports/ktor-websocket-server/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/websocket/server/WebSocketServerTransport.kt @@ -24,8 +24,6 @@ import io.rsocket.kotlin.* import io.rsocket.kotlin.transport.* import io.rsocket.kotlin.transport.ktor.websocket.internal.* -//TODO: will be reworked later with transport API rework - @Suppress("FunctionName") public fun WebSocketServerTransport( engineFactory: ApplicationEngineFactory, diff --git a/rsocket-transports/ktor-websocket-server/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/websocket/server/CIOWebSocketTransportTest.kt b/rsocket-transports/ktor-websocket-server/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/websocket/server/CIOWebSocketTransportTest.kt index 850b3cd0..82b880da 100644 --- a/rsocket-transports/ktor-websocket-server/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/websocket/server/CIOWebSocketTransportTest.kt +++ b/rsocket-transports/ktor-websocket-server/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/websocket/server/CIOWebSocketTransportTest.kt @@ -20,3 +20,5 @@ import io.ktor.client.engine.cio.CIO as ClientCIO import io.ktor.server.cio.CIO as ServerCIO class CIOWebSocketTransportTest : WebSocketTransportTest(ClientCIO, ServerCIO) + +class CIOKtorWebSocketTransportTest : KtorWebSocketTransportTest(ClientCIO, ServerCIO) diff --git a/rsocket-transports/ktor-websocket-server/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/websocket/server/WebSocketTransportTest.kt b/rsocket-transports/ktor-websocket-server/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/websocket/server/WebSocketTransportTest.kt index ffd26b12..088f6cda 100644 --- a/rsocket-transports/ktor-websocket-server/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/websocket/server/WebSocketTransportTest.kt +++ b/rsocket-transports/ktor-websocket-server/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/websocket/server/WebSocketTransportTest.kt @@ -35,3 +35,22 @@ abstract class WebSocketTransportTest( ) } } + +abstract class KtorWebSocketTransportTest( + private val clientEngine: HttpClientEngineFactory<*>, + private val serverEngine: ApplicationEngineFactory<*, *>, +) : TransportTest() { + override suspend fun before() { + val server = startServer( + KtorWebSocketServerTransport(testContext) { + httpEngine(serverEngine) + }.target(port = 0) + ) + val port = server.connectors.single().port + client = connectClient( + KtorWebSocketClientTransport(testContext) { + httpEngine(clientEngine) + }.target(port = port) + ) + } +} diff --git a/rsocket-transports/ktor-websocket-server/src/jvmTest/kotlin/io/rsocket/kotlin/transport/ktor/websocket/server/WebSocketTransportTests.kt b/rsocket-transports/ktor-websocket-server/src/jvmTest/kotlin/io/rsocket/kotlin/transport/ktor/websocket/server/WebSocketTransportTests.kt index bb17b9a2..84bfc2f3 100644 --- a/rsocket-transports/ktor-websocket-server/src/jvmTest/kotlin/io/rsocket/kotlin/transport/ktor/websocket/server/WebSocketTransportTests.kt +++ b/rsocket-transports/ktor-websocket-server/src/jvmTest/kotlin/io/rsocket/kotlin/transport/ktor/websocket/server/WebSocketTransportTests.kt @@ -22,8 +22,14 @@ import io.ktor.server.netty.* import io.ktor.client.engine.cio.CIO as ClientCIO import io.ktor.server.cio.CIO as ServerCIO -class OkHttpClientWebSocketTransportTest : WebSocketTransportTest(OkHttp, ServerCIO) +//class OkHttpClientWebSocketTransportTest : WebSocketTransportTest(OkHttp, ServerCIO) +// +//class NettyServerWebSocketTransportTest : WebSocketTransportTest(ClientCIO, Netty) +// +//class JettyServerWebSocketTransportTest : WebSocketTransportTest(ClientCIO, Jetty) -class NettyServerWebSocketTransportTest : WebSocketTransportTest(ClientCIO, Netty) +class OkHttpClientKtorWebSocketTransportTest : KtorWebSocketTransportTest(OkHttp, ServerCIO) -class JettyServerWebSocketTransportTest : WebSocketTransportTest(ClientCIO, Jetty) +class NettyServerKtorWebSocketTransportTest : KtorWebSocketTransportTest(ClientCIO, Netty) + +class JettyServerKtorWebSocketTransportTest : KtorWebSocketTransportTest(ClientCIO, Jetty)