Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit ae02d1c

Browse files
committedMar 21, 2025·
Propagate CoroutineContext to WebClient filter
This commit introduces a new ResponseSpec.awaitEntityOrNull() extension function to replace ResponseSpec.toEntity(...).awaitFirstOrNull() and pass the CoroutineContext to the CoExchangeFilterFunction. CoroutineContext propagation is implemented via ReactorContext and ClientRequest attribute. See gh-32148 Signed-off-by: Dmitry Sulman <dmitry.sulman@gmail.com>
1 parent 5ce64f4 commit ae02d1c

File tree

4 files changed

+131
-7
lines changed

4 files changed

+131
-7
lines changed
 

‎spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultWebClient.java

+5
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,9 @@
7171
*/
7272
final class DefaultWebClient implements WebClient {
7373

74+
// Copy of CoExchangeFilterFunction.COROUTINE_CONTEXT_ATTRIBUTE value to avoid compilation errors in Eclipse
75+
private static final String COROUTINE_CONTEXT_ATTRIBUTE = "org.springframework.web.reactive.function.client.CoExchangeFilterFunction.context";
76+
7477
private static final String URI_TEMPLATE_ATTRIBUTE = WebClient.class.getName() + ".uriTemplate";
7578

7679
private static final Mono<ClientResponse> NO_HTTP_CLIENT_RESPONSE_ERROR = Mono.error(
@@ -430,6 +433,8 @@ private Mono<ClientResponse> exchange() {
430433
if (filterFunctions != null) {
431434
filterFunction = filterFunctions.andThen(filterFunction);
432435
}
436+
contextView.getOrEmpty(COROUTINE_CONTEXT_ATTRIBUTE)
437+
.ifPresent(context -> requestBuilder.attribute(COROUTINE_CONTEXT_ATTRIBUTE, context));
433438
ClientRequest request = requestBuilder.build();
434439
observationContext.setUriTemplate((String) request.attribute(URI_TEMPLATE_ATTRIBUTE).orElse(null));
435440
observationContext.setRequest(request);

‎spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/client/CoExchangeFilterFunction.kt

+22-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2023 the original author or authors.
2+
* Copyright 2002-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -17,9 +17,13 @@
1717
package org.springframework.web.reactive.function.client
1818

1919
import kotlinx.coroutines.Dispatchers
20+
import kotlinx.coroutines.Job
21+
import kotlinx.coroutines.currentCoroutineContext
2022
import kotlinx.coroutines.reactor.awaitSingle
2123
import kotlinx.coroutines.reactor.mono
2224
import reactor.core.publisher.Mono
25+
import kotlin.coroutines.CoroutineContext
26+
import kotlin.jvm.optionals.getOrNull
2327

2428
/**
2529
* Kotlin-specific implementation of the [ExchangeFilterFunction] interface
@@ -31,10 +35,14 @@ import reactor.core.publisher.Mono
3135
abstract class CoExchangeFilterFunction : ExchangeFilterFunction {
3236

3337
final override fun filter(request: ClientRequest, next: ExchangeFunction): Mono<ClientResponse> {
34-
return mono(Dispatchers.Unconfined) {
38+
val context = request.attribute(COROUTINE_CONTEXT_ATTRIBUTE).getOrNull() as CoroutineContext?
39+
return mono(context ?: Dispatchers.Unconfined) {
3540
filter(request, object : CoExchangeFunction {
3641
override suspend fun exchange(request: ClientRequest): ClientResponse {
37-
return next.exchange(request).awaitSingle()
42+
val newRequest = ClientRequest.from(request)
43+
.attribute(COROUTINE_CONTEXT_ATTRIBUTE, currentCoroutineContext().minusKey(Job.Key))
44+
.build()
45+
return next.exchange(newRequest).awaitSingle()
3846
}
3947
})
4048
}
@@ -58,6 +66,17 @@ abstract class CoExchangeFilterFunction : ExchangeFilterFunction {
5866
* @return the filtered response
5967
*/
6068
protected abstract suspend fun filter(request: ClientRequest, next: CoExchangeFunction): ClientResponse
69+
70+
companion object {
71+
72+
/**
73+
* Name of the [ClientRequest] attribute that contains the
74+
* [kotlin.coroutines.CoroutineContext] to be passed to the
75+
* [CoExchangeFilterFunction.filter].
76+
*/
77+
@JvmField
78+
val COROUTINE_CONTEXT_ATTRIBUTE = CoExchangeFilterFunction::class.java.name + ".context"
79+
}
6180
}
6281

6382

‎spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/client/WebClientExtensions.kt

+19-4
Original file line numberDiff line numberDiff line change
@@ -20,17 +20,17 @@ import kotlinx.coroutines.Job
2020
import kotlinx.coroutines.currentCoroutineContext
2121
import kotlinx.coroutines.flow.Flow
2222
import kotlinx.coroutines.reactive.asFlow
23-
import kotlinx.coroutines.reactor.asFlux
24-
import kotlinx.coroutines.reactor.awaitSingle
25-
import kotlinx.coroutines.reactor.awaitSingleOrNull
26-
import kotlinx.coroutines.reactor.mono
23+
import kotlinx.coroutines.reactor.*
24+
import kotlinx.coroutines.withContext
2725
import org.reactivestreams.Publisher
2826
import org.springframework.core.ParameterizedTypeReference
2927
import org.springframework.http.ResponseEntity
28+
import org.springframework.web.reactive.function.client.CoExchangeFilterFunction.Companion.COROUTINE_CONTEXT_ATTRIBUTE
3029
import org.springframework.web.reactive.function.client.WebClient.RequestBodySpec
3130
import org.springframework.web.reactive.function.client.WebClient.RequestHeadersSpec
3231
import reactor.core.publisher.Flux
3332
import reactor.core.publisher.Mono
33+
import reactor.util.context.Context
3434

3535
/**
3636
* Extension for [WebClient.RequestBodySpec.body] providing a `body(Publisher<T>)` variant
@@ -203,3 +203,18 @@ inline fun <reified T : Any> WebClient.ResponseSpec.toEntityList(): Mono<Respons
203203
*/
204204
inline fun <reified T : Any> WebClient.ResponseSpec.toEntityFlux(): Mono<ResponseEntity<Flux<T>>> =
205205
toEntityFlux(object : ParameterizedTypeReference<T>() {})
206+
207+
/**
208+
* Extension for [WebClient.ResponseSpec.toEntity] providing a `toEntity<Foo>()` variant
209+
* leveraging Kotlin reified type parameters and allows [kotlin.coroutines.CoroutineContext]
210+
* propagation to the [CoExchangeFilterFunction]. This extension is not subject to type erasure
211+
* and retains actual generic type arguments.
212+
*
213+
* @since 7.0.0
214+
*/
215+
suspend inline fun <reified T : Any> WebClient.ResponseSpec.awaitEntityOrNull(): ResponseEntity<T>? {
216+
val coroutineContext = currentCoroutineContext().minusKey(Job.Key).minusKey(ReactorContext.Key)
217+
val reactorContext = currentCoroutineContext()[ReactorContext.Key]?.context ?: Context.empty()
218+
val newReactorContext = reactorContext.put(COROUTINE_CONTEXT_ATTRIBUTE, coroutineContext)
219+
return withContext(newReactorContext.asCoroutineContext()) { toEntity(T::class.java).awaitSingleOrNull() }
220+
}

‎spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/client/WebClientExtensionsTests.kt

+85
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,19 @@ import kotlinx.coroutines.flow.Flow
2525
import kotlinx.coroutines.flow.flow
2626
import kotlinx.coroutines.flow.toList
2727
import kotlinx.coroutines.runBlocking
28+
import kotlinx.coroutines.withContext
2829
import org.assertj.core.api.Assertions.assertThat
2930
import org.junit.jupiter.api.Test
3031
import org.reactivestreams.Publisher
3132
import org.springframework.core.ParameterizedTypeReference
33+
import org.springframework.http.HttpHeaders
34+
import org.springframework.http.HttpStatus
35+
import org.springframework.http.MediaType
3236
import org.springframework.http.ResponseEntity
37+
import org.springframework.web.reactive.function.client.CoExchangeFilterFunction.Companion.COROUTINE_CONTEXT_ATTRIBUTE
3338
import reactor.core.publisher.Flux
3439
import reactor.core.publisher.Mono
40+
import java.util.*
3541
import java.util.concurrent.CompletableFuture
3642
import java.util.function.Function
3743
import kotlin.coroutines.AbstractCoroutineContextElement
@@ -226,9 +232,88 @@ class WebClientExtensionsTests {
226232
verify { responseSpec.toEntityFlux(object : ParameterizedTypeReference<List<Foo>>() {}) }
227233
}
228234

235+
@Test
236+
fun `ResponseSpec#awaitEntityOrNull with coroutine context propagation`() {
237+
val exchangeFunction = mockk<ExchangeFunction>()
238+
val mockResponse = mockk<ClientResponse>()
239+
val foo = mockk<Foo>()
240+
val slot = slot<ClientRequest>()
241+
every { exchangeFunction.exchange(capture(slot)) } returns Mono.just(mockResponse)
242+
every { mockResponse.statusCode() } returns HttpStatus.OK
243+
every { mockResponse.headers() } returns MockClientHeaders()
244+
every { mockResponse.bodyToMono(Foo::class.java) } returns Mono.just(foo)
245+
runBlocking {
246+
withContext(FooContextElement(foo)) {
247+
val responseEntity = WebClient.builder()
248+
.exchangeFunction(exchangeFunction)
249+
.filter(object : CoExchangeFilterFunction() {
250+
override suspend fun filter(request: ClientRequest, next: CoExchangeFunction): ClientResponse {
251+
assertThat(currentCoroutineContext()[FooContextElement.Key]!!.foo).isEqualTo(foo)
252+
return next.exchange(request)
253+
}
254+
})
255+
.build().get().uri("/path").retrieve().awaitEntityOrNull<Foo>()
256+
val capturedContext = slot.captured.attribute(COROUTINE_CONTEXT_ATTRIBUTE).get() as CoroutineContext
257+
assertThat(capturedContext[FooContextElement.Key]!!.foo).isEqualTo(foo)
258+
assertThat(responseEntity!!.body).isEqualTo(foo)
259+
}
260+
}
261+
}
262+
263+
@Test
264+
fun `ResponseSpec#awaitEntityOrNull with coroutine context propagation to multiple CoExchangeFilterFunctions`() {
265+
val exchangeFunction = mockk<ExchangeFunction>()
266+
val mockResponse = mockk<ClientResponse>()
267+
val foo = mockk<Foo>()
268+
val slot = slot<ClientRequest>()
269+
every { exchangeFunction.exchange(capture(slot)) } returns Mono.just(mockResponse)
270+
every { mockResponse.statusCode() } returns HttpStatus.OK
271+
every { mockResponse.headers() } returns MockClientHeaders()
272+
every { mockResponse.bodyToMono(Foo::class.java) } returns Mono.just(foo)
273+
runBlocking {
274+
val responseEntity = WebClient.builder()
275+
.exchangeFunction(exchangeFunction)
276+
.filter(object : CoExchangeFilterFunction() {
277+
override suspend fun filter(request: ClientRequest, next: CoExchangeFunction): ClientResponse {
278+
return withContext(FooContextElement(foo)) { next.exchange(request) }
279+
}
280+
})
281+
.filter(object : CoExchangeFilterFunction() {
282+
override suspend fun filter(request: ClientRequest, next: CoExchangeFunction): ClientResponse {
283+
assertThat(currentCoroutineContext()[FooContextElement.Key]!!.foo).isEqualTo(foo)
284+
return next.exchange(request)
285+
}
286+
})
287+
.build().get().uri("/path").retrieve().awaitEntityOrNull<Foo>()
288+
val capturedContext = slot.captured.attribute(COROUTINE_CONTEXT_ATTRIBUTE).get() as CoroutineContext
289+
assertThat(capturedContext[FooContextElement.Key]!!.foo).isEqualTo(foo)
290+
assertThat(responseEntity!!.body).isEqualTo(foo)
291+
}
292+
}
293+
229294
class Foo
230295

231296
private data class FooContextElement(val foo: Foo) : AbstractCoroutineContextElement(FooContextElement) {
232297
companion object Key : CoroutineContext.Key<FooContextElement>
233298
}
299+
300+
private class MockClientHeaders : ClientResponse.Headers {
301+
private val headers = HttpHeaders()
302+
303+
override fun contentLength(): OptionalLong {
304+
return OptionalLong.empty()
305+
}
306+
307+
override fun contentType(): Optional<MediaType> {
308+
return Optional.empty()
309+
}
310+
311+
override fun header(headerName: String): List<String> {
312+
return emptyList()
313+
}
314+
315+
override fun asHttpHeaders(): HttpHeaders {
316+
return headers
317+
}
318+
}
234319
}

0 commit comments

Comments
 (0)
Please sign in to comment.