Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
import org.springframework.security.web.csrf.CsrfFilter;
import org.springframework.security.web.csrf.CsrfLogoutHandler;
import org.springframework.security.web.csrf.CsrfTokenRepository;
import org.springframework.security.web.csrf.CsrfTokenRepositoryRequestHandler;
import org.springframework.security.web.csrf.CsrfTokenRequestHandler;
import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository;
import org.springframework.security.web.csrf.LazyCsrfTokenRepository;
Expand Down Expand Up @@ -249,13 +248,7 @@ public CsrfConfigurer<H> sessionAuthenticationStrategy(
@SuppressWarnings("unchecked")
@Override
public void configure(H http) {
CsrfFilter filter;
if (this.requestHandler != null) {
filter = new CsrfFilter(this.requestHandler);
}
else {
filter = new CsrfFilter(new CsrfTokenRepositoryRequestHandler(this.csrfTokenRepository));
}
CsrfFilter filter = new CsrfFilter(this.csrfTokenRepository);
RequestMatcher requireCsrfProtectionMatcher = getRequireCsrfProtectionMatcher();
if (requireCsrfProtectionMatcher != null) {
filter.setRequireCsrfProtectionMatcher(requireCsrfProtectionMatcher);
Expand All @@ -272,6 +265,9 @@ public void configure(H http) {
if (sessionConfigurer != null) {
sessionConfigurer.addSessionAuthenticationStrategy(getSessionAuthenticationStrategy());
}
if (this.requestHandler != null) {
filter.setRequestHandler(this.requestHandler);
}
filter = postProcess(filter);
http.addFilter(filter);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2020 the original author or authors.
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -41,7 +41,6 @@
import org.springframework.security.web.csrf.CsrfAuthenticationStrategy;
import org.springframework.security.web.csrf.CsrfFilter;
import org.springframework.security.web.csrf.CsrfLogoutHandler;
import org.springframework.security.web.csrf.CsrfTokenRepositoryRequestHandler;
import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository;
import org.springframework.security.web.csrf.LazyCsrfTokenRepository;
import org.springframework.security.web.csrf.MissingCsrfTokenException;
Expand Down Expand Up @@ -112,18 +111,13 @@ public BeanDefinition parse(Element element, ParserContext pc) {
new BeanComponentDefinition(lazyTokenRepository.getBeanDefinition(), this.csrfRepositoryRef));
}
BeanDefinitionBuilder builder = BeanDefinitionBuilder.rootBeanDefinition(CsrfFilter.class);
if (!StringUtils.hasText(this.requestHandlerRef)) {
BeanDefinition csrfTokenRequestHandler = BeanDefinitionBuilder
.rootBeanDefinition(CsrfTokenRepositoryRequestHandler.class)
.addConstructorArgReference(this.csrfRepositoryRef).getBeanDefinition();
builder.addConstructorArgValue(csrfTokenRequestHandler);
}
else {
builder.addConstructorArgReference(this.requestHandlerRef);
}
builder.addConstructorArgReference(this.csrfRepositoryRef);
if (StringUtils.hasText(this.requestMatcherRef)) {
builder.addPropertyReference("requireCsrfProtectionMatcher", this.requestMatcherRef);
}
if (StringUtils.hasText(this.requestHandlerRef)) {
builder.addPropertyReference("requestHandler", this.requestHandlerRef);
}
this.csrfFilter = builder.getBeanDefinition();
return this.csrfFilter;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1152,7 +1152,7 @@ csrf-options.attlist &=
## The CsrfTokenRepository to use. The default is HttpSessionCsrfTokenRepository wrapped by LazyCsrfTokenRepository.
attribute token-repository-ref { xsd:token }?
csrf-options.attlist &=
## The CsrfTokenRequestHandler to use. The default is CsrfTokenRepositoryRequestHandler.
## The CsrfTokenRequestHandler to use. The default is CsrfTokenRequestAttributeHandler.
attribute request-handler-ref { xsd:token }?

headers =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3258,7 +3258,7 @@
</xs:attribute>
<xs:attribute name="request-handler-ref" type="xs:token">
<xs:annotation>
<xs:documentation>The CsrfTokenRequestHandler to use. The default is CsrfTokenRepositoryRequestHandler.
<xs:documentation>The CsrfTokenRequestHandler to use. The default is CsrfTokenRequestAttributeHandler.
</xs:documentation>
</xs:annotation>
</xs:attribute>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
import org.springframework.security.config.test.SpringTestContextExtension;
import org.springframework.security.web.DefaultSecurityFilterChain;
import org.springframework.security.web.FilterChainProxy;
import org.springframework.security.web.csrf.CsrfTokenRepositoryRequestHandler;
import org.springframework.security.web.csrf.CsrfTokenRequestAttributeHandler;
import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository;
import org.springframework.security.web.csrf.LazyCsrfTokenRepository;
import org.springframework.security.web.savedrequest.HttpSessionRequestCache;
Expand Down Expand Up @@ -85,7 +85,7 @@ DefaultSecurityFilterChain springSecurity(HttpSecurity http) throws Exception {
csrfRepository.setDeferLoadToken(true);
HttpSessionRequestCache requestCache = new HttpSessionRequestCache();
requestCache.setMatchingRequestParameterName("continue");
CsrfTokenRepositoryRequestHandler requestHandler = new CsrfTokenRepositoryRequestHandler();
CsrfTokenRequestAttributeHandler requestHandler = new CsrfTokenRequestAttributeHandler();
requestHandler.setCsrfRequestAttributeName("_csrf");
// @formatter:off
http
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,10 @@
import org.springframework.security.web.authentication.session.SessionAuthenticationStrategy;
import org.springframework.security.web.csrf.CsrfToken;
import org.springframework.security.web.csrf.CsrfTokenRepository;
import org.springframework.security.web.csrf.CsrfTokenRepositoryRequestHandler;
import org.springframework.security.web.csrf.CsrfTokenRequestAttributeHandler;
import org.springframework.security.web.csrf.CsrfTokenRequestHandler;
import org.springframework.security.web.csrf.DefaultCsrfToken;
import org.springframework.security.web.csrf.DeferredCsrfToken;
import org.springframework.security.web.firewall.StrictHttpFirewall;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import org.springframework.security.web.util.matcher.RequestMatcher;
Expand All @@ -61,7 +63,6 @@
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.hamcrest.Matchers.containsString;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isNull;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.atLeastOnce;
Expand Down Expand Up @@ -207,30 +208,30 @@ public void loginWhenCsrfDisabledThenRedirectsToPreviousPostRequest() throws Exc
public void loginWhenCsrfEnabledThenDoesNotRedirectToPreviousPostRequest() throws Exception {
CsrfDisablesPostRequestFromRequestCacheConfig.REPO = mock(CsrfTokenRepository.class);
DefaultCsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token");
given(CsrfDisablesPostRequestFromRequestCacheConfig.REPO.loadToken(any())).willReturn(csrfToken);
given(CsrfDisablesPostRequestFromRequestCacheConfig.REPO.generateToken(any())).willReturn(csrfToken);
given(CsrfDisablesPostRequestFromRequestCacheConfig.REPO.loadDeferredToken(any(HttpServletRequest.class),
any(HttpServletResponse.class))).willReturn(new TestDeferredCsrfToken(csrfToken));
this.spring.register(CsrfDisablesPostRequestFromRequestCacheConfig.class).autowire();
MvcResult mvcResult = this.mvc.perform(post("/some-url")).andReturn();
this.mvc.perform(post("/login").param("username", "user").param("password", "password").with(csrf())
.session((MockHttpSession) mvcResult.getRequest().getSession())).andExpect(status().isFound())
.andExpect(redirectedUrl("/"));
verify(CsrfDisablesPostRequestFromRequestCacheConfig.REPO, atLeastOnce())
.loadToken(any(HttpServletRequest.class));
.loadDeferredToken(any(HttpServletRequest.class), any(HttpServletResponse.class));
}

@Test
public void loginWhenCsrfEnabledThenRedirectsToPreviousGetRequest() throws Exception {
CsrfDisablesPostRequestFromRequestCacheConfig.REPO = mock(CsrfTokenRepository.class);
DefaultCsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token");
given(CsrfDisablesPostRequestFromRequestCacheConfig.REPO.loadToken(any())).willReturn(csrfToken);
given(CsrfDisablesPostRequestFromRequestCacheConfig.REPO.generateToken(any())).willReturn(csrfToken);
given(CsrfDisablesPostRequestFromRequestCacheConfig.REPO.loadDeferredToken(any(HttpServletRequest.class),
any(HttpServletResponse.class))).willReturn(new TestDeferredCsrfToken(csrfToken));
this.spring.register(CsrfDisablesPostRequestFromRequestCacheConfig.class).autowire();
MvcResult mvcResult = this.mvc.perform(get("/some-url")).andReturn();
this.mvc.perform(post("/login").param("username", "user").param("password", "password").with(csrf())
.session((MockHttpSession) mvcResult.getRequest().getSession())).andExpect(status().isFound())
.andExpect(redirectedUrl("http://localhost/some-url"));
verify(CsrfDisablesPostRequestFromRequestCacheConfig.REPO, atLeastOnce())
.loadToken(any(HttpServletRequest.class));
.loadDeferredToken(any(HttpServletRequest.class), any(HttpServletResponse.class));
}

// SEC-2422
Expand Down Expand Up @@ -277,11 +278,13 @@ public void requireCsrfProtectionMatcherInLambdaWhenRequestMatchesThenRespondsWi
@Test
public void getWhenCustomCsrfTokenRepositoryThenRepositoryIsUsed() throws Exception {
CsrfTokenRepositoryConfig.REPO = mock(CsrfTokenRepository.class);
given(CsrfTokenRepositoryConfig.REPO.loadToken(any()))
.willReturn(new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token"));
given(CsrfTokenRepositoryConfig.REPO.loadDeferredToken(any(HttpServletRequest.class),
any(HttpServletResponse.class)))
.willReturn(new TestDeferredCsrfToken(new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token")));
this.spring.register(CsrfTokenRepositoryConfig.class, BasicController.class).autowire();
this.mvc.perform(get("/")).andExpect(status().isOk());
verify(CsrfTokenRepositoryConfig.REPO).loadToken(any(HttpServletRequest.class));
verify(CsrfTokenRepositoryConfig.REPO).loadDeferredToken(any(HttpServletRequest.class),
any(HttpServletResponse.class));
}

@Test
Expand All @@ -297,8 +300,8 @@ public void logoutWhenCustomCsrfTokenRepositoryThenCsrfTokenIsCleared() throws E
public void loginWhenCustomCsrfTokenRepositoryThenCsrfTokenIsCleared() throws Exception {
CsrfTokenRepositoryConfig.REPO = mock(CsrfTokenRepository.class);
DefaultCsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token");
given(CsrfTokenRepositoryConfig.REPO.loadToken(any())).willReturn(csrfToken);
given(CsrfTokenRepositoryConfig.REPO.generateToken(any())).willReturn(csrfToken);
given(CsrfTokenRepositoryConfig.REPO.loadDeferredToken(any(HttpServletRequest.class),
any(HttpServletResponse.class))).willReturn(new TestDeferredCsrfToken(csrfToken));
this.spring.register(CsrfTokenRepositoryConfig.class, BasicController.class).autowire();
// @formatter:off
MockHttpServletRequestBuilder loginRequest = post("/login")
Expand All @@ -314,11 +317,13 @@ public void loginWhenCustomCsrfTokenRepositoryThenCsrfTokenIsCleared() throws Ex
@Test
public void getWhenCustomCsrfTokenRepositoryInLambdaThenRepositoryIsUsed() throws Exception {
CsrfTokenRepositoryInLambdaConfig.REPO = mock(CsrfTokenRepository.class);
given(CsrfTokenRepositoryInLambdaConfig.REPO.loadToken(any()))
.willReturn(new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token"));
given(CsrfTokenRepositoryInLambdaConfig.REPO.loadDeferredToken(any(HttpServletRequest.class),
any(HttpServletResponse.class)))
.willReturn(new TestDeferredCsrfToken(new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token")));
this.spring.register(CsrfTokenRepositoryInLambdaConfig.class, BasicController.class).autowire();
this.mvc.perform(get("/")).andExpect(status().isOk());
verify(CsrfTokenRepositoryInLambdaConfig.REPO).loadToken(any(HttpServletRequest.class));
verify(CsrfTokenRepositoryInLambdaConfig.REPO).loadDeferredToken(any(HttpServletRequest.class),
any(HttpServletResponse.class));
}

@Test
Expand Down Expand Up @@ -418,40 +423,39 @@ public void csrfAuthenticationStrategyConfiguredThenStrategyUsed() throws Except
}

@Test
public void getLoginWhenCsrfTokenRequestProcessorSetThenRespondsWithNormalCsrfToken() throws Exception {
public void getLoginWhenCsrfTokenRequestHandlerSetThenRespondsWithNormalCsrfToken() throws Exception {
CsrfTokenRepository csrfTokenRepository = mock(CsrfTokenRepository.class);
CsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token");
given(csrfTokenRepository.generateToken(any(HttpServletRequest.class))).willReturn(csrfToken);
CsrfTokenRequestProcessorConfig.HANDLER = new CsrfTokenRepositoryRequestHandler(csrfTokenRepository);
this.spring.register(CsrfTokenRequestProcessorConfig.class, BasicController.class).autowire();
given(csrfTokenRepository.loadDeferredToken(any(HttpServletRequest.class), any(HttpServletResponse.class)))
.willReturn(new TestDeferredCsrfToken(csrfToken));
CsrfTokenRequestHandlerConfig.REPO = csrfTokenRepository;
CsrfTokenRequestHandlerConfig.HANDLER = new CsrfTokenRequestAttributeHandler();
this.spring.register(CsrfTokenRequestHandlerConfig.class, BasicController.class).autowire();
this.mvc.perform(get("/login")).andExpect(status().isOk())
.andExpect(content().string(containsString(csrfToken.getToken())));
verify(csrfTokenRepository).loadToken(any(HttpServletRequest.class));
verify(csrfTokenRepository).generateToken(any(HttpServletRequest.class));
verify(csrfTokenRepository).saveToken(eq(csrfToken), any(HttpServletRequest.class),
any(HttpServletResponse.class));
verify(csrfTokenRepository).loadDeferredToken(any(HttpServletRequest.class), any(HttpServletResponse.class));
verifyNoMoreInteractions(csrfTokenRepository);
}

@Test
public void loginWhenCsrfTokenRequestProcessorSetAndNormalCsrfTokenThenSuccess() throws Exception {
public void loginWhenCsrfTokenRequestHandlerSetAndNormalCsrfTokenThenSuccess() throws Exception {
CsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token");
CsrfTokenRepository csrfTokenRepository = mock(CsrfTokenRepository.class);
given(csrfTokenRepository.loadToken(any(HttpServletRequest.class))).willReturn(null, csrfToken);
given(csrfTokenRepository.generateToken(any(HttpServletRequest.class))).willReturn(csrfToken);
CsrfTokenRequestProcessorConfig.HANDLER = new CsrfTokenRepositoryRequestHandler(csrfTokenRepository);
given(csrfTokenRepository.loadDeferredToken(any(HttpServletRequest.class), any(HttpServletResponse.class)))
.willReturn(new TestDeferredCsrfToken(csrfToken));
CsrfTokenRequestHandlerConfig.REPO = csrfTokenRepository;
CsrfTokenRequestHandlerConfig.HANDLER = new CsrfTokenRequestAttributeHandler();
this.spring.register(CsrfTokenRequestHandlerConfig.class, BasicController.class).autowire();

this.spring.register(CsrfTokenRequestProcessorConfig.class, BasicController.class).autowire();
// @formatter:off
MockHttpServletRequestBuilder loginRequest = post("/login")
.header(csrfToken.getHeaderName(), csrfToken.getToken())
.param("username", "user")
.param("password", "password");
// @formatter:on
this.mvc.perform(loginRequest).andExpect(redirectedUrl("/"));
verify(csrfTokenRepository, times(2)).loadToken(any(HttpServletRequest.class));
verify(csrfTokenRepository).generateToken(any(HttpServletRequest.class));
verify(csrfTokenRepository).saveToken(eq(csrfToken), any(HttpServletRequest.class),
verify(csrfTokenRepository).saveToken(isNull(), any(HttpServletRequest.class), any(HttpServletResponse.class));
verify(csrfTokenRepository, times(2)).loadDeferredToken(any(HttpServletRequest.class),
any(HttpServletResponse.class));
verifyNoMoreInteractions(csrfTokenRepository);
}
Expand Down Expand Up @@ -799,9 +803,11 @@ protected void configure(AuthenticationManagerBuilder auth) throws Exception {

@Configuration
@EnableWebSecurity
static class CsrfTokenRequestProcessorConfig {
static class CsrfTokenRequestHandlerConfig {

static CsrfTokenRepository REPO;

static CsrfTokenRepositoryRequestHandler HANDLER;
static CsrfTokenRequestHandler HANDLER;

@Bean
SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception {
Expand All @@ -811,7 +817,10 @@ SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception {
.anyRequest().authenticated()
)
.formLogin(Customizer.withDefaults())
.csrf((csrf) -> csrf.csrfTokenRequestHandler(HANDLER));
.csrf((csrf) -> csrf
.csrfTokenRepository(REPO)
.csrfTokenRequestHandler(HANDLER)
);
// @formatter:on

return http.build();
Expand Down Expand Up @@ -841,4 +850,24 @@ void rootPost() {

}

private static final class TestDeferredCsrfToken implements DeferredCsrfToken {

private final CsrfToken csrfToken;

private TestDeferredCsrfToken(CsrfToken csrfToken) {
this.csrfToken = csrfToken;
}

@Override
public CsrfToken get() {
return this.csrfToken;
}

@Override
public boolean isGenerated() {
return false;
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.HttpMethod;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.mock.web.MockHttpSession;
import org.springframework.security.access.AccessDeniedException;
import org.springframework.security.config.test.SpringTestContext;
Expand All @@ -42,7 +41,6 @@
import org.springframework.security.web.access.AccessDeniedHandler;
import org.springframework.security.web.csrf.CsrfFilter;
import org.springframework.security.web.csrf.CsrfToken;
import org.springframework.security.web.csrf.DeferredCsrfToken;
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.stereotype.Controller;
import org.springframework.test.context.junit.jupiter.SpringExtension;
Expand Down Expand Up @@ -546,9 +544,8 @@ static class CsrfCreatedResultMatcher implements ResultMatcher {
@Override
public void match(MvcResult result) {
MockHttpServletRequest request = result.getRequest();
MockHttpServletResponse response = result.getResponse();
DeferredCsrfToken token = WebTestUtils.getCsrfTokenRequestHandler(request).handle(request, response);
assertThat(token.isGenerated()).isFalse();
CsrfToken token = WebTestUtils.getCsrfTokenRepository(request).loadToken(request);
assertThat(token).isNotNull();
}

}
Expand All @@ -564,8 +561,7 @@ static class CsrfReturnedResultMatcher implements ResultMatcher {
@Override
public void match(MvcResult result) throws Exception {
MockHttpServletRequest request = result.getRequest();
MockHttpServletResponse response = result.getResponse();
CsrfToken token = WebTestUtils.getCsrfTokenRequestHandler(request).handle(request, response).get();
CsrfToken token = WebTestUtils.getCsrfTokenRepository(request).loadToken(request);
assertThat(token).isNotNull();
assertThat(token.getToken()).isEqualTo(this.token.apply(result));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
<csrf request-handler-ref="requestHandler"/>
</http>

<b:bean id="requestHandler" class="org.springframework.security.web.csrf.CsrfTokenRepositoryRequestHandler"
<b:bean id="requestHandler" class="org.springframework.security.web.csrf.CsrfTokenRequestAttributeHandler"
p:csrfRequestAttributeName="csrf-attribute-name"/>
<b:import resource="CsrfConfigTests-shared-userservice.xml"/>
</b:beans>
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
<b:bean id="csrfRepository" class="org.springframework.security.web.csrf.LazyCsrfTokenRepository"
c:delegate-ref="httpSessionCsrfRepository"
p:deferLoadToken="true"/>
<b:bean id="requestHandler" class="org.springframework.security.web.csrf.CsrfTokenRepositoryRequestHandler"
<b:bean id="requestHandler" class="org.springframework.security.web.csrf.CsrfTokenRequestAttributeHandler"
p:csrfRequestAttributeName="_csrf"/>
<b:import resource="CsrfConfigTests-shared-userservice.xml"/>
</b:beans>
Loading