-
Notifications
You must be signed in to change notification settings - Fork 343
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
271 additions
and
0 deletions.
There are no files selected for viewing
101 changes: 101 additions & 0 deletions
101
mirai-api-http/src/main/kotlin/net/mamoe/mirai/api/http/adapter/http/plugin/HttpForward.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
/* | ||
* Copyright 2023 Mamoe Technologies and contributors. | ||
* | ||
* 此源代码的使用受 GNU AFFERO GENERAL PUBLIC LICENSE version 3 许可证的约束, 可以在以下链接找到该许可证. | ||
* Use of this source code is governed by the GNU AGPLv3 license that can be found through the following link. | ||
* | ||
* https://github.com/mamoe/mirai/blob/master/LICENSE | ||
*/ | ||
|
||
package net.mamoe.mirai.api.http.adapter.http.plugin | ||
|
||
import io.ktor.http.* | ||
import io.ktor.server.application.* | ||
import io.ktor.server.request.* | ||
import io.ktor.util.* | ||
import io.ktor.util.pipeline.* | ||
import io.ktor.util.reflect.* | ||
import kotlinx.serialization.InternalSerializationApi | ||
import kotlinx.serialization.json.Json | ||
import kotlinx.serialization.json.JsonElement | ||
import kotlinx.serialization.serializer | ||
|
||
|
||
internal val HttpForwardAttributeKey = AttributeKey<HttpForwardContext>("HttpForward") | ||
internal val HttpForwardPhase = PipelinePhase("Forward") | ||
val HttpForward = createApplicationPlugin("HttpForward", ::HttpForwardConfig) { | ||
application.insertPhaseAfter(ApplicationCallPipeline.Call, HttpForwardPhase) | ||
|
||
application.intercept(HttpForwardPhase) { | ||
val forwardContext = call.attributes.getOrNull(HttpForwardAttributeKey) | ||
if (forwardContext != null && !forwardContext.forwarded) { | ||
forwardContext.forwarded = true | ||
forwardContext.convertors = this@createApplicationPlugin.pluginConfig.getConvertors() | ||
application.execute(ApplicationForwardCall(call, forwardContext)) | ||
} | ||
} | ||
} | ||
|
||
typealias BodyConvertor = (Any, TypeInfo) -> Any? | ||
|
||
class HttpForwardConfig { | ||
private val convertors: MutableList<BodyConvertor> = mutableListOf(DefaultBodyConvertor) | ||
|
||
fun addConvertor(convertor: BodyConvertor) { | ||
convertors.add(convertor) | ||
} | ||
|
||
fun getConvertors(): List<BodyConvertor> = convertors | ||
|
||
@OptIn(InternalSerializationApi::class) | ||
fun jsonElementBodyConvertor(json: Json) { | ||
addConvertor { body, typeInfo -> | ||
if (body is JsonElement) json.decodeFromJsonElement(typeInfo.type.serializer(), body) else null | ||
} | ||
} | ||
} | ||
|
||
val DefaultBodyConvertor: (Any, TypeInfo) -> Any? = { body, typeInfo -> | ||
if (typeInfo.type.isInstance(body)) body else null | ||
} | ||
|
||
internal data class HttpForwardContext(val router: String, val body: Any?) { | ||
var forwarded = false | ||
var convertors = emptyList<BodyConvertor>() | ||
} | ||
|
||
fun ApplicationCall.forward(forward: String) { | ||
attributes.put(HttpForwardAttributeKey, HttpForwardContext(forward, null)) | ||
} | ||
|
||
fun ApplicationCall.forward(forward: String, body: Any) { | ||
attributes.put(HttpForwardAttributeKey, HttpForwardContext(forward, body)) | ||
} | ||
|
||
internal fun forwardReceivePipeline(convertors: List<BodyConvertor>, body: Any): ApplicationReceivePipeline = | ||
ApplicationReceivePipeline().apply { | ||
intercept(ApplicationReceivePipeline.Transform) { | ||
proceedWith(convertors.firstNotNullOfOrNull { it.invoke(body, context.receiveType) } | ||
?: throw NoSuchElementException("fuck")) | ||
} | ||
} | ||
|
||
internal class ApplicationForwardCall( | ||
val delegate: ApplicationCall, val context: HttpForwardContext | ||
) : ApplicationCall by delegate { | ||
override val request: ApplicationRequest = DelegateApplicationRequest(this, context.router, context.body) | ||
} | ||
|
||
internal class DelegateApplicationRequest( | ||
override val call: ApplicationForwardCall, forward: String, body: Any? | ||
) : ApplicationRequest by call.delegate.request { | ||
private val _pipeline by lazy { | ||
body?.let { forwardReceivePipeline(call.context.convertors, it) } ?: call.delegate.request.pipeline | ||
} | ||
override val local = DelegateRequestConnectionPoint(call.delegate.request.local, forward) | ||
override val pipeline: ApplicationReceivePipeline = _pipeline | ||
} | ||
|
||
internal class DelegateRequestConnectionPoint( | ||
private val delegate: RequestConnectionPoint, override val uri: String | ||
) : RequestConnectionPoint by delegate |
170 changes: 170 additions & 0 deletions
170
...-api-http/src/test/kotlin/net/mamoe/mirai/api/http/adapter/http/plugin/HttpForwardTest.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,170 @@ | ||
/* | ||
* Copyright 2023 Mamoe Technologies and contributors. | ||
* | ||
* 此源代码的使用受 GNU AFFERO GENERAL PUBLIC LICENSE version 3 许可证的约束, 可以在以下链接找到该许可证. | ||
* Use of this source code is governed by the GNU AGPLv3 license that can be found through the following link. | ||
* | ||
* https://github.com/mamoe/mirai/blob/master/LICENSE | ||
*/ | ||
package net.mamoe.mirai.api.http.adapter.http.plugin | ||
|
||
import io.ktor.client.request.* | ||
import io.ktor.client.statement.* | ||
import io.ktor.http.* | ||
import io.ktor.serialization.kotlinx.json.* | ||
import io.ktor.server.application.* | ||
import io.ktor.server.plugins.contentnegotiation.* | ||
import io.ktor.server.plugins.doublereceive.* | ||
import io.ktor.server.request.* | ||
import io.ktor.server.response.* | ||
import io.ktor.server.routing.* | ||
import io.ktor.server.testing.* | ||
import kotlinx.serialization.Serializable | ||
import kotlinx.serialization.json.JsonElement | ||
import net.mamoe.mirai.api.http.adapter.internal.dto.parameter.LongTargetDTO | ||
import net.mamoe.mirai.api.http.adapter.internal.dto.parameter.NudgeDTO | ||
import net.mamoe.mirai.api.http.adapter.internal.serializer.BuiltinJsonSerializer | ||
import kotlin.test.Test | ||
import kotlin.test.assertEquals | ||
|
||
class HttpForwardTest { | ||
|
||
@Test | ||
fun testGetRequestForward() = testApplication { | ||
routing { | ||
get("/test") { | ||
call.forward("/forward") | ||
} | ||
|
||
get("/forward") { | ||
call.respondText(call.parameters["key"] ?: "null") | ||
} | ||
} | ||
|
||
client.get("/test") { | ||
parameter("key", "value") | ||
}.also { | ||
assertEquals(HttpStatusCode.OK, it.status) | ||
assertEquals("value", it.bodyAsText()) | ||
} | ||
} | ||
|
||
@Test | ||
fun testPostRequestForwardReceiveBody() = testApplication { | ||
install(ContentNegotiation) { json(json=BuiltinJsonSerializer.buildJson()) } | ||
|
||
routing { | ||
post("/test") { | ||
call.forward("/forward") | ||
} | ||
|
||
post("/forward") { | ||
val receive = call.receive<LongTargetDTO>() | ||
call.respondText(receive.target.toString()) | ||
} | ||
} | ||
|
||
client.post("/test") { | ||
contentType(ContentType.Application.Json) | ||
setBody("""{"target":123}""") | ||
}.also { | ||
assertEquals(HttpStatusCode.OK, it.status) | ||
assertEquals("123", it.bodyAsText()) | ||
} | ||
} | ||
|
||
@Test | ||
fun testPostRequestForwardDoubleReceiveBody() = testApplication { | ||
install(DoubleReceive) | ||
install(ContentNegotiation) { json(json=BuiltinJsonSerializer.buildJson()) } | ||
|
||
routing { | ||
post("/test") { | ||
val receive = call.receive<LongTargetDTO>() | ||
assertEquals(123, receive.target) | ||
call.forward("/forward") | ||
} | ||
|
||
post("/forward") { | ||
val receive = call.receive<LongTargetDTO>() | ||
call.respondText(receive.target.toString()) | ||
} | ||
} | ||
|
||
client.post("/test") { | ||
contentType(ContentType.Application.Json) | ||
setBody("""{"target":123}""") | ||
}.also { | ||
assertEquals(HttpStatusCode.OK, it.status) | ||
assertEquals("123", it.bodyAsText()) | ||
} | ||
} | ||
|
||
@Test | ||
fun testPostRequestForwardResetBody() = testApplication { | ||
install(DoubleReceive) | ||
install(HttpRouterMonitor) | ||
install(ContentNegotiation) { json(json=BuiltinJsonSerializer.buildJson()) } | ||
|
||
routing { | ||
post("/test") { | ||
val receive = call.receive<LongTargetDTO>() | ||
assertEquals(123, receive.target) | ||
call.forward("/forward", NudgeDTO(321, 321, "kind")) | ||
} | ||
|
||
post("/forward") { | ||
val receive = call.receive<NudgeDTO>() | ||
call.respondText(receive.target.toString()) | ||
} | ||
} | ||
|
||
client.post("/test") { | ||
contentType(ContentType.Application.Json) | ||
setBody("""{"target":123}""") | ||
}.also { | ||
assertEquals(HttpStatusCode.OK, it.status) | ||
assertEquals("321", it.bodyAsText()) | ||
} | ||
} | ||
|
||
|
||
@Serializable | ||
private data class NestedDto( | ||
val router: String, | ||
val body: JsonElement, | ||
) | ||
|
||
@Test | ||
fun testPostRequestForwardNestedBody() = testApplication { | ||
val json = BuiltinJsonSerializer.buildJson() | ||
|
||
install(DoubleReceive) | ||
install(HttpRouterMonitor) | ||
install(ContentNegotiation) { json(json) } | ||
install(HttpForward) { jsonElementBodyConvertor(json) } | ||
|
||
routing { | ||
post("/test") { | ||
val receive = call.receive<NestedDto>() | ||
assertEquals("/forward", receive.router) | ||
call.forward("/forward", receive.body) | ||
|
||
call.respond(HttpStatusCode.OK) | ||
} | ||
|
||
post("/forward") { | ||
val receive = call.receive<LongTargetDTO>() | ||
call.respondText(receive.target.toString()) | ||
} | ||
} | ||
|
||
client.post("/test") { | ||
contentType(ContentType.Application.Json) | ||
setBody("""{"router":"/forward","body":{"target":321}}""") | ||
}.also { | ||
assertEquals(HttpStatusCode.OK, it.status) | ||
assertEquals("321", it.bodyAsText()) | ||
} | ||
} | ||
} |