diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ReactiveOAuth2AuthorizedClientProviderBuilder.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ReactiveOAuth2AuthorizedClientProviderBuilder.java index f75de343045..085643e7b42 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ReactiveOAuth2AuthorizedClientProviderBuilder.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ReactiveOAuth2AuthorizedClientProviderBuilder.java @@ -362,6 +362,8 @@ public final class RefreshTokenGrantBuilder implements Builder { private ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient; + private ReactiveOAuth2AuthorizationSuccessHandler refreshTokenSuccessHandler; + private Duration clockSkew; private Clock clock; @@ -382,6 +384,21 @@ public RefreshTokenGrantBuilder accessTokenResponseClient( return this; } + /** + * Sets a {@link ReactiveOAuth2AuthorizationSuccessHandler} that is called after + * the client is re-authorized, defaults to + * {@link RefreshTokenReactiveOAuth2AuthorizationSuccessHandler}. + * @param refreshTokenSuccessHandler the + * {@link ReactiveOAuth2AuthorizationSuccessHandler} to use + * @return the {@link RefreshTokenGrantBuilder} + * @since 7.0 + */ + public RefreshTokenGrantBuilder refreshTokenSuccessHandler( + ReactiveOAuth2AuthorizationSuccessHandler refreshTokenSuccessHandler) { + this.refreshTokenSuccessHandler = refreshTokenSuccessHandler; + return this; + } + /** * Sets the maximum acceptable clock skew, which is used when checking the access * token expiry. An access token is considered expired if @@ -418,6 +435,9 @@ public ReactiveOAuth2AuthorizedClientProvider build() { if (this.accessTokenResponseClient != null) { authorizedClientProvider.setAccessTokenResponseClient(this.accessTokenResponseClient); } + if (this.refreshTokenSuccessHandler != null) { + authorizedClientProvider.setRefreshTokenSuccessHandler(this.refreshTokenSuccessHandler); + } if (this.clockSkew != null) { authorizedClientProvider.setClockSkew(this.clockSkew); } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenReactiveOAuth2AuthorizationSuccessHandler.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenReactiveOAuth2AuthorizationSuccessHandler.java new file mode 100644 index 00000000000..612b274a126 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenReactiveOAuth2AuthorizationSuccessHandler.java @@ -0,0 +1,302 @@ +/* + * Copyright 2002-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client; + +import java.time.Duration; +import java.util.Collection; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import reactor.core.publisher.Mono; + +import org.springframework.security.core.Authentication; +import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper; +import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.core.context.SecurityContextImpl; +import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; +import org.springframework.security.oauth2.client.oidc.authentication.ReactiveOidcIdTokenDecoderFactory; +import org.springframework.security.oauth2.client.oidc.userinfo.OidcReactiveOAuth2UserService; +import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserRequest; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.userinfo.ReactiveOAuth2UserService; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.oidc.OidcIdToken; +import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames; +import org.springframework.security.oauth2.core.oidc.user.OidcUser; +import org.springframework.security.oauth2.jwt.JwtException; +import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder; +import org.springframework.security.oauth2.jwt.ReactiveJwtDecoderFactory; +import org.springframework.security.web.server.context.ServerSecurityContextRepository; +import org.springframework.security.web.server.context.WebSessionServerSecurityContextRepository; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; +import org.springframework.web.server.ServerWebExchange; + +/** + * A {@link ReactiveOAuth2AuthorizationSuccessHandler} that refreshes an {@link OidcUser} + * in the {@link SecurityContext} if the refreshed {@link OidcIdToken} is valid according + * to OpenID + * Connect Core 1.0 - Section 12.2 Successful Refresh Response + * + * @author Evgeniy Cheban + * @since 7.0 + */ +public final class RefreshTokenReactiveOAuth2AuthorizationSuccessHandler + implements ReactiveOAuth2AuthorizationSuccessHandler { + + private static final String INVALID_ID_TOKEN_ERROR_CODE = "invalid_id_token"; + + private static final String INVALID_NONCE_ERROR_CODE = "invalid_nonce"; + + private static final String REFRESH_TOKEN_RESPONSE_ERROR_URI = "https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokenResponse"; + + // @formatter:off + private static final Mono currentServerWebExchangeMono = Mono.deferContextual(Mono::just) + .filter((c) -> c.hasKey(ServerWebExchange.class)) + .map((c) -> c.get(ServerWebExchange.class)); + // @formatter:on + + private ServerSecurityContextRepository serverSecurityContextRepository = new WebSessionServerSecurityContextRepository(); + + private ReactiveJwtDecoderFactory jwtDecoderFactory = new ReactiveOidcIdTokenDecoderFactory(); + + private ReactiveOAuth2UserService userService = new OidcReactiveOAuth2UserService(); + + private GrantedAuthoritiesMapper authoritiesMapper = (authorities) -> authorities; + + private Duration clockSkew = Duration.ofSeconds(60); + + @Override + public Mono onAuthorizationSuccess(OAuth2AuthorizedClient authorizedClient, Authentication principal, + Map attributes) { + if (!(principal instanceof OAuth2AuthenticationToken authenticationToken) + || authenticationToken.getClass() != OAuth2AuthenticationToken.class) { + // If the application customizes the authentication result, then a custom + // handler should be provided. + return Mono.empty(); + } + // The current principal must be an OidcUser. + if (!(authenticationToken.getPrincipal() instanceof OidcUser existingOidcUser)) { + return Mono.empty(); + } + ClientRegistration clientRegistration = authorizedClient.getClientRegistration(); + // The registrationId must match the one used to log in. + if (!authenticationToken.getAuthorizedClientRegistrationId().equals(clientRegistration.getRegistrationId())) { + return Mono.empty(); + } + // Create, validate OidcIdToken and refresh OidcUser in the SecurityContext. + return Mono.zip(serverWebExchange(attributes), accessTokenResponse(attributes)).flatMap((t2) -> { + ReactiveJwtDecoder jwtDecoder = this.jwtDecoderFactory.createDecoder(clientRegistration); + Map additionalParameters = t2.getT2().getAdditionalParameters(); + return jwtDecoder.decode((String) additionalParameters.get(OidcParameterNames.ID_TOKEN)) + .onErrorMap(JwtException.class, (ex) -> { + OAuth2Error invalidIdTokenError = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, ex.getMessage(), + null); + return new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString(), ex); + }) + .map((jwt) -> new OidcIdToken(jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), + jwt.getClaims())) + .doOnNext((idToken) -> validateIdToken(existingOidcUser, idToken)) + .flatMap((idToken) -> { + OidcUserRequest userRequest = new OidcUserRequest(clientRegistration, + authorizedClient.getAccessToken(), idToken); + return this.userService.loadUser(userRequest); + }) + .flatMap((oidcUser) -> refreshSecurityContext(t2.getT1(), clientRegistration, authenticationToken, + oidcUser)); + }); + } + + private Mono serverWebExchange(Map attributes) { + if (attributes.get(ServerWebExchange.class.getName()) instanceof ServerWebExchange exchange) { + return Mono.just(exchange); + } + return currentServerWebExchangeMono; + } + + private Mono accessTokenResponse(Map attributes) { + if (attributes.get(OAuth2AccessTokenResponse.class.getName()) instanceof OAuth2AccessTokenResponse response) { + return Mono.just(response); + } + return Mono.empty(); + } + + private void validateIdToken(OidcUser existingOidcUser, OidcIdToken idToken) { + // OpenID Connect Core 1.0 - Section 12.2 Successful Refresh Response + // If an ID Token is returned as a result of a token refresh request, the + // following requirements apply: + // its iss Claim Value MUST be the same as in the ID Token issued when the + // original authentication occurred, + validateIssuer(existingOidcUser, idToken); + // its sub Claim Value MUST be the same as in the ID Token issued when the + // original authentication occurred, + validateSubject(existingOidcUser, idToken); + // its iat Claim MUST represent the time that the new ID Token is issued, + validateIssuedAt(existingOidcUser, idToken); + // its aud Claim Value MUST be the same as in the ID Token issued when the + // original authentication occurred, + validateAudience(existingOidcUser, idToken); + // if the ID Token contains an auth_time Claim, its value MUST represent the time + // of the original authentication - not the time that the new ID token is issued, + validateAuthenticatedAt(existingOidcUser, idToken); + // it SHOULD NOT have a nonce Claim, even when the ID Token issued at the time of + // the original authentication contained nonce; however, if it is present, its + // value MUST be the same as in the ID Token issued at the time of the original + // authentication, + validateNonce(existingOidcUser, idToken); + } + + private void validateIssuer(OidcUser existingOidcUser, OidcIdToken idToken) { + if (!idToken.getIssuer().toString().equals(existingOidcUser.getIdToken().getIssuer().toString())) { + OAuth2Error oauth2Error = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, "Invalid issuer", + REFRESH_TOKEN_RESPONSE_ERROR_URI); + throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); + } + } + + private void validateSubject(OidcUser existingOidcUser, OidcIdToken idToken) { + if (!idToken.getSubject().equals(existingOidcUser.getIdToken().getSubject())) { + OAuth2Error oauth2Error = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, "Invalid subject", + REFRESH_TOKEN_RESPONSE_ERROR_URI); + throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); + } + } + + private void validateIssuedAt(OidcUser existingOidcUser, OidcIdToken idToken) { + if (!idToken.getIssuedAt().isAfter(existingOidcUser.getIdToken().getIssuedAt().minus(this.clockSkew))) { + OAuth2Error oauth2Error = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, "Invalid issued at time", + REFRESH_TOKEN_RESPONSE_ERROR_URI); + throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); + } + } + + private void validateAudience(OidcUser existingOidcUser, OidcIdToken idToken) { + if (!isValidAudience(existingOidcUser, idToken)) { + OAuth2Error oauth2Error = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, "Invalid audience", + REFRESH_TOKEN_RESPONSE_ERROR_URI); + throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); + } + } + + private boolean isValidAudience(OidcUser existingOidcUser, OidcIdToken idToken) { + List idTokenAudiences = idToken.getAudience(); + Set oidcUserAudiences = new HashSet<>(existingOidcUser.getIdToken().getAudience()); + if (idTokenAudiences.size() != oidcUserAudiences.size()) { + return false; + } + for (String audience : idTokenAudiences) { + if (!oidcUserAudiences.contains(audience)) { + return false; + } + } + return true; + } + + private void validateAuthenticatedAt(OidcUser existingOidcUser, OidcIdToken idToken) { + if (idToken.getAuthenticatedAt() == null) { + return; + } + if (!idToken.getAuthenticatedAt().equals(existingOidcUser.getIdToken().getAuthenticatedAt())) { + OAuth2Error oauth2Error = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, "Invalid authenticated at time", + REFRESH_TOKEN_RESPONSE_ERROR_URI); + throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); + } + } + + private void validateNonce(OidcUser existingOidcUser, OidcIdToken idToken) { + if (!StringUtils.hasText(idToken.getNonce())) { + return; + } + if (!idToken.getNonce().equals(existingOidcUser.getIdToken().getNonce())) { + OAuth2Error oauth2Error = new OAuth2Error(INVALID_NONCE_ERROR_CODE, "Invalid nonce", + REFRESH_TOKEN_RESPONSE_ERROR_URI); + throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); + } + } + + private Mono refreshSecurityContext(ServerWebExchange exchange, ClientRegistration clientRegistration, + OAuth2AuthenticationToken authenticationToken, OidcUser oidcUser) { + Collection mappedAuthorities = this.authoritiesMapper + .mapAuthorities(oidcUser.getAuthorities()); + OAuth2AuthenticationToken authenticationResult = new OAuth2AuthenticationToken(oidcUser, mappedAuthorities, + clientRegistration.getRegistrationId()); + authenticationResult.setDetails(authenticationToken.getDetails()); + SecurityContextImpl securityContext = new SecurityContextImpl(authenticationResult); + return this.serverSecurityContextRepository.save(exchange, securityContext); + } + + /** + * Sets a {@link ServerSecurityContextRepository} to use for refreshing a + * {@link SecurityContext}, defaults to + * {@link WebSessionServerSecurityContextRepository}. + * @param serverSecurityContextRepository the {@link ServerSecurityContextRepository} + * to use + */ + public void setServerSecurityContextRepository(ServerSecurityContextRepository serverSecurityContextRepository) { + Assert.notNull(serverSecurityContextRepository, "serverSecurityContextRepository cannot be null"); + this.serverSecurityContextRepository = serverSecurityContextRepository; + } + + /** + * Sets a {@link ReactiveJwtDecoderFactory} to use for decoding refreshed oidc + * id-token, defaults to {@link ReactiveOidcIdTokenDecoderFactory}. + * @param jwtDecoderFactory the {@link ReactiveJwtDecoderFactory} to use + */ + public void setJwtDecoderFactory(ReactiveJwtDecoderFactory jwtDecoderFactory) { + Assert.notNull(jwtDecoderFactory, "jwtDecoderFactory cannot be null"); + this.jwtDecoderFactory = jwtDecoderFactory; + } + + /** + * Sets a {@link GrantedAuthoritiesMapper} to use for mapping + * {@link GrantedAuthority}s, defaults to no-op implementation. + * @param authoritiesMapper the {@link GrantedAuthoritiesMapper} to use + */ + public void setAuthoritiesMapper(GrantedAuthoritiesMapper authoritiesMapper) { + Assert.notNull(authoritiesMapper, "authoritiesMapper cannot be null"); + this.authoritiesMapper = authoritiesMapper; + } + + /** + * Sets a {@link ReactiveOAuth2UserService} to use for loading an {@link OidcUser} + * from refreshed oidc id-token, defaults to {@link OidcReactiveOAuth2UserService}. + * @param userService the {@link ReactiveOAuth2UserService} to use + */ + public void setUserService(ReactiveOAuth2UserService userService) { + Assert.notNull(userService, "userService cannot be null"); + this.userService = userService; + } + + /** + * Sets the maximum acceptable clock skew, which is used when checking the + * {@link OidcIdToken#getIssuedAt()} to match the existing + * {@link OidcUser#getIdToken()}'s issuedAt time, defaults to 60 seconds. + * @param clockSkew the maximum acceptable clock skew to use + */ + public void setClockSkew(Duration clockSkew) { + Assert.notNull(clockSkew, "clockSkew cannot be null"); + Assert.isTrue(clockSkew.getSeconds() >= 0, "clockSkew must be >= 0"); + this.clockSkew = clockSkew; + } + +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenReactiveOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenReactiveOAuth2AuthorizedClientProvider.java index 523fe303bbf..5d50d738a4d 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenReactiveOAuth2AuthorizedClientProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenReactiveOAuth2AuthorizedClientProvider.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,7 +21,9 @@ import java.time.Instant; import java.util.Arrays; import java.util.Collections; +import java.util.HashMap; import java.util.HashSet; +import java.util.Map; import java.util.Set; import reactor.core.publisher.Mono; @@ -33,6 +35,7 @@ import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import org.springframework.security.oauth2.core.OAuth2Token; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.util.Assert; /** @@ -40,6 +43,7 @@ * {@link AuthorizationGrantType#REFRESH_TOKEN refresh_token} grant. * * @author Joe Grandja + * @author Evgeniy Cheban * @since 5.2 * @see ReactiveOAuth2AuthorizedClientProvider * @see WebClientReactiveRefreshTokenTokenResponseClient @@ -49,6 +53,8 @@ public final class RefreshTokenReactiveOAuth2AuthorizedClientProvider private ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient = new WebClientReactiveRefreshTokenTokenResponseClient(); + private ReactiveOAuth2AuthorizationSuccessHandler refreshTokenSuccessHandler = new RefreshTokenReactiveOAuth2AuthorizationSuccessHandler(); + private Duration clockSkew = Duration.ofSeconds(60); private Clock clock = Clock.systemUTC(); @@ -96,8 +102,16 @@ public Mono authorize(OAuth2AuthorizationContext context .flatMap(this.accessTokenResponseClient::getTokenResponse) .onErrorMap(OAuth2AuthorizationException.class, (e) -> new ClientAuthorizationException(e.getError(), clientRegistration.getRegistrationId(), e)) - .map((tokenResponse) -> new OAuth2AuthorizedClient(clientRegistration, context.getPrincipal().getName(), - tokenResponse.getAccessToken(), tokenResponse.getRefreshToken())); + .flatMap((tokenResponse) -> { + OAuth2AuthorizedClient refreshedAuthorizedClient = new OAuth2AuthorizedClient(clientRegistration, + context.getPrincipal().getName(), tokenResponse.getAccessToken(), + tokenResponse.getRefreshToken()); + Map attributes = new HashMap<>(context.getAttributes()); + attributes.put(OAuth2AccessTokenResponse.class.getName(), tokenResponse); + return this.refreshTokenSuccessHandler + .onAuthorizationSuccess(refreshedAuthorizedClient, context.getPrincipal(), attributes) + .then(Mono.just(refreshedAuthorizedClient)); + }); } private boolean hasTokenExpired(OAuth2Token token) { @@ -116,6 +130,19 @@ public void setAccessTokenResponseClient( this.accessTokenResponseClient = accessTokenResponseClient; } + /** + * Sets a {@link ReactiveOAuth2AuthorizationSuccessHandler} that is called after the + * client is re-authorized, defaults to + * {@link RefreshTokenReactiveOAuth2AuthorizationSuccessHandler}. + * @param refreshTokenSuccessHandler the + * {@link ReactiveOAuth2AuthorizationSuccessHandler} to use + * @since 7.0 + */ + public void setRefreshTokenSuccessHandler(ReactiveOAuth2AuthorizationSuccessHandler refreshTokenSuccessHandler) { + Assert.notNull(refreshTokenSuccessHandler, "refreshTokenSuccessHandler cannot be null"); + this.refreshTokenSuccessHandler = refreshTokenSuccessHandler; + } + /** * Sets the maximum acceptable clock skew, which is used when checking the * {@link OAuth2AuthorizedClient#getAccessToken() access token} expiry. The default is diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultReactiveOAuth2AuthorizedClientManager.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultReactiveOAuth2AuthorizedClientManager.java index 17e4c1c7b40..4473f65c6b4 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultReactiveOAuth2AuthorizedClientManager.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultReactiveOAuth2AuthorizedClientManager.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -85,6 +85,7 @@ * * @author Joe Grandja * @author Phil Clay + * @author Evgeniy Cheban * @since 5.2 * @see ReactiveOAuth2AuthorizedClientManager * @see ReactiveOAuth2AuthorizedClientProvider @@ -319,10 +320,10 @@ public Mono> apply(OAuth2AuthorizeRequest authorizeRequest) return Mono.justOrEmpty(serverWebExchange) .switchIfEmpty(currentServerWebExchangeMono) .flatMap((exchange) -> { - Map contextAttributes = Collections.emptyMap(); + Map contextAttributes = new HashMap<>(); + contextAttributes.put(ServerWebExchange.class.getName(), serverWebExchange); String scope = exchange.getRequest().getQueryParams().getFirst(OAuth2ParameterNames.SCOPE); if (StringUtils.hasText(scope)) { - contextAttributes = new HashMap<>(); contextAttributes.put(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME, StringUtils.delimitedListToStringArray(scope, " ")); } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java index 2d9be5ebf44..8653dc59664 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java @@ -50,6 +50,8 @@ import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.web.server.context.ServerSecurityContextRepository; +import org.springframework.security.web.server.context.WebSessionServerSecurityContextRepository; import org.springframework.util.Assert; import org.springframework.util.StringUtils; import org.springframework.web.reactive.function.client.ClientRequest; @@ -95,6 +97,7 @@ * @author Rob Winch * @author Joe Grandja * @author Phil Clay + * @author Evgeniy Cheban * @since 5.1 */ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements ExchangeFilterFunction { @@ -145,6 +148,9 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements private ClientResponseHandler clientResponseHandler; + // This should be replaced with PrincipalResolver introduced in gh-16284 + private final ServerSecurityContextRepository serverSecurityContextRepository = new WebSessionServerSecurityContextRepository(); + /** * Constructs a {@code ServerOAuth2AuthorizedClientExchangeFilterFunction} using the * provided parameters. @@ -336,8 +342,11 @@ public Mono filter(ClientRequest request, ExchangeFunction next) } private Mono exchangeAndHandleResponse(ClientRequest request, ExchangeFunction next) { - return next.exchange(request) - .transform((responseMono) -> this.clientResponseHandler.handleResponse(request, responseMono)); + // Re-request an Authentication from serverSecurityContextRepository since it + // might have been changed during provider invocation. + return effectiveAuthentication(request).flatMap((authentication) -> next.exchange(request) + .transform((responseMono) -> this.clientResponseHandler.handleResponse(request, responseMono)) + .contextWrite(ReactiveSecurityContextHolder.withAuthentication(authentication))); } private Mono authorizedClient(ClientRequest request) { @@ -368,6 +377,17 @@ private Mono authorizeRequest(ClientRequest request) { // @formatter:on } + private Mono effectiveAuthentication(ClientRequest request) { + // @formatter:off + return effectiveServerWebExchange(request) + .filter(Optional::isPresent) + .map(Optional::get) + .flatMap(this.serverSecurityContextRepository::load) + .map(SecurityContext::getAuthentication) + .switchIfEmpty(this.currentAuthenticationMono); + // @formatter:on + } + /** * Returns a {@link Mono} the emits the {@code clientRegistrationId} that is active * for the given request. diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenReactiveOAuth2AuthorizationSuccessHandlerTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenReactiveOAuth2AuthorizationSuccessHandlerTests.java new file mode 100644 index 00000000000..deb2b972910 --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenReactiveOAuth2AuthorizationSuccessHandlerTests.java @@ -0,0 +1,412 @@ +/* + * Copyright 2002-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client; + +import java.time.Duration; +import java.time.Instant; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import org.springframework.mock.http.server.reactive.MockServerHttpRequest; +import org.springframework.mock.web.server.MockServerWebExchange; +import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; +import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserRequest; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.client.userinfo.ReactiveOAuth2UserService; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames; +import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser; +import org.springframework.security.oauth2.core.oidc.user.OidcUser; +import org.springframework.security.oauth2.core.oidc.user.TestOidcUsers; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder; +import org.springframework.security.oauth2.jwt.ReactiveJwtDecoderFactory; +import org.springframework.security.web.server.context.WebSessionServerSecurityContextRepository; +import org.springframework.web.server.ServerWebExchange; + +import static org.assertj.core.api.Assertions.assertThatException; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; + +/** + * Tests for {@link RefreshTokenReactiveOAuth2AuthorizationSuccessHandler}. + * + * @author Evgeniy Cheban + */ +class RefreshTokenReactiveOAuth2AuthorizationSuccessHandlerTests { + + @Test + void onAuthorizationSuccessWhenIdTokenValidThenSecurityContextRefreshed() { + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); + DefaultOidcUser principal = TestOidcUsers.create(); + OAuth2AuthenticationToken authenticationToken = new OAuth2AuthenticationToken(principal, + principal.getAuthorities(), clientRegistration.getRegistrationId()); + OAuth2AccessToken accessToken = createAccessToken(); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(clientRegistration, principal.getName(), + accessToken); + OAuth2AccessTokenResponse tokenResponse = createAccessTokenResponse(); + MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/").build()); + Map attributes = Map.of(OAuth2AccessTokenResponse.class.getName(), tokenResponse, + ServerWebExchange.class.getName(), exchange); + Map claims = new HashMap<>(); + claims.put("iss", principal.getIssuer()); + claims.put("sub", principal.getSubject()); + claims.put("aud", principal.getAudience()); + claims.put("nonce", principal.getNonce()); + Jwt jwt = mock(Jwt.class); + given(jwt.getTokenValue()).willReturn("id-token-1234"); + given(jwt.getIssuedAt()).willReturn(principal.getIssuedAt()); + given(jwt.getClaims()).willReturn(claims); + ReactiveJwtDecoder jwtDecoder = mock(ReactiveJwtDecoder.class); + given(jwtDecoder.decode(any())).willReturn(Mono.just(jwt)); + ReactiveJwtDecoderFactory reactiveJwtDecoderFactory = mock(ReactiveJwtDecoderFactory.class); + given(reactiveJwtDecoderFactory.createDecoder(any())).willReturn(jwtDecoder); + ReactiveOAuth2UserService userService = mock(ReactiveOAuth2UserService.class); + given(userService.loadUser(any())).willReturn(Mono.just(principal)); + WebSessionServerSecurityContextRepository serverSecurityContextRepository = new WebSessionServerSecurityContextRepository(); + RefreshTokenReactiveOAuth2AuthorizationSuccessHandler handler = new RefreshTokenReactiveOAuth2AuthorizationSuccessHandler(); + handler.setJwtDecoderFactory(reactiveJwtDecoderFactory); + handler.setUserService(userService); + handler.setServerSecurityContextRepository(serverSecurityContextRepository); + handler.onAuthorizationSuccess(authorizedClient, authenticationToken, attributes).block(); + StepVerifier.create(handler.onAuthorizationSuccess(authorizedClient, authenticationToken, attributes)) + .verifyComplete(); + StepVerifier.create(serverSecurityContextRepository.load(exchange).map(SecurityContext::getAuthentication)) + .expectNext(authenticationToken) + .verifyComplete(); + } + + @Test + void onAuthorizationSuccessWhenIdTokenIssuerNotSameThenException() { + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); + DefaultOidcUser principal = TestOidcUsers.create(); + OAuth2AuthenticationToken authenticationToken = new OAuth2AuthenticationToken(principal, + principal.getAuthorities(), clientRegistration.getRegistrationId()); + OAuth2AccessToken accessToken = createAccessToken(); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(clientRegistration, principal.getName(), + accessToken); + OAuth2AccessTokenResponse tokenResponse = createAccessTokenResponse(); + MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/").build()); + Map attributes = Map.of(OAuth2AccessTokenResponse.class.getName(), tokenResponse, + ServerWebExchange.class.getName(), exchange); + Map claims = new HashMap<>(); + claims.put("iss", "https://issuer.com"); + claims.put("sub", principal.getSubject()); + claims.put("aud", principal.getAudience()); + claims.put("nonce", principal.getNonce()); + Jwt jwt = mock(Jwt.class); + given(jwt.getTokenValue()).willReturn("id-token-1234"); + given(jwt.getIssuedAt()).willReturn(principal.getIssuedAt()); + given(jwt.getClaims()).willReturn(claims); + ReactiveJwtDecoder jwtDecoder = mock(ReactiveJwtDecoder.class); + given(jwtDecoder.decode(any())).willReturn(Mono.just(jwt)); + ReactiveJwtDecoderFactory reactiveJwtDecoderFactory = mock(ReactiveJwtDecoderFactory.class); + given(reactiveJwtDecoderFactory.createDecoder(any())).willReturn(jwtDecoder); + ReactiveOAuth2UserService userService = mock(ReactiveOAuth2UserService.class); + given(userService.loadUser(any())).willReturn(Mono.just(principal)); + WebSessionServerSecurityContextRepository serverSecurityContextRepository = new WebSessionServerSecurityContextRepository(); + RefreshTokenReactiveOAuth2AuthorizationSuccessHandler handler = new RefreshTokenReactiveOAuth2AuthorizationSuccessHandler(); + handler.setJwtDecoderFactory(reactiveJwtDecoderFactory); + handler.setUserService(userService); + handler.setServerSecurityContextRepository(serverSecurityContextRepository); + StepVerifier.create(handler.onAuthorizationSuccess(authorizedClient, authenticationToken, attributes)) + .verifyErrorMessage("[invalid_id_token] Invalid issuer"); + } + + @Test + void onAuthorizationSuccessWhenIdTokenSubNotSameThenException() { + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); + DefaultOidcUser principal = TestOidcUsers.create(); + OAuth2AuthenticationToken authenticationToken = new OAuth2AuthenticationToken(principal, + principal.getAuthorities(), clientRegistration.getRegistrationId()); + OAuth2AccessToken accessToken = createAccessToken(); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(clientRegistration, principal.getName(), + accessToken); + OAuth2AccessTokenResponse tokenResponse = createAccessTokenResponse(); + MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/").build()); + Map attributes = Map.of(OAuth2AccessTokenResponse.class.getName(), tokenResponse, + ServerWebExchange.class.getName(), exchange); + Map claims = new HashMap<>(); + claims.put("iss", principal.getIssuer()); + claims.put("sub", "invalid_sub"); + claims.put("aud", principal.getAudience()); + claims.put("nonce", principal.getNonce()); + Jwt jwt = mock(Jwt.class); + given(jwt.getTokenValue()).willReturn("id-token-1234"); + given(jwt.getIssuedAt()).willReturn(principal.getIssuedAt()); + given(jwt.getClaims()).willReturn(claims); + ReactiveJwtDecoder jwtDecoder = mock(ReactiveJwtDecoder.class); + given(jwtDecoder.decode(any())).willReturn(Mono.just(jwt)); + ReactiveJwtDecoderFactory reactiveJwtDecoderFactory = mock(ReactiveJwtDecoderFactory.class); + given(reactiveJwtDecoderFactory.createDecoder(any())).willReturn(jwtDecoder); + ReactiveOAuth2UserService userService = mock(ReactiveOAuth2UserService.class); + given(userService.loadUser(any())).willReturn(Mono.just(principal)); + WebSessionServerSecurityContextRepository serverSecurityContextRepository = new WebSessionServerSecurityContextRepository(); + RefreshTokenReactiveOAuth2AuthorizationSuccessHandler handler = new RefreshTokenReactiveOAuth2AuthorizationSuccessHandler(); + handler.setJwtDecoderFactory(reactiveJwtDecoderFactory); + handler.setUserService(userService); + handler.setServerSecurityContextRepository(serverSecurityContextRepository); + StepVerifier.create(handler.onAuthorizationSuccess(authorizedClient, authenticationToken, attributes)) + .verifyErrorMessage("[invalid_id_token] Invalid subject"); + } + + @Test + void onAuthorizationSuccessWhenIdTokenIatNotAfterThenException() { + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); + DefaultOidcUser principal = TestOidcUsers.create(); + OAuth2AuthenticationToken authenticationToken = new OAuth2AuthenticationToken(principal, + principal.getAuthorities(), clientRegistration.getRegistrationId()); + OAuth2AccessToken accessToken = createAccessToken(); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(clientRegistration, principal.getName(), + accessToken); + OAuth2AccessTokenResponse tokenResponse = createAccessTokenResponse(); + MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/").build()); + Map attributes = Map.of(OAuth2AccessTokenResponse.class.getName(), tokenResponse, + ServerWebExchange.class.getName(), exchange); + Map claims = new HashMap<>(); + claims.put("iss", principal.getIssuer()); + claims.put("sub", principal.getSubject()); + claims.put("aud", principal.getAudience()); + claims.put("nonce", principal.getNonce()); + Jwt jwt = mock(Jwt.class); + given(jwt.getTokenValue()).willReturn("id-token-1234"); + given(jwt.getIssuedAt()).willReturn(principal.getIssuedAt().minus(Duration.ofDays(1))); + given(jwt.getClaims()).willReturn(claims); + ReactiveJwtDecoder jwtDecoder = mock(ReactiveJwtDecoder.class); + given(jwtDecoder.decode(any())).willReturn(Mono.just(jwt)); + ReactiveJwtDecoderFactory reactiveJwtDecoderFactory = mock(ReactiveJwtDecoderFactory.class); + given(reactiveJwtDecoderFactory.createDecoder(any())).willReturn(jwtDecoder); + ReactiveOAuth2UserService userService = mock(ReactiveOAuth2UserService.class); + given(userService.loadUser(any())).willReturn(Mono.just(principal)); + WebSessionServerSecurityContextRepository serverSecurityContextRepository = new WebSessionServerSecurityContextRepository(); + RefreshTokenReactiveOAuth2AuthorizationSuccessHandler handler = new RefreshTokenReactiveOAuth2AuthorizationSuccessHandler(); + handler.setJwtDecoderFactory(reactiveJwtDecoderFactory); + handler.setUserService(userService); + handler.setServerSecurityContextRepository(serverSecurityContextRepository); + StepVerifier.create(handler.onAuthorizationSuccess(authorizedClient, authenticationToken, attributes)) + .verifyErrorMessage("[invalid_id_token] Invalid issued at time"); + } + + @Test + void onAuthorizationSuccessWhenIdTokenAudEmptyThenException() { + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); + DefaultOidcUser principal = TestOidcUsers.create(); + OAuth2AuthenticationToken authenticationToken = new OAuth2AuthenticationToken(principal, + principal.getAuthorities(), clientRegistration.getRegistrationId()); + OAuth2AccessToken accessToken = createAccessToken(); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(clientRegistration, principal.getName(), + accessToken); + OAuth2AccessTokenResponse tokenResponse = createAccessTokenResponse(); + MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/").build()); + Map attributes = Map.of(OAuth2AccessTokenResponse.class.getName(), tokenResponse, + ServerWebExchange.class.getName(), exchange); + Map claims = new HashMap<>(); + claims.put("iss", principal.getIssuer()); + claims.put("sub", principal.getSubject()); + claims.put("aud", Collections.emptyList()); + claims.put("nonce", principal.getNonce()); + Jwt jwt = mock(Jwt.class); + given(jwt.getTokenValue()).willReturn("id-token-1234"); + given(jwt.getIssuedAt()).willReturn(principal.getIssuedAt()); + given(jwt.getClaims()).willReturn(claims); + ReactiveJwtDecoder jwtDecoder = mock(ReactiveJwtDecoder.class); + given(jwtDecoder.decode(any())).willReturn(Mono.just(jwt)); + ReactiveJwtDecoderFactory reactiveJwtDecoderFactory = mock(ReactiveJwtDecoderFactory.class); + given(reactiveJwtDecoderFactory.createDecoder(any())).willReturn(jwtDecoder); + ReactiveOAuth2UserService userService = mock(ReactiveOAuth2UserService.class); + given(userService.loadUser(any())).willReturn(Mono.just(principal)); + WebSessionServerSecurityContextRepository serverSecurityContextRepository = new WebSessionServerSecurityContextRepository(); + RefreshTokenReactiveOAuth2AuthorizationSuccessHandler handler = new RefreshTokenReactiveOAuth2AuthorizationSuccessHandler(); + handler.setJwtDecoderFactory(reactiveJwtDecoderFactory); + handler.setUserService(userService); + handler.setServerSecurityContextRepository(serverSecurityContextRepository); + StepVerifier.create(handler.onAuthorizationSuccess(authorizedClient, authenticationToken, attributes)) + .verifyErrorMessage("[invalid_id_token] Invalid audience"); + } + + @Test + void onAuthorizationSuccessWhenIdTokenAudNotContainThenException() { + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); + DefaultOidcUser principal = TestOidcUsers.create(); + OAuth2AuthenticationToken authenticationToken = new OAuth2AuthenticationToken(principal, + principal.getAuthorities(), clientRegistration.getRegistrationId()); + OAuth2AccessToken accessToken = createAccessToken(); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(clientRegistration, principal.getName(), + accessToken); + OAuth2AccessTokenResponse tokenResponse = createAccessTokenResponse(); + MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/").build()); + Map attributes = Map.of(OAuth2AccessTokenResponse.class.getName(), tokenResponse, + ServerWebExchange.class.getName(), exchange); + Map claims = new HashMap<>(); + claims.put("iss", principal.getIssuer()); + claims.put("sub", principal.getSubject()); + claims.put("aud", List.of("invalid_client-id")); + claims.put("nonce", principal.getNonce()); + Jwt jwt = mock(Jwt.class); + given(jwt.getTokenValue()).willReturn("id-token-1234"); + given(jwt.getIssuedAt()).willReturn(principal.getIssuedAt()); + given(jwt.getClaims()).willReturn(claims); + ReactiveJwtDecoder jwtDecoder = mock(ReactiveJwtDecoder.class); + given(jwtDecoder.decode(any())).willReturn(Mono.just(jwt)); + ReactiveJwtDecoderFactory reactiveJwtDecoderFactory = mock(ReactiveJwtDecoderFactory.class); + given(reactiveJwtDecoderFactory.createDecoder(any())).willReturn(jwtDecoder); + ReactiveOAuth2UserService userService = mock(ReactiveOAuth2UserService.class); + given(userService.loadUser(any())).willReturn(Mono.just(principal)); + WebSessionServerSecurityContextRepository serverSecurityContextRepository = new WebSessionServerSecurityContextRepository(); + RefreshTokenReactiveOAuth2AuthorizationSuccessHandler handler = new RefreshTokenReactiveOAuth2AuthorizationSuccessHandler(); + handler.setJwtDecoderFactory(reactiveJwtDecoderFactory); + handler.setUserService(userService); + handler.setServerSecurityContextRepository(serverSecurityContextRepository); + StepVerifier.create(handler.onAuthorizationSuccess(authorizedClient, authenticationToken, attributes)) + .verifyErrorMessage("[invalid_id_token] Invalid audience"); + } + + @Test + void onAuthorizationSuccessWhenIdTokenAuthTimeNotSameThenException() { + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); + DefaultOidcUser principal = TestOidcUsers.create(); + OAuth2AuthenticationToken authenticationToken = new OAuth2AuthenticationToken(principal, + principal.getAuthorities(), clientRegistration.getRegistrationId()); + OAuth2AccessToken accessToken = createAccessToken(); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(clientRegistration, principal.getName(), + accessToken); + OAuth2AccessTokenResponse tokenResponse = createAccessTokenResponse(); + MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/").build()); + Map attributes = Map.of(OAuth2AccessTokenResponse.class.getName(), tokenResponse, + ServerWebExchange.class.getName(), exchange); + Map claims = new HashMap<>(); + claims.put("iss", principal.getIssuer()); + claims.put("sub", principal.getSubject()); + claims.put("aud", principal.getAudience()); + claims.put("auth_time", principal.getIssuedAt()); + claims.put("nonce", principal.getNonce()); + Jwt jwt = mock(Jwt.class); + given(jwt.getTokenValue()).willReturn("id-token-1234"); + given(jwt.getIssuedAt()).willReturn(principal.getIssuedAt()); + given(jwt.getClaims()).willReturn(claims); + ReactiveJwtDecoder jwtDecoder = mock(ReactiveJwtDecoder.class); + given(jwtDecoder.decode(any())).willReturn(Mono.just(jwt)); + ReactiveJwtDecoderFactory reactiveJwtDecoderFactory = mock(ReactiveJwtDecoderFactory.class); + given(reactiveJwtDecoderFactory.createDecoder(any())).willReturn(jwtDecoder); + ReactiveOAuth2UserService userService = mock(ReactiveOAuth2UserService.class); + given(userService.loadUser(any())).willReturn(Mono.just(principal)); + WebSessionServerSecurityContextRepository serverSecurityContextRepository = new WebSessionServerSecurityContextRepository(); + RefreshTokenReactiveOAuth2AuthorizationSuccessHandler handler = new RefreshTokenReactiveOAuth2AuthorizationSuccessHandler(); + handler.setJwtDecoderFactory(reactiveJwtDecoderFactory); + handler.setUserService(userService); + handler.setServerSecurityContextRepository(serverSecurityContextRepository); + StepVerifier.create(handler.onAuthorizationSuccess(authorizedClient, authenticationToken, attributes)) + .verifyErrorMessage("[invalid_id_token] Invalid authenticated at time"); + } + + @Test + void onAuthorizationSuccessWhenIdTokenNonceNotSameThenException() { + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); + DefaultOidcUser principal = TestOidcUsers.create(); + OAuth2AuthenticationToken authenticationToken = new OAuth2AuthenticationToken(principal, + principal.getAuthorities(), clientRegistration.getRegistrationId()); + OAuth2AccessToken accessToken = createAccessToken(); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(clientRegistration, principal.getName(), + accessToken); + OAuth2AccessTokenResponse tokenResponse = createAccessTokenResponse(); + MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/").build()); + Map attributes = Map.of(OAuth2AccessTokenResponse.class.getName(), tokenResponse, + ServerWebExchange.class.getName(), exchange); + Map claims = new HashMap<>(); + claims.put("iss", principal.getIssuer()); + claims.put("sub", principal.getSubject()); + claims.put("aud", principal.getAudience()); + claims.put("nonce", "invalid_nonce"); + Jwt jwt = mock(Jwt.class); + given(jwt.getTokenValue()).willReturn("id-token-1234"); + given(jwt.getIssuedAt()).willReturn(principal.getIssuedAt()); + given(jwt.getClaims()).willReturn(claims); + ReactiveJwtDecoder jwtDecoder = mock(ReactiveJwtDecoder.class); + given(jwtDecoder.decode(any())).willReturn(Mono.just(jwt)); + ReactiveJwtDecoderFactory reactiveJwtDecoderFactory = mock(ReactiveJwtDecoderFactory.class); + given(reactiveJwtDecoderFactory.createDecoder(any())).willReturn(jwtDecoder); + ReactiveOAuth2UserService userService = mock(ReactiveOAuth2UserService.class); + given(userService.loadUser(any())).willReturn(Mono.just(principal)); + WebSessionServerSecurityContextRepository serverSecurityContextRepository = new WebSessionServerSecurityContextRepository(); + RefreshTokenReactiveOAuth2AuthorizationSuccessHandler handler = new RefreshTokenReactiveOAuth2AuthorizationSuccessHandler(); + handler.setJwtDecoderFactory(reactiveJwtDecoderFactory); + handler.setUserService(userService); + handler.setServerSecurityContextRepository(serverSecurityContextRepository); + StepVerifier.create(handler.onAuthorizationSuccess(authorizedClient, authenticationToken, attributes)) + .verifyErrorMessage("[invalid_nonce] Invalid nonce"); + } + + @Test + void setServerSecurityContextRepositoryWhenNullThenException() { + assertThatException() + .isThrownBy(() -> new RefreshTokenReactiveOAuth2AuthorizationSuccessHandler() + .setServerSecurityContextRepository(null)) + .withMessage("serverSecurityContextRepository cannot be null"); + } + + @Test + void setJwtDecoderFactoryWhenNullThenException() { + assertThatException() + .isThrownBy(() -> new RefreshTokenReactiveOAuth2AuthorizationSuccessHandler().setJwtDecoderFactory(null)) + .withMessage("jwtDecoderFactory cannot be null"); + } + + @Test + void setAuthoritiesMapperWhenNullThenException() { + assertThatException() + .isThrownBy(() -> new RefreshTokenReactiveOAuth2AuthorizationSuccessHandler().setAuthoritiesMapper(null)) + .withMessage("authoritiesMapper cannot be null"); + } + + @Test + void setUserServiceWhenNullThenException() { + assertThatException() + .isThrownBy(() -> new RefreshTokenReactiveOAuth2AuthorizationSuccessHandler().setUserService(null)) + .withMessage("userService cannot be null"); + } + + @Test + void setClockSkewWhenNullThenException() { + assertThatException() + .isThrownBy(() -> new RefreshTokenReactiveOAuth2AuthorizationSuccessHandler().setClockSkew(null)) + .withMessage("clockSkew cannot be null"); + } + + private static OAuth2AccessToken createAccessToken() { + Instant issuedAt = Instant.now().minus(Duration.ofDays(1)); + Instant expiresAt = issuedAt.plus(Duration.ofMinutes(60)); + return new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "access-token-1234", issuedAt, expiresAt); + } + + private static OAuth2AccessTokenResponse createAccessTokenResponse() { + return OAuth2AccessTokenResponse.withToken("access-token-1234") + .tokenType(OAuth2AccessToken.TokenType.BEARER) + .additionalParameters(Map.of(OidcParameterNames.ID_TOKEN, "id-token-1234")) + .build(); + } + +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenReactiveOAuth2AuthorizedClientProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenReactiveOAuth2AuthorizedClientProviderTests.java index 3e438c60bb9..39f0a13d711 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenReactiveOAuth2AuthorizedClientProviderTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenReactiveOAuth2AuthorizedClientProviderTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -49,6 +49,7 @@ * Tests for {@link RefreshTokenReactiveOAuth2AuthorizedClientProvider}. * * @author Joe Grandja + * @author Evgeniy Cheban */ public class RefreshTokenReactiveOAuth2AuthorizedClientProviderTests { @@ -84,6 +85,15 @@ public void setAccessTokenResponseClientWhenClientIsNullThenThrowIllegalArgument .withMessage("accessTokenResponseClient cannot be null"); } + @Test + public void setRefreshTokenSuccessHandlerWhenHandlerIsNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.setRefreshTokenSuccessHandler(null)) + .withMessage("refreshTokenSuccessHandler cannot be null"); + // @formatter:on + } + @Test public void setClockSkewWhenNullThenThrowIllegalArgumentException() { // @formatter:off