From f57d4dd1cc84c1ceb398eac8eaff174768b9fbc0 Mon Sep 17 00:00:00 2001 From: Budlee Date: Tue, 5 May 2020 01:35:04 +0100 Subject: [PATCH] Fix for private jwt client to rebuild when expired (#221) The PrivateJwt assertion with a certificate is generated once This means that when a request for a new OAuth Token is made that the client assertion has expired and fails. This change fixes that by reubuilding the assersion for a private jwt when it has expired --- .../aad/msal4j/ClientApplicationBase.java | 8 +- .../msal4j/ConfidentialClientApplication.java | 40 ++++++-- .../aad/msal4j/PublicClientApplication.java | 20 ++-- .../aad/msal4j/TokenRequestExecutor.java | 19 ++-- .../ConfidentialClientApplicationTest.java | 95 +++++++++++++++++-- .../aad/msal4j/TestConfiguration.java | 5 - 6 files changed, 141 insertions(+), 46 deletions(-) diff --git a/src/main/java/com/microsoft/aad/msal4j/ClientApplicationBase.java b/src/main/java/com/microsoft/aad/msal4j/ClientApplicationBase.java index bee675c2..4406759c 100644 --- a/src/main/java/com/microsoft/aad/msal4j/ClientApplicationBase.java +++ b/src/main/java/com/microsoft/aad/msal4j/ClientApplicationBase.java @@ -13,7 +13,10 @@ import java.net.MalformedURLException; import java.net.Proxy; import java.net.URL; -import java.util.*; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutorService; import java.util.function.Consumer; @@ -28,7 +31,6 @@ abstract class ClientApplicationBase implements IClientApplicationBase { protected Logger log; - protected ClientAuthentication clientAuthentication; protected Authority authenticationAuthority; private ServiceBundle serviceBundle; @@ -80,6 +82,8 @@ abstract class ClientApplicationBase implements IClientApplicationBase { @Getter private AadInstanceDiscoveryResponse aadAadInstanceDiscoveryResponse; + protected abstract ClientAuthentication clientAuthentication(); + @Override public CompletableFuture acquireToken(AuthorizationCodeParameters parameters) { diff --git a/src/main/java/com/microsoft/aad/msal4j/ConfidentialClientApplication.java b/src/main/java/com/microsoft/aad/msal4j/ConfidentialClientApplication.java index a988cb5a..23672f33 100644 --- a/src/main/java/com/microsoft/aad/msal4j/ConfidentialClientApplication.java +++ b/src/main/java/com/microsoft/aad/msal4j/ConfidentialClientApplication.java @@ -12,6 +12,7 @@ import org.slf4j.LoggerFactory; import java.util.Collections; +import java.util.Date; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -23,11 +24,15 @@ * Class to be used to acquire tokens for confidential client applications (Web Apps, Web APIs, * and daemon applications). * For details see {@link IConfidentialClientApplication} - * + *

* Conditionally thread-safe */ public class ConfidentialClientApplication extends ClientApplicationBase implements IConfidentialClientApplication { + private ClientAuthentication clientAuthentication; + private boolean clientCertAuthentication = false; + private ClientCertificate clientCertificate; + @Override public CompletableFuture acquireToken(ClientCredentialParameters parameters) { @@ -71,19 +76,37 @@ private void initClientAuthentication(IClientCredential clientCredential) { new ClientID(clientId()), new Secret(((ClientSecret) clientCredential).clientSecret())); } else if (clientCredential instanceof ClientCertificate) { - ClientAssertion clientAssertion = JwtHelper.buildJwt( - clientId(), - (ClientCertificate) clientCredential, - this.authenticationAuthority.selfSignedJwtAudience()); - - clientAuthentication = createClientAuthFromClientAssertion(clientAssertion); - } else if (clientCredential instanceof ClientAssertion){ + this.clientCertAuthentication = true; + this.clientCertificate = (ClientCertificate) clientCredential; + clientAuthentication = buildValidClientCertificateAuthority(); + } else if (clientCredential instanceof ClientAssertion) { clientAuthentication = createClientAuthFromClientAssertion((ClientAssertion) clientCredential); } else { throw new IllegalArgumentException("Unsupported client credential"); } } + @Override + protected ClientAuthentication clientAuthentication() { + if (clientCertAuthentication) { + final Date currentDateTime = new Date(System.currentTimeMillis()); + final Date expirationTime = ((PrivateKeyJWT) clientAuthentication).getJWTAuthenticationClaimsSet().getExpirationTime(); + if (expirationTime.before(currentDateTime)) { + //The asserted private jwt with the client certificate can expire so rebuild it when the + clientAuthentication = buildValidClientCertificateAuthority(); + } + } + return clientAuthentication; + } + + private ClientAuthentication buildValidClientCertificateAuthority() { + ClientAssertion clientAssertion = JwtHelper.buildJwt( + clientId(), + clientCertificate, + this.authenticationAuthority.selfSignedJwtAudience()); + return createClientAuthFromClientAssertion(clientAssertion); + } + private ClientAuthentication createClientAuthFromClientAssertion( final ClientAssertion clientAssertion) { try { @@ -102,7 +125,6 @@ private ClientAuthentication createClientAuthFromClientAssertion( * @param clientId Client ID (Application ID) of the application as registered * in the application registration portal (portal.azure.com) * @param clientCredential The client credential to use for token acquisition. - * * @return instance of Builder of ConfidentialClientApplication */ public static Builder builder(String clientId, IClientCredential clientCredential) { diff --git a/src/main/java/com/microsoft/aad/msal4j/PublicClientApplication.java b/src/main/java/com/microsoft/aad/msal4j/PublicClientApplication.java index cf795fb8..fe75ce53 100644 --- a/src/main/java/com/microsoft/aad/msal4j/PublicClientApplication.java +++ b/src/main/java/com/microsoft/aad/msal4j/PublicClientApplication.java @@ -3,11 +3,9 @@ package com.microsoft.aad.msal4j; +import com.nimbusds.oauth2.sdk.auth.ClientAuthentication; import com.nimbusds.oauth2.sdk.auth.ClientAuthenticationMethod; import com.nimbusds.oauth2.sdk.id.ClientID; -import lombok.EqualsAndHashCode; -import lombok.Getter; -import lombok.experimental.Accessors; import org.slf4j.LoggerFactory; import java.util.concurrent.CompletableFuture; @@ -24,6 +22,8 @@ */ public class PublicClientApplication extends ClientApplicationBase implements IPublicClientApplication { + private final ClientAuthenticationPost clientAuthentication; + @Override public CompletableFuture acquireToken(UserNamePasswordParameters parameters) { @@ -96,17 +96,15 @@ public CompletableFuture acquireToken(InteractiveRequestP private PublicClientApplication(Builder builder) { super(builder); - + validateNotBlank("clientId", clientId()); log = LoggerFactory.getLogger(PublicClientApplication.class); - - initClientAuthentication(clientId()); + this.clientAuthentication = new ClientAuthenticationPost(ClientAuthenticationMethod.NONE, + new ClientID(clientId())); } - private void initClientAuthentication(String clientId) { - validateNotBlank("clientId", clientId); - - clientAuthentication = new ClientAuthenticationPost(ClientAuthenticationMethod.NONE, - new ClientID(clientId)); + @Override + protected ClientAuthentication clientAuthentication() { + return clientAuthentication; } /** diff --git a/src/main/java/com/microsoft/aad/msal4j/TokenRequestExecutor.java b/src/main/java/com/microsoft/aad/msal4j/TokenRequestExecutor.java index 8b98afb7..0e0dd3ec 100644 --- a/src/main/java/com/microsoft/aad/msal4j/TokenRequestExecutor.java +++ b/src/main/java/com/microsoft/aad/msal4j/TokenRequestExecutor.java @@ -3,12 +3,6 @@ package com.microsoft.aad.msal4j; -import java.io.IOException; -import java.net.MalformedURLException; -import java.util.Date; -import java.util.List; -import java.util.Map; - import com.nimbusds.oauth2.sdk.ParseException; import com.nimbusds.oauth2.sdk.SerializeException; import com.nimbusds.oauth2.sdk.http.CommonContentTypes; @@ -16,11 +10,16 @@ import com.nimbusds.oauth2.sdk.http.HTTPResponse; import com.nimbusds.oauth2.sdk.util.URLUtils; import com.nimbusds.openid.connect.sdk.token.OIDCTokens; +import lombok.AccessLevel; +import lombok.Getter; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import lombok.AccessLevel; -import lombok.Getter; +import java.io.IOException; +import java.net.MalformedURLException; +import java.util.Date; +import java.util.List; +import java.util.Map; @Getter(AccessLevel.PACKAGE) class TokenRequestExecutor { @@ -61,8 +60,8 @@ OAuthHttpRequest createOauthHttpRequest() throws SerializeException, MalformedUR final Map> params = msalRequest.msalAuthorizationGrant().toParameters(); oauthHttpRequest.setQuery(URLUtils.serializeParameters(params)); - if (msalRequest.application().clientAuthentication != null) { - msalRequest.application().clientAuthentication.applyTo(oauthHttpRequest); + if (msalRequest.application().clientAuthentication() != null) { + msalRequest.application().clientAuthentication().applyTo(oauthHttpRequest); } return oauthHttpRequest; } diff --git a/src/test/java/com/microsoft/aad/msal4j/ConfidentialClientApplicationTest.java b/src/test/java/com/microsoft/aad/msal4j/ConfidentialClientApplicationTest.java index b7850e7a..d080e985 100644 --- a/src/test/java/com/microsoft/aad/msal4j/ConfidentialClientApplicationTest.java +++ b/src/test/java/com/microsoft/aad/msal4j/ConfidentialClientApplicationTest.java @@ -3,6 +3,14 @@ package com.microsoft.aad.msal4j; +import com.nimbusds.jose.JWSAlgorithm; +import com.nimbusds.jose.JWSHeader; +import com.nimbusds.jose.crypto.RSASSASigner; +import com.nimbusds.jose.util.Base64; +import com.nimbusds.jose.util.Base64URL; +import com.nimbusds.jwt.JWTClaimsSet; +import com.nimbusds.jwt.SignedJWT; +import com.nimbusds.oauth2.sdk.auth.PrivateKeyJWT; import org.easymock.EasyMock; import org.powermock.api.easymock.PowerMock; import org.powermock.core.classloader.annotations.PowerMockIgnore; @@ -16,17 +24,20 @@ import java.security.KeyStore; import java.security.PrivateKey; import java.security.cert.X509Certificate; +import java.util.ArrayList; import java.util.Collections; import java.util.Date; +import java.util.List; +import java.util.UUID; import java.util.concurrent.Future; import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertNotNull; @PowerMockIgnore({"javax.net.ssl.*"}) -@Test(groups = { "checkin" }) -@PrepareForTest({ ConfidentialClientApplication.class, - ClientCertificate.class, UserDiscoveryRequest.class }) +@Test(groups = {"checkin"}) +@PrepareForTest({ConfidentialClientApplication.class, + ClientCertificate.class, UserDiscoveryRequest.class, JwtHelper.class}) public class ConfidentialClientApplicationTest extends PowerMockTestCase { private ConfidentialClientApplication app = null; @@ -34,9 +45,9 @@ public class ConfidentialClientApplicationTest extends PowerMockTestCase { @Test public void testAcquireTokenAuthCode_ClientCredential() throws Exception { app = PowerMock.createPartialMock(ConfidentialClientApplication.class, - new String[] { "acquireTokenCommon" }, + new String[]{"acquireTokenCommon"}, ConfidentialClientApplication.builder(TestConfiguration.AAD_CLIENT_ID, - ClientCredentialFactory.createFromSecret(TestConfiguration.AAD_CLIENT_SECRET)) + ClientCredentialFactory.createFromSecret(TestConfiguration.AAD_CLIENT_SECRET)) .authority(TestConfiguration.AAD_TENANT_ENDPOINT) ); @@ -78,10 +89,10 @@ public void testAcquireTokenAuthCode_KeyCredential() throws Exception { final X509Certificate cert = (X509Certificate) keystore .getCertificate(alias); - IClientCredential clientCredential = ClientCredentialFactory.createFromCertificate(key, cert); + IClientCredential clientCredential = ClientCredentialFactory.createFromCertificate(key, cert); app = PowerMock.createPartialMock(ConfidentialClientApplication.class, - new String[] { "acquireTokenCommon" }, + new String[]{"acquireTokenCommon"}, ConfidentialClientApplication.builder(TestConfiguration.AAD_CLIENT_ID, clientCredential) .authority(TestConfiguration.AAD_TENANT_ENDPOINT)); @@ -123,10 +134,10 @@ public void testAcquireToken_KeyCred() throws Exception { final X509Certificate cert = (X509Certificate) keystore .getCertificate(alias); - IClientCredential clientCredential = ClientCredentialFactory.createFromCertificate(key, cert); + IClientCredential clientCredential = ClientCredentialFactory.createFromCertificate(key, cert); app = PowerMock.createPartialMock(ConfidentialClientApplication.class, - new String[] { "acquireTokenCommon" }, + new String[]{"acquireTokenCommon"}, ConfidentialClientApplication.builder(TestConfiguration.AAD_CLIENT_ID, clientCredential) .authority(TestConfiguration.AAD_TENANT_ENDPOINT)); @@ -153,4 +164,70 @@ public void testAcquireToken_KeyCred() throws Exception { PowerMock.verifyAll(); PowerMock.resetAll(app); } + + @Test + public void testClientCertificateRebuildsWhenExpired() throws Exception { + final KeyStore keystore = KeyStore.getInstance("PKCS12", "SunJSSE"); + keystore.load( + new FileInputStream(this.getClass() + .getResource(TestConfiguration.AAD_CERTIFICATE_PATH) + .getFile()), + TestConfiguration.AAD_CERTIFICATE_PASSWORD.toCharArray()); + final String alias = keystore.aliases().nextElement(); + final PrivateKey key = (PrivateKey) keystore.getKey(alias, + TestConfiguration.AAD_CERTIFICATE_PASSWORD.toCharArray()); + final X509Certificate cert = (X509Certificate) keystore + .getCertificate(alias); + + ClientCertificate clientCredential = (ClientCertificate) ClientCredentialFactory.createFromCertificate(key, cert); + + PowerMock.mockStaticPartial(JwtHelper.class, new String[]{"buildJwt"}); + long jwtExperiationPeriodMilli = 2000; + ClientAssertion shortExperationJwt = buildShortJwt(TestConfiguration.AAD_CLIENT_ID, clientCredential, TestConfiguration.AAD_TENANT_ENDPOINT, jwtExperiationPeriodMilli); + PowerMock.expectPrivate( + JwtHelper.class, + "buildJwt", + EasyMock.isA(String.class), + EasyMock.isA(ClientCertificate.class), + EasyMock.isA(String.class)) + .andReturn(shortExperationJwt) + .times(2); // By this being called twice we ensure the client assertion is rebuilt once it has expired + + PowerMock.replay(JwtHelper.class); + app = ConfidentialClientApplication.builder(TestConfiguration.AAD_CLIENT_ID, clientCredential) + .authority(TestConfiguration.AAD_TENANT_ENDPOINT).build(); + Thread.sleep(jwtExperiationPeriodMilli + 1000); //Have to sleep to ensure that the time period has passed + final PrivateKeyJWT clientAuthentication = (PrivateKeyJWT) app.clientAuthentication(); + assertNotNull(clientAuthentication); + PowerMock.verifyAll(); + } + + private ClientAssertion buildShortJwt(String clientId, + ClientCertificate credential, + String jwtAudience, + long jwtExperiationPeriod) { + final long time = System.currentTimeMillis(); + final JWTClaimsSet claimsSet = new JWTClaimsSet.Builder() + .audience(Collections.singletonList(jwtAudience)) + .issuer(clientId) + .jwtID(UUID.randomUUID().toString()) + .notBeforeTime(new Date(time)) + .expirationTime(new Date(time + jwtExperiationPeriod)) + .subject(clientId) + .build(); + SignedJWT jwt; + try { + List certs = new ArrayList<>(); + certs.add(new Base64(credential.publicCertificate())); + JWSHeader.Builder builder = new JWSHeader.Builder(JWSAlgorithm.RS256); + builder.x509CertChain(certs); + builder.x509CertThumbprint(new Base64URL(credential.publicCertificateHash())); + jwt = new SignedJWT(builder.build(), claimsSet); + final RSASSASigner signer = new RSASSASigner(credential.key()); + jwt.sign(signer); + } catch (final Exception e) { + throw new MsalClientException(e); + } + return new ClientAssertion(jwt.serialize()); + } } diff --git a/src/test/java/com/microsoft/aad/msal4j/TestConfiguration.java b/src/test/java/com/microsoft/aad/msal4j/TestConfiguration.java index d56405ac..a13e96f8 100644 --- a/src/test/java/com/microsoft/aad/msal4j/TestConfiguration.java +++ b/src/test/java/com/microsoft/aad/msal4j/TestConfiguration.java @@ -20,7 +20,6 @@ public final class TestConfiguration { public final static String AAD_TOKEN_SUCCESS_FILE = "/token.xml"; public final static String AAD_CERTIFICATE_PASSWORD = "password"; public final static String AAD_DEFAULT_REDIRECT_URI = "https://non_existing_uri.windows.com/"; - public final static String AAD_REDIRECT_URI_FOR_CONFIDENTIAL_CLIENT = "https://non_existing_uri_for_confidential_client.com/"; public final static String AAD_COMMON_AUTHORITY = "https://login.microsoftonline.com/common/"; public final static String ADFS_HOST_NAME = "fs.ade2eadfs30.com"; @@ -39,16 +38,12 @@ public final class TestConfiguration { public final static String B2C_AUTHORITY_CUSTOM_PORT = "https://login.microsoftonline.in:444/tfp/tenant/policy"; public final static String B2C_AUTHORITY_CUSTOM_PORT_TAIL_SLASH = "https://login.microsoftonline.in:444/tfp/tenant/policy/"; - - - public static String INSTANCE_DISCOVERY_RESPONSE = "{" + "\"tenant_discovery_endpoint\":\"https://login.microsoftonline.com/organizations/v2.0/.well-known/openid-appConfiguration\"," + "\"api-version\":\"1.1\"," + "\"metadata\":[{\"preferred_network\":\"login.microsoftonline.com\",\"preferred_cache\":\"login.windows.net\",\"aliases\":[\"login.microsoftonline.com\",\"login.windows.net\",\"login.microsoft.com\",\"sts.windows.net\"]},{\"preferred_network\":\"login.partner.microsoftonline.cn\",\"preferred_cache\":\"login.partner.microsoftonline.cn\",\"aliases\":[\"login.partner.microsoftonline.cn\",\"login.chinacloudapi.cn\"]},{\"preferred_network\":\"login.microsoftonline.de\",\"preferred_cache\":\"login.microsoftonline.de\",\"aliases\":[\"login.microsoftonline.de\"]},{\"preferred_network\":\"login.microsoftonline.us\",\"preferred_cache\":\"login.microsoftonline.us\",\"aliases\":[\"login.microsoftonline.us\",\"login.usgovcloudapi.net\"]},{\"preferred_network\":\"login-us.microsoftonline.com\",\"preferred_cache\":\"login-us.microsoftonline.com\",\"aliases\":[\"login-us.microsoftonline.com\"]}]}"; public final static String AAD_PREFERRED_NETWORK_ENV_ALIAS = "login.microsoftonline.com"; - public final static String AAD_PREFERRED_CACHE__ENV_ALIAS = "login.windows.net"; public final static String HTTP_RESPONSE_FROM_AUTH_CODE = "{\"access_token\":\"eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsIng1dCI6I" + "k5HVEZ2ZEstZnl0aEV1THdqcHdBSk9NOW4tQSJ9.eyJhdWQiOiJiN2E2NzFkOC1hNDA4LTQyZmYtODZlMC1hYWY0NDdmZDE3YzQiLCJpc3MiOiJod"