Skip to content

Commit

Permalink
Introduce ForwardedHeaderFilter for WebFlux
Browse files Browse the repository at this point in the history
This commit introduces a ForwardedHeaderFilter for WebFlux, similar to
the existing Servlet version. As part of this the
DefaultServerHttpRequestBuilder had to be changed to no longer use
delegation, but instead use a deep copy at the point of mutate().
Otherwise, headers could not be removed.

Issue: SPR-15954
  • Loading branch information
poutsma committed Sep 14, 2017
1 parent 69af698 commit e70210a
Show file tree
Hide file tree
Showing 5 changed files with 385 additions and 84 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,24 @@

package org.springframework.http.server.reactive;

import java.net.InetSocketAddress;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;

import reactor.core.publisher.Flux;

import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.http.HttpCookie;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.server.RequestPath;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;

/**
* Package-private default implementation of {@link ServerHttpRequest.Builder}.
Expand All @@ -34,36 +44,66 @@
*/
class DefaultServerHttpRequestBuilder implements ServerHttpRequest.Builder {

private final ServerHttpRequest delegate;
private URI uri;

private HttpHeaders httpHeaders;

private String httpMethodValue;

private final MultiValueMap<String, HttpCookie> cookies;

@Nullable
private HttpMethod httpMethod;
private final InetSocketAddress remoteAddress;

@Nullable
private String path;
private String uriPath;

@Nullable
private String contextPath;

@Nullable
private HttpHeaders httpHeaders;
private Flux<DataBuffer> body;

public DefaultServerHttpRequestBuilder(ServerHttpRequest original) {
Assert.notNull(original, "ServerHttpRequest is required");

this.uri = original.getURI();
this.httpMethodValue = original.getMethodValue();
this.remoteAddress = original.getRemoteAddress();
this.body = original.getBody();

this.httpHeaders = new HttpHeaders();
copyMultiValueMap(original.getHeaders(), this.httpHeaders);

public DefaultServerHttpRequestBuilder(ServerHttpRequest delegate) {
Assert.notNull(delegate, "ServerHttpRequest delegate is required");
this.delegate = delegate;
this.cookies = new LinkedMultiValueMap<>(original.getCookies().size());
copyMultiValueMap(original.getCookies(), this.cookies);
}

private static <K, V> void copyMultiValueMap(MultiValueMap<K,V> source,
MultiValueMap<K,V> destination) {

for (Map.Entry<K, List<V>> entry : source.entrySet()) {
K key = entry.getKey();
List<V> values = new LinkedList<>(entry.getValue());
destination.put(key, values);
}
}


@Override
public ServerHttpRequest.Builder method(HttpMethod httpMethod) {
this.httpMethod = httpMethod;
this.httpMethodValue = httpMethod.name();
return this;
}

@Override
public ServerHttpRequest.Builder uri(URI uri) {
this.uri = uri;
return this;
}

@Override
public ServerHttpRequest.Builder path(String path) {
this.path = path;
this.uriPath = path;
return this;
}

Expand All @@ -75,111 +115,79 @@ public ServerHttpRequest.Builder contextPath(String contextPath) {

@Override
public ServerHttpRequest.Builder header(String key, String value) {
if (this.httpHeaders == null) {
this.httpHeaders = new HttpHeaders();
}
this.httpHeaders.add(key, value);
return this;
}

@Override
public ServerHttpRequest.Builder headers(Consumer<HttpHeaders> headersConsumer) {
Assert.notNull(headersConsumer, "'headersConsumer' must not be null");
headersConsumer.accept(this.httpHeaders);
return this;
}

@Override
public ServerHttpRequest build() {
URI uriToUse = getUriToUse();
RequestPath path = getRequestPathToUse(uriToUse);
HttpHeaders headers = getHeadersToUse();
return new MutativeDecorator(this.delegate, this.httpMethod, uriToUse, path, headers);
return new DefaultServerHttpRequest(uriToUse, this.contextPath, this.httpHeaders,
this.httpMethodValue, this.cookies, this.remoteAddress, this.body);

}

@Nullable
private URI getUriToUse() {
if (this.path == null) {
return null;
if (this.uriPath == null) {
return this.uri;
}
URI uri = this.delegate.getURI();
try {
return new URI(uri.getScheme(), uri.getUserInfo(), uri.getHost(), uri.getPort(),
this.path, uri.getQuery(), uri.getFragment());
return new URI(this.uri.getScheme(), this.uri.getUserInfo(), uri.getHost(), uri.getPort(),
uriPath, uri.getQuery(), uri.getFragment());
}
catch (URISyntaxException ex) {
throw new IllegalStateException("Invalid URI path: \"" + this.path + "\"");
}
}

@Nullable
private RequestPath getRequestPathToUse(@Nullable URI uriToUse) {
if (uriToUse == null && this.contextPath == null) {
return null;
}
else if (uriToUse == null) {
return this.delegate.getPath().modifyContextPath(this.contextPath);
}
else {
return RequestPath.parse(uriToUse, this.contextPath);
}
}

@Nullable
private HttpHeaders getHeadersToUse() {
if (this.httpHeaders != null) {
HttpHeaders headers = new HttpHeaders();
headers.putAll(this.delegate.getHeaders());
headers.putAll(this.httpHeaders);
return headers;
}
else {
return null;
throw new IllegalStateException("Invalid URI path: \"" + this.uriPath + "\"");
}
}

private static class DefaultServerHttpRequest extends AbstractServerHttpRequest {

/**
* An immutable wrapper of a request returning property overrides -- given
* to the constructor -- or original values otherwise.
*/
private static class MutativeDecorator extends ServerHttpRequestDecorator {

@Nullable
private final HttpMethod httpMethod;

@Nullable
private final URI uri;
private final String methodValue;

@Nullable
private final RequestPath requestPath;
private final MultiValueMap<String, HttpCookie> cookies;

@Nullable
private final HttpHeaders httpHeaders;


public MutativeDecorator(ServerHttpRequest delegate, @Nullable HttpMethod method,
@Nullable URI uri, @Nullable RequestPath requestPath, @Nullable HttpHeaders httpHeaders) {

super(delegate);
this.httpMethod = method;
this.uri = uri;
this.requestPath = requestPath;
this.httpHeaders = httpHeaders;
private final InetSocketAddress remoteAddress;

private final Flux<DataBuffer> body;

public DefaultServerHttpRequest(URI uri, @Nullable String contextPath,
HttpHeaders headers, String methodValue,
MultiValueMap<String, HttpCookie> cookies, @Nullable InetSocketAddress remoteAddress,
Flux<DataBuffer> body) {
super(uri, contextPath, headers);
this.methodValue = methodValue;
this.cookies = cookies;
this.remoteAddress = remoteAddress;
this.body = body;
}

@Override
@Nullable
public HttpMethod getMethod() {
return (this.httpMethod != null ? this.httpMethod : super.getMethod());
public String getMethodValue() {
return this.methodValue;
}

@Override
public URI getURI() {
return (this.uri != null ? this.uri : super.getURI());
protected MultiValueMap<String, HttpCookie> initCookies() {
return this.cookies;
}

@Nullable
@Override
public RequestPath getPath() {
return (this.requestPath != null ? this.requestPath : super.getPath());
public InetSocketAddress getRemoteAddress() {
return this.remoteAddress;
}

@Override
public HttpHeaders getHeaders() {
return (this.httpHeaders != null ? this.httpHeaders : super.getHeaders());
public Flux<DataBuffer> getBody() {
return this.body;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@
package org.springframework.http.server.reactive;

import java.net.InetSocketAddress;
import java.net.URI;
import java.util.function.Consumer;

import org.springframework.http.HttpCookie;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpRequest;
import org.springframework.http.ReactiveHttpInputMessage;
Expand Down Expand Up @@ -79,6 +82,11 @@ interface Builder {
*/
Builder method(HttpMethod httpMethod);

/**
* Set the URI to return.
*/
Builder uri(URI uri);

/**
* Set the path to use instead of the {@code "rawPath"} of
* {@link ServerHttpRequest#getURI()}.
Expand All @@ -95,6 +103,17 @@ interface Builder {
*/
Builder header(String key, String value);

/**
* Manipulate this request's headers with the given consumer. The
* headers provided to the consumer are "live", so that the consumer can be used to
* {@linkplain HttpHeaders#set(String, String) overwrite} existing header values,
* {@linkplain HttpHeaders#remove(Object) remove} values, or use any of the other
* {@link HttpHeaders} methods.
* @param headersConsumer a function that consumes the {@code HttpHeaders}
* @return this builder
*/
Builder headers(Consumer<HttpHeaders> headersConsumer);

/**
* Build a {@link ServerHttpRequest} decorator with the mutated properties.
*/
Expand Down
Loading

0 comments on commit e70210a

Please sign in to comment.