diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/resource/DPoPAuthenticationConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/resource/DPoPAuthenticationConfigurer.java new file mode 100644 index 0000000000..acbe822191 --- /dev/null +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/resource/DPoPAuthenticationConfigurer.java @@ -0,0 +1,164 @@ +/* + * Copyright 2002-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.config.annotation.web.configurers.oauth2.server.resource; + +import java.util.Collections; +import java.util.List; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import jakarta.servlet.http.HttpServletRequest; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.security.authentication.AuthenticationManager; +import org.springframework.security.config.annotation.web.HttpSecurityBuilder; +import org.springframework.security.config.annotation.web.configurers.AbstractHttpConfigurer; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.server.resource.authentication.DPoPAuthenticationProvider; +import org.springframework.security.oauth2.server.resource.authentication.DPoPAuthenticationToken; +import org.springframework.security.web.authentication.AuthenticationConverter; +import org.springframework.security.web.authentication.AuthenticationEntryPointFailureHandler; +import org.springframework.security.web.authentication.AuthenticationFailureHandler; +import org.springframework.security.web.authentication.AuthenticationFilter; +import org.springframework.security.web.authentication.AuthenticationSuccessHandler; +import org.springframework.security.web.authentication.HttpStatusEntryPoint; +import org.springframework.security.web.context.RequestAttributeSecurityContextRepository; +import org.springframework.security.web.util.matcher.RequestMatcher; +import org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; + +/** + * @author Joe Grandja + * @since 6.5 + * @see DPoPAuthenticationProvider + */ +final class DPoPAuthenticationConfigurer> + extends AbstractHttpConfigurer, B> { + + private RequestMatcher requestMatcher; + + private AuthenticationConverter authenticationConverter; + + private AuthenticationSuccessHandler authenticationSuccessHandler; + + private AuthenticationFailureHandler authenticationFailureHandler; + + @Override + public void configure(B http) { + AuthenticationManager authenticationManager = http.getSharedObject(AuthenticationManager.class); + http.authenticationProvider(new DPoPAuthenticationProvider(authenticationManager)); + AuthenticationFilter authenticationFilter = new AuthenticationFilter(authenticationManager, + getAuthenticationConverter()); + authenticationFilter.setRequestMatcher(getRequestMatcher()); + authenticationFilter.setSuccessHandler(getAuthenticationSuccessHandler()); + authenticationFilter.setFailureHandler(getAuthenticationFailureHandler()); + authenticationFilter.setSecurityContextRepository(new RequestAttributeSecurityContextRepository()); + authenticationFilter = postProcess(authenticationFilter); + http.addFilter(authenticationFilter); + } + + private RequestMatcher getRequestMatcher() { + if (this.requestMatcher == null) { + this.requestMatcher = new DPoPRequestMatcher(); + } + return this.requestMatcher; + } + + private AuthenticationConverter getAuthenticationConverter() { + if (this.authenticationConverter == null) { + this.authenticationConverter = new DPoPAuthenticationConverter(); + } + return this.authenticationConverter; + } + + private AuthenticationSuccessHandler getAuthenticationSuccessHandler() { + if (this.authenticationSuccessHandler == null) { + this.authenticationSuccessHandler = (request, response, authentication) -> { + // No-op - will continue on filter chain + }; + } + return this.authenticationSuccessHandler; + } + + private AuthenticationFailureHandler getAuthenticationFailureHandler() { + if (this.authenticationFailureHandler == null) { + this.authenticationFailureHandler = new AuthenticationEntryPointFailureHandler( + new HttpStatusEntryPoint(HttpStatus.UNAUTHORIZED)); + } + return this.authenticationFailureHandler; + } + + private static final class DPoPRequestMatcher implements RequestMatcher { + + @Override + public boolean matches(HttpServletRequest request) { + String authorization = request.getHeader(HttpHeaders.AUTHORIZATION); + if (!StringUtils.hasText(authorization)) { + return false; + } + return StringUtils.startsWithIgnoreCase(authorization, OAuth2AccessToken.TokenType.DPOP.getValue()); + } + + } + + private static final class DPoPAuthenticationConverter implements AuthenticationConverter { + + private static final Pattern AUTHORIZATION_PATTERN = Pattern.compile("^DPoP (?[a-zA-Z0-9-._~+/]+=*)$", + Pattern.CASE_INSENSITIVE); + + @Override + public Authentication convert(HttpServletRequest request) { + List authorizationList = Collections.list(request.getHeaders(HttpHeaders.AUTHORIZATION)); + if (CollectionUtils.isEmpty(authorizationList)) { + return null; + } + if (authorizationList.size() != 1) { + OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST, + "Found multiple Authorization headers.", null); + throw new OAuth2AuthenticationException(error); + } + String authorization = authorizationList.get(0); + if (!StringUtils.startsWithIgnoreCase(authorization, OAuth2AccessToken.TokenType.DPOP.getValue())) { + return null; + } + Matcher matcher = AUTHORIZATION_PATTERN.matcher(authorization); + if (!matcher.matches()) { + OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_TOKEN, "DPoP access token is malformed.", + null); + throw new OAuth2AuthenticationException(error); + } + String accessToken = matcher.group("token"); + List dPoPProofList = Collections + .list(request.getHeaders(OAuth2AccessToken.TokenType.DPOP.getValue())); + if (CollectionUtils.isEmpty(dPoPProofList) || dPoPProofList.size() != 1) { + OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST, + "DPoP proof is missing or invalid.", null); + throw new OAuth2AuthenticationException(error); + } + String dPoPProof = dPoPProofList.get(0); + return new DPoPAuthenticationToken(accessToken, dPoPProof, request.getMethod(), + request.getRequestURL().toString()); + } + + } + +} diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/resource/OAuth2ResourceServerConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/resource/OAuth2ResourceServerConfigurer.java index 31a8c265a0..e9a425d46d 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/resource/OAuth2ResourceServerConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/resource/OAuth2ResourceServerConfigurer.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2024 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -152,6 +152,8 @@ public final class OAuth2ResourceServerConfigurer dPoPAuthenticationConfigurer = new DPoPAuthenticationConfigurer<>(); + private AuthenticationManagerResolver authenticationManagerResolver; private BearerTokenResolver bearerTokenResolver; @@ -283,6 +285,7 @@ public void configure(H http) { filter.setSecurityContextHolderStrategy(getSecurityContextHolderStrategy()); filter = postProcess(filter); http.addFilter(filter); + this.dPoPAuthenticationConfigurer.configure(http); } private void validateConfiguration() { diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/resource/DPoPAuthenticationConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/resource/DPoPAuthenticationConfigurerTests.java new file mode 100644 index 0000000000..0011728624 --- /dev/null +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/resource/DPoPAuthenticationConfigurerTests.java @@ -0,0 +1,279 @@ +/* + * Copyright 2002-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.config.annotation.web.configurers.oauth2.server.resource; + +import java.nio.charset.StandardCharsets; +import java.security.MessageDigest; +import java.security.PublicKey; +import java.security.interfaces.ECPrivateKey; +import java.security.interfaces.ECPublicKey; +import java.security.interfaces.RSAPrivateKey; +import java.security.interfaces.RSAPublicKey; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Base64; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; +import java.util.UUID; + +import com.nimbusds.jose.jwk.ECKey; +import com.nimbusds.jose.jwk.JWKSet; +import com.nimbusds.jose.jwk.RSAKey; +import com.nimbusds.jose.jwk.source.JWKSource; +import com.nimbusds.jose.proc.SecurityContext; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.security.config.Customizer; +import org.springframework.security.config.annotation.web.builders.HttpSecurity; +import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; +import org.springframework.security.config.test.SpringTestContext; +import org.springframework.security.config.test.SpringTestContextExtension; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.jose.TestJwks; +import org.springframework.security.oauth2.jose.TestKeys; +import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; +import org.springframework.security.oauth2.jwt.JwsHeader; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.JwtClaimsSet; +import org.springframework.security.oauth2.jwt.JwtEncoderParameters; +import org.springframework.security.oauth2.jwt.NimbusJwtDecoder; +import org.springframework.security.oauth2.jwt.NimbusJwtEncoder; +import org.springframework.security.web.SecurityFilterChain; +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RequestMethod; +import org.springframework.web.bind.annotation.RestController; +import org.springframework.web.servlet.config.annotation.EnableWebMvc; + +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; + +/** + * Tests for {@link DPoPAuthenticationConfigurer}. + * + * @author Joe Grandja + */ +@ExtendWith(SpringTestContextExtension.class) +public class DPoPAuthenticationConfigurerTests { + + private static final RSAPublicKey PROVIDER_RSA_PUBLIC_KEY = TestKeys.DEFAULT_PUBLIC_KEY; + + private static final RSAPrivateKey PROVIDER_RSA_PRIVATE_KEY = TestKeys.DEFAULT_PRIVATE_KEY; + + private static final ECPublicKey CLIENT_EC_PUBLIC_KEY = (ECPublicKey) TestKeys.DEFAULT_EC_KEY_PAIR.getPublic(); + + private static final ECPrivateKey CLIENT_EC_PRIVATE_KEY = (ECPrivateKey) TestKeys.DEFAULT_EC_KEY_PAIR.getPrivate(); + + private static NimbusJwtEncoder providerJwtEncoder; + + private static NimbusJwtEncoder clientJwtEncoder; + + public final SpringTestContext spring = new SpringTestContext(this); + + @Autowired + private MockMvc mvc; + + @BeforeAll + public static void init() { + RSAKey providerRsaKey = TestJwks.jwk(PROVIDER_RSA_PUBLIC_KEY, PROVIDER_RSA_PRIVATE_KEY).build(); + JWKSource providerJwkSource = (jwkSelector, securityContext) -> jwkSelector + .select(new JWKSet(providerRsaKey)); + providerJwtEncoder = new NimbusJwtEncoder(providerJwkSource); + ECKey clientEcKey = TestJwks.jwk(CLIENT_EC_PUBLIC_KEY, CLIENT_EC_PRIVATE_KEY).build(); + JWKSource clientJwkSource = (jwkSelector, securityContext) -> jwkSelector + .select(new JWKSet(clientEcKey)); + clientJwtEncoder = new NimbusJwtEncoder(clientJwkSource); + } + + @Test + public void requestWhenDPoPAndBearerAuthenticationThenUnauthorized() throws Exception { + this.spring.register(SecurityConfig.class, ResourceEndpoints.class).autowire(); + Set scope = Collections.singleton("resource1.read"); + String accessToken = generateAccessToken(scope, CLIENT_EC_PUBLIC_KEY); + String dPoPProof = generateDPoPProof(HttpMethod.GET.name(), "http://localhost/resource1", accessToken); + // @formatter:off + this.mvc.perform(get("/resource1") + .header(HttpHeaders.AUTHORIZATION, "DPoP " + accessToken) + .header(HttpHeaders.AUTHORIZATION, "Bearer " + accessToken) + .header("DPoP", dPoPProof)) + .andExpect(status().isUnauthorized()); + // @formatter:on + } + + @Test + public void requestWhenDPoPAccessTokenMalformedThenUnauthorized() throws Exception { + this.spring.register(SecurityConfig.class, ResourceEndpoints.class).autowire(); + Set scope = Collections.singleton("resource1.read"); + String accessToken = generateAccessToken(scope, CLIENT_EC_PUBLIC_KEY); + String dPoPProof = generateDPoPProof(HttpMethod.GET.name(), "http://localhost/resource1", accessToken); + // @formatter:off + this.mvc.perform(get("/resource1") + .header(HttpHeaders.AUTHORIZATION, "DPoP " + accessToken + " m a l f o r m e d ") + .header("DPoP", dPoPProof)) + .andExpect(status().isUnauthorized()); + // @formatter:on + } + + @Test + public void requestWhenMultipleDPoPProofsThenUnauthorized() throws Exception { + this.spring.register(SecurityConfig.class, ResourceEndpoints.class).autowire(); + Set scope = Collections.singleton("resource1.read"); + String accessToken = generateAccessToken(scope, CLIENT_EC_PUBLIC_KEY); + String dPoPProof = generateDPoPProof(HttpMethod.GET.name(), "http://localhost/resource1", accessToken); + // @formatter:off + this.mvc.perform(get("/resource1") + .header(HttpHeaders.AUTHORIZATION, "DPoP " + accessToken) + .header("DPoP", dPoPProof) + .header("DPoP", dPoPProof)) + .andExpect(status().isUnauthorized()); + // @formatter:on + } + + @Test + public void requestWhenDPoPAuthenticationValidThenAccessed() throws Exception { + this.spring.register(SecurityConfig.class, ResourceEndpoints.class).autowire(); + Set scope = Collections.singleton("resource1.read"); + String accessToken = generateAccessToken(scope, CLIENT_EC_PUBLIC_KEY); + String dPoPProof = generateDPoPProof(HttpMethod.GET.name(), "http://localhost/resource1", accessToken); + // @formatter:off + this.mvc.perform(get("/resource1") + .header(HttpHeaders.AUTHORIZATION, "DPoP " + accessToken) + .header("DPoP", dPoPProof)) + .andExpect(status().isOk()) + .andExpect(content().string("resource1")); + // @formatter:on + } + + private static String generateAccessToken(Set scope, PublicKey clientPublicKey) { + Map jktClaim = null; + if (clientPublicKey != null) { + try { + String sha256Thumbprint = computeSHA256(clientPublicKey); + jktClaim = new HashMap<>(); + jktClaim.put("jkt", sha256Thumbprint); + } + catch (Exception ignored) { + } + } + JwsHeader jwsHeader = JwsHeader.with(SignatureAlgorithm.RS256).build(); + Instant issuedAt = Instant.now(); + Instant expiresAt = issuedAt.plus(30, ChronoUnit.MINUTES); + // @formatter:off + JwtClaimsSet.Builder claimsBuilder = JwtClaimsSet.builder() + .issuer("https://provider.com") + .subject("subject") + .issuedAt(issuedAt) + .expiresAt(expiresAt) + .id(UUID.randomUUID().toString()) + .claim(OAuth2ParameterNames.SCOPE, scope); + if (jktClaim != null) { + claimsBuilder.claim("cnf", jktClaim); // Bind client public key + } + // @formatter:on + Jwt jwt = providerJwtEncoder.encode(JwtEncoderParameters.from(jwsHeader, claimsBuilder.build())); + return jwt.getTokenValue(); + } + + private static String generateDPoPProof(String method, String resourceUri, String accessToken) throws Exception { + // @formatter:off + Map publicJwk = TestJwks.jwk(CLIENT_EC_PUBLIC_KEY, CLIENT_EC_PRIVATE_KEY) + .build() + .toPublicJWK() + .toJSONObject(); + JwsHeader jwsHeader = JwsHeader.with(SignatureAlgorithm.ES256) + .type("dpop+jwt") + .jwk(publicJwk) + .build(); + JwtClaimsSet claims = JwtClaimsSet.builder() + .issuedAt(Instant.now()) + .claim("htm", method) + .claim("htu", resourceUri) + .claim("ath", computeSHA256(accessToken)) + .id(UUID.randomUUID().toString()) + .build(); + // @formatter:on + Jwt jwt = clientJwtEncoder.encode(JwtEncoderParameters.from(jwsHeader, claims)); + return jwt.getTokenValue(); + } + + private static String computeSHA256(String value) throws Exception { + MessageDigest md = MessageDigest.getInstance("SHA-256"); + byte[] digest = md.digest(value.getBytes(StandardCharsets.UTF_8)); + return Base64.getUrlEncoder().withoutPadding().encodeToString(digest); + } + + private static String computeSHA256(PublicKey publicKey) throws Exception { + MessageDigest md = MessageDigest.getInstance("SHA-256"); + byte[] digest = md.digest(publicKey.getEncoded()); + return Base64.getUrlEncoder().withoutPadding().encodeToString(digest); + } + + @Configuration + @EnableWebSecurity + @EnableWebMvc + static class SecurityConfig { + + @Bean + SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeHttpRequests((authorize) -> + authorize + .requestMatchers("/resource1").hasAnyAuthority("SCOPE_resource1.read", "SCOPE_resource1.write") + .requestMatchers("/resource2").hasAnyAuthority("SCOPE_resource2.read", "SCOPE_resource2.write") + .anyRequest().authenticated() + ) + .oauth2ResourceServer((oauth2ResourceServer) -> + oauth2ResourceServer + .jwt(Customizer.withDefaults())); + // @formatter:on + return http.build(); + } + + @Bean + NimbusJwtDecoder jwtDecoder() { + return NimbusJwtDecoder.withPublicKey(PROVIDER_RSA_PUBLIC_KEY).build(); + } + + } + + @RestController + static class ResourceEndpoints { + + @RequestMapping(value = "/resource1", method = { RequestMethod.GET, RequestMethod.POST }) + String resource1() { + return "resource1"; + } + + @RequestMapping(value = "/resource2", method = { RequestMethod.GET, RequestMethod.POST }) + String resource2() { + return "resource2"; + } + + } + +} diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/OAuth2ErrorCodes.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/OAuth2ErrorCodes.java index 47587435bc..004c65350a 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/OAuth2ErrorCodes.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/OAuth2ErrorCodes.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -139,6 +139,15 @@ public final class OAuth2ErrorCodes { */ public static final String INVALID_REDIRECT_URI = "invalid_redirect_uri"; + /** + * {@code invalid_dpop_proof} - The DPoP Proof JWT is invalid. + * + * @since 6.5 + * @see RFC-9449 - OAuth 2.0 + * Demonstrating Proof of Possession (DPoP) + */ + public static final String INVALID_DPOP_PROOF = "invalid_dpop_proof"; + private OAuth2ErrorCodes() { } diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/DPoPProofContext.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/DPoPProofContext.java new file mode 100644 index 0000000000..16a5947cf5 --- /dev/null +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/DPoPProofContext.java @@ -0,0 +1,127 @@ +/* + * Copyright 2002-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.jwt; + +import java.net.URI; + +import org.springframework.lang.Nullable; +import org.springframework.security.oauth2.core.OAuth2Token; +import org.springframework.util.Assert; + +/** + * @author Joe Grandja + * @since 6.5 + * @see DPoPProofJwtDecoderFactory + */ +public final class DPoPProofContext { + + private final String dPoPProof; + + private final String method; + + private final String targetUri; + + private final OAuth2Token accessToken; + + private DPoPProofContext(String dPoPProof, String method, String targetUri, @Nullable OAuth2Token accessToken) { + this.dPoPProof = dPoPProof; + this.method = method; + this.targetUri = targetUri; + this.accessToken = accessToken; + } + + public String getDPoPProof() { + return this.dPoPProof; + } + + public String getMethod() { + return this.method; + } + + public String getTargetUri() { + return this.targetUri; + } + + @SuppressWarnings("unchecked") + @Nullable + public T getAccessToken() { + return (T) this.accessToken; + } + + public static Builder withDPoPProof(String dPoPProof) { + return new Builder(dPoPProof); + } + + public static final class Builder { + + private String dPoPProof; + + private String method; + + private String targetUri; + + private OAuth2Token accessToken; + + private Builder(String dPoPProof) { + Assert.hasText(dPoPProof, "dPoPProof cannot be empty"); + this.dPoPProof = dPoPProof; + } + + public Builder method(String method) { + this.method = method; + return this; + } + + public Builder targetUri(String targetUri) { + this.targetUri = targetUri; + return this; + } + + public Builder accessToken(OAuth2Token accessToken) { + this.accessToken = accessToken; + return this; + } + + public DPoPProofContext build() { + validate(); + return new DPoPProofContext(this.dPoPProof, this.method, this.targetUri, this.accessToken); + } + + private void validate() { + Assert.hasText(this.method, "method cannot be empty"); + Assert.hasText(this.targetUri, "targetUri cannot be empty"); + if (!"GET".equals(this.method) && !"HEAD".equals(this.method) && !"POST".equals(this.method) + && !"PUT".equals(this.method) && !"PATCH".equals(this.method) && !"DELETE".equals(this.method) + && !"OPTIONS".equals(this.method) && !"TRACE".equals(this.method)) { + throw new IllegalArgumentException("method is invalid"); + } + URI uri; + try { + uri = new URI(this.targetUri); + uri.toURL(); + } + catch (Exception ex) { + throw new IllegalArgumentException("targetUri must be a valid URL", ex); + } + if (uri.getQuery() != null || uri.getFragment() != null) { + throw new IllegalArgumentException("targetUri cannot contain query or fragment parts"); + } + } + + } + +} diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/DPoPProofJwtDecoderFactory.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/DPoPProofJwtDecoderFactory.java new file mode 100644 index 0000000000..32a5913526 --- /dev/null +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/DPoPProofJwtDecoderFactory.java @@ -0,0 +1,203 @@ +/* + * Copyright 2002-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.jwt; + +import java.nio.charset.StandardCharsets; +import java.security.MessageDigest; +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.Base64; +import java.util.Collections; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Function; + +import com.nimbusds.jose.JOSEException; +import com.nimbusds.jose.JOSEObjectType; +import com.nimbusds.jose.JWSAlgorithm; +import com.nimbusds.jose.jwk.ECKey; +import com.nimbusds.jose.jwk.JWK; +import com.nimbusds.jose.jwk.RSAKey; +import com.nimbusds.jose.proc.DefaultJOSEObjectTypeVerifier; +import com.nimbusds.jose.proc.JOSEObjectTypeVerifier; +import com.nimbusds.jose.proc.JWSKeySelector; +import com.nimbusds.jose.proc.SecurityContext; +import com.nimbusds.jwt.proc.ConfigurableJWTProcessor; +import com.nimbusds.jwt.proc.DefaultJWTProcessor; + +import org.springframework.security.oauth2.core.DelegatingOAuth2TokenValidator; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.OAuth2TokenValidator; +import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; + +/** + * @author Joe Grandja + * @since 6.5 + * @see DPoPProofContext + */ +public final class DPoPProofJwtDecoderFactory implements JwtDecoderFactory { + + private static final JOSEObjectTypeVerifier DPOP_TYPE_VERIFIER = new DefaultJOSEObjectTypeVerifier<>( + new JOSEObjectType("dpop+jwt")); + + public static final Function> DEFAULT_JWT_VALIDATOR_FACTORY = defaultJwtValidatorFactory(); + + private Function> jwtValidatorFactory = DEFAULT_JWT_VALIDATOR_FACTORY; + + @Override + public JwtDecoder createDecoder(DPoPProofContext dPoPProofContext) { + Assert.notNull(dPoPProofContext, "dPoPProofContext cannot be null"); + NimbusJwtDecoder jwtDecoder = buildDecoder(); + jwtDecoder.setJwtValidator(this.jwtValidatorFactory.apply(dPoPProofContext)); + return jwtDecoder; + } + + public void setJwtValidatorFactory(Function> jwtValidatorFactory) { + Assert.notNull(jwtValidatorFactory, "jwtValidatorFactory cannot be null"); + this.jwtValidatorFactory = jwtValidatorFactory; + } + + private static NimbusJwtDecoder buildDecoder() { + ConfigurableJWTProcessor jwtProcessor = new DefaultJWTProcessor<>(); + jwtProcessor.setJWSTypeVerifier(DPOP_TYPE_VERIFIER); + jwtProcessor.setJWSKeySelector(jwsKeySelector()); + // Override the default Nimbus claims set verifier and use jwtValidatorFactory for + // claims validation + jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> { + }); + return new NimbusJwtDecoder(jwtProcessor); + } + + private static JWSKeySelector jwsKeySelector() { + return (header, context) -> { + JWSAlgorithm algorithm = header.getAlgorithm(); + if (!JWSAlgorithm.Family.RSA.contains(algorithm) && !JWSAlgorithm.Family.EC.contains(algorithm)) { + throw new BadJwtException("Unsupported alg parameter in JWS Header: " + algorithm.getName()); + } + + JWK jwk = header.getJWK(); + if (jwk == null) { + throw new BadJwtException("Missing jwk parameter in JWS Header."); + } + if (jwk.isPrivate()) { + throw new BadJwtException("Invalid jwk parameter in JWS Header."); + } + + try { + if (JWSAlgorithm.Family.RSA.contains(algorithm) && jwk instanceof RSAKey rsaKey) { + return Collections.singletonList(rsaKey.toRSAPublicKey()); + } + else if (JWSAlgorithm.Family.EC.contains(algorithm) && jwk instanceof ECKey ecKey) { + return Collections.singletonList(ecKey.toECPublicKey()); + } + } + catch (JOSEException ex) { + throw new BadJwtException("Invalid jwk parameter in JWS Header."); + } + + throw new BadJwtException("Invalid alg / jwk parameter in JWS Header: alg=" + algorithm.getName() + + ", jwk.kty=" + jwk.getKeyType().getValue()); + }; + } + + private static Function> defaultJwtValidatorFactory() { + return (context) -> new DelegatingOAuth2TokenValidator<>( + new JwtClaimValidator<>("htm", context.getMethod()::equals), + new JwtClaimValidator<>("htu", context.getTargetUri()::equals), new JtiClaimValidator(), + new IatClaimValidator()); + } + + private static final class JtiClaimValidator implements OAuth2TokenValidator { + + private static final Map jtiCache = new ConcurrentHashMap<>(); + + @Override + public OAuth2TokenValidatorResult validate(Jwt jwt) { + Assert.notNull(jwt, "DPoP proof jwt cannot be null"); + String jti = jwt.getId(); + if (!StringUtils.hasText(jti)) { + OAuth2Error error = createOAuth2Error("jti claim is required."); + return OAuth2TokenValidatorResult.failure(error); + } + + // Enforce single-use to protect against DPoP proof replay + String jtiHash; + try { + jtiHash = computeSHA256(jti); + } + catch (Exception ex) { + OAuth2Error error = createOAuth2Error("jti claim is invalid."); + return OAuth2TokenValidatorResult.failure(error); + } + Instant now = Instant.now(Clock.systemUTC()); + if ((jtiCache.putIfAbsent(jtiHash, now.toEpochMilli())) != null) { + // Already used + OAuth2Error error = createOAuth2Error("jti claim is invalid."); + return OAuth2TokenValidatorResult.failure(error); + } + return OAuth2TokenValidatorResult.success(); + } + + private static OAuth2Error createOAuth2Error(String reason) { + return new OAuth2Error(OAuth2ErrorCodes.INVALID_DPOP_PROOF, reason, null); + } + + private static String computeSHA256(String value) throws Exception { + MessageDigest md = MessageDigest.getInstance("SHA-256"); + byte[] digest = md.digest(value.getBytes(StandardCharsets.UTF_8)); + return Base64.getUrlEncoder().withoutPadding().encodeToString(digest); + } + + } + + private static final class IatClaimValidator implements OAuth2TokenValidator { + + private final Duration clockSkew = Duration.ofSeconds(60); + + private final Clock clock = Clock.systemUTC(); + + @Override + public OAuth2TokenValidatorResult validate(Jwt jwt) { + Assert.notNull(jwt, "DPoP proof jwt cannot be null"); + Instant issuedAt = jwt.getIssuedAt(); + if (issuedAt == null) { + OAuth2Error error = createOAuth2Error("iat claim is required."); + return OAuth2TokenValidatorResult.failure(error); + } + + // Check time window of validity + Instant now = Instant.now(this.clock); + Instant notBefore = now.minus(this.clockSkew); + Instant notAfter = now.plus(this.clockSkew); + if (issuedAt.isBefore(notBefore) || issuedAt.isAfter(notAfter)) { + OAuth2Error error = createOAuth2Error("iat claim is invalid."); + return OAuth2TokenValidatorResult.failure(error); + } + return OAuth2TokenValidatorResult.success(); + } + + private static OAuth2Error createOAuth2Error(String reason) { + return new OAuth2Error(OAuth2ErrorCodes.INVALID_DPOP_PROOF, reason, null); + } + + } + +} diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/DPoPProofJwtDecoderFactoryTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/DPoPProofJwtDecoderFactoryTests.java new file mode 100644 index 0000000000..6f107fa675 --- /dev/null +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/DPoPProofJwtDecoderFactoryTests.java @@ -0,0 +1,451 @@ +/* + * Copyright 2002-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.jwt; + +import java.time.Duration; +import java.time.Instant; +import java.util.Collections; +import java.util.Map; +import java.util.UUID; + +import com.nimbusds.jose.jwk.RSAKey; +import com.nimbusds.jose.jwk.source.JWKSource; +import com.nimbusds.jose.proc.SecurityContext; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import org.springframework.security.oauth2.jose.TestJwks; +import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; + +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; + +/** + * Tests for {@link DPoPProofJwtDecoderFactory}. + * + * @author Joe Grandja + */ +public class DPoPProofJwtDecoderFactoryTests { + + private JWKSource jwkSource; + + private NimbusJwtEncoder jwtEncoder; + + private DPoPProofJwtDecoderFactory jwtDecoderFactory = new DPoPProofJwtDecoderFactory(); + + @BeforeEach + public void setUp() { + this.jwkSource = mock(JWKSource.class); + this.jwtEncoder = new NimbusJwtEncoder(this.jwkSource); + } + + @Test + public void setJwtValidatorFactoryWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.jwtDecoderFactory.setJwtValidatorFactory(null)) + .withMessage("jwtValidatorFactory cannot be null"); + } + + @Test + public void createDecoderWhenContextNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.jwtDecoderFactory.createDecoder(null)) + .withMessage("dPoPProofContext cannot be null"); + } + + @Test + public void decodeWhenJoseTypeInvalidThenThrowBadJwtException() throws Exception { + RSAKey rsaJwk = TestJwks.DEFAULT_RSA_JWK; + given(this.jwkSource.get(any(), any())).willReturn(Collections.singletonList(rsaJwk)); + + String method = "GET"; + String targetUri = "https://resource1"; + + // @formatter:off + Map publicJwk = rsaJwk.toPublicJWK().toJSONObject(); + JwsHeader jwsHeader = JwsHeader.with(SignatureAlgorithm.RS256) + .type("invalid-type") + .jwk(publicJwk) + .build(); + JwtClaimsSet claims = JwtClaimsSet.builder() + .issuedAt(Instant.now()) + .claim("htm", method) + .claim("htu", targetUri) + .id(UUID.randomUUID().toString()) + .build(); + // @formatter:on + + Jwt dPoPProof = this.jwtEncoder.encode(JwtEncoderParameters.from(jwsHeader, claims)); + + // @formatter:off + DPoPProofContext dPoPProofContext = DPoPProofContext.withDPoPProof(dPoPProof.getTokenValue()) + .method(method) + .targetUri(targetUri) + .build(); + // @formatter:on + + JwtDecoder jwtDecoder = this.jwtDecoderFactory.createDecoder(dPoPProofContext); + + assertThatExceptionOfType(BadJwtException.class) + .isThrownBy(() -> jwtDecoder.decode(dPoPProofContext.getDPoPProof())) + .withMessageContaining("JOSE header typ (type) invalid-type not allowed"); + } + + @Test + public void decodeWhenJwkMissingThenThrowBadJwtException() throws Exception { + RSAKey rsaJwk = TestJwks.DEFAULT_RSA_JWK; + given(this.jwkSource.get(any(), any())).willReturn(Collections.singletonList(rsaJwk)); + + String method = "GET"; + String targetUri = "https://resource1"; + + // @formatter:off + Map publicJwk = rsaJwk.toPublicJWK().toJSONObject(); + JwsHeader jwsHeader = JwsHeader.with(SignatureAlgorithm.RS256) + .type("dpop+jwt") +// .jwk(publicJwk) + .build(); + JwtClaimsSet claims = JwtClaimsSet.builder() + .issuedAt(Instant.now()) + .claim("htm", method) + .claim("htu", targetUri) + .id(UUID.randomUUID().toString()) + .build(); + // @formatter:on + + Jwt dPoPProof = this.jwtEncoder.encode(JwtEncoderParameters.from(jwsHeader, claims)); + + // @formatter:off + DPoPProofContext dPoPProofContext = DPoPProofContext.withDPoPProof(dPoPProof.getTokenValue()) + .method(method) + .targetUri(targetUri) + .build(); + // @formatter:on + + JwtDecoder jwtDecoder = this.jwtDecoderFactory.createDecoder(dPoPProofContext); + + assertThatExceptionOfType(BadJwtException.class) + .isThrownBy(() -> jwtDecoder.decode(dPoPProofContext.getDPoPProof())) + .withMessageContaining("Missing jwk parameter in JWS Header."); + } + + @Test + public void decodeWhenMethodInvalidThenThrowBadJwtException() throws Exception { + RSAKey rsaJwk = TestJwks.DEFAULT_RSA_JWK; + given(this.jwkSource.get(any(), any())).willReturn(Collections.singletonList(rsaJwk)); + + String method = "GET"; + String targetUri = "https://resource1"; + + // @formatter:off + Map publicJwk = rsaJwk.toPublicJWK().toJSONObject(); + JwsHeader jwsHeader = JwsHeader.with(SignatureAlgorithm.RS256) + .type("dpop+jwt") + .jwk(publicJwk) + .build(); + JwtClaimsSet claims = JwtClaimsSet.builder() + .issuedAt(Instant.now()) + .claim("htm", method) + .claim("htu", targetUri) + .id(UUID.randomUUID().toString()) + .build(); + // @formatter:on + + Jwt dPoPProof = this.jwtEncoder.encode(JwtEncoderParameters.from(jwsHeader, claims)); + + // @formatter:off + DPoPProofContext dPoPProofContext = DPoPProofContext.withDPoPProof(dPoPProof.getTokenValue()) + .method("POST") // Mismatch + .targetUri(targetUri) + .build(); + // @formatter:on + + JwtDecoder jwtDecoder = this.jwtDecoderFactory.createDecoder(dPoPProofContext); + + assertThatExceptionOfType(BadJwtException.class) + .isThrownBy(() -> jwtDecoder.decode(dPoPProofContext.getDPoPProof())) + .withMessageContaining("The htm claim is not valid"); + } + + @Test + public void decodeWhenTargetUriInvalidThenThrowBadJwtException() throws Exception { + RSAKey rsaJwk = TestJwks.DEFAULT_RSA_JWK; + given(this.jwkSource.get(any(), any())).willReturn(Collections.singletonList(rsaJwk)); + + String method = "GET"; + String targetUri = "https://resource1"; + + // @formatter:off + Map publicJwk = rsaJwk.toPublicJWK().toJSONObject(); + JwsHeader jwsHeader = JwsHeader.with(SignatureAlgorithm.RS256) + .type("dpop+jwt") + .jwk(publicJwk) + .build(); + JwtClaimsSet claims = JwtClaimsSet.builder() + .issuedAt(Instant.now()) + .claim("htm", method) + .claim("htu", targetUri) + .id(UUID.randomUUID().toString()) + .build(); + // @formatter:on + + Jwt dPoPProof = this.jwtEncoder.encode(JwtEncoderParameters.from(jwsHeader, claims)); + + // @formatter:off + DPoPProofContext dPoPProofContext = DPoPProofContext.withDPoPProof(dPoPProof.getTokenValue()) + .method(method) + .targetUri("https://resource2") // Mismatch + .build(); + // @formatter:on + + JwtDecoder jwtDecoder = this.jwtDecoderFactory.createDecoder(dPoPProofContext); + + assertThatExceptionOfType(BadJwtException.class) + .isThrownBy(() -> jwtDecoder.decode(dPoPProofContext.getDPoPProof())) + .withMessageContaining("The htu claim is not valid"); + } + + @Test + public void decodeWhenJtiMissingThenThrowBadJwtException() throws Exception { + RSAKey rsaJwk = TestJwks.DEFAULT_RSA_JWK; + given(this.jwkSource.get(any(), any())).willReturn(Collections.singletonList(rsaJwk)); + + String method = "GET"; + String targetUri = "https://resource1"; + + // @formatter:off + Map publicJwk = rsaJwk.toPublicJWK().toJSONObject(); + JwsHeader jwsHeader = JwsHeader.with(SignatureAlgorithm.RS256) + .type("dpop+jwt") + .jwk(publicJwk) + .build(); + JwtClaimsSet claims = JwtClaimsSet.builder() + .issuedAt(Instant.now()) + .claim("htm", method) + .claim("htu", targetUri) +// .id(UUID.randomUUID().toString()) + .build(); + // @formatter:on + + Jwt dPoPProof = this.jwtEncoder.encode(JwtEncoderParameters.from(jwsHeader, claims)); + + // @formatter:off + DPoPProofContext dPoPProofContext = DPoPProofContext.withDPoPProof(dPoPProof.getTokenValue()) + .method(method) + .targetUri(targetUri) + .build(); + // @formatter:on + + JwtDecoder jwtDecoder = this.jwtDecoderFactory.createDecoder(dPoPProofContext); + + assertThatExceptionOfType(BadJwtException.class) + .isThrownBy(() -> jwtDecoder.decode(dPoPProofContext.getDPoPProof())) + .withMessageContaining("jti claim is required"); + } + + @Test + public void decodeWhenJtiAlreadyUsedThenThrowBadJwtException() throws Exception { + RSAKey rsaJwk = TestJwks.DEFAULT_RSA_JWK; + given(this.jwkSource.get(any(), any())).willReturn(Collections.singletonList(rsaJwk)); + + String method = "GET"; + String targetUri = "https://resource1"; + + // @formatter:off + Map publicJwk = rsaJwk.toPublicJWK().toJSONObject(); + JwsHeader jwsHeader = JwsHeader.with(SignatureAlgorithm.RS256) + .type("dpop+jwt") + .jwk(publicJwk) + .build(); + JwtClaimsSet claims = JwtClaimsSet.builder() + .issuedAt(Instant.now()) + .claim("htm", method) + .claim("htu", targetUri) + .id(UUID.randomUUID().toString()) + .build(); + // @formatter:on + + Jwt dPoPProof = this.jwtEncoder.encode(JwtEncoderParameters.from(jwsHeader, claims)); + + // @formatter:off + DPoPProofContext dPoPProofContext = DPoPProofContext.withDPoPProof(dPoPProof.getTokenValue()) + .method(method) + .targetUri(targetUri) + .build(); + // @formatter:on + + JwtDecoder jwtDecoder = this.jwtDecoderFactory.createDecoder(dPoPProofContext); + + jwtDecoder.decode(dPoPProofContext.getDPoPProof()); + assertThatExceptionOfType(BadJwtException.class) + .isThrownBy(() -> jwtDecoder.decode(dPoPProofContext.getDPoPProof())) + .withMessageContaining("jti claim is invalid"); + } + + @Test + public void decodeWhenIatMissingThenThrowBadJwtException() throws Exception { + RSAKey rsaJwk = TestJwks.DEFAULT_RSA_JWK; + given(this.jwkSource.get(any(), any())).willReturn(Collections.singletonList(rsaJwk)); + + String method = "GET"; + String targetUri = "https://resource1"; + + // @formatter:off + Map publicJwk = rsaJwk.toPublicJWK().toJSONObject(); + JwsHeader jwsHeader = JwsHeader.with(SignatureAlgorithm.RS256) + .type("dpop+jwt") + .jwk(publicJwk) + .build(); + JwtClaimsSet claims = JwtClaimsSet.builder() +// .issuedAt(Instant.now()) + .claim("htm", method) + .claim("htu", targetUri) + .id(UUID.randomUUID().toString()) + .build(); + // @formatter:on + + Jwt dPoPProof = this.jwtEncoder.encode(JwtEncoderParameters.from(jwsHeader, claims)); + + // @formatter:off + DPoPProofContext dPoPProofContext = DPoPProofContext.withDPoPProof(dPoPProof.getTokenValue()) + .method(method) + .targetUri(targetUri) + .build(); + // @formatter:on + + JwtDecoder jwtDecoder = this.jwtDecoderFactory.createDecoder(dPoPProofContext); + + assertThatExceptionOfType(BadJwtException.class) + .isThrownBy(() -> jwtDecoder.decode(dPoPProofContext.getDPoPProof())) + .withMessageContaining("iat claim is required"); + } + + @Test + public void decodeWhenIatBeforeTimeWindowThenThrowBadJwtException() throws Exception { + RSAKey rsaJwk = TestJwks.DEFAULT_RSA_JWK; + given(this.jwkSource.get(any(), any())).willReturn(Collections.singletonList(rsaJwk)); + + String method = "GET"; + String targetUri = "https://resource1"; + + // @formatter:off + Map publicJwk = rsaJwk.toPublicJWK().toJSONObject(); + JwsHeader jwsHeader = JwsHeader.with(SignatureAlgorithm.RS256) + .type("dpop+jwt") + .jwk(publicJwk) + .build(); + Instant issuedAt = Instant.now().minus(Duration.ofSeconds(65)); // now minus 65 seconds + JwtClaimsSet claims = JwtClaimsSet.builder() + .issuedAt(issuedAt) + .claim("htm", method) + .claim("htu", targetUri) + .id(UUID.randomUUID().toString()) + .build(); + // @formatter:on + + Jwt dPoPProof = this.jwtEncoder.encode(JwtEncoderParameters.from(jwsHeader, claims)); + + // @formatter:off + DPoPProofContext dPoPProofContext = DPoPProofContext.withDPoPProof(dPoPProof.getTokenValue()) + .method(method) + .targetUri(targetUri) + .build(); + // @formatter:on + + JwtDecoder jwtDecoder = this.jwtDecoderFactory.createDecoder(dPoPProofContext); + + assertThatExceptionOfType(BadJwtException.class) + .isThrownBy(() -> jwtDecoder.decode(dPoPProofContext.getDPoPProof())) + .withMessageContaining("iat claim is invalid"); + } + + @Test + public void decodeWhenIatAfterTimeWindowThenThrowBadJwtException() throws Exception { + RSAKey rsaJwk = TestJwks.DEFAULT_RSA_JWK; + given(this.jwkSource.get(any(), any())).willReturn(Collections.singletonList(rsaJwk)); + + String method = "GET"; + String targetUri = "https://resource1"; + + // @formatter:off + Map publicJwk = rsaJwk.toPublicJWK().toJSONObject(); + JwsHeader jwsHeader = JwsHeader.with(SignatureAlgorithm.RS256) + .type("dpop+jwt") + .jwk(publicJwk) + .build(); + Instant issuedAt = Instant.now().plus(Duration.ofSeconds(65)); // now plus 65 seconds + JwtClaimsSet claims = JwtClaimsSet.builder() + .issuedAt(issuedAt) + .claim("htm", method) + .claim("htu", targetUri) + .id(UUID.randomUUID().toString()) + .build(); + // @formatter:on + + Jwt dPoPProof = this.jwtEncoder.encode(JwtEncoderParameters.from(jwsHeader, claims)); + + // @formatter:off + DPoPProofContext dPoPProofContext = DPoPProofContext.withDPoPProof(dPoPProof.getTokenValue()) + .method(method) + .targetUri(targetUri) + .build(); + // @formatter:on + + JwtDecoder jwtDecoder = this.jwtDecoderFactory.createDecoder(dPoPProofContext); + + assertThatExceptionOfType(BadJwtException.class) + .isThrownBy(() -> jwtDecoder.decode(dPoPProofContext.getDPoPProof())) + .withMessageContaining("iat claim is invalid"); + } + + @Test + public void decodeWhenDPoPProofValidThenDecoded() throws Exception { + RSAKey rsaJwk = TestJwks.DEFAULT_RSA_JWK; + given(this.jwkSource.get(any(), any())).willReturn(Collections.singletonList(rsaJwk)); + + String method = "GET"; + String targetUri = "https://resource1"; + + // @formatter:off + Map publicJwk = rsaJwk.toPublicJWK().toJSONObject(); + JwsHeader jwsHeader = JwsHeader.with(SignatureAlgorithm.RS256) + .type("dpop+jwt") + .jwk(publicJwk) + .build(); + JwtClaimsSet claims = JwtClaimsSet.builder() + .issuedAt(Instant.now()) + .claim("htm", method) + .claim("htu", targetUri) + .id(UUID.randomUUID().toString()) + .build(); + // @formatter:on + + Jwt dPoPProof = this.jwtEncoder.encode(JwtEncoderParameters.from(jwsHeader, claims)); + + // @formatter:off + DPoPProofContext dPoPProofContext = DPoPProofContext.withDPoPProof(dPoPProof.getTokenValue()) + .method(method) + .targetUri(targetUri) + .build(); + // @formatter:on + + JwtDecoder jwtDecoder = this.jwtDecoderFactory.createDecoder(dPoPProofContext); + jwtDecoder.decode(dPoPProof.getTokenValue()); + } + +} diff --git a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/DPoPAuthenticationProvider.java b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/DPoPAuthenticationProvider.java new file mode 100644 index 0000000000..b26cb754c7 --- /dev/null +++ b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/DPoPAuthenticationProvider.java @@ -0,0 +1,273 @@ +/* + * Copyright 2002-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.server.resource.authentication; + +import java.nio.charset.StandardCharsets; +import java.security.MessageDigest; +import java.security.PublicKey; +import java.time.Instant; +import java.util.Base64; +import java.util.Map; +import java.util.function.Function; + +import com.nimbusds.jose.jwk.AsymmetricJWK; +import com.nimbusds.jose.jwk.JWK; + +import org.springframework.security.authentication.AuthenticationManager; +import org.springframework.security.authentication.AuthenticationProvider; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.AuthenticationException; +import org.springframework.security.oauth2.core.ClaimAccessor; +import org.springframework.security.oauth2.core.DelegatingOAuth2TokenValidator; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.OAuth2Token; +import org.springframework.security.oauth2.core.OAuth2TokenValidator; +import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult; +import org.springframework.security.oauth2.jwt.DPoPProofContext; +import org.springframework.security.oauth2.jwt.DPoPProofJwtDecoderFactory; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.JwtDecoder; +import org.springframework.security.oauth2.jwt.JwtDecoderFactory; +import org.springframework.security.oauth2.jwt.JwtException; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; + +/** + * @author Joe Grandja + * @since 6.5 + * @see DPoPAuthenticationToken + * @see DPoPProofJwtDecoderFactory + */ +public final class DPoPAuthenticationProvider implements AuthenticationProvider { + + private final AuthenticationManager tokenAuthenticationManager; + + private JwtDecoderFactory dPoPProofVerifierFactory; + + public DPoPAuthenticationProvider(AuthenticationManager tokenAuthenticationManager) { + Assert.notNull(tokenAuthenticationManager, "tokenAuthenticationManager cannot be null"); + this.tokenAuthenticationManager = tokenAuthenticationManager; + Function> jwtValidatorFactory = ( + context) -> new DelegatingOAuth2TokenValidator<>( + // Use default validators + DPoPProofJwtDecoderFactory.DEFAULT_JWT_VALIDATOR_FACTORY.apply(context), + // Add custom validators + new AthClaimValidator(context.getAccessToken()), + new JwkThumbprintValidator(context.getAccessToken())); + DPoPProofJwtDecoderFactory dPoPProofJwtDecoderFactory = new DPoPProofJwtDecoderFactory(); + dPoPProofJwtDecoderFactory.setJwtValidatorFactory(jwtValidatorFactory); + this.dPoPProofVerifierFactory = dPoPProofJwtDecoderFactory; + } + + @Override + public Authentication authenticate(Authentication authentication) throws AuthenticationException { + DPoPAuthenticationToken dPoPAuthenticationToken = (DPoPAuthenticationToken) authentication; + + BearerTokenAuthenticationToken accessTokenAuthenticationRequest = new BearerTokenAuthenticationToken( + dPoPAuthenticationToken.getAccessToken()); + Authentication accessTokenAuthenticationResult = this.tokenAuthenticationManager + .authenticate(accessTokenAuthenticationRequest); + + AbstractOAuth2TokenAuthenticationToken accessTokenAuthentication = null; + if (accessTokenAuthenticationResult instanceof AbstractOAuth2TokenAuthenticationToken) { + accessTokenAuthentication = (AbstractOAuth2TokenAuthenticationToken) accessTokenAuthenticationResult; + } + if (accessTokenAuthentication == null) { + OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_TOKEN, + "Unable to authenticate the DPoP-bound access token.", null); + throw new OAuth2AuthenticationException(error); + } + + OAuth2AccessTokenClaims accessToken = new OAuth2AccessTokenClaims(accessTokenAuthentication.getToken(), + accessTokenAuthentication.getTokenAttributes()); + + DPoPProofContext dPoPProofContext = DPoPProofContext.withDPoPProof(dPoPAuthenticationToken.getDPoPProof()) + .accessToken(accessToken) + .method(dPoPAuthenticationToken.getMethod()) + .targetUri(dPoPAuthenticationToken.getResourceUri()) + .build(); + JwtDecoder dPoPProofVerifier = this.dPoPProofVerifierFactory.createDecoder(dPoPProofContext); + + try { + dPoPProofVerifier.decode(dPoPProofContext.getDPoPProof()); + } + catch (JwtException ex) { + OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_DPOP_PROOF); + throw new OAuth2AuthenticationException(error, ex); + } + + return accessTokenAuthenticationResult; + } + + @Override + public boolean supports(Class authentication) { + return DPoPAuthenticationToken.class.isAssignableFrom(authentication); + } + + public void setDPoPProofVerifierFactory(JwtDecoderFactory dPoPProofVerifierFactory) { + Assert.notNull(dPoPProofVerifierFactory, "dPoPProofVerifierFactory cannot be null"); + this.dPoPProofVerifierFactory = dPoPProofVerifierFactory; + } + + private static final class AthClaimValidator implements OAuth2TokenValidator { + + private final OAuth2AccessTokenClaims accessToken; + + private AthClaimValidator(OAuth2AccessTokenClaims accessToken) { + Assert.notNull(accessToken, "accessToken cannot be null"); + this.accessToken = accessToken; + } + + @Override + public OAuth2TokenValidatorResult validate(Jwt jwt) { + Assert.notNull(jwt, "DPoP proof jwt cannot be null"); + String accessTokenHashClaim = jwt.getClaimAsString("ath"); + if (!StringUtils.hasText(accessTokenHashClaim)) { + OAuth2Error error = createOAuth2Error("ath claim is required."); + return OAuth2TokenValidatorResult.failure(error); + } + + String accessTokenHash; + try { + accessTokenHash = computeSHA256(this.accessToken.getTokenValue()); + } + catch (Exception ex) { + OAuth2Error error = createOAuth2Error("Failed to compute SHA-256 Thumbprint for access token."); + return OAuth2TokenValidatorResult.failure(error); + } + if (!accessTokenHashClaim.equals(accessTokenHash)) { + OAuth2Error error = createOAuth2Error("ath claim is invalid."); + return OAuth2TokenValidatorResult.failure(error); + } + return OAuth2TokenValidatorResult.success(); + } + + private static OAuth2Error createOAuth2Error(String reason) { + return new OAuth2Error(OAuth2ErrorCodes.INVALID_DPOP_PROOF, reason, null); + } + + private static String computeSHA256(String value) throws Exception { + MessageDigest md = MessageDigest.getInstance("SHA-256"); + byte[] digest = md.digest(value.getBytes(StandardCharsets.UTF_8)); + return Base64.getUrlEncoder().withoutPadding().encodeToString(digest); + } + + } + + private static final class JwkThumbprintValidator implements OAuth2TokenValidator { + + private final OAuth2AccessTokenClaims accessToken; + + private JwkThumbprintValidator(OAuth2AccessTokenClaims accessToken) { + Assert.notNull(accessToken, "accessToken cannot be null"); + this.accessToken = accessToken; + } + + @Override + public OAuth2TokenValidatorResult validate(Jwt jwt) { + Assert.notNull(jwt, "DPoP proof jwt cannot be null"); + String jwkThumbprintClaim = null; + Map confirmationMethodClaim = this.accessToken.getClaimAsMap("cnf"); + if (!CollectionUtils.isEmpty(confirmationMethodClaim) && confirmationMethodClaim.containsKey("jkt")) { + jwkThumbprintClaim = (String) confirmationMethodClaim.get("jkt"); + } + if (jwkThumbprintClaim == null) { + OAuth2Error error = createOAuth2Error("jkt claim is required."); + return OAuth2TokenValidatorResult.failure(error); + } + + PublicKey publicKey = null; + @SuppressWarnings("unchecked") + Map jwkJson = (Map) jwt.getHeaders().get("jwk"); + try { + JWK jwk = JWK.parse(jwkJson); + if (jwk instanceof AsymmetricJWK) { + publicKey = ((AsymmetricJWK) jwk).toPublicKey(); + } + } + catch (Exception ignored) { + } + if (publicKey == null) { + OAuth2Error error = createOAuth2Error("jwk header is missing or invalid."); + return OAuth2TokenValidatorResult.failure(error); + } + + String jwkThumbprint; + try { + jwkThumbprint = computeSHA256(publicKey); + } + catch (Exception ex) { + OAuth2Error error = createOAuth2Error("Failed to compute SHA-256 Thumbprint for jwk."); + return OAuth2TokenValidatorResult.failure(error); + } + + if (!jwkThumbprintClaim.equals(jwkThumbprint)) { + OAuth2Error error = createOAuth2Error("jkt claim is invalid."); + return OAuth2TokenValidatorResult.failure(error); + } + return OAuth2TokenValidatorResult.success(); + } + + private static OAuth2Error createOAuth2Error(String reason) { + return new OAuth2Error(OAuth2ErrorCodes.INVALID_DPOP_PROOF, reason, null); + } + + private static String computeSHA256(PublicKey publicKey) throws Exception { + MessageDigest md = MessageDigest.getInstance("SHA-256"); + byte[] digest = md.digest(publicKey.getEncoded()); + return Base64.getUrlEncoder().withoutPadding().encodeToString(digest); + } + + } + + private static final class OAuth2AccessTokenClaims implements OAuth2Token, ClaimAccessor { + + private final OAuth2Token accessToken; + + private final Map claims; + + private OAuth2AccessTokenClaims(OAuth2Token accessToken, Map claims) { + this.accessToken = accessToken; + this.claims = claims; + } + + @Override + public String getTokenValue() { + return this.accessToken.getTokenValue(); + } + + @Override + public Instant getIssuedAt() { + return this.accessToken.getIssuedAt(); + } + + @Override + public Instant getExpiresAt() { + return this.accessToken.getExpiresAt(); + } + + @Override + public Map getClaims() { + return this.claims; + } + + } + +} diff --git a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/DPoPAuthenticationToken.java b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/DPoPAuthenticationToken.java new file mode 100644 index 0000000000..0abca69706 --- /dev/null +++ b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/DPoPAuthenticationToken.java @@ -0,0 +1,81 @@ +/* + * Copyright 2002-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.server.resource.authentication; + +import java.io.Serial; +import java.util.Collections; + +import org.springframework.security.authentication.AbstractAuthenticationToken; +import org.springframework.util.Assert; + +/** + * @author Joe Grandja + * @since 6.5 + * @see DPoPAuthenticationProvider + */ +public class DPoPAuthenticationToken extends AbstractAuthenticationToken { + + @Serial + private static final long serialVersionUID = 5481690438914686216L; + + private final String accessToken; + + private final String dPoPProof; + + private final String method; + + private final String resourceUri; + + public DPoPAuthenticationToken(String accessToken, String dPoPProof, String method, String resourceUri) { + super(Collections.emptyList()); + Assert.hasText(accessToken, "accessToken cannot be empty"); + Assert.hasText(dPoPProof, "dPoPProof cannot be empty"); + Assert.hasText(method, "method cannot be empty"); + Assert.hasText(resourceUri, "resourceUri cannot be empty"); + this.accessToken = accessToken; + this.dPoPProof = dPoPProof; + this.method = method; + this.resourceUri = resourceUri; + } + + @Override + public Object getPrincipal() { + return getAccessToken(); + } + + @Override + public Object getCredentials() { + return getAccessToken(); + } + + public String getAccessToken() { + return this.accessToken; + } + + public String getDPoPProof() { + return this.dPoPProof; + } + + public String getMethod() { + return this.method; + } + + public String getResourceUri() { + return this.resourceUri; + } + +} diff --git a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/DPoPAuthenticationProviderTests.java b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/DPoPAuthenticationProviderTests.java new file mode 100644 index 0000000000..08aec38900 --- /dev/null +++ b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/DPoPAuthenticationProviderTests.java @@ -0,0 +1,331 @@ +/* + * Copyright 2002-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.server.resource.authentication; + +import java.nio.charset.StandardCharsets; +import java.security.MessageDigest; +import java.security.PublicKey; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Base64; +import java.util.HashMap; +import java.util.Map; +import java.util.UUID; + +import com.nimbusds.jose.jwk.JWKSet; +import com.nimbusds.jose.jwk.source.JWKSource; +import com.nimbusds.jose.proc.SecurityContext; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import org.springframework.security.authentication.AuthenticationManager; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.jose.TestJwks; +import org.springframework.security.oauth2.jose.TestKeys; +import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; +import org.springframework.security.oauth2.jwt.JwsHeader; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.JwtClaimsSet; +import org.springframework.security.oauth2.jwt.JwtEncoderParameters; +import org.springframework.security.oauth2.jwt.NimbusJwtEncoder; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; + +/** + * Tests for {@link DPoPAuthenticationProvider}. + * + * @author Joe Grandja + */ +public class DPoPAuthenticationProviderTests { + + private NimbusJwtEncoder accessTokenJwtEncoder; + + private NimbusJwtEncoder dPoPProofJwtEncoder; + + private AuthenticationManager tokenAuthenticationManager; + + private DPoPAuthenticationProvider authenticationProvider; + + @BeforeEach + public void setUp() { + JWKSource jwkSource = (jwkSelector, securityContext) -> jwkSelector + .select(new JWKSet(TestJwks.DEFAULT_EC_JWK)); + this.accessTokenJwtEncoder = new NimbusJwtEncoder(jwkSource); + jwkSource = (jwkSelector, securityContext) -> jwkSelector.select(new JWKSet(TestJwks.DEFAULT_RSA_JWK)); + this.dPoPProofJwtEncoder = new NimbusJwtEncoder(jwkSource); + this.tokenAuthenticationManager = mock(AuthenticationManager.class); + this.authenticationProvider = new DPoPAuthenticationProvider(this.tokenAuthenticationManager); + } + + @Test + public void constructorWhenTokenAuthenticationManagerNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> new DPoPAuthenticationProvider(null)) + .withMessage("tokenAuthenticationManager cannot be null"); + } + + @Test + public void supportsWhenDPoPAuthenticationTokenThenReturnsTrue() { + assertThat(this.authenticationProvider.supports(DPoPAuthenticationToken.class)).isTrue(); + } + + @Test + public void setDPoPProofVerifierFactoryWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authenticationProvider.setDPoPProofVerifierFactory(null)) + .withMessage("dPoPProofVerifierFactory cannot be null"); + } + + @Test + public void authenticateWhenUnableToAuthenticateAccessTokenThenThrowOAuth2AuthenticationException() { + DPoPAuthenticationToken dPoPAuthenticationToken = new DPoPAuthenticationToken("access-token", "dpop-proof", + "GET", "https://resource1"); + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.authenticationProvider.authenticate(dPoPAuthenticationToken)) + .satisfies((ex) -> { + assertThat(ex.getError().getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_TOKEN); + assertThat(ex.getError().getDescription()) + .isEqualTo("Unable to authenticate the DPoP-bound access token."); + }); + } + + @Test + public void authenticateWhenAthMissingThenThrowOAuth2AuthenticationException() { + Jwt accessToken = generateAccessToken(); + JwtAuthenticationToken jwtAuthenticationToken = new JwtAuthenticationToken(accessToken); + given(this.tokenAuthenticationManager.authenticate(any())).willReturn(jwtAuthenticationToken); + + String method = "GET"; + String resourceUri = "https://resource1"; + + // @formatter:off + Map publicJwk = TestJwks.DEFAULT_RSA_JWK.toPublicJWK().toJSONObject(); + JwsHeader jwsHeader = JwsHeader.with(SignatureAlgorithm.RS256) + .type("dpop+jwt") + .jwk(publicJwk) + .build(); + JwtClaimsSet claims = JwtClaimsSet.builder() + .issuedAt(Instant.now()) + .claim("htm", method) + .claim("htu", resourceUri) +// .claim("ath", computeSHA256(accessToken.getTokenValue())) + .id(UUID.randomUUID().toString()) + .build(); + // @formatter:on + + Jwt dPoPProof = this.dPoPProofJwtEncoder.encode(JwtEncoderParameters.from(jwsHeader, claims)); + + DPoPAuthenticationToken dPoPAuthenticationToken = new DPoPAuthenticationToken(accessToken.getTokenValue(), + dPoPProof.getTokenValue(), method, resourceUri); + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.authenticationProvider.authenticate(dPoPAuthenticationToken)) + .satisfies((ex) -> { + assertThat(ex.getError().getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_DPOP_PROOF); + assertThat(ex.getMessage()).contains("ath claim is required"); + }); + } + + @Test + public void authenticateWhenAthDoesNotMatchThenThrowOAuth2AuthenticationException() throws Exception { + Jwt accessToken = generateAccessToken(); + JwtAuthenticationToken jwtAuthenticationToken = new JwtAuthenticationToken(accessToken); + given(this.tokenAuthenticationManager.authenticate(any())).willReturn(jwtAuthenticationToken); + + String method = "GET"; + String resourceUri = "https://resource1"; + + // @formatter:off + Map publicJwk = TestJwks.DEFAULT_RSA_JWK.toPublicJWK().toJSONObject(); + JwsHeader jwsHeader = JwsHeader.with(SignatureAlgorithm.RS256) + .type("dpop+jwt") + .jwk(publicJwk) + .build(); + JwtClaimsSet claims = JwtClaimsSet.builder() + .issuedAt(Instant.now()) + .claim("htm", method) + .claim("htu", resourceUri) + .claim("ath", computeSHA256(accessToken.getTokenValue()) + "-mismatch") + .id(UUID.randomUUID().toString()) + .build(); + // @formatter:on + + Jwt dPoPProof = this.dPoPProofJwtEncoder.encode(JwtEncoderParameters.from(jwsHeader, claims)); + + DPoPAuthenticationToken dPoPAuthenticationToken = new DPoPAuthenticationToken(accessToken.getTokenValue(), + dPoPProof.getTokenValue(), method, resourceUri); + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.authenticationProvider.authenticate(dPoPAuthenticationToken)) + .satisfies((ex) -> { + assertThat(ex.getError().getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_DPOP_PROOF); + assertThat(ex.getMessage()).contains("ath claim is invalid"); + }); + } + + @Test + public void authenticateWhenJktMissingThenThrowOAuth2AuthenticationException() throws Exception { + Jwt accessToken = generateAccessToken(null); // jkt claim is not added + JwtAuthenticationToken jwtAuthenticationToken = new JwtAuthenticationToken(accessToken); + given(this.tokenAuthenticationManager.authenticate(any())).willReturn(jwtAuthenticationToken); + + String method = "GET"; + String resourceUri = "https://resource1"; + + // @formatter:off + Map publicJwk = TestJwks.DEFAULT_RSA_JWK.toPublicJWK().toJSONObject(); + JwsHeader jwsHeader = JwsHeader.with(SignatureAlgorithm.RS256) + .type("dpop+jwt") + .jwk(publicJwk) + .build(); + JwtClaimsSet claims = JwtClaimsSet.builder() + .issuedAt(Instant.now()) + .claim("htm", method) + .claim("htu", resourceUri) + .claim("ath", computeSHA256(accessToken.getTokenValue())) + .id(UUID.randomUUID().toString()) + .build(); + // @formatter:on + + Jwt dPoPProof = this.dPoPProofJwtEncoder.encode(JwtEncoderParameters.from(jwsHeader, claims)); + + DPoPAuthenticationToken dPoPAuthenticationToken = new DPoPAuthenticationToken(accessToken.getTokenValue(), + dPoPProof.getTokenValue(), method, resourceUri); + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.authenticationProvider.authenticate(dPoPAuthenticationToken)) + .satisfies((ex) -> { + assertThat(ex.getError().getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_DPOP_PROOF); + assertThat(ex.getMessage()).contains("jkt claim is required"); + }); + } + + @Test + public void authenticateWhenJktDoesNotMatchThenThrowOAuth2AuthenticationException() throws Exception { + // Use different client public key + Jwt accessToken = generateAccessToken(TestKeys.DEFAULT_EC_KEY_PAIR.getPublic()); + JwtAuthenticationToken jwtAuthenticationToken = new JwtAuthenticationToken(accessToken); + given(this.tokenAuthenticationManager.authenticate(any())).willReturn(jwtAuthenticationToken); + + String method = "GET"; + String resourceUri = "https://resource1"; + + // @formatter:off + Map publicJwk = TestJwks.DEFAULT_RSA_JWK.toPublicJWK().toJSONObject(); + JwsHeader jwsHeader = JwsHeader.with(SignatureAlgorithm.RS256) + .type("dpop+jwt") + .jwk(publicJwk) + .build(); + JwtClaimsSet claims = JwtClaimsSet.builder() + .issuedAt(Instant.now()) + .claim("htm", method) + .claim("htu", resourceUri) + .claim("ath", computeSHA256(accessToken.getTokenValue())) + .id(UUID.randomUUID().toString()) + .build(); + // @formatter:on + + Jwt dPoPProof = this.dPoPProofJwtEncoder.encode(JwtEncoderParameters.from(jwsHeader, claims)); + + DPoPAuthenticationToken dPoPAuthenticationToken = new DPoPAuthenticationToken(accessToken.getTokenValue(), + dPoPProof.getTokenValue(), method, resourceUri); + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.authenticationProvider.authenticate(dPoPAuthenticationToken)) + .satisfies((ex) -> { + assertThat(ex.getError().getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_DPOP_PROOF); + assertThat(ex.getMessage()).contains("jkt claim is invalid"); + }); + } + + @Test + public void authenticateWhenDPoPProofValidThenSuccess() throws Exception { + Jwt accessToken = generateAccessToken(); + JwtAuthenticationToken jwtAuthenticationToken = new JwtAuthenticationToken(accessToken); + given(this.tokenAuthenticationManager.authenticate(any())).willReturn(jwtAuthenticationToken); + + String method = "GET"; + String resourceUri = "https://resource1"; + + // @formatter:off + Map publicJwk = TestJwks.DEFAULT_RSA_JWK.toPublicJWK().toJSONObject(); + JwsHeader jwsHeader = JwsHeader.with(SignatureAlgorithm.RS256) + .type("dpop+jwt") + .jwk(publicJwk) + .build(); + JwtClaimsSet claims = JwtClaimsSet.builder() + .issuedAt(Instant.now()) + .claim("htm", method) + .claim("htu", resourceUri) + .claim("ath", computeSHA256(accessToken.getTokenValue())) + .id(UUID.randomUUID().toString()) + .build(); + // @formatter:on + + Jwt dPoPProof = this.dPoPProofJwtEncoder.encode(JwtEncoderParameters.from(jwsHeader, claims)); + + DPoPAuthenticationToken dPoPAuthenticationToken = new DPoPAuthenticationToken(accessToken.getTokenValue(), + dPoPProof.getTokenValue(), method, resourceUri); + assertThat(this.authenticationProvider.authenticate(dPoPAuthenticationToken)).isSameAs(jwtAuthenticationToken); + } + + private Jwt generateAccessToken() { + return generateAccessToken(TestKeys.DEFAULT_PUBLIC_KEY); + } + + private Jwt generateAccessToken(PublicKey clientPublicKey) { + Map jktClaim = null; + if (clientPublicKey != null) { + try { + String sha256Thumbprint = computeSHA256(clientPublicKey); + jktClaim = new HashMap<>(); + jktClaim.put("jkt", sha256Thumbprint); + } + catch (Exception ignored) { + } + } + JwsHeader jwsHeader = JwsHeader.with(SignatureAlgorithm.ES256).build(); + Instant issuedAt = Instant.now(); + Instant expiresAt = issuedAt.plus(30, ChronoUnit.MINUTES); + // @formatter:off + JwtClaimsSet.Builder claimsBuilder = JwtClaimsSet.builder() + .issuer("https://provider.com") + .subject("subject") + .issuedAt(issuedAt) + .expiresAt(expiresAt) + .id(UUID.randomUUID().toString()); + if (jktClaim != null) { + claimsBuilder.claim("cnf", jktClaim); // Bind client public key + } + // @formatter:on + return this.accessTokenJwtEncoder.encode(JwtEncoderParameters.from(jwsHeader, claimsBuilder.build())); + } + + private static String computeSHA256(String value) throws Exception { + MessageDigest md = MessageDigest.getInstance("SHA-256"); + byte[] digest = md.digest(value.getBytes(StandardCharsets.UTF_8)); + return Base64.getUrlEncoder().withoutPadding().encodeToString(digest); + } + + private static String computeSHA256(PublicKey publicKey) throws Exception { + MessageDigest md = MessageDigest.getInstance("SHA-256"); + byte[] digest = md.digest(publicKey.getEncoded()); + return Base64.getUrlEncoder().withoutPadding().encodeToString(digest); + } + +}