diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/ExtendedWebExchangeDataBinder.java b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/ExtendedWebExchangeDataBinder.java index 3bba7da3d88f..ec3f998573ae 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/ExtendedWebExchangeDataBinder.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/ExtendedWebExchangeDataBinder.java @@ -18,6 +18,8 @@ import java.util.List; import java.util.Map; +import java.util.Set; +import java.util.function.Predicate; import reactor.core.publisher.Mono; @@ -41,12 +43,40 @@ */ public class ExtendedWebExchangeDataBinder extends WebExchangeDataBinder { + private static final Set FILTERED_HEADER_NAMES = Set.of("Priority"); + + + private Predicate headerPredicate = name -> !FILTERED_HEADER_NAMES.contains(name); + public ExtendedWebExchangeDataBinder(@Nullable Object target, String objectName) { super(target, objectName); } + /** + * Add a Predicate that filters the header names to use for data binding. + * Multiple predicates are combined with {@code AND}. + * @param headerPredicate the predicate to add + * @since 6.2.1 + */ + public void addHeaderPredicate(Predicate headerPredicate) { + this.headerPredicate = this.headerPredicate.and(headerPredicate); + } + + /** + * Set the Predicate that filters the header names to use for data binding. + *

Note that this method resets any previous predicates that may have been + * set, including headers excluded by default such as the RFC 9218 defined + * "Priority" header. + * @param headerPredicate the predicate to add + * @since 6.2.1 + */ + public void setHeaderPredicate(Predicate headerPredicate) { + this.headerPredicate = headerPredicate; + } + + @Override public Mono> getValuesToBind(ServerWebExchange exchange) { return super.getValuesToBind(exchange).doOnNext(map -> { @@ -56,10 +86,13 @@ public Mono> getValuesToBind(ServerWebExchange exchange) { } HttpHeaders headers = exchange.getRequest().getHeaders(); for (Map.Entry> entry : headers.entrySet()) { + String name = entry.getKey(); + if (!this.headerPredicate.test(entry.getKey())) { + continue; + } List values = entry.getValue(); if (!CollectionUtils.isEmpty(values)) { // For constructor args with @BindParam mapped to the actual header name - String name = entry.getKey(); addValueIfNotPresent(map, "Header", name, (values.size() == 1 ? values.get(0) : values)); // Also adapt to Java conventions for setters name = StringUtils.uncapitalize(entry.getKey().replace("-", "")); diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/InitBinderBindingContextTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/InitBinderBindingContextTests.java index 159c687d572c..ad876a9ad951 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/InitBinderBindingContextTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/InitBinderBindingContextTests.java @@ -202,6 +202,24 @@ void bindUriVarsAndHeadersAddedConditionally() throws Exception { assertThat(target.getAge()).isEqualTo(25); } + @Test + void headerPredicate() throws Exception { + MockServerHttpRequest request = MockServerHttpRequest.get("/path") + .header("Priority", "u1") + .header("Some-Int-Array", "1") + .header("Another-Int-Array", "1") + .build(); + + MockServerWebExchange exchange = MockServerWebExchange.from(request); + + BindingContext context = createBindingContext("initBinderWithAttributeName", WebDataBinder.class); + ExtendedWebExchangeDataBinder binder = (ExtendedWebExchangeDataBinder) context.createDataBinder(exchange, null, "", null); + binder.addHeaderPredicate(name -> !name.equalsIgnoreCase("Another-Int-Array")); + + Map map = binder.getValuesToBind(exchange).block(); + assertThat(map).containsExactlyInAnyOrderEntriesOf(Map.of("someIntArray", "1", "Some-Int-Array", "1")); + } + private BindingContext createBindingContext(String methodName, Class... parameterTypes) throws Exception { Object handler = new InitBinderHandler(); Method method = handler.getClass().getMethod(methodName, parameterTypes); diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ExtendedServletRequestDataBinder.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ExtendedServletRequestDataBinder.java index c74b37fab8f3..019bf9a75b12 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ExtendedServletRequestDataBinder.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ExtendedServletRequestDataBinder.java @@ -21,12 +21,14 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.function.Predicate; import jakarta.servlet.ServletRequest; import jakarta.servlet.http.HttpServletRequest; import org.springframework.beans.MutablePropertyValues; import org.springframework.lang.Nullable; +import org.springframework.util.StringUtils; import org.springframework.web.bind.ServletRequestDataBinder; import org.springframework.web.bind.WebDataBinder; import org.springframework.web.servlet.HandlerMapping; @@ -51,6 +53,12 @@ */ public class ExtendedServletRequestDataBinder extends ServletRequestDataBinder { + private static final Set FILTERED_HEADER_NAMES = Set.of("Priority"); + + + private Predicate headerPredicate = name -> !FILTERED_HEADER_NAMES.contains(name); + + /** * Create a new instance, with default object name. * @param target the target object to bind onto (or {@code null} @@ -73,6 +81,29 @@ public ExtendedServletRequestDataBinder(@Nullable Object target, String objectNa } + /** + * Add a Predicate that filters the header names to use for data binding. + * Multiple predicates are combined with {@code AND}. + * @param headerPredicate the predicate to add + * @since 6.2.1 + */ + public void addHeaderPredicate(Predicate headerPredicate) { + this.headerPredicate = this.headerPredicate.and(headerPredicate); + } + + /** + * Set the Predicate that filters the header names to use for data binding. + *

Note that this method resets any previous predicates that may have been + * set, including headers excluded by default such as the RFC 9218 defined + * "Priority" header. + * @param headerPredicate the predicate to add + * @since 6.2.1 + */ + public void setHeaderPredicate(Predicate headerPredicate) { + this.headerPredicate = headerPredicate; + } + + @Override protected ServletRequestValueResolver createValueResolver(ServletRequest request) { return new ExtendedServletRequestValueResolver(request, this); @@ -93,7 +124,7 @@ protected void addBindValues(MutablePropertyValues mpvs, ServletRequest request) String name = names.nextElement(); Object value = getHeaderValue(httpRequest, name); if (value != null) { - name = name.replace("-", ""); + name = StringUtils.uncapitalize(name.replace("-", "")); addValueIfNotPresent(mpvs, "Header", name, value); } } @@ -118,7 +149,11 @@ private static void addValueIfNotPresent(MutablePropertyValues mpvs, String labe } @Nullable - private static Object getHeaderValue(HttpServletRequest request, String name) { + private Object getHeaderValue(HttpServletRequest request, String name) { + if (!this.headerPredicate.test(name)) { + return null; + } + Enumeration valuesEnum = request.getHeaders(name); if (!valuesEnum.hasMoreElements()) { return null; @@ -141,7 +176,7 @@ private static Object getHeaderValue(HttpServletRequest request, String name) { /** * Resolver of values that looks up URI path variables. */ - private static class ExtendedServletRequestValueResolver extends ServletRequestValueResolver { + private class ExtendedServletRequestValueResolver extends ServletRequestValueResolver { ExtendedServletRequestValueResolver(ServletRequest request, WebDataBinder dataBinder) { super(request, dataBinder); diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ExtendedServletRequestDataBinderTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ExtendedServletRequestDataBinderTests.java index 36fd05508cd8..1d7653bdf5af 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ExtendedServletRequestDataBinderTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ExtendedServletRequestDataBinderTests.java @@ -18,9 +18,11 @@ import java.util.Map; +import jakarta.servlet.ServletRequest; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.springframework.beans.MutablePropertyValues; import org.springframework.beans.testfixture.beans.TestBean; import org.springframework.core.ResolvableType; import org.springframework.web.bind.ServletRequestDataBinder; @@ -102,6 +104,22 @@ void uriVarsAndHeadersAddedConditionally() { assertThat(target.getAge()).isEqualTo(25); } + @Test + void headerPredicate() { + TestBinder binder = new TestBinder(); + binder.addHeaderPredicate(name -> !name.equalsIgnoreCase("Another-Int-Array")); + + MutablePropertyValues mpvs = new MutablePropertyValues(); + request.addHeader("Priority", "u1"); + request.addHeader("Some-Int-Array", "1"); + request.addHeader("Another-Int-Array", "1"); + + binder.addBindValues(mpvs, request); + + assertThat(mpvs.size()).isEqualTo(1); + assertThat(mpvs.get("someIntArray")).isEqualTo("1"); + } + @Test void noUriTemplateVars() { TestBean target = new TestBean(); @@ -116,4 +134,17 @@ void noUriTemplateVars() { private record DataBean(String name, int age, @BindParam("Some-Int-Array") Integer[] someIntArray) { } + + private static class TestBinder extends ExtendedServletRequestDataBinder { + + public TestBinder() { + super(null); + } + + @Override + public void addBindValues(MutablePropertyValues mpvs, ServletRequest request) { + super.addBindValues(mpvs, request); + } + } + }