Skip to content

Commit

Permalink
Support http forward
Browse files Browse the repository at this point in the history
  • Loading branch information
ryoii committed Oct 10, 2023
1 parent 0a3e454 commit fc6b79f
Show file tree
Hide file tree
Showing 2 changed files with 280 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
/*
* Copyright 2023 Mamoe Technologies and contributors.
*
* 此源代码的使用受 GNU AFFERO GENERAL PUBLIC LICENSE version 3 许可证的约束, 可以在以下链接找到该许可证.
* Use of this source code is governed by the GNU AGPLv3 license that can be found through the following link.
*
* https://github.com/mamoe/mirai/blob/master/LICENSE
*/

package net.mamoe.mirai.api.http.adapter.http.plugin

import io.ktor.http.*
import io.ktor.http.content.*
import io.ktor.server.application.*
import io.ktor.server.plugins.*
import io.ktor.server.request.*
import io.ktor.util.*
import io.ktor.util.pipeline.*
import io.ktor.util.reflect.*
import kotlinx.serialization.InternalSerializationApi
import kotlinx.serialization.encodeToString
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.JsonElement
import kotlinx.serialization.json.JsonNull
import kotlinx.serialization.serializer


internal val HttpForwardAttributeKey = AttributeKey<HttpForwardContext>("HttpForward")
internal val HttpForwardPhase = PipelinePhase("Forward")
val HttpForward = createApplicationPlugin("HttpForward", ::HttpForwardConfig) {
application.insertPhaseAfter(ApplicationCallPipeline.Call, HttpForwardPhase)

application.intercept(HttpForwardPhase) {
val forwardContext = call.attributes.getOrNull(HttpForwardAttributeKey)
if (forwardContext != null && !forwardContext.forwarded) {
forwardContext.forwarded = true
forwardContext.convertors = this@createApplicationPlugin.pluginConfig.getConvertors()
application.execute(ApplicationForwardCall(call, forwardContext))
}
}
}

typealias BodyConvertor = (Any, TypeInfo) -> Any?

class HttpForwardConfig {
private val convertors: MutableList<BodyConvertor> = mutableListOf(DefaultBodyConvertor)

fun addConvertor(convertor: BodyConvertor) {
convertors.add(convertor)
}

fun getConvertors(): List<BodyConvertor> = convertors

@OptIn(InternalSerializationApi::class)
fun jsonElementBodyConvertor(json: Json) {
addConvertor { body, typeInfo ->
val b = if (body == NullBody) JsonNull else body
when {
b !is JsonElement -> null
typeInfo.type == String::class -> json.encodeToString(b)
else -> json.decodeFromJsonElement(typeInfo.type.serializer(), b)
}
}
}
}

val DefaultBodyConvertor: (Any, TypeInfo) -> Any? = { body, typeInfo ->
if (typeInfo.type.isInstance(body)) body else null
}

internal data class HttpForwardContext(val router: String, val body: Any?) {
var forwarded = false
var convertors = emptyList<BodyConvertor>()
}

fun ApplicationCall.forward(forward: String) {
attributes.put(HttpForwardAttributeKey, HttpForwardContext(forward, null))
}

fun ApplicationCall.forward(forward: String, body: Any?) {
attributes.put(HttpForwardAttributeKey, HttpForwardContext(forward, body ?: NullBody))
}

internal fun forwardReceivePipeline(convertors: List<BodyConvertor>, body: Any): ApplicationReceivePipeline =
ApplicationReceivePipeline().apply {
intercept(ApplicationReceivePipeline.Transform) {
proceedWith(convertors.firstNotNullOfOrNull { it.invoke(body, context.receiveType) }
?: throw CannotTransformContentToTypeException(context.receiveType.kotlinType!!))
}
}

internal class ApplicationForwardCall(
val delegate: ApplicationCall, val context: HttpForwardContext
) : ApplicationCall by delegate {
override val request: ApplicationRequest = DelegateApplicationRequest(this, context.router, context.body)
}

internal class DelegateApplicationRequest(
override val call: ApplicationForwardCall, forward: String, body: Any?
) : ApplicationRequest by call.delegate.request {
private val _pipeline by lazy {
body?.let { forwardReceivePipeline(call.context.convertors, it) } ?: call.delegate.request.pipeline
}
override val local = DelegateRequestConnectionPoint(call.delegate.request.local, forward)
override val pipeline: ApplicationReceivePipeline = _pipeline
}

internal class DelegateRequestConnectionPoint(
private val delegate: RequestConnectionPoint, override val uri: String
) : RequestConnectionPoint by delegate
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
/*
* Copyright 2023 Mamoe Technologies and contributors.
*
* 此源代码的使用受 GNU AFFERO GENERAL PUBLIC LICENSE version 3 许可证的约束, 可以在以下链接找到该许可证.
* Use of this source code is governed by the GNU AGPLv3 license that can be found through the following link.
*
* https://github.com/mamoe/mirai/blob/master/LICENSE
*/
package net.mamoe.mirai.api.http.adapter.http.plugin

import io.ktor.client.request.*
import io.ktor.client.statement.*
import io.ktor.http.*
import io.ktor.serialization.kotlinx.json.*
import io.ktor.server.application.*
import io.ktor.server.plugins.contentnegotiation.*
import io.ktor.server.plugins.doublereceive.*
import io.ktor.server.request.*
import io.ktor.server.response.*
import io.ktor.server.routing.*
import io.ktor.server.testing.*
import kotlinx.serialization.Serializable
import kotlinx.serialization.json.JsonElement
import net.mamoe.mirai.api.http.adapter.internal.dto.parameter.LongTargetDTO
import net.mamoe.mirai.api.http.adapter.internal.dto.parameter.NudgeDTO
import net.mamoe.mirai.api.http.adapter.internal.serializer.BuiltinJsonSerializer
import kotlin.test.Test
import kotlin.test.assertEquals

class HttpForwardTest {

@Test
fun testGetRequestForward() = testApplication {
routing {
get("/test") {
call.forward("/forward")
}

get("/forward") {
call.respondText(call.parameters["key"] ?: "null")
}
}

client.get("/test") {
parameter("key", "value")
}.also {
assertEquals(HttpStatusCode.OK, it.status)
assertEquals("value", it.bodyAsText())
}
}

@Test
fun testPostRequestForwardReceiveBody() = testApplication {
install(ContentNegotiation) { json(json=BuiltinJsonSerializer.buildJson()) }

routing {
post("/test") {
call.forward("/forward")
}

post("/forward") {
val receive = call.receive<LongTargetDTO>()
call.respondText(receive.target.toString())
}
}

client.post("/test") {
contentType(ContentType.Application.Json)
setBody("""{"target":123}""")
}.also {
assertEquals(HttpStatusCode.OK, it.status)
assertEquals("123", it.bodyAsText())
}
}

@Test
fun testPostRequestForwardDoubleReceiveBody() = testApplication {
install(DoubleReceive)
install(ContentNegotiation) { json(json=BuiltinJsonSerializer.buildJson()) }

routing {
post("/test") {
val receive = call.receive<LongTargetDTO>()
assertEquals(123, receive.target)
call.forward("/forward")
}

post("/forward") {
val receive = call.receive<LongTargetDTO>()
call.respondText(receive.target.toString())
}
}

client.post("/test") {
contentType(ContentType.Application.Json)
setBody("""{"target":123}""")
}.also {
assertEquals(HttpStatusCode.OK, it.status)
assertEquals("123", it.bodyAsText())
}
}

@Test
fun testPostRequestForwardResetBody() = testApplication {
install(DoubleReceive)
install(HttpRouterMonitor)
install(ContentNegotiation) { json(json=BuiltinJsonSerializer.buildJson()) }

routing {
post("/test") {
val receive = call.receive<LongTargetDTO>()
assertEquals(123, receive.target)
call.forward("/forward", NudgeDTO(321, 321, "kind"))
}

post("/forward") {
val receive = call.receive<NudgeDTO>()
call.respondText(receive.target.toString())
}
}

client.post("/test") {
contentType(ContentType.Application.Json)
setBody("""{"target":123}""")
}.also {
assertEquals(HttpStatusCode.OK, it.status)
assertEquals("321", it.bodyAsText())
}
}


@Serializable
private data class NestedDto(
val router: String,
val body: JsonElement,
)

@Test
fun testPostRequestForwardNestedBody() = testApplication {
val json = BuiltinJsonSerializer.buildJson()

install(DoubleReceive)
install(HttpRouterMonitor)
install(ContentNegotiation) { json(json) }
install(HttpForward) { jsonElementBodyConvertor(json) }

routing {
post("/test") {
val receive = call.receive<NestedDto>()
assertEquals("/forward", receive.router)
call.forward("/forward", receive.body)

call.respond(HttpStatusCode.OK)
}

post("/forward") {
val receive = call.receive<LongTargetDTO>()
call.respondText(receive.target.toString())
}
}

client.post("/test") {
contentType(ContentType.Application.Json)
setBody("""{"router":"/forward","body":{"target":321}}""")
}.also {
assertEquals(HttpStatusCode.OK, it.status)
assertEquals("321", it.bodyAsText())
}
}
}

0 comments on commit fc6b79f

Please sign in to comment.