Skip to content

Commit

Permalink
Refactor implementation of retrieve in RestClient
Browse files Browse the repository at this point in the history
Closes gh-33777
  • Loading branch information
rstoyanchev committed Oct 23, 2024
1 parent 8fa99dc commit bff76d7
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 89 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,8 @@ public Builder mutate() {

@Nullable
@SuppressWarnings({"rawtypes", "unchecked"})
private <T> T readWithMessageConverters(ClientHttpResponse clientResponse, Runnable callback, Type bodyType,
Class<T> bodyClass, @Nullable Observation observation) {
private <T> T readWithMessageConverters(
ClientHttpResponse clientResponse, Runnable callback, Type bodyType, Class<T> bodyClass) {

MediaType contentType = getContentType(clientResponse);

Expand Down Expand Up @@ -257,21 +257,8 @@ else if (messageConverter.canRead(bodyClass, contentType)) {
else {
cause = exc;
}
RestClientException restClientException = new RestClientException("Error while extracting response for type [" +
throw new RestClientException("Error while extracting response for type [" +
ResolvableType.forType(bodyType) + "] and content type [" + contentType + "]", cause);
if (observation != null) {
observation.error(restClientException);
}
throw restClientException;
}
catch (RestClientException restClientException) {
if (observation != null) {
observation.error(restClientException);
}
throw restClientException;
}
finally {
clientResponse.close();
}
}

Expand Down Expand Up @@ -536,14 +523,16 @@ private void logBody(Object body, @Nullable MediaType mediaType, HttpMessageConv

@Override
public ResponseSpec retrieve() {
return exchangeInternal(DefaultResponseSpec::new, false);
return new DefaultResponseSpec(this);
}

@Override
@Nullable
public <T> T exchange(ExchangeFunction<T> exchangeFunction, boolean close) {
return exchangeInternal(exchangeFunction, close);
}

@Nullable
private <T> T exchangeInternal(ExchangeFunction<T> exchangeFunction, boolean close) {
Assert.notNull(exchangeFunction, "ExchangeFunction must not be null");

Expand Down Expand Up @@ -578,39 +567,31 @@ private <T> T exchangeInternal(ExchangeFunction<T> exchangeFunction, boolean clo
}
clientResponse = clientRequest.execute();
observationContext.setResponse(clientResponse);
ConvertibleClientHttpResponse convertibleWrapper = new DefaultConvertibleClientHttpResponse(clientResponse, observation, observationScope);
ConvertibleClientHttpResponse convertibleWrapper = new DefaultConvertibleClientHttpResponse(clientResponse);
return exchangeFunction.exchange(clientRequest, convertibleWrapper);
}
catch (IOException ex) {
ResourceAccessException resourceAccessException = createResourceAccessException(uri, this.httpMethod, ex);
if (observationScope != null) {
observationScope.close();
}
if (observation != null) {
observation.error(resourceAccessException);
observation.stop();
}
throw resourceAccessException;
}
catch (Throwable error) {
if (observationScope != null) {
observationScope.close();
}
if (observation != null) {
observation.error(error);
observation.stop();
}
throw error;
}
finally {
if (observationScope != null) {
observationScope.close();
}
if (observation != null) {
observation.stop();
}
if (close && clientResponse != null) {
clientResponse.close();
if (observationScope != null) {
observationScope.close();
}
if (observation != null) {
observation.stop();
}
}
}
}
Expand Down Expand Up @@ -719,17 +700,14 @@ private interface InternalBody {

private class DefaultResponseSpec implements ResponseSpec {

private final HttpRequest clientRequest;

private final ClientHttpResponse clientResponse;
private final RequestHeadersSpec<?> requestHeadersSpec;

private final List<StatusHandler> statusHandlers = new ArrayList<>(1);

private final int defaultStatusHandlerCount;

DefaultResponseSpec(HttpRequest clientRequest, ClientHttpResponse clientResponse) {
this.clientRequest = clientRequest;
this.clientResponse = clientResponse;
DefaultResponseSpec(RequestHeadersSpec<?> requestHeadersSpec) {
this.requestHeadersSpec = requestHeadersSpec;
this.statusHandlers.addAll(DefaultRestClient.this.defaultStatusHandlers);
this.statusHandlers.add(StatusHandler.defaultHandler(DefaultRestClient.this.messageConverters));
this.defaultStatusHandlerCount = this.statusHandlers.size();
Expand Down Expand Up @@ -761,15 +739,15 @@ private ResponseSpec onStatusInternal(StatusHandler statusHandler) {
@Override
@Nullable
public <T> T body(Class<T> bodyType) {
return readBody(bodyType, bodyType);
return executeAndExtract((request, response) -> readBody(request, response, bodyType, bodyType));
}

@Override
@Nullable
public <T> T body(ParameterizedTypeReference<T> bodyType) {
Type type = bodyType.getType();
Class<T> bodyClass = bodyClass(type);
return readBody(type, bodyClass);
return executeAndExtract((request, response) -> readBody(request, response, type, bodyClass));
}

@Override
Expand All @@ -785,50 +763,64 @@ public <T> ResponseEntity<T> toEntity(ParameterizedTypeReference<T> bodyType) {
}

private <T> ResponseEntity<T> toEntityInternal(Type bodyType, Class<T> bodyClass) {
T body = readBody(bodyType, bodyClass);
try {
return ResponseEntity.status(this.clientResponse.getStatusCode())
.headers(this.clientResponse.getHeaders())
.body(body);
}
catch (IOException ex) {
throw new ResourceAccessException("Could not retrieve response status code: " + ex.getMessage(), ex);
}
ResponseEntity<T> entity = executeAndExtract((request, response) -> {
T body = readBody(request, response, bodyType, bodyClass);
try {
return ResponseEntity.status(response.getStatusCode())
.headers(response.getHeaders())
.body(body);
}
catch (IOException ex) {
throw new ResourceAccessException(
"Could not retrieve response status code: " + ex.getMessage(), ex);
}
});
Assert.state(entity != null, "No ResponseEntity");
return entity;
}

@Override
public ResponseEntity<Void> toBodilessEntity() {
try (this.clientResponse) {
applyStatusHandlers();
return ResponseEntity.status(this.clientResponse.getStatusCode())
.headers(this.clientResponse.getHeaders())
.build();
}
catch (UncheckedIOException ex) {
throw new ResourceAccessException("Could not retrieve response status code: " + ex.getMessage(), ex.getCause());
}
catch (IOException ex) {
throw new ResourceAccessException("Could not retrieve response status code: " + ex.getMessage(), ex);
}
ResponseEntity<Void> entity = executeAndExtract((request, response) -> {
try (response) {
applyStatusHandlers(request, response);
return ResponseEntity.status(response.getStatusCode())
.headers(response.getHeaders())
.build();
}
catch (UncheckedIOException ex) {
throw new ResourceAccessException(
"Could not retrieve response status code: " + ex.getMessage(), ex.getCause());
}
catch (IOException ex) {
throw new ResourceAccessException(
"Could not retrieve response status code: " + ex.getMessage(), ex);
}
});
Assert.state(entity != null, "No ResponseEntity");
return entity;
}

@Nullable
public <T> T executeAndExtract(RequestHeadersSpec.ExchangeFunction<T> exchangeFunction) {
return this.requestHeadersSpec.exchange(exchangeFunction);
}

@Nullable
private <T> T readBody(Type bodyType, Class<T> bodyClass) {
return DefaultRestClient.this.readWithMessageConverters(this.clientResponse, this::applyStatusHandlers,
bodyType, bodyClass, getCurrentObservation());
private <T> T readBody(HttpRequest request, ClientHttpResponse response, Type bodyType, Class<T> bodyClass) {
return DefaultRestClient.this.readWithMessageConverters(
response, () -> applyStatusHandlers(request, response), bodyType, bodyClass);

}

private void applyStatusHandlers() {
private void applyStatusHandlers(HttpRequest request, ClientHttpResponse response) {
try {
ClientHttpResponse response = this.clientResponse;
if (response instanceof DefaultConvertibleClientHttpResponse convertibleResponse) {
response = convertibleResponse.delegate;
}
for (StatusHandler handler : this.statusHandlers) {
if (handler.test(response)) {
handler.handle(this.clientRequest, response);
handler.handle(request, response);
return;
}
}
Expand All @@ -838,44 +830,29 @@ private void applyStatusHandlers() {
}
}

@Nullable
private Observation getCurrentObservation() {
if (this.clientResponse instanceof DefaultConvertibleClientHttpResponse convertibleResponse) {
return convertibleResponse.observation;
}
return null;
}

}


private class DefaultConvertibleClientHttpResponse implements RequestHeadersSpec.ConvertibleClientHttpResponse {

private final ClientHttpResponse delegate;

private final Observation observation;

private final Observation.Scope observationScope;

public DefaultConvertibleClientHttpResponse(ClientHttpResponse delegate, Observation observation, Observation.Scope observationScope) {
public DefaultConvertibleClientHttpResponse(ClientHttpResponse delegate) {
this.delegate = delegate;
this.observation = observation;
this.observationScope = observationScope;
}


@Nullable
@Override
public <T> T bodyTo(Class<T> bodyType) {
return readWithMessageConverters(this.delegate, () -> {} , bodyType, bodyType, this.observation);
return readWithMessageConverters(this.delegate, () -> {} , bodyType, bodyType);
}

@Nullable
@Override
public <T> T bodyTo(ParameterizedTypeReference<T> bodyType) {
Type type = bodyType.getType();
Class<T> bodyClass = bodyClass(type);
return readWithMessageConverters(this.delegate, () -> {}, type, bodyClass, this.observation);
return readWithMessageConverters(this.delegate, () -> {}, type, bodyClass);
}

@Override
Expand All @@ -901,8 +878,6 @@ public String getStatusText() throws IOException {
@Override
public void close() {
this.delegate.close();
this.observationScope.close();
this.observation.stop();
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -615,8 +615,10 @@ interface RequestHeadersSpec<S extends RequestHeadersSpec<S>> {
S httpRequest(Consumer<ClientHttpRequest> requestConsumer);

/**
* Proceed to declare how to extract the response. For example to extract
* a {@link ResponseEntity} with status, headers, and body:
* Enter the retrieve workflow and use the returned {@link ResponseSpec}
* to select from a number of built-in options to extract the response.
* For example:
*
* <pre class="code">
* ResponseEntity&lt;Person&gt; entity = client.get()
* .uri("/persons/1")
Expand All @@ -632,6 +634,10 @@ interface RequestHeadersSpec<S extends RequestHeadersSpec<S>> {
* .retrieve()
* .body(Person.class);
* </pre>
* Note that this method does not actually execute the request until you
* call one of the returned {@link ResponseSpec}. Use the
* {@link #exchange(ExchangeFunction)} variants if you need to separate
* request execution from response extraction.
* <p>By default, 4xx response code result in a
* {@link HttpClientErrorException} and 5xx response codes in a
* {@link HttpServerErrorException}. To customize error handling, use
Expand Down Expand Up @@ -664,6 +670,7 @@ interface RequestHeadersSpec<S extends RequestHeadersSpec<S>> {
* @param <T> the type the response will be transformed to
* @return the value returned from the exchange function
*/
@Nullable
default <T> T exchange(ExchangeFunction<T> exchangeFunction) {
return exchange(exchangeFunction, true);
}
Expand Down Expand Up @@ -695,6 +702,7 @@ default <T> T exchange(ExchangeFunction<T> exchangeFunction) {
* @param <T> the type the response will be transformed to
* @return the value returned from the exchange function
*/
@Nullable
<T> T exchange(ExchangeFunction<T> exchangeFunction, boolean close);


Expand All @@ -712,6 +720,7 @@ interface ExchangeFunction<T> {
* @return the exchanged type
* @throws IOException in case of I/O errors
*/
@Nullable
T exchange(HttpRequest clientRequest, ConvertibleClientHttpResponse clientResponse) throws IOException;
}

Expand Down

0 comments on commit bff76d7

Please sign in to comment.