Skip to content

Refactor DefaultReactiveElasticsearchClient to do request customizati… #1795

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -281,15 +281,23 @@ private static WebClientProvider getWebClientProvider(ClientConfiguration client
scheme = "https";
}

ReactorClientHttpConnector connector = new ReactorClientHttpConnector(httpClient);
WebClientProvider provider = WebClientProvider.create(scheme, connector);
WebClientProvider provider = WebClientProvider.create(scheme, new ReactorClientHttpConnector(httpClient));

if (clientConfiguration.getPathPrefix() != null) {
provider = provider.withPathPrefix(clientConfiguration.getPathPrefix());
}

provider = provider.withDefaultHeaders(clientConfiguration.getDefaultHeaders()) //
.withWebClientConfigurer(clientConfiguration.getWebClientConfigurer());
provider = provider //
.withDefaultHeaders(clientConfiguration.getDefaultHeaders()) //
.withWebClientConfigurer(clientConfiguration.getWebClientConfigurer()) //
.withRequestConfigurer(requestHeadersSpec -> requestHeadersSpec.headers(httpHeaders -> {
HttpHeaders suppliedHeaders = clientConfiguration.getHeadersSupplier().get();

if (suppliedHeaders != null && suppliedHeaders != HttpHeaders.EMPTY) {
httpHeaders.addAll(suppliedHeaders);
}
}));

return provider;
}

Expand Down Expand Up @@ -584,12 +592,6 @@ private RequestBodySpec sendRequest(WebClient webClient, String logId, Request r
request.getOptions().getHeaders().forEach(it -> theHeaders.add(it.getName(), it.getValue()));
}
}

// plus the ones from the supplier
HttpHeaders suppliedHeaders = headersSupplier.get();
if (suppliedHeaders != null && suppliedHeaders != HttpHeaders.EMPTY) {
theHeaders.addAll(suppliedHeaders);
}
});

if (request.getEntity() != null) {
Expand All @@ -599,8 +601,8 @@ private RequestBodySpec sendRequest(WebClient webClient, String logId, Request r
ClientLogger.logRequest(logId, request.getMethod().toUpperCase(), request.getEndpoint(), request.getParameters(),
body::get);

requestBodySpec.contentType(MediaType.valueOf(request.getEntity().getContentType().getValue()));
requestBodySpec.body(Mono.fromSupplier(body), String.class);
requestBodySpec.contentType(MediaType.valueOf(request.getEntity().getContentType().getValue()))
.body(Mono.fromSupplier(body), String.class);
} else {
ClientLogger.logRequest(logId, request.getMethod().toUpperCase(), request.getEndpoint(), request.getParameters());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class DefaultWebClientProvider implements WebClientProvider {
private final HttpHeaders headers;
private final @Nullable String pathPrefix;
private final Function<WebClient, WebClient> webClientConfigurer;
private final Consumer<WebClient.RequestHeadersSpec<?>> requestConfigurer;

/**
* Create new {@link DefaultWebClientProvider} with empty {@link HttpHeaders} and no-op {@literal error listener}.
Expand All @@ -56,7 +57,7 @@ class DefaultWebClientProvider implements WebClientProvider {
* @param connector can be {@literal null}.
*/
DefaultWebClientProvider(String scheme, @Nullable ClientHttpConnector connector) {
this(scheme, connector, e -> {}, HttpHeaders.EMPTY, null, Function.identity());
this(scheme, connector, e -> {}, HttpHeaders.EMPTY, null, Function.identity(), requestHeadersSpec -> {});
}

/**
Expand All @@ -66,18 +67,21 @@ class DefaultWebClientProvider implements WebClientProvider {
* @param connector can be {@literal null}.
* @param errorListener must not be {@literal null}.
* @param headers must not be {@literal null}.
* @param pathPrefix can be {@literal null}
* @param pathPrefix can be {@literal null}.
* @param webClientConfigurer must not be {@literal null}.
* @param requestConfigurer must not be {@literal null}.
*/
private DefaultWebClientProvider(String scheme, @Nullable ClientHttpConnector connector,
Consumer<Throwable> errorListener, HttpHeaders headers, @Nullable String pathPrefix,
Function<WebClient, WebClient> webClientConfigurer) {
Function<WebClient, WebClient> webClientConfigurer, Consumer<WebClient.RequestHeadersSpec<?>> requestConfigurer) {

Assert.notNull(scheme, "Scheme must not be null! A common scheme would be 'http'.");
Assert.notNull(errorListener, "errorListener must not be null! You may want use a no-op one 'e -> {}' instead.");
Assert.notNull(headers, "headers must not be null! Think about using 'HttpHeaders.EMPTY' as an alternative.");
Assert.notNull(webClientConfigurer,
"webClientConfigurer must not be null! You may want use a no-op one 'Function.identity()' instead.");
Assert.notNull(requestConfigurer,
"requestConfigurer must not be null! You may want use a no-op one 'r -> {}' instead.\"");

this.cachedClients = new ConcurrentHashMap<>();
this.scheme = scheme;
Expand All @@ -86,6 +90,7 @@ private DefaultWebClientProvider(String scheme, @Nullable ClientHttpConnector co
this.headers = headers;
this.pathPrefix = pathPrefix;
this.webClientConfigurer = webClientConfigurer;
this.requestConfigurer = requestConfigurer;
}

@Override
Expand All @@ -106,6 +111,7 @@ public Consumer<Throwable> getErrorListener() {
return this.errorListener;
}

@Nullable
@Override
public String getPathPrefix() {
return pathPrefix;
Expand All @@ -120,7 +126,17 @@ public WebClientProvider withDefaultHeaders(HttpHeaders headers) {
merged.addAll(this.headers);
merged.addAll(headers);

return new DefaultWebClientProvider(scheme, connector, errorListener, merged, pathPrefix, webClientConfigurer);
return new DefaultWebClientProvider(scheme, connector, errorListener, merged, pathPrefix, webClientConfigurer,
requestConfigurer);
}

@Override
public WebClientProvider withRequestConfigurer(Consumer<WebClient.RequestHeadersSpec<?>> requestConfigurer) {

Assert.notNull(requestConfigurer, "requestConfigurer must not be null.");

return new DefaultWebClientProvider(scheme, connector, errorListener, headers, pathPrefix, webClientConfigurer,
requestConfigurer);
}

@Override
Expand All @@ -129,26 +145,30 @@ public WebClientProvider withErrorListener(Consumer<Throwable> errorListener) {
Assert.notNull(errorListener, "Error listener must not be null.");

Consumer<Throwable> listener = this.errorListener.andThen(errorListener);
return new DefaultWebClientProvider(scheme, this.connector, listener, headers, pathPrefix, webClientConfigurer);
return new DefaultWebClientProvider(scheme, this.connector, listener, headers, pathPrefix, webClientConfigurer,
requestConfigurer);
}

@Override
public WebClientProvider withPathPrefix(String pathPrefix) {
Assert.notNull(pathPrefix, "pathPrefix must not be null.");

return new DefaultWebClientProvider(this.scheme, this.connector, this.errorListener, this.headers, pathPrefix,
webClientConfigurer);
webClientConfigurer, requestConfigurer);
}

@Override
public WebClientProvider withWebClientConfigurer(Function<WebClient, WebClient> webClientConfigurer) {
return new DefaultWebClientProvider(scheme, connector, errorListener, headers, pathPrefix, webClientConfigurer);
return new DefaultWebClientProvider(scheme, connector, errorListener, headers, pathPrefix, webClientConfigurer,
requestConfigurer);

}

protected WebClient createWebClientForSocketAddress(InetSocketAddress socketAddress) {

Builder builder = WebClient.builder().defaultHeaders(it -> it.addAll(getDefaultHeaders()));
Builder builder = WebClient.builder() //
.defaultHeaders(it -> it.addAll(getDefaultHeaders())) //
.defaultRequest(requestConfigurer);

if (connector != null) {
builder = builder.clientConnector(connector);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,9 @@ static HostProvider<?> provider(WebClientProvider clientProvider, Supplier<HttpH
Assert.notEmpty(endpoints, "Please provide at least one endpoint to connect to.");

if (endpoints.length == 1) {
return new SingleNodeHostProvider(clientProvider, headersSupplier, endpoints[0]);
return new SingleNodeHostProvider(clientProvider, endpoints[0]);
} else {
return new MultiNodeHostProvider(clientProvider, headersSupplier, endpoints);
return new MultiNodeHostProvider(clientProvider, endpoints);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,12 @@
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Supplier;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.data.elasticsearch.client.ElasticsearchHost;
import org.springframework.data.elasticsearch.client.ElasticsearchHost.State;
import org.springframework.data.elasticsearch.client.NoReachableHostException;
import org.springframework.http.HttpHeaders;
import org.springframework.lang.Nullable;
import org.springframework.web.reactive.function.client.ClientResponse;
import org.springframework.web.reactive.function.client.WebClient;
Expand All @@ -53,14 +51,11 @@ class MultiNodeHostProvider implements HostProvider<MultiNodeHostProvider> {
private final static Logger LOG = LoggerFactory.getLogger(MultiNodeHostProvider.class);

private final WebClientProvider clientProvider;
private final Supplier<HttpHeaders> headersSupplier;
private final Map<InetSocketAddress, ElasticsearchHost> hosts;

MultiNodeHostProvider(WebClientProvider clientProvider, Supplier<HttpHeaders> headersSupplier,
InetSocketAddress... endpoints) {
MultiNodeHostProvider(WebClientProvider clientProvider, InetSocketAddress... endpoints) {

this.clientProvider = clientProvider;
this.headersSupplier = headersSupplier;
this.hosts = new ConcurrentHashMap<>();
for (InetSocketAddress endpoint : endpoints) {
this.hosts.put(endpoint, new ElasticsearchHost(endpoint, State.UNKNOWN));
Expand Down Expand Up @@ -166,7 +161,6 @@ private Flux<Tuple2<InetSocketAddress, State>> checkNodes(@Nullable State state)

Mono<ClientResponse> clientResponseMono = createWebClient(host) //
.head().uri("/") //
.headers(httpHeaders -> httpHeaders.addAll(headersSupplier.get())) //
.exchangeToMono(Mono::just) //
.timeout(Duration.ofSeconds(1)) //
.doOnError(throwable -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,10 @@

import java.net.InetSocketAddress;
import java.util.Collections;
import java.util.function.Supplier;

import org.springframework.data.elasticsearch.client.ElasticsearchHost;
import org.springframework.data.elasticsearch.client.ElasticsearchHost.State;
import org.springframework.data.elasticsearch.client.NoReachableHostException;
import org.springframework.http.HttpHeaders;
import org.springframework.web.reactive.function.client.WebClient;

/**
Expand All @@ -38,15 +36,12 @@
class SingleNodeHostProvider implements HostProvider<SingleNodeHostProvider> {

private final WebClientProvider clientProvider;
private final Supplier<HttpHeaders> headersSupplier;
private final InetSocketAddress endpoint;
private volatile ElasticsearchHost state;

SingleNodeHostProvider(WebClientProvider clientProvider, Supplier<HttpHeaders> headersSupplier,
InetSocketAddress endpoint) {
SingleNodeHostProvider(WebClientProvider clientProvider, InetSocketAddress endpoint) {

this.clientProvider = clientProvider;
this.headersSupplier = headersSupplier;
this.endpoint = endpoint;
this.state = new ElasticsearchHost(this.endpoint, State.UNKNOWN);
}
Expand All @@ -60,7 +55,6 @@ public Mono<ClusterInformation> clusterInfo() {

return createWebClient(endpoint) //
.head().uri("/") //
.headers(httpHeaders -> httpHeaders.addAll(headersSupplier.get())) //
.exchangeToMono(it -> {
if (it.statusCode().isError()) {
state = ElasticsearchHost.offline(endpoint);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ static WebClientProvider create(String scheme, @Nullable ClientHttpConnector con

/**
* Obtain the {@link String pathPrefix} to be used.
*
*
* @return the pathPrefix if set.
* @since 4.0
*/
Expand All @@ -126,7 +126,7 @@ static WebClientProvider create(String scheme, @Nullable ClientHttpConnector con

/**
* Create a new instance of {@link WebClientProvider} where HTTP requests are called with the given path prefix.
*
*
* @param pathPrefix Path prefix to add to requests
* @return new instance of {@link WebClientProvider}
* @since 4.0
Expand All @@ -136,10 +136,20 @@ static WebClientProvider create(String scheme, @Nullable ClientHttpConnector con
/**
* Create a new instance of {@link WebClientProvider} calling the given {@link Function} to configure the
* {@link WebClient}.
*
*
* @param webClientConfigurer configuration function
* @return new instance of {@link WebClientProvider}
* @since 4.0
*/
WebClientProvider withWebClientConfigurer(Function<WebClient, WebClient> webClientConfigurer);

/**
* Create a new instance of {@link WebClientProvider} calling the given {@link Consumer} to configure the requests of
* this {@link WebClient}.
*
* @param requestConfigurer request configuration callback
* @return new instance of {@link WebClientProvider}
* @since 4.3
*/
WebClientProvider withRequestConfigurer(Consumer<WebClient.RequestHeadersSpec<?>> requestConfigurer);
}
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,10 @@ public static <T extends HostProvider<T>> MockDelegatingElasticsearchHostProvide

if (hosts.length == 1) {
// noinspection unchecked
delegate = (T) new SingleNodeHostProvider(clientProvider, HttpHeaders::new, getInetSocketAddress(hosts[0])) {};
delegate = (T) new SingleNodeHostProvider(clientProvider, getInetSocketAddress(hosts[0])) {};
} else {
// noinspection unchecked
delegate = (T) new MultiNodeHostProvider(clientProvider, HttpHeaders::new, Arrays.stream(hosts)
delegate = (T) new MultiNodeHostProvider(clientProvider, Arrays.stream(hosts)
.map(ReactiveMockClientTestsUtils::getInetSocketAddress).toArray(InetSocketAddress[]::new)) {};
}

Expand Down Expand Up @@ -297,6 +297,11 @@ public WebClientProvider withWebClientConfigurer(Function<WebClient, WebClient>
throw new UnsupportedOperationException("not implemented");
}

@Override
public WebClientProvider withRequestConfigurer(Consumer<WebClient.RequestHeadersSpec<?>> requestConfigurer) {
throw new UnsupportedOperationException("not implemented");
}

public Send when(String host) {
InetSocketAddress inetSocketAddress = getInetSocketAddress(host);
return new CallbackImpl(get(host), headersUriSpecMap.get(inetSocketAddress),
Expand Down