From 6f682ace51332e18cd084ac6ebc9dd845dace900 Mon Sep 17 00:00:00 2001 From: DingHao Date: Thu, 9 Jan 2025 17:32:25 +0800 Subject: [PATCH] Add ClientRegistration codeChallengeMethod to Enable PKCE Closes gh-16382 Signed-off-by: DingHao --- .../registration/ClientRegistration.java | 29 ++++++++++++++- ...ultOAuth2AuthorizationRequestResolver.java | 7 +++- ...uth2AuthorizationRequestResolverTests.java | 37 +++++++++++++++++-- 3 files changed, 67 insertions(+), 6 deletions(-) diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/ClientRegistration.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/ClientRegistration.java index 0639a395f89..a4a73f401b2 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/ClientRegistration.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/ClientRegistration.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -71,6 +71,8 @@ public final class ClientRegistration implements Serializable { private String clientName; + private String codeChallengeMethod; + private ClientRegistration() { } @@ -162,6 +164,14 @@ public String getClientName() { return this.clientName; } + /** + * Returns the codeChallengeMethod of the client or registration. + * @return the codeChallengeMethod + */ + public String getCodeChallengeMethod() { + return this.codeChallengeMethod; + } + @Override public String toString() { // @formatter:off @@ -175,6 +185,7 @@ public String toString() { + '\'' + ", scopes=" + this.scopes + ", providerDetails=" + this.providerDetails + ", clientName='" + this.clientName + '\'' + + ", codeChallengeMethod='" + this.codeChallengeMethod + '\'' + '}'; // @formatter:on } @@ -367,6 +378,8 @@ public static final class Builder implements Serializable { private String clientName; + private String codeChallengeMethod; + private Builder(String registrationId) { this.registrationId = registrationId; } @@ -391,6 +404,7 @@ private Builder(ClientRegistration clientRegistration) { this.configurationMetadata = new HashMap<>(configurationMetadata); } this.clientName = clientRegistration.clientName; + this.codeChallengeMethod = clientRegistration.codeChallengeMethod; } /** @@ -594,6 +608,16 @@ public Builder clientName(String clientName) { return this; } + /** + * Sets the codeChallengeMethod of the client or registration. + * @param codeChallengeMethod the codeChallengeMethod + * @return the {@link Builder} + */ + public Builder codeChallengeMethod(String codeChallengeMethod) { + this.codeChallengeMethod = codeChallengeMethod; + return this; + } + /** * Builds a new {@link ClientRegistration}. * @return a {@link ClientRegistration} @@ -627,12 +651,13 @@ private ClientRegistration create() { clientRegistration.providerDetails = createProviderDetails(clientRegistration); clientRegistration.clientName = StringUtils.hasText(this.clientName) ? this.clientName : this.registrationId; + clientRegistration.codeChallengeMethod = this.codeChallengeMethod; return clientRegistration; } private ClientAuthenticationMethod deduceClientAuthenticationMethod(ClientRegistration clientRegistration) { if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(this.authorizationGrantType) - && !StringUtils.hasText(this.clientSecret)) { + && (!StringUtils.hasText(this.clientSecret) || StringUtils.hasText(this.codeChallengeMethod))) { return ClientAuthenticationMethod.NONE; } return ClientAuthenticationMethod.CLIENT_SECRET_BASIC; diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolver.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolver.java index c189317ec43..9530a33bb43 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolver.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolver.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -34,6 +34,7 @@ import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.core.endpoint.PkceParameterNames; import org.springframework.security.oauth2.core.oidc.OidcScopes; import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames; import org.springframework.security.web.util.UrlUtils; @@ -185,6 +186,10 @@ private OAuth2AuthorizationRequest.Builder getBuilder(ClientRegistration clientR } if (ClientAuthenticationMethod.NONE.equals(clientRegistration.getClientAuthenticationMethod())) { DEFAULT_PKCE_APPLIER.accept(builder); + if (StringUtils.hasText(clientRegistration.getCodeChallengeMethod())) { + builder.additionalParameters((params) -> params.put(PkceParameterNames.CODE_CHALLENGE_METHOD, + clientRegistration.getCodeChallengeMethod())); + } } return builder; } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolverTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolverTests.java index c10a3f82cfc..d1f981cfb13 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolverTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolverTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -56,6 +56,8 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { private ClientRegistration registration2; + private ClientRegistration pkceClientRegistration; + private ClientRegistration fineRedirectUriTemplateRegistration; private ClientRegistration publicClientRegistration; @@ -72,6 +74,9 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { public void setUp() { this.registration1 = TestClientRegistrations.clientRegistration().build(); this.registration2 = TestClientRegistrations.clientRegistration2().build(); + + this.pkceClientRegistration = pkceClientRegistration().build(); + this.fineRedirectUriTemplateRegistration = fineRedirectUriTemplateClientRegistration().build(); // @formatter:off this.publicClientRegistration = TestClientRegistrations.clientRegistration() @@ -86,8 +91,8 @@ public void setUp() { .build(); // @formatter:on this.clientRegistrationRepository = new InMemoryClientRegistrationRepository(this.registration1, - this.registration2, this.fineRedirectUriTemplateRegistration, this.publicClientRegistration, - this.oidcRegistration); + this.registration2, this.pkceClientRegistration, this.fineRedirectUriTemplateRegistration, + this.publicClientRegistration, this.oidcRegistration); this.resolver = new DefaultOAuth2AuthorizationRequestResolver(this.clientRegistrationRepository, this.authorizationRequestBaseUri); } @@ -563,6 +568,32 @@ public void resolveWhenAuthorizationRequestCustomizerOverridesParameterThenQuery + "nonce=([a-zA-Z0-9\\-\\.\\_\\~]){43}&" + "appid=client-id"); } + @Test + public void resolveWhenAuthorizationRequestProvideCodeChallengeMethod() { + ClientRegistration clientRegistration = this.pkceClientRegistration; + String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); + MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); + request.setServletPath(requestUri); + OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); + assertThat(authorizationRequest.getAdditionalParameters().get(PkceParameterNames.CODE_CHALLENGE_METHOD)) + .isEqualTo(clientRegistration.getCodeChallengeMethod()); + } + + private static ClientRegistration.Builder pkceClientRegistration() { + return ClientRegistration.withRegistrationId("pkce") + .redirectUri("{baseUrl}/{action}/oauth2/code/{registrationId}") + .codeChallengeMethod("S256") + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .scope("read:user") + .authorizationUri("https://example.com/login/oauth/authorize") + .tokenUri("https://example.com/login/oauth/access_token") + .userInfoUri("https://api.example.com/user") + .userNameAttributeName("id") + .clientName("Client Name") + .clientId("client-id-3") + .clientSecret("client-secret"); + } + private static ClientRegistration.Builder fineRedirectUriTemplateClientRegistration() { // @formatter:off return ClientRegistration.withRegistrationId("fine-redirect-uri-template-client-registration")