Skip to content

Commit

Permalink
Use SessionAuthenticationStrategy for Remember-Me authentication
Browse files Browse the repository at this point in the history
Closes gh-2253
  • Loading branch information
xhaggi authored and jzheaux committed Oct 15, 2024
1 parent d37d41c commit 7f53724
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.springframework.security.web.authentication.rememberme.PersistentTokenRepository;
import org.springframework.security.web.authentication.rememberme.RememberMeAuthenticationFilter;
import org.springframework.security.web.authentication.rememberme.TokenBasedRememberMeServices;
import org.springframework.security.web.authentication.session.SessionAuthenticationStrategy;
import org.springframework.security.web.authentication.ui.DefaultLoginPageGeneratingFilter;
import org.springframework.security.web.context.SecurityContextRepository;
import org.springframework.util.Assert;
Expand Down Expand Up @@ -296,6 +297,13 @@ public void configure(H http) {
rememberMeFilter.setSecurityContextRepository(securityContextRepository);
}
rememberMeFilter.setSecurityContextHolderStrategy(getSecurityContextHolderStrategy());

SessionAuthenticationStrategy sessionAuthenticationStrategy = http
.getSharedObject(SessionAuthenticationStrategy.class);
if (sessionAuthenticationStrategy != null) {
rememberMeFilter.setSessionAuthenticationStrategy(sessionAuthenticationStrategy);
}

rememberMeFilter = postProcess(rememberMeFilter);
http.addFilter(rememberMeFilter);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.hamcrest.Matchers.startsWith;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.BDDMockito.given;
Expand All @@ -74,6 +75,7 @@
import static org.springframework.security.test.web.servlet.response.SecurityMockMvcResultMatchers.authenticated;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.cookie;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl;

Expand Down Expand Up @@ -334,6 +336,27 @@ public void getWhenCustomSecurityContextRepositoryThenUses() throws Exception {
verify(repository).saveContext(any(), any(), any());
}

@Test
public void rememberMeExpiresSessionWhenSessionManagementMaximumSessionsExceeds() throws Exception {
this.spring.register(RememberMeMaximumSessionsConfig.class).autowire();

MockHttpServletRequestBuilder loginRequest = post("/login").with(csrf())
.param("username", "user")
.param("password", "password")
.param("remember-me", "true");
MvcResult mvcResult = this.mvc.perform(loginRequest).andReturn();
Cookie rememberMeCookie = mvcResult.getResponse().getCookie("remember-me");
HttpSession session = mvcResult.getRequest().getSession();

MockHttpServletRequestBuilder exceedsMaximumSessionsRequest = get("/abc").cookie(rememberMeCookie);
this.mvc.perform(exceedsMaximumSessionsRequest);

MockHttpServletRequestBuilder sessionExpiredRequest = get("/abc").cookie(rememberMeCookie)
.session((MockHttpSession) session);
this.mvc.perform(sessionExpiredRequest)
.andExpect(content().string(startsWith("This session has been expired")));
}

@Configuration
@EnableWebSecurity
static class NullUserDetailsConfig {
Expand Down Expand Up @@ -617,6 +640,35 @@ SecurityFilterChain filterChain(HttpSecurity http) throws Exception {

}

@Configuration
@EnableWebSecurity
static class RememberMeMaximumSessionsConfig {

@Bean
SecurityFilterChain filterChain(HttpSecurity http) throws Exception {
// @formatter:off
http
.authorizeRequests((authorizeRequests) ->
authorizeRequests
.anyRequest().hasRole("USER")
)
.sessionManagement((sessionManagement) ->
sessionManagement
.maximumSessions(1)
)
.formLogin(withDefaults())
.rememberMe(withDefaults());
return http.build();
// @formatter:on
}

@Bean
UserDetailsService userDetailsService() {
return new InMemoryUserDetailsManager(PasswordEncodedUser.user());
}

}

@Configuration
@EnableWebSecurity
static class SecurityContextRepositoryConfig {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
import org.springframework.security.web.authentication.RememberMeServices;
import org.springframework.security.web.authentication.session.NullAuthenticatedSessionStrategy;
import org.springframework.security.web.authentication.session.SessionAuthenticationStrategy;
import org.springframework.security.web.context.HttpSessionSecurityContextRepository;
import org.springframework.security.web.context.SecurityContextRepository;
import org.springframework.util.Assert;
Expand Down Expand Up @@ -81,6 +83,8 @@ public class RememberMeAuthenticationFilter extends GenericFilterBean implements

private SecurityContextRepository securityContextRepository = new HttpSessionSecurityContextRepository();

private SessionAuthenticationStrategy sessionStrategy = new NullAuthenticatedSessionStrategy();

public RememberMeAuthenticationFilter(AuthenticationManager authenticationManager,
RememberMeServices rememberMeServices) {
Assert.notNull(authenticationManager, "authenticationManager cannot be null");
Expand Down Expand Up @@ -115,6 +119,7 @@ private void doFilter(HttpServletRequest request, HttpServletResponse response,
// Attempt authentication via AuthenticationManager
try {
rememberMeAuth = this.authenticationManager.authenticate(rememberMeAuth);
this.sessionStrategy.onAuthentication(rememberMeAuth, request, response);
// Store to SecurityContextHolder
SecurityContext context = this.securityContextHolderStrategy.createEmptyContext();
context.setAuthentication(rememberMeAuth);
Expand Down Expand Up @@ -211,4 +216,17 @@ public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy secur
this.securityContextHolderStrategy = securityContextHolderStrategy;
}

/**
* The session handling strategy which will be invoked immediately after an
* authentication request is successfully processed by the
* <tt>AuthenticationManager</tt>. Used, for example, to handle changing of the
* session identifier to prevent session fixation attacks.
* @param sessionStrategy the implementation to use. If not set a null implementation
* is used.
*/
public void setSessionAuthenticationStrategy(SessionAuthenticationStrategy sessionStrategy) {
Assert.notNull(sessionStrategy, "sessionStrategy cannot be null");
this.sessionStrategy = sessionStrategy;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.springframework.security.web.authentication.NullRememberMeServices;
import org.springframework.security.web.authentication.RememberMeServices;
import org.springframework.security.web.authentication.SimpleUrlAuthenticationSuccessHandler;
import org.springframework.security.web.authentication.session.SessionAuthenticationStrategy;
import org.springframework.security.web.context.SecurityContextRepository;

import static org.assertj.core.api.Assertions.assertThat;
Expand Down Expand Up @@ -170,6 +171,23 @@ public void securityContextRepositoryInvokedIfSet() throws Exception {
verify(securityContextRepository).saveContext(any(), eq(request), eq(response));
}

@Test
public void sessionAuthenticationStrategyInvokedIfSet() throws Exception {
SessionAuthenticationStrategy sessionAuthenticationStrategy = mock(SessionAuthenticationStrategy.class);
AuthenticationManager am = mock(AuthenticationManager.class);
given(am.authenticate(this.remembered)).willReturn(this.remembered);
RememberMeAuthenticationFilter filter = new RememberMeAuthenticationFilter(am,
new MockRememberMeServices(this.remembered));
filter.setAuthenticationSuccessHandler(new SimpleUrlAuthenticationSuccessHandler("/target"));
filter.setSessionAuthenticationStrategy(sessionAuthenticationStrategy);
MockHttpServletRequest request = new MockHttpServletRequest();
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain fc = mock(FilterChain.class);
request.setRequestURI("x");
filter.doFilter(request, response, fc);
verify(sessionAuthenticationStrategy).onAuthentication(any(), eq(request), eq(response));
}

private class MockRememberMeServices implements RememberMeServices {

private Authentication authToReturn;
Expand Down

0 comments on commit 7f53724

Please sign in to comment.