diff --git a/config/src/main/java/org/springframework/security/config/websocket/WebSocketMessageBrokerSecurityBeanDefinitionParser.java b/config/src/main/java/org/springframework/security/config/websocket/WebSocketMessageBrokerSecurityBeanDefinitionParser.java index edb6646b363..2c903ce5270 100644 --- a/config/src/main/java/org/springframework/security/config/websocket/WebSocketMessageBrokerSecurityBeanDefinitionParser.java +++ b/config/src/main/java/org/springframework/security/config/websocket/WebSocketMessageBrokerSecurityBeanDefinitionParser.java @@ -16,9 +16,13 @@ package org.springframework.security.config.websocket; +import java.util.Arrays; +import java.util.Collections; import java.util.Comparator; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.function.Supplier; import org.w3c.dom.Element; @@ -307,6 +311,11 @@ static class MessageSecurityPostProcessor implements BeanDefinitionRegistryPostP private static final String TEMPLATE_EXPRESSION_BEAN_ID = "annotationExpressionTemplateDefaults"; + private static final Set CSRF_HANDSHAKE_HANDLER_CLASSES = Collections.unmodifiableSet( + new HashSet<>(Arrays.asList("org.springframework.web.socket.server.support.WebSocketHttpRequestHandler", + "org.springframework.web.socket.sockjs.transport.TransportHandlingSockJsService", + "org.springframework.web.socket.sockjs.transport.handler.DefaultSockJsService"))); + private final String inboundSecurityInterceptorId; private final boolean sameOriginDisabled; @@ -345,16 +354,7 @@ public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) t } } } - else if ("org.springframework.web.socket.server.support.WebSocketHttpRequestHandler" - .equals(beanClassName)) { - addCsrfTokenHandshakeInterceptor(bd); - } - else if ("org.springframework.web.socket.sockjs.transport.TransportHandlingSockJsService" - .equals(beanClassName)) { - addCsrfTokenHandshakeInterceptor(bd); - } - else if ("org.springframework.web.socket.sockjs.transport.handler.DefaultSockJsService" - .equals(beanClassName)) { + else if (CSRF_HANDSHAKE_HANDLER_CLASSES.contains(beanClassName)) { addCsrfTokenHandshakeInterceptor(bd); } }