From 270b487b1c5c0a46eb58fb7e10305d72c7c9abe5 Mon Sep 17 00:00:00 2001 From: Steve Riesenberg Date: Thu, 1 Jul 2021 12:55:12 -0500 Subject: [PATCH] Add converter for authentication result in OAuth2LoginAuthenticationFilter Closes gh-10033 --- .../web/OAuth2LoginAuthenticationFilter.java | 29 ++++++++-- .../OAuth2LoginAuthenticationFilterTests.java | 55 ++++++++++++++++++- 2 files changed, 79 insertions(+), 5 deletions(-) diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilter.java index 6943215ded4..9bebb0869c8 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilter.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; +import org.springframework.core.convert.converter.Converter; import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; @@ -111,6 +112,8 @@ public class OAuth2LoginAuthenticationFilter extends AbstractAuthenticationProce private AuthorizationRequestRepository authorizationRequestRepository = new HttpSessionOAuth2AuthorizationRequestRepository(); + private Converter authenticationResultConverter = this::createAuthenticationResult; + /** * Constructs an {@code OAuth2LoginAuthenticationFilter} using the provided * parameters. @@ -190,9 +193,9 @@ public Authentication attemptAuthentication(HttpServletRequest request, HttpServ authenticationRequest.setDetails(authenticationDetails); OAuth2LoginAuthenticationToken authenticationResult = (OAuth2LoginAuthenticationToken) this .getAuthenticationManager().authenticate(authenticationRequest); - OAuth2AuthenticationToken oauth2Authentication = new OAuth2AuthenticationToken( - authenticationResult.getPrincipal(), authenticationResult.getAuthorities(), - authenticationResult.getClientRegistration().getRegistrationId()); + OAuth2AuthenticationToken oauth2Authentication = this.authenticationResultConverter + .convert(authenticationResult); + Assert.notNull(oauth2Authentication, "authentication result cannot be null"); oauth2Authentication.setDetails(authenticationDetails); OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( authenticationResult.getClientRegistration(), oauth2Authentication.getName(), @@ -213,4 +216,22 @@ public final void setAuthorizationRequestRepository( this.authorizationRequestRepository = authorizationRequestRepository; } + /** + * Sets the converter responsible for converting from + * {@link OAuth2LoginAuthenticationToken} to {@link OAuth2AuthenticationToken} + * authentication result. + * @param authenticationResultConverter the converter for + * {@link OAuth2AuthenticationToken}'s + */ + public final void setAuthenticationResultConverter( + Converter authenticationResultConverter) { + Assert.notNull(authenticationResultConverter, "authenticationResultConverter cannot be null"); + this.authenticationResultConverter = authenticationResultConverter; + } + + private OAuth2AuthenticationToken createAuthenticationResult(OAuth2LoginAuthenticationToken authenticationResult) { + return new OAuth2AuthenticationToken(authenticationResult.getPrincipal(), authenticationResult.getAuthorities(), + authenticationResult.getClientRegistration().getRegistrationId()); + } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilterTests.java index 411b8573ae0..d399f3c3dbe 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilterTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilterTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ package org.springframework.security.oauth2.client.web; +import java.util.Collection; import java.util.HashMap; import java.util.Map; @@ -33,10 +34,12 @@ import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; +import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.oauth2.client.InMemoryOAuth2AuthorizedClientService; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService; +import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; import org.springframework.security.oauth2.client.authentication.OAuth2LoginAuthenticationToken; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; @@ -152,6 +155,12 @@ public void setAuthorizationRequestRepositoryWhenAuthorizationRequestRepositoryI assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setAuthorizationRequestRepository(null)); } + // gh-10033 + @Test + public void setAuthenticationResultConverterWhenAuthenticationResultConverterIsNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setAuthenticationResultConverter(null)); + } + @Test public void doFilterWhenNotAuthorizationResponseThenNextFilter() throws Exception { String requestUri = "/path"; @@ -416,6 +425,41 @@ public void attemptAuthenticationShouldSetAuthenticationDetailsOnAuthenticationR assertThat(result.getDetails()).isEqualTo(webAuthenticationDetails); } + // gh-10033 + @Test + public void attemptAuthenticationWhenAuthenticationResultIsNullThenIllegalArgumentException() throws Exception { + this.filter.setAuthenticationResultConverter((authentication) -> null); + String requestUri = "/login/oauth2/code/" + this.registration1.getRegistrationId(); + String state = "state"; + MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); + request.setServletPath(requestUri); + request.addParameter(OAuth2ParameterNames.CODE, "code"); + request.addParameter(OAuth2ParameterNames.STATE, state); + MockHttpServletResponse response = new MockHttpServletResponse(); + this.setUpAuthorizationRequest(request, response, this.registration1, state); + this.setUpAuthenticationResult(this.registration1); + assertThatIllegalArgumentException().isThrownBy(() -> this.filter.attemptAuthentication(request, response)); + } + + // gh-10033 + @Test + public void attemptAuthenticationWhenAuthenticationResultConverterSetThenUsed() { + this.filter.setAuthenticationResultConverter( + (authentication) -> new CustomOAuth2AuthenticationToken(authentication.getPrincipal(), + authentication.getAuthorities(), authentication.getClientRegistration().getRegistrationId())); + String requestUri = "/login/oauth2/code/" + this.registration1.getRegistrationId(); + String state = "state"; + MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); + request.setServletPath(requestUri); + request.addParameter(OAuth2ParameterNames.CODE, "code"); + request.addParameter(OAuth2ParameterNames.STATE, state); + MockHttpServletResponse response = new MockHttpServletResponse(); + this.setUpAuthorizationRequest(request, response, this.registration1, state); + this.setUpAuthenticationResult(this.registration1); + Authentication authenticationResult = this.filter.attemptAuthentication(request, response); + assertThat(authenticationResult).isInstanceOf(CustomOAuth2AuthenticationToken.class); + } + private void setUpAuthorizationRequest(HttpServletRequest request, HttpServletResponse response, ClientRegistration registration, String state) { Map attributes = new HashMap<>(); @@ -454,4 +498,13 @@ private void setUpAuthenticationResult(ClientRegistration registration) { given(this.authenticationManager.authenticate(any(Authentication.class))).willReturn(this.loginAuthentication); } + private static final class CustomOAuth2AuthenticationToken extends OAuth2AuthenticationToken { + + CustomOAuth2AuthenticationToken(OAuth2User principal, Collection authorities, + String authorizedClientRegistrationId) { + super(principal, authorities, authorizedClientRegistrationId); + } + + } + }