|
19 | 19 | import java.util.Comparator;
|
20 | 20 | import java.util.List;
|
21 | 21 | import java.util.Map;
|
| 22 | +import java.util.Set; |
| 23 | +import java.util.HashSet; |
| 24 | +import java.util.Arrays; |
| 25 | + |
22 | 26 | import java.util.function.Supplier;
|
23 | 27 |
|
24 | 28 | import org.w3c.dom.Element;
|
@@ -307,6 +311,12 @@ static class MessageSecurityPostProcessor implements BeanDefinitionRegistryPostP
|
307 | 311 |
|
308 | 312 | private static final String TEMPLATE_EXPRESSION_BEAN_ID = "annotationExpressionTemplateDefaults";
|
309 | 313 |
|
| 314 | + private static final Set<String> CSRF_HANDSHAKE_HANDLER_CLASSES = new HashSet<>(Arrays.asList( |
| 315 | + "org.springframework.web.socket.server.support.WebSocketHttpRequestHandler", |
| 316 | + "org.springframework.web.socket.sockjs.transport.TransportHandlingSockJsService", |
| 317 | + "org.springframework.web.socket.sockjs.transport.handler.DefaultSockJsService" |
| 318 | + )); |
| 319 | + |
310 | 320 | private final String inboundSecurityInterceptorId;
|
311 | 321 |
|
312 | 322 | private final boolean sameOriginDisabled;
|
@@ -345,16 +355,7 @@ public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) t
|
345 | 355 | }
|
346 | 356 | }
|
347 | 357 | }
|
348 |
| - else if ("org.springframework.web.socket.server.support.WebSocketHttpRequestHandler" |
349 |
| - .equals(beanClassName)) { |
350 |
| - addCsrfTokenHandshakeInterceptor(bd); |
351 |
| - } |
352 |
| - else if ("org.springframework.web.socket.sockjs.transport.TransportHandlingSockJsService" |
353 |
| - .equals(beanClassName)) { |
354 |
| - addCsrfTokenHandshakeInterceptor(bd); |
355 |
| - } |
356 |
| - else if ("org.springframework.web.socket.sockjs.transport.handler.DefaultSockJsService" |
357 |
| - .equals(beanClassName)) { |
| 358 | + else if (CSRF_HANDSHAKE_HANDLER_CLASSES.contains(beanClassName)) { |
358 | 359 | addCsrfTokenHandshakeInterceptor(bd);
|
359 | 360 | }
|
360 | 361 | }
|
|
0 commit comments