Skip to content

Commit

Permalink
Fix for private jwt client to rebuild when expired (#221)
Browse files Browse the repository at this point in the history
The PrivateJwt assertion with a certificate is generated once
This means that when a request for a new OAuth Token is made that
the client assertion has expired and fails. This change
fixes that by reubuilding the assersion for a private jwt when
it has expired
  • Loading branch information
Budlee authored May 5, 2020
1 parent 9479deb commit f57d4dd
Show file tree
Hide file tree
Showing 6 changed files with 141 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@
import java.net.MalformedURLException;
import java.net.Proxy;
import java.net.URL;
import java.util.*;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.function.Consumer;
Expand All @@ -28,7 +31,6 @@
abstract class ClientApplicationBase implements IClientApplicationBase {

protected Logger log;
protected ClientAuthentication clientAuthentication;
protected Authority authenticationAuthority;
private ServiceBundle serviceBundle;

Expand Down Expand Up @@ -80,6 +82,8 @@ abstract class ClientApplicationBase implements IClientApplicationBase {
@Getter
private AadInstanceDiscoveryResponse aadAadInstanceDiscoveryResponse;

protected abstract ClientAuthentication clientAuthentication();

@Override
public CompletableFuture<IAuthenticationResult> acquireToken(AuthorizationCodeParameters parameters) {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.slf4j.LoggerFactory;

import java.util.Collections;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand All @@ -23,11 +24,15 @@
* Class to be used to acquire tokens for confidential client applications (Web Apps, Web APIs,
* and daemon applications).
* For details see {@link IConfidentialClientApplication}
*
* <p>
* Conditionally thread-safe
*/
public class ConfidentialClientApplication extends ClientApplicationBase implements IConfidentialClientApplication {

private ClientAuthentication clientAuthentication;
private boolean clientCertAuthentication = false;
private ClientCertificate clientCertificate;

@Override
public CompletableFuture<IAuthenticationResult> acquireToken(ClientCredentialParameters parameters) {

Expand Down Expand Up @@ -71,19 +76,37 @@ private void initClientAuthentication(IClientCredential clientCredential) {
new ClientID(clientId()),
new Secret(((ClientSecret) clientCredential).clientSecret()));
} else if (clientCredential instanceof ClientCertificate) {
ClientAssertion clientAssertion = JwtHelper.buildJwt(
clientId(),
(ClientCertificate) clientCredential,
this.authenticationAuthority.selfSignedJwtAudience());

clientAuthentication = createClientAuthFromClientAssertion(clientAssertion);
} else if (clientCredential instanceof ClientAssertion){
this.clientCertAuthentication = true;
this.clientCertificate = (ClientCertificate) clientCredential;
clientAuthentication = buildValidClientCertificateAuthority();
} else if (clientCredential instanceof ClientAssertion) {
clientAuthentication = createClientAuthFromClientAssertion((ClientAssertion) clientCredential);
} else {
throw new IllegalArgumentException("Unsupported client credential");
}
}

@Override
protected ClientAuthentication clientAuthentication() {
if (clientCertAuthentication) {
final Date currentDateTime = new Date(System.currentTimeMillis());
final Date expirationTime = ((PrivateKeyJWT) clientAuthentication).getJWTAuthenticationClaimsSet().getExpirationTime();
if (expirationTime.before(currentDateTime)) {
//The asserted private jwt with the client certificate can expire so rebuild it when the
clientAuthentication = buildValidClientCertificateAuthority();
}
}
return clientAuthentication;
}

private ClientAuthentication buildValidClientCertificateAuthority() {
ClientAssertion clientAssertion = JwtHelper.buildJwt(
clientId(),
clientCertificate,
this.authenticationAuthority.selfSignedJwtAudience());
return createClientAuthFromClientAssertion(clientAssertion);
}

private ClientAuthentication createClientAuthFromClientAssertion(
final ClientAssertion clientAssertion) {
try {
Expand All @@ -102,7 +125,6 @@ private ClientAuthentication createClientAuthFromClientAssertion(
* @param clientId Client ID (Application ID) of the application as registered
* in the application registration portal (portal.azure.com)
* @param clientCredential The client credential to use for token acquisition.
*
* @return instance of Builder of ConfidentialClientApplication
*/
public static Builder builder(String clientId, IClientCredential clientCredential) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,9 @@

package com.microsoft.aad.msal4j;

import com.nimbusds.oauth2.sdk.auth.ClientAuthentication;
import com.nimbusds.oauth2.sdk.auth.ClientAuthenticationMethod;
import com.nimbusds.oauth2.sdk.id.ClientID;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.experimental.Accessors;
import org.slf4j.LoggerFactory;

import java.util.concurrent.CompletableFuture;
Expand All @@ -24,6 +22,8 @@
*/
public class PublicClientApplication extends ClientApplicationBase implements IPublicClientApplication {

private final ClientAuthenticationPost clientAuthentication;

@Override
public CompletableFuture<IAuthenticationResult> acquireToken(UserNamePasswordParameters parameters) {

Expand Down Expand Up @@ -96,17 +96,15 @@ public CompletableFuture<IAuthenticationResult> acquireToken(InteractiveRequestP

private PublicClientApplication(Builder builder) {
super(builder);

validateNotBlank("clientId", clientId());
log = LoggerFactory.getLogger(PublicClientApplication.class);

initClientAuthentication(clientId());
this.clientAuthentication = new ClientAuthenticationPost(ClientAuthenticationMethod.NONE,
new ClientID(clientId()));
}

private void initClientAuthentication(String clientId) {
validateNotBlank("clientId", clientId);

clientAuthentication = new ClientAuthenticationPost(ClientAuthenticationMethod.NONE,
new ClientID(clientId));
@Override
protected ClientAuthentication clientAuthentication() {
return clientAuthentication;
}

/**
Expand Down
19 changes: 9 additions & 10 deletions src/main/java/com/microsoft/aad/msal4j/TokenRequestExecutor.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,23 @@

package com.microsoft.aad.msal4j;

import java.io.IOException;
import java.net.MalformedURLException;
import java.util.Date;
import java.util.List;
import java.util.Map;

import com.nimbusds.oauth2.sdk.ParseException;
import com.nimbusds.oauth2.sdk.SerializeException;
import com.nimbusds.oauth2.sdk.http.CommonContentTypes;
import com.nimbusds.oauth2.sdk.http.HTTPRequest;
import com.nimbusds.oauth2.sdk.http.HTTPResponse;
import com.nimbusds.oauth2.sdk.util.URLUtils;
import com.nimbusds.openid.connect.sdk.token.OIDCTokens;
import lombok.AccessLevel;
import lombok.Getter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import lombok.AccessLevel;
import lombok.Getter;
import java.io.IOException;
import java.net.MalformedURLException;
import java.util.Date;
import java.util.List;
import java.util.Map;

@Getter(AccessLevel.PACKAGE)
class TokenRequestExecutor {
Expand Down Expand Up @@ -61,8 +60,8 @@ OAuthHttpRequest createOauthHttpRequest() throws SerializeException, MalformedUR
final Map<String, List<String>> params = msalRequest.msalAuthorizationGrant().toParameters();
oauthHttpRequest.setQuery(URLUtils.serializeParameters(params));

if (msalRequest.application().clientAuthentication != null) {
msalRequest.application().clientAuthentication.applyTo(oauthHttpRequest);
if (msalRequest.application().clientAuthentication() != null) {
msalRequest.application().clientAuthentication().applyTo(oauthHttpRequest);
}
return oauthHttpRequest;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,14 @@

package com.microsoft.aad.msal4j;

import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.JWSHeader;
import com.nimbusds.jose.crypto.RSASSASigner;
import com.nimbusds.jose.util.Base64;
import com.nimbusds.jose.util.Base64URL;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.SignedJWT;
import com.nimbusds.oauth2.sdk.auth.PrivateKeyJWT;
import org.easymock.EasyMock;
import org.powermock.api.easymock.PowerMock;
import org.powermock.core.classloader.annotations.PowerMockIgnore;
Expand All @@ -16,27 +24,30 @@
import java.security.KeyStore;
import java.security.PrivateKey;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Date;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.Future;

import static org.testng.Assert.assertFalse;
import static org.testng.Assert.assertNotNull;

@PowerMockIgnore({"javax.net.ssl.*"})
@Test(groups = { "checkin" })
@PrepareForTest({ ConfidentialClientApplication.class,
ClientCertificate.class, UserDiscoveryRequest.class })
@Test(groups = {"checkin"})
@PrepareForTest({ConfidentialClientApplication.class,
ClientCertificate.class, UserDiscoveryRequest.class, JwtHelper.class})
public class ConfidentialClientApplicationTest extends PowerMockTestCase {

private ConfidentialClientApplication app = null;

@Test
public void testAcquireTokenAuthCode_ClientCredential() throws Exception {
app = PowerMock.createPartialMock(ConfidentialClientApplication.class,
new String[] { "acquireTokenCommon" },
new String[]{"acquireTokenCommon"},
ConfidentialClientApplication.builder(TestConfiguration.AAD_CLIENT_ID,
ClientCredentialFactory.createFromSecret(TestConfiguration.AAD_CLIENT_SECRET))
ClientCredentialFactory.createFromSecret(TestConfiguration.AAD_CLIENT_SECRET))
.authority(TestConfiguration.AAD_TENANT_ENDPOINT)
);

Expand Down Expand Up @@ -78,10 +89,10 @@ public void testAcquireTokenAuthCode_KeyCredential() throws Exception {
final X509Certificate cert = (X509Certificate) keystore
.getCertificate(alias);

IClientCredential clientCredential = ClientCredentialFactory.createFromCertificate(key, cert);
IClientCredential clientCredential = ClientCredentialFactory.createFromCertificate(key, cert);

app = PowerMock.createPartialMock(ConfidentialClientApplication.class,
new String[] { "acquireTokenCommon" },
new String[]{"acquireTokenCommon"},
ConfidentialClientApplication.builder(TestConfiguration.AAD_CLIENT_ID, clientCredential)
.authority(TestConfiguration.AAD_TENANT_ENDPOINT));

Expand Down Expand Up @@ -123,10 +134,10 @@ public void testAcquireToken_KeyCred() throws Exception {
final X509Certificate cert = (X509Certificate) keystore
.getCertificate(alias);

IClientCredential clientCredential = ClientCredentialFactory.createFromCertificate(key, cert);
IClientCredential clientCredential = ClientCredentialFactory.createFromCertificate(key, cert);

app = PowerMock.createPartialMock(ConfidentialClientApplication.class,
new String[] { "acquireTokenCommon" },
new String[]{"acquireTokenCommon"},
ConfidentialClientApplication.builder(TestConfiguration.AAD_CLIENT_ID, clientCredential)
.authority(TestConfiguration.AAD_TENANT_ENDPOINT));

Expand All @@ -153,4 +164,70 @@ public void testAcquireToken_KeyCred() throws Exception {
PowerMock.verifyAll();
PowerMock.resetAll(app);
}

@Test
public void testClientCertificateRebuildsWhenExpired() throws Exception {
final KeyStore keystore = KeyStore.getInstance("PKCS12", "SunJSSE");
keystore.load(
new FileInputStream(this.getClass()
.getResource(TestConfiguration.AAD_CERTIFICATE_PATH)
.getFile()),
TestConfiguration.AAD_CERTIFICATE_PASSWORD.toCharArray());
final String alias = keystore.aliases().nextElement();
final PrivateKey key = (PrivateKey) keystore.getKey(alias,
TestConfiguration.AAD_CERTIFICATE_PASSWORD.toCharArray());
final X509Certificate cert = (X509Certificate) keystore
.getCertificate(alias);

ClientCertificate clientCredential = (ClientCertificate) ClientCredentialFactory.createFromCertificate(key, cert);

PowerMock.mockStaticPartial(JwtHelper.class, new String[]{"buildJwt"});
long jwtExperiationPeriodMilli = 2000;
ClientAssertion shortExperationJwt = buildShortJwt(TestConfiguration.AAD_CLIENT_ID, clientCredential, TestConfiguration.AAD_TENANT_ENDPOINT, jwtExperiationPeriodMilli);
PowerMock.expectPrivate(
JwtHelper.class,
"buildJwt",
EasyMock.isA(String.class),
EasyMock.isA(ClientCertificate.class),
EasyMock.isA(String.class))
.andReturn(shortExperationJwt)
.times(2); // By this being called twice we ensure the client assertion is rebuilt once it has expired

PowerMock.replay(JwtHelper.class);
app = ConfidentialClientApplication.builder(TestConfiguration.AAD_CLIENT_ID, clientCredential)
.authority(TestConfiguration.AAD_TENANT_ENDPOINT).build();
Thread.sleep(jwtExperiationPeriodMilli + 1000); //Have to sleep to ensure that the time period has passed
final PrivateKeyJWT clientAuthentication = (PrivateKeyJWT) app.clientAuthentication();
assertNotNull(clientAuthentication);
PowerMock.verifyAll();
}

private ClientAssertion buildShortJwt(String clientId,
ClientCertificate credential,
String jwtAudience,
long jwtExperiationPeriod) {
final long time = System.currentTimeMillis();
final JWTClaimsSet claimsSet = new JWTClaimsSet.Builder()
.audience(Collections.singletonList(jwtAudience))
.issuer(clientId)
.jwtID(UUID.randomUUID().toString())
.notBeforeTime(new Date(time))
.expirationTime(new Date(time + jwtExperiationPeriod))
.subject(clientId)
.build();
SignedJWT jwt;
try {
List<Base64> certs = new ArrayList<>();
certs.add(new Base64(credential.publicCertificate()));
JWSHeader.Builder builder = new JWSHeader.Builder(JWSAlgorithm.RS256);
builder.x509CertChain(certs);
builder.x509CertThumbprint(new Base64URL(credential.publicCertificateHash()));
jwt = new SignedJWT(builder.build(), claimsSet);
final RSASSASigner signer = new RSASSASigner(credential.key());
jwt.sign(signer);
} catch (final Exception e) {
throw new MsalClientException(e);
}
return new ClientAssertion(jwt.serialize());
}
}
5 changes: 0 additions & 5 deletions src/test/java/com/microsoft/aad/msal4j/TestConfiguration.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ public final class TestConfiguration {
public final static String AAD_TOKEN_SUCCESS_FILE = "/token.xml";
public final static String AAD_CERTIFICATE_PASSWORD = "password";
public final static String AAD_DEFAULT_REDIRECT_URI = "https://non_existing_uri.windows.com/";
public final static String AAD_REDIRECT_URI_FOR_CONFIDENTIAL_CLIENT = "https://non_existing_uri_for_confidential_client.com/";
public final static String AAD_COMMON_AUTHORITY = "https://login.microsoftonline.com/common/";

public final static String ADFS_HOST_NAME = "fs.ade2eadfs30.com";
Expand All @@ -39,16 +38,12 @@ public final class TestConfiguration {
public final static String B2C_AUTHORITY_CUSTOM_PORT = "https://login.microsoftonline.in:444/tfp/tenant/policy";
public final static String B2C_AUTHORITY_CUSTOM_PORT_TAIL_SLASH = "https://login.microsoftonline.in:444/tfp/tenant/policy/";




public static String INSTANCE_DISCOVERY_RESPONSE = "{" +
"\"tenant_discovery_endpoint\":\"https://login.microsoftonline.com/organizations/v2.0/.well-known/openid-appConfiguration\"," +
"\"api-version\":\"1.1\"," +
"\"metadata\":[{\"preferred_network\":\"login.microsoftonline.com\",\"preferred_cache\":\"login.windows.net\",\"aliases\":[\"login.microsoftonline.com\",\"login.windows.net\",\"login.microsoft.com\",\"sts.windows.net\"]},{\"preferred_network\":\"login.partner.microsoftonline.cn\",\"preferred_cache\":\"login.partner.microsoftonline.cn\",\"aliases\":[\"login.partner.microsoftonline.cn\",\"login.chinacloudapi.cn\"]},{\"preferred_network\":\"login.microsoftonline.de\",\"preferred_cache\":\"login.microsoftonline.de\",\"aliases\":[\"login.microsoftonline.de\"]},{\"preferred_network\":\"login.microsoftonline.us\",\"preferred_cache\":\"login.microsoftonline.us\",\"aliases\":[\"login.microsoftonline.us\",\"login.usgovcloudapi.net\"]},{\"preferred_network\":\"login-us.microsoftonline.com\",\"preferred_cache\":\"login-us.microsoftonline.com\",\"aliases\":[\"login-us.microsoftonline.com\"]}]}";

public final static String AAD_PREFERRED_NETWORK_ENV_ALIAS = "login.microsoftonline.com";
public final static String AAD_PREFERRED_CACHE__ENV_ALIAS = "login.windows.net";

public final static String HTTP_RESPONSE_FROM_AUTH_CODE = "{\"access_token\":\"eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsIng1dCI6I"
+ "k5HVEZ2ZEstZnl0aEV1THdqcHdBSk9NOW4tQSJ9.eyJhdWQiOiJiN2E2NzFkOC1hNDA4LTQyZmYtODZlMC1hYWY0NDdmZDE3YzQiLCJpc3MiOiJod"
Expand Down

0 comments on commit f57d4dd

Please sign in to comment.