diff --git a/spring-web/src/main/java/org/springframework/web/service/invoker/HttpServiceMethod.java b/spring-web/src/main/java/org/springframework/web/service/invoker/HttpServiceMethod.java index c9ebda761307..eaf5ea87a8b5 100644 --- a/spring-web/src/main/java/org/springframework/web/service/invoker/HttpServiceMethod.java +++ b/spring-web/src/main/java/org/springframework/web/service/invoker/HttpServiceMethod.java @@ -22,6 +22,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Optional; +import java.util.concurrent.CompletableFuture; import java.util.function.Function; import java.util.function.Supplier; @@ -63,6 +64,7 @@ * @author Sebastien Deleuze * @author Olga Maciaszek-Sharma * @author Sam Brannen + * @author Mengqi Xu * @since 6.0 */ final class HttpServiceMethod { @@ -412,7 +414,29 @@ public static ResponseFunction create(HttpExchangeAdapter client, Method method) "Kotlin Coroutines are only supported with reactive implementations"); } - MethodParameter param = new MethodParameter(method, -1).nestedIfOptional(); + MethodParameter param = new MethodParameter(method, -1); + Class paramType = param.getNestedParameterType(); + + Function responseFunction; + if (paramType.equals(CompletableFuture.class)) { + MethodParameter bodyParam = param.nested(); + MethodParameter nestedParamIfOptional = bodyParam.getNestedParameterType().equals(Optional.class) ? + bodyParam.nested() : bodyParam; + responseFunction = request -> + CompletableFuture.supplyAsync(() -> + asOptionalIfNecessary(buildResponseFunction(client, nestedParamIfOptional).apply(request), + bodyParam.getNestedParameterType())); + } + else { + responseFunction = request -> + asOptionalIfNecessary(buildResponseFunction(client, param.nestedIfOptional()).apply(request), + param.getParameterType()); + } + + return new ExchangeResponseFunction(responseFunction); + } + + private static Function buildResponseFunction(HttpExchangeAdapter client, MethodParameter param) { Class paramType = param.getNestedParameterType(); Function responseFunction; @@ -423,33 +447,30 @@ public static ResponseFunction create(HttpExchangeAdapter client, Method method) }; } else if (paramType.equals(HttpHeaders.class)) { - responseFunction = request -> asOptionalIfNecessary(client.exchangeForHeaders(request), param); + responseFunction = client::exchangeForHeaders; } else if (paramType.equals(ResponseEntity.class)) { MethodParameter bodyParam = param.nested(); if (bodyParam.getNestedParameterType().equals(Void.class)) { - responseFunction = request -> - asOptionalIfNecessary(client.exchangeForBodilessEntity(request), param); + responseFunction = client::exchangeForBodilessEntity; } else { ParameterizedTypeReference bodyTypeRef = ParameterizedTypeReference.forType(bodyParam.getNestedGenericParameterType()); - responseFunction = request -> - asOptionalIfNecessary(client.exchangeForEntity(request, bodyTypeRef), param); + responseFunction = request -> client.exchangeForEntity(request, bodyTypeRef); } } else { ParameterizedTypeReference bodyTypeRef = ParameterizedTypeReference.forType(param.getNestedGenericParameterType()); - responseFunction = request -> - asOptionalIfNecessary(client.exchangeForBody(request, bodyTypeRef), param); + responseFunction = request -> client.exchangeForBody(request, bodyTypeRef); } - return new ExchangeResponseFunction(responseFunction); + return responseFunction; } - private static @Nullable Object asOptionalIfNecessary(@Nullable Object response, MethodParameter param) { - return param.getParameterType().equals(Optional.class) ? Optional.ofNullable(response) : response; + private static @Nullable Object asOptionalIfNecessary(@Nullable Object response, Class type) { + return type.equals(Optional.class) ? Optional.ofNullable(response) : response; } } diff --git a/spring-web/src/test/java/org/springframework/web/service/invoker/HttpServiceMethodTests.java b/spring-web/src/test/java/org/springframework/web/service/invoker/HttpServiceMethodTests.java index dabecb110b97..6dd001ff664f 100644 --- a/spring-web/src/test/java/org/springframework/web/service/invoker/HttpServiceMethodTests.java +++ b/spring-web/src/test/java/org/springframework/web/service/invoker/HttpServiceMethodTests.java @@ -23,6 +23,8 @@ import java.lang.reflect.Method; import java.util.List; import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Flowable; @@ -61,6 +63,7 @@ * @author Rossen Stoyanchev * @author Olga Maciaszek-Sharma * @author Sam Brannen + * @author Mengqi Xu */ class HttpServiceMethodTests { @@ -103,6 +106,34 @@ void service() { assertThat(list).containsOnly("exchangeForBody"); } + @Test // gh-34748 + void completableFutureService() throws ExecutionException, InterruptedException { + CompletableFutureService service = this.proxyFactory.createClient(CompletableFutureService.class); + + service.execute(); + + HttpHeaders headers = service.getHeaders().get(); + assertThat(headers).isNotNull(); + + String body = service.getBody().get(); + assertThat(body).isEqualTo(this.client.getInvokedMethodName()); + + Optional optional = service.getBodyOptional().get(); + assertThat(optional.get()).isEqualTo("exchangeForBody"); + + ResponseEntity entity = service.getEntity().get(); + assertThat(entity.getBody()).isEqualTo("exchangeForEntity"); + + Optional> entityOptional = service.getEntityOptional().get(); + assertThat(entityOptional.get().getBody()).isEqualTo("exchangeForEntity"); + + ResponseEntity voidEntity = service.getVoidEntity().get(); + assertThat(voidEntity.getBody()).isNull(); + + List list = service.getList().get(); + assertThat(list).containsOnly("exchangeForBody"); + } + @Test void reactorService() { ReactorService service = this.reactorProxyFactory.createClient(ReactorService.class); @@ -294,6 +325,35 @@ private interface Service { } + @SuppressWarnings("unused") + private interface CompletableFutureService { + + @GetExchange + CompletableFuture execute(); + + @GetExchange + CompletableFuture getHeaders(); + + @GetExchange + CompletableFuture getBody(); + + @GetExchange + CompletableFuture> getBodyOptional(); + + @GetExchange + CompletableFuture> getVoidEntity(); + + @GetExchange + CompletableFuture> getEntity(); + + @GetExchange + CompletableFuture>> getEntityOptional(); + + @GetExchange + CompletableFuture> getList(); + + } + private interface ReactorService {