Skip to content

Commit

Permalink
KTOR-2822 Change HttpReceivePipeline context to Unit (#2522)
Browse files Browse the repository at this point in the history
Otherwise it's impossible to modify response, since context has its own instance of response.
  • Loading branch information
rsinukov authored and e5l committed Oct 4, 2021
1 parent 3c37f30 commit 7d6faea
Show file tree
Hide file tree
Showing 9 changed files with 80 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,9 @@ public class HttpClient(

sendPipeline.intercept(HttpSendPipeline.Receive) { call ->
check(call is HttpClientCall) { "Error: HttpClientCall expected, but found $call(${call::class})." }
val receivedCall = receivePipeline.execute(call, call.response).call
proceedWith(receivedCall)
val response = receivePipeline.execute(Unit, call.response)
call.response = response
proceedWith(call)
}

with(userConfig) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
package io.ktor.client.plugins

import io.ktor.client.*
import io.ktor.client.call.*
import io.ktor.client.content.*
import io.ktor.client.plugins.observer.*
import io.ktor.client.request.*
Expand Down Expand Up @@ -42,14 +41,10 @@ public class BodyProgress internal constructor() {
}

scope.receivePipeline.intercept(HttpReceivePipeline.After) { response ->
val listener = context.request.attributes
val listener = response.call.request.attributes
.getOrNull(DownloadProgressListenerAttributeKey) ?: return@intercept
val observableCall = context.withObservableDownload(listener)

context.response = observableCall.response
context.request = observableCall.request

proceedWith(context.response)
val observableResponse = response.withObservableDownload(listener)
proceedWith(observableResponse)
}
}

Expand All @@ -66,8 +61,8 @@ public class BodyProgress internal constructor() {
}
}

internal fun HttpClientCall.withObservableDownload(listener: ProgressListener): HttpClientCall {
val observableByteChannel = response.content.observable(coroutineContext, response.contentLength(), listener)
internal fun HttpResponse.withObservableDownload(listener: ProgressListener): HttpResponse {
val observableByteChannel = content.observable(coroutineContext, contentLength(), listener)
return wrapWithContent(observableByteChannel)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ public class HttpCache(
}

scope.receivePipeline.intercept(HttpReceivePipeline.State) { response ->
if (context.request.method != HttpMethod.Get) return@intercept
if (response.call.request.method != HttpMethod.Get) return@intercept

if (response.status.isSuccess()) {
val reusableResponse = plugin.cacheResponse(response)
Expand All @@ -105,8 +105,8 @@ public class HttpCache(

if (response.status == HttpStatusCode.NotModified) {
response.complete()
val responseFromCache = plugin.findAndRefresh(context.request, response)
?: throw InvalidCacheStateException(context.request.url)
val responseFromCache = plugin.findAndRefresh(response.call.request, response)
?: throw InvalidCacheStateException(response.call.request.url)

scope.monitor.raise(HttpResponseFromCache, responseFromCache)
proceedWith(responseFromCache)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,13 @@ public fun HttpClientCall.wrapWithContent(content: ByteReadChannel): HttpClientC
return DelegatedCall(currentClient, content, this)
}

/**
* Wrap existing [HttpResponse] with new [content].
*/
internal fun HttpResponse.wrapWithContent(content: ByteReadChannel): HttpResponse {
return DelegatedResponse(call, content, this)
}

internal class DelegatedCall(
client: HttpClient,
content: ByteReadChannel,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,24 +44,22 @@ public class ResponseObserver(
scope.receivePipeline.intercept(HttpReceivePipeline.After) { response ->
val (loggingContent, responseContent) = response.content.split(response)

val newClientCall = context.wrapWithContent(responseContent)
val sideCall = newClientCall.wrapWithContent(loggingContent)
val newResponse = response.wrapWithContent(responseContent)
val sideResponse = response.call.wrapWithContent(loggingContent).response

scope.launch {
try {
plugin.responseHandler(sideCall.response)
plugin.responseHandler(sideResponse)
} catch (_: Throwable) {
}

val content = sideCall.response.content
val content = sideResponse.content
if (!content.isClosedForRead) {
content.discard()
}
}

context.response = newClientCall.response
context.request = newClientCall.request
proceedWith(context.response)
proceedWith(newResponse)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public class HttpResponsePipeline(
*/
public class HttpReceivePipeline(
override val developmentMode: Boolean = false
) : Pipeline<HttpResponse, HttpClientCall>(Before, State, After) {
) : Pipeline<HttpResponse, Unit>(Before, State, After) {
public companion object Phases {
/**
* The earliest phase that happens before any other
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,9 @@ public class Logging(
}
}

private fun logResponseException(context: HttpClientCall, cause: Throwable) {
private fun logResponseException(request: HttpRequest, cause: Throwable) {
if (level.info) {
logger.log("RESPONSE ${context.request.url} failed with exception: $cause")
logger.log("RESPONSE ${request.url} failed with exception: $cause")
}
}

Expand Down Expand Up @@ -181,13 +181,13 @@ public class Logging(
}
}

scope.receivePipeline.intercept(HttpReceivePipeline.State) {
scope.receivePipeline.intercept(HttpReceivePipeline.State) { response ->
try {
plugin.beginLogging()
plugin.logResponse(context.response)
plugin.logResponse(response.call.response)
proceedWith(subject)
} catch (cause: Throwable) {
plugin.logResponseException(context, cause)
plugin.logResponseException(response.call.request, cause)
throw cause
} finally {
if (!plugin.level.body) {
Expand All @@ -200,7 +200,7 @@ public class Logging(
try {
proceed()
} catch (cause: Throwable) {
plugin.logResponseException(context, cause)
plugin.logResponseException(context.request, cause)
throw cause
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
* Copyright 2014-2019 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
*/

package io.ktor.client.tests

import io.ktor.client.call.*
import io.ktor.client.request.*
import io.ktor.client.statement.*
import io.ktor.client.tests.utils.*
import io.ktor.client.utils.*
import io.ktor.http.*
import io.ktor.util.date.*
import io.ktor.utils.io.*
import kotlin.coroutines.*
import kotlin.test.*

class ClientPipelinesTest : ClientLoader() {
@Test
fun testCanAddHeaders() = clientTests {
config {
install("attr-test") {
receivePipeline.intercept(HttpReceivePipeline.State) { response ->
val headers = buildHeaders {
appendAll(response.headers)
append(HttpHeaders.WWWAuthenticate, "Bearer")
}
proceedWith(
object : HttpResponse() {
override val call: HttpClientCall = response.call
override val status: HttpStatusCode = response.status
override val version: HttpProtocolVersion = response.version
override val requestTime: GMTDate = response.requestTime
override val responseTime: GMTDate = response.responseTime
override val content: ByteReadChannel = response.content
override val headers get() = headers
override val coroutineContext: CoroutineContext = response.coroutineContext
}
)
}
}
}

test { client ->
val response = client.get("$TEST_SERVER/content/hello")
assertEquals("Bearer", response.headers[HttpHeaders.WWWAuthenticate])
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class PluginsTest : ClientLoader() {
val task = Job()
config {
ResponseObserver { response ->
val text = response.body<String>()
val text = response.content.readRemaining(Long.MAX_VALUE, 0).readText()
assertEquals(body, text)
task.complete()
}
Expand Down

0 comments on commit 7d6faea

Please sign in to comment.