Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -19,6 +19,7 @@
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.springframework.core.convert.converter.Converter;
import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
Expand Down Expand Up @@ -111,6 +112,8 @@ public class OAuth2LoginAuthenticationFilter extends AbstractAuthenticationProce

private AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository = new HttpSessionOAuth2AuthorizationRequestRepository();

private Converter<OAuth2LoginAuthenticationToken, OAuth2AuthenticationToken> authenticationResultConverter = this::createAuthenticationResult;

/**
* Constructs an {@code OAuth2LoginAuthenticationFilter} using the provided
* parameters.
Expand Down Expand Up @@ -190,9 +193,9 @@ public Authentication attemptAuthentication(HttpServletRequest request, HttpServ
authenticationRequest.setDetails(authenticationDetails);
OAuth2LoginAuthenticationToken authenticationResult = (OAuth2LoginAuthenticationToken) this
.getAuthenticationManager().authenticate(authenticationRequest);
OAuth2AuthenticationToken oauth2Authentication = new OAuth2AuthenticationToken(
authenticationResult.getPrincipal(), authenticationResult.getAuthorities(),
authenticationResult.getClientRegistration().getRegistrationId());
OAuth2AuthenticationToken oauth2Authentication = this.authenticationResultConverter
.convert(authenticationResult);
Assert.notNull(oauth2Authentication, "authentication result cannot be null");
oauth2Authentication.setDetails(authenticationDetails);
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
authenticationResult.getClientRegistration(), oauth2Authentication.getName(),
Expand All @@ -213,4 +216,22 @@ public final void setAuthorizationRequestRepository(
this.authorizationRequestRepository = authorizationRequestRepository;
}

/**
* Sets the converter responsible for converting from
* {@link OAuth2LoginAuthenticationToken} to {@link OAuth2AuthenticationToken}
* authentication result.
* @param authenticationResultConverter the converter for
* {@link OAuth2AuthenticationToken}'s
*/
public final void setAuthenticationResultConverter(
Converter<OAuth2LoginAuthenticationToken, OAuth2AuthenticationToken> authenticationResultConverter) {
Assert.notNull(authenticationResultConverter, "authenticationResultConverter cannot be null");
this.authenticationResultConverter = authenticationResultConverter;
}

private OAuth2AuthenticationToken createAuthenticationResult(OAuth2LoginAuthenticationToken authenticationResult) {
return new OAuth2AuthenticationToken(authenticationResult.getPrincipal(), authenticationResult.getAuthorities(),
authenticationResult.getClientRegistration().getRegistrationId());
}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2020 the original author or authors.
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,6 +16,7 @@

package org.springframework.security.oauth2.client.web;

import java.util.Collection;
import java.util.HashMap;
import java.util.Map;

Expand All @@ -33,10 +34,12 @@
import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.oauth2.client.InMemoryOAuth2AuthorizedClientService;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
import org.springframework.security.oauth2.client.authentication.OAuth2LoginAuthenticationToken;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
Expand Down Expand Up @@ -152,6 +155,12 @@ public void setAuthorizationRequestRepositoryWhenAuthorizationRequestRepositoryI
assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setAuthorizationRequestRepository(null));
}

// gh-10033
@Test
public void setAuthenticationResultConverterWhenAuthenticationResultConverterIsNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setAuthenticationResultConverter(null));
}

@Test
public void doFilterWhenNotAuthorizationResponseThenNextFilter() throws Exception {
String requestUri = "/path";
Expand Down Expand Up @@ -416,6 +425,41 @@ public void attemptAuthenticationShouldSetAuthenticationDetailsOnAuthenticationR
assertThat(result.getDetails()).isEqualTo(webAuthenticationDetails);
}

// gh-10033
@Test
public void attemptAuthenticationWhenAuthenticationResultIsNullThenIllegalArgumentException() throws Exception {
this.filter.setAuthenticationResultConverter((authentication) -> null);
String requestUri = "/login/oauth2/code/" + this.registration1.getRegistrationId();
String state = "state";
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
request.addParameter(OAuth2ParameterNames.CODE, "code");
request.addParameter(OAuth2ParameterNames.STATE, state);
MockHttpServletResponse response = new MockHttpServletResponse();
this.setUpAuthorizationRequest(request, response, this.registration1, state);
this.setUpAuthenticationResult(this.registration1);
assertThatIllegalArgumentException().isThrownBy(() -> this.filter.attemptAuthentication(request, response));
}

// gh-10033
@Test
public void attemptAuthenticationWhenAuthenticationResultConverterSetThenUsed() {
this.filter.setAuthenticationResultConverter(
(authentication) -> new CustomOAuth2AuthenticationToken(authentication.getPrincipal(),
authentication.getAuthorities(), authentication.getClientRegistration().getRegistrationId()));
String requestUri = "/login/oauth2/code/" + this.registration1.getRegistrationId();
String state = "state";
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
request.addParameter(OAuth2ParameterNames.CODE, "code");
request.addParameter(OAuth2ParameterNames.STATE, state);
MockHttpServletResponse response = new MockHttpServletResponse();
this.setUpAuthorizationRequest(request, response, this.registration1, state);
this.setUpAuthenticationResult(this.registration1);
Authentication authenticationResult = this.filter.attemptAuthentication(request, response);
assertThat(authenticationResult).isInstanceOf(CustomOAuth2AuthenticationToken.class);
}

private void setUpAuthorizationRequest(HttpServletRequest request, HttpServletResponse response,
ClientRegistration registration, String state) {
Map<String, Object> attributes = new HashMap<>();
Expand Down Expand Up @@ -454,4 +498,13 @@ private void setUpAuthenticationResult(ClientRegistration registration) {
given(this.authenticationManager.authenticate(any(Authentication.class))).willReturn(this.loginAuthentication);
}

private static final class CustomOAuth2AuthenticationToken extends OAuth2AuthenticationToken {

CustomOAuth2AuthenticationToken(OAuth2User principal, Collection<? extends GrantedAuthority> authorities,
String authorizedClientRegistrationId) {
super(principal, authorities, authorizedClientRegistrationId);
}

}

}