diff --git a/mirai-api-http/src/main/kotlin/net/mamoe/mirai/api/http/adapter/http/plugin/HttpForward.kt b/mirai-api-http/src/main/kotlin/net/mamoe/mirai/api/http/adapter/http/plugin/HttpForward.kt new file mode 100644 index 00000000..e9d49d37 --- /dev/null +++ b/mirai-api-http/src/main/kotlin/net/mamoe/mirai/api/http/adapter/http/plugin/HttpForward.kt @@ -0,0 +1,110 @@ +/* + * Copyright 2023 Mamoe Technologies and contributors. + * + * 此源代码的使用受 GNU AFFERO GENERAL PUBLIC LICENSE version 3 许可证的约束, 可以在以下链接找到该许可证. + * Use of this source code is governed by the GNU AGPLv3 license that can be found through the following link. + * + * https://github.com/mamoe/mirai/blob/master/LICENSE + */ + +package net.mamoe.mirai.api.http.adapter.http.plugin + +import io.ktor.http.* +import io.ktor.http.content.* +import io.ktor.server.application.* +import io.ktor.server.plugins.* +import io.ktor.server.request.* +import io.ktor.util.* +import io.ktor.util.pipeline.* +import io.ktor.util.reflect.* +import kotlinx.serialization.InternalSerializationApi +import kotlinx.serialization.encodeToString +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonElement +import kotlinx.serialization.json.JsonNull +import kotlinx.serialization.serializer + + +internal val HttpForwardAttributeKey = AttributeKey("HttpForward") +internal val HttpForwardPhase = PipelinePhase("Forward") +val HttpForward = createApplicationPlugin("HttpForward", ::HttpForwardConfig) { + application.insertPhaseAfter(ApplicationCallPipeline.Call, HttpForwardPhase) + + application.intercept(HttpForwardPhase) { + val forwardContext = call.attributes.getOrNull(HttpForwardAttributeKey) + if (forwardContext != null && !forwardContext.forwarded) { + forwardContext.forwarded = true + forwardContext.convertors = this@createApplicationPlugin.pluginConfig.getConvertors() + application.execute(ApplicationForwardCall(call, forwardContext)) + } + } +} + +typealias BodyConvertor = (Any, TypeInfo) -> Any? + +class HttpForwardConfig { + private val convertors: MutableList = mutableListOf(DefaultBodyConvertor) + + fun addConvertor(convertor: BodyConvertor) { + convertors.add(convertor) + } + + fun getConvertors(): List = convertors + + @OptIn(InternalSerializationApi::class) + fun jsonElementBodyConvertor(json: Json) { + addConvertor { body, typeInfo -> + val b = if (body == NullBody) JsonNull else body + when { + b !is JsonElement -> null + typeInfo.type == String::class -> json.encodeToString(b) + else -> json.decodeFromJsonElement(typeInfo.type.serializer(), b) + } + } + } +} + +val DefaultBodyConvertor: (Any, TypeInfo) -> Any? = { body, typeInfo -> + if (typeInfo.type.isInstance(body)) body else null +} + +internal data class HttpForwardContext(val router: String, val body: Any?) { + var forwarded = false + var convertors = emptyList() +} + +fun ApplicationCall.forward(forward: String) { + attributes.put(HttpForwardAttributeKey, HttpForwardContext(forward, null)) +} + +fun ApplicationCall.forward(forward: String, body: Any?) { + attributes.put(HttpForwardAttributeKey, HttpForwardContext(forward, body ?: NullBody)) +} + +internal fun forwardReceivePipeline(convertors: List, body: Any): ApplicationReceivePipeline = + ApplicationReceivePipeline().apply { + intercept(ApplicationReceivePipeline.Transform) { + proceedWith(convertors.firstNotNullOfOrNull { it.invoke(body, context.receiveType) } + ?: throw CannotTransformContentToTypeException(context.receiveType.kotlinType!!)) + } + } + +internal class ApplicationForwardCall( + val delegate: ApplicationCall, val context: HttpForwardContext +) : ApplicationCall by delegate { + override val request: ApplicationRequest = DelegateApplicationRequest(this, context.router, context.body) +} + +internal class DelegateApplicationRequest( + override val call: ApplicationForwardCall, forward: String, body: Any? +) : ApplicationRequest by call.delegate.request { + private val _pipeline by lazy { + body?.let { forwardReceivePipeline(call.context.convertors, it) } ?: call.delegate.request.pipeline + } + override val local = DelegateRequestConnectionPoint(call.delegate.request.local, forward) + override val pipeline: ApplicationReceivePipeline = _pipeline +} + +internal class DelegateRequestConnectionPoint( + private val delegate: RequestConnectionPoint, override val uri: String +) : RequestConnectionPoint by delegate \ No newline at end of file diff --git a/mirai-api-http/src/test/kotlin/net/mamoe/mirai/api/http/adapter/http/plugin/HttpForwardTest.kt b/mirai-api-http/src/test/kotlin/net/mamoe/mirai/api/http/adapter/http/plugin/HttpForwardTest.kt new file mode 100644 index 00000000..99634bea --- /dev/null +++ b/mirai-api-http/src/test/kotlin/net/mamoe/mirai/api/http/adapter/http/plugin/HttpForwardTest.kt @@ -0,0 +1,170 @@ +/* + * Copyright 2023 Mamoe Technologies and contributors. + * + * 此源代码的使用受 GNU AFFERO GENERAL PUBLIC LICENSE version 3 许可证的约束, 可以在以下链接找到该许可证. + * Use of this source code is governed by the GNU AGPLv3 license that can be found through the following link. + * + * https://github.com/mamoe/mirai/blob/master/LICENSE + */ +package net.mamoe.mirai.api.http.adapter.http.plugin + +import io.ktor.client.request.* +import io.ktor.client.statement.* +import io.ktor.http.* +import io.ktor.serialization.kotlinx.json.* +import io.ktor.server.application.* +import io.ktor.server.plugins.contentnegotiation.* +import io.ktor.server.plugins.doublereceive.* +import io.ktor.server.request.* +import io.ktor.server.response.* +import io.ktor.server.routing.* +import io.ktor.server.testing.* +import kotlinx.serialization.Serializable +import kotlinx.serialization.json.JsonElement +import net.mamoe.mirai.api.http.adapter.internal.dto.parameter.LongTargetDTO +import net.mamoe.mirai.api.http.adapter.internal.dto.parameter.NudgeDTO +import net.mamoe.mirai.api.http.adapter.internal.serializer.BuiltinJsonSerializer +import kotlin.test.Test +import kotlin.test.assertEquals + +class HttpForwardTest { + + @Test + fun testGetRequestForward() = testApplication { + routing { + get("/test") { + call.forward("/forward") + } + + get("/forward") { + call.respondText(call.parameters["key"] ?: "null") + } + } + + client.get("/test") { + parameter("key", "value") + }.also { + assertEquals(HttpStatusCode.OK, it.status) + assertEquals("value", it.bodyAsText()) + } + } + + @Test + fun testPostRequestForwardReceiveBody() = testApplication { + install(ContentNegotiation) { json(json=BuiltinJsonSerializer.buildJson()) } + + routing { + post("/test") { + call.forward("/forward") + } + + post("/forward") { + val receive = call.receive() + call.respondText(receive.target.toString()) + } + } + + client.post("/test") { + contentType(ContentType.Application.Json) + setBody("""{"target":123}""") + }.also { + assertEquals(HttpStatusCode.OK, it.status) + assertEquals("123", it.bodyAsText()) + } + } + + @Test + fun testPostRequestForwardDoubleReceiveBody() = testApplication { + install(DoubleReceive) + install(ContentNegotiation) { json(json=BuiltinJsonSerializer.buildJson()) } + + routing { + post("/test") { + val receive = call.receive() + assertEquals(123, receive.target) + call.forward("/forward") + } + + post("/forward") { + val receive = call.receive() + call.respondText(receive.target.toString()) + } + } + + client.post("/test") { + contentType(ContentType.Application.Json) + setBody("""{"target":123}""") + }.also { + assertEquals(HttpStatusCode.OK, it.status) + assertEquals("123", it.bodyAsText()) + } + } + + @Test + fun testPostRequestForwardResetBody() = testApplication { + install(DoubleReceive) + install(HttpRouterMonitor) + install(ContentNegotiation) { json(json=BuiltinJsonSerializer.buildJson()) } + + routing { + post("/test") { + val receive = call.receive() + assertEquals(123, receive.target) + call.forward("/forward", NudgeDTO(321, 321, "kind")) + } + + post("/forward") { + val receive = call.receive() + call.respondText(receive.target.toString()) + } + } + + client.post("/test") { + contentType(ContentType.Application.Json) + setBody("""{"target":123}""") + }.also { + assertEquals(HttpStatusCode.OK, it.status) + assertEquals("321", it.bodyAsText()) + } + } + + + @Serializable + private data class NestedDto( + val router: String, + val body: JsonElement, + ) + + @Test + fun testPostRequestForwardNestedBody() = testApplication { + val json = BuiltinJsonSerializer.buildJson() + + install(DoubleReceive) + install(HttpRouterMonitor) + install(ContentNegotiation) { json(json) } + install(HttpForward) { jsonElementBodyConvertor(json) } + + routing { + post("/test") { + val receive = call.receive() + assertEquals("/forward", receive.router) + call.forward("/forward", receive.body) + + call.respond(HttpStatusCode.OK) + } + + post("/forward") { + val receive = call.receive() + call.respondText(receive.target.toString()) + } + } + + client.post("/test") { + contentType(ContentType.Application.Json) + setBody("""{"router":"/forward","body":{"target":321}}""") + }.also { + assertEquals(HttpStatusCode.OK, it.status) + assertEquals("321", it.bodyAsText()) + } + } +} \ No newline at end of file