Skip to content

Commit

Permalink
Switch JWT library implementations from cxf to nimbus (#3421)
Browse files Browse the repository at this point in the history
Switch from org.apache.cxf.rs.security.jose to com.nimbusds.jose.jwk.

Signed-off-by: Peter Nied <petern@amazon.com>
Signed-off-by: Maciej Mierzwa <dev.maciej.mierzwa@gmail.com>
Signed-off-by: Ryan Liang <jiallian@amazon.com>
Co-authored-by: Peter Nied <petern@amazon.com>
Co-authored-by: Ryan Liang <jiallian@amazon.com>
  • Loading branch information
3 people authored Oct 24, 2023
1 parent ccc3e34 commit 4f89b4a
Show file tree
Hide file tree
Showing 22 changed files with 672 additions and 537 deletions.
4 changes: 1 addition & 3 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -558,6 +558,7 @@ dependencies {
implementation 'commons-cli:commons-cli:1.5.0'
implementation "org.bouncycastle:bcprov-jdk15to18:${versions.bouncycastle}"
implementation 'org.ldaptive:ldaptive:1.2.3'
implementation 'com.nimbusds:nimbus-jose-jwt:9.31'

//JWT
implementation "io.jsonwebtoken:jjwt-api:${jjwt_version}"
Expand All @@ -581,9 +582,6 @@ dependencies {

runtimeOnly 'net.minidev:accessors-smart:2.5.0'

implementation("org.apache.cxf:cxf-rt-rs-security-jose:${apache_cxf_version}") {
exclude(group: 'jakarta.activation', module: 'jakarta.activation-api')
}
runtimeOnly "org.apache.cxf:cxf-core:${apache_cxf_version}"
implementation "org.apache.cxf:cxf-rt-rs-json-basic:${apache_cxf_version}"
runtimeOnly "org.apache.cxf:cxf-rt-security:${apache_cxf_version}"
Expand Down
1 change: 1 addition & 0 deletions checkstyle/checkstyle.xml
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@
</module>
<module name="IllegalImport"> <!-- defaults to sun.* packages -->
<property name="severity" value="error"/>
<property name="illegalPkgs" value="org.apache.cxf.rs.security.jose"/>
</module>
<module name="RedundantImport">
<property name="severity" value="error"/>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,16 @@
import java.nio.file.Path;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.text.ParseException;
import java.util.Collection;
import java.util.Map;
import java.util.Optional;
import java.util.Map.Entry;
import java.util.regex.Pattern;

import com.google.common.annotations.VisibleForTesting;
import org.apache.cxf.rs.security.jose.jwt.JwtClaims;
import org.apache.cxf.rs.security.jose.jwt.JwtToken;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.SignedJWT;
import org.apache.http.HttpStatus;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
Expand Down Expand Up @@ -112,37 +113,34 @@ private AuthCredentials extractCredentials0(final SecurityRequest request) throw
return null;
}

JwtToken jwt;
SignedJWT jwt;
JWTClaimsSet claimsSet;

try {
jwt = jwtVerifier.getVerifiedJwtToken(jwtString);
claimsSet = jwt.getJWTClaimsSet();
} catch (AuthenticatorUnavailableException e) {
log.info(e.toString());
throw new OpenSearchSecurityException(e.getMessage(), RestStatus.SERVICE_UNAVAILABLE);
} catch (BadCredentialsException e) {
} catch (BadCredentialsException | ParseException e) {
log.info("Extracting JWT token from {} failed", jwtString, e);
return null;
}

JwtClaims claims = jwt.getClaims();

final String subject = extractSubject(claims);

final String subject = extractSubject(claimsSet);
if (subject == null) {
log.error("No subject found in JWT token");
return null;
}

final String[] roles = extractRoles(claims);

final String[] roles = extractRoles(claimsSet);
final AuthCredentials ac = new AuthCredentials(subject, roles).markComplete();

for (Entry<String, Object> claim : claims.asMap().entrySet()) {
for (Entry<String, Object> claim : claimsSet.getClaims().entrySet()) {
ac.addAttribute("attr.jwt." + claim.getKey(), String.valueOf(claim.getValue()));
}

return ac;

}

protected String getJwtTokenString(SecurityRequest request) {
Expand Down Expand Up @@ -174,7 +172,7 @@ protected String getJwtTokenString(SecurityRequest request) {
}

@VisibleForTesting
public String extractSubject(JwtClaims claims) {
public String extractSubject(JWTClaimsSet claims) {
String subject = claims.getSubject();

if (subjectKey != null) {
Expand Down Expand Up @@ -204,7 +202,7 @@ public String extractSubject(JwtClaims claims) {

@SuppressWarnings("unchecked")
@VisibleForTesting
public String[] extractRoles(JwtClaims claims) {
public String[] extractRoles(JWTClaimsSet claims) {
if (rolesKey == null) {
return new String[0];
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,24 @@
package com.amazon.dlic.auth.http.jwt.keybyoidc;

import com.google.common.base.Strings;
import com.nimbusds.jose.Algorithm;
import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWSVerifier;
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.OctetSequenceKey;
import com.nimbusds.jose.crypto.factories.DefaultJWSVerifierFactory;
import com.nimbusds.jose.proc.SimpleSecurityContext;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.SignedJWT;
import com.nimbusds.jwt.proc.BadJWTException;
import com.nimbusds.jwt.proc.DefaultJWTClaimsVerifier;
import org.apache.commons.lang3.StringEscapeUtils;
import org.apache.cxf.rs.security.jose.jwa.SignatureAlgorithm;
import org.apache.cxf.rs.security.jose.jwk.JsonWebKey;
import org.apache.cxf.rs.security.jose.jwk.KeyType;
import org.apache.cxf.rs.security.jose.jwk.PublicKeyUse;
import org.apache.cxf.rs.security.jose.jws.JwsJwtCompactConsumer;
import org.apache.cxf.rs.security.jose.jws.JwsSignatureVerifier;
import org.apache.cxf.rs.security.jose.jws.JwsUtils;
import org.apache.cxf.rs.security.jose.jwt.JwtClaims;
import org.apache.cxf.rs.security.jose.jwt.JwtException;
import org.apache.cxf.rs.security.jose.jwt.JwtToken;
import org.apache.cxf.rs.security.jose.jwt.JwtUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import java.text.ParseException;
import java.util.Collections;

public class JwtVerifier {

private final static Logger log = LogManager.getLogger(JwtVerifier.class);
Expand All @@ -43,31 +46,24 @@ public JwtVerifier(KeyProvider keyProvider, int clockSkewToleranceSeconds, Strin
this.requiredAudience = requiredAudience;
}

public JwtToken getVerifiedJwtToken(String encodedJwt) throws BadCredentialsException {
public SignedJWT getVerifiedJwtToken(String encodedJwt) throws BadCredentialsException {
try {
JwsJwtCompactConsumer jwtConsumer = new JwsJwtCompactConsumer(encodedJwt);
JwtToken jwt = jwtConsumer.getJwtToken();
SignedJWT jwt = SignedJWT.parse(encodedJwt);

String escapedKid = jwt.getJwsHeaders().getKeyId();
String escapedKid = jwt.getHeader().getKeyID();
String kid = escapedKid;
if (!Strings.isNullOrEmpty(kid)) {
kid = StringEscapeUtils.unescapeJava(escapedKid);
}
JsonWebKey key = keyProvider.getKey(kid);

// Algorithm is not mandatory for the key material, so we set it to the same as the JWT
if (key.getAlgorithm() == null && key.getPublicKeyUse() == PublicKeyUse.SIGN && key.getKeyType() == KeyType.RSA) {
key.setAlgorithm(jwt.getJwsHeaders().getAlgorithm());
}

JwsSignatureVerifier signatureVerifier = getInitializedSignatureVerifier(key, jwt);
JWK key = keyProvider.getKey(kid);

boolean signatureValid = jwtConsumer.verifySignatureWith(signatureVerifier);
JWSVerifier signatureVerifier = getInitializedSignatureVerifier(key, jwt);
boolean signatureValid = jwt.verify(signatureVerifier);

if (!signatureValid && Strings.isNullOrEmpty(kid)) {
key = keyProvider.getKeyAfterRefresh(null);
signatureVerifier = getInitializedSignatureVerifier(key, jwt);
signatureValid = jwtConsumer.verifySignatureWith(signatureVerifier);
signatureValid = jwt.verify(signatureVerifier);
}

if (!signatureValid) {
Expand All @@ -77,18 +73,18 @@ public JwtToken getVerifiedJwtToken(String encodedJwt) throws BadCredentialsExce
validateClaims(jwt);

return jwt;
} catch (JwtException e) {
} catch (JOSEException | ParseException | BadJWTException e) {
throw new BadCredentialsException(e.getMessage(), e);
}
}

private void validateSignatureAlgorithm(JsonWebKey key, JwtToken jwt) throws BadCredentialsException {
if (Strings.isNullOrEmpty(key.getAlgorithm())) {
private void validateSignatureAlgorithm(JWK key, SignedJWT jwt) throws BadCredentialsException {
if (key.getAlgorithm() == null || jwt.getHeader().getAlgorithm() == null) {
return;
}

SignatureAlgorithm keyAlgorithm = SignatureAlgorithm.getAlgorithm(key.getAlgorithm());
SignatureAlgorithm tokenAlgorithm = SignatureAlgorithm.getAlgorithm(jwt.getJwsHeaders().getAlgorithm());
Algorithm keyAlgorithm = key.getAlgorithm();
Algorithm tokenAlgorithm = jwt.getHeader().getAlgorithm();

if (!keyAlgorithm.equals(tokenAlgorithm)) {
throw new BadCredentialsException(
Expand All @@ -97,38 +93,48 @@ private void validateSignatureAlgorithm(JsonWebKey key, JwtToken jwt) throws Bad
}
}

private JwsSignatureVerifier getInitializedSignatureVerifier(JsonWebKey key, JwtToken jwt) throws BadCredentialsException,
JwtException {
private JWSVerifier getInitializedSignatureVerifier(JWK key, SignedJWT jwt) throws BadCredentialsException, JOSEException {

validateSignatureAlgorithm(key, jwt);
JwsSignatureVerifier result = JwsUtils.getSignatureVerifier(key, jwt.getJwsHeaders().getSignatureAlgorithm());
final JWSVerifier result;
if (key.getClass() == OctetSequenceKey.class) {
result = new DefaultJWSVerifierFactory().createJWSVerifier(jwt.getHeader(), key.toOctetSequenceKey().toSecretKey());
} else {
result = new DefaultJWSVerifierFactory().createJWSVerifier(jwt.getHeader(), key.toRSAKey().toRSAPublicKey());
}

if (result == null) {
throw new BadCredentialsException("Cannot verify JWT");
} else {
return result;
}
}

private void validateClaims(JwtToken jwt) throws JwtException {
JwtClaims claims = jwt.getClaims();
private void validateClaims(SignedJWT jwt) throws ParseException, BadJWTException {
JWTClaimsSet claims = jwt.getJWTClaimsSet();

if (claims != null) {
JwtUtils.validateJwtExpiry(claims, clockSkewToleranceSeconds, false);
JwtUtils.validateJwtNotBefore(claims, clockSkewToleranceSeconds, false);
DefaultJWTClaimsVerifier<SimpleSecurityContext> claimsVerifier = new DefaultJWTClaimsVerifier<>(
requiredAudience,
null,
Collections.emptySet()
);
claimsVerifier.setMaxClockSkew(clockSkewToleranceSeconds);
claimsVerifier.verify(claims, null);
validateRequiredAudienceAndIssuer(claims);
}
}

private void validateRequiredAudienceAndIssuer(JwtClaims claims) {
String audience = claims.getAudience();
private void validateRequiredAudienceAndIssuer(JWTClaimsSet claims) throws BadJWTException {
String audience = claims.getAudience().stream().findFirst().orElse("");
String issuer = claims.getIssuer();

if (!Strings.isNullOrEmpty(requiredAudience) && !requiredAudience.equals(audience)) {
throw new JwtException("Invalid audience");
throw new BadJWTException("Invalid audience");
}

if (!Strings.isNullOrEmpty(requiredIssuer) && !requiredIssuer.equals(issuer)) {
throw new JwtException("Invalid issuer");
throw new BadJWTException("Invalid issuer");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@

package com.amazon.dlic.auth.http.jwt.keybyoidc;

import org.apache.cxf.rs.security.jose.jwk.JsonWebKey;
import com.nimbusds.jose.jwk.JWK;

public interface KeyProvider {
public JsonWebKey getKey(String kid) throws AuthenticatorUnavailableException, BadCredentialsException;
JWK getKey(String kid) throws AuthenticatorUnavailableException, BadCredentialsException;

public JsonWebKey getKeyAfterRefresh(String kid) throws AuthenticatorUnavailableException, BadCredentialsException;
JWK getKeyAfterRefresh(String kid) throws AuthenticatorUnavailableException, BadCredentialsException;
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@

package com.amazon.dlic.auth.http.jwt.keybyoidc;

import org.apache.cxf.rs.security.jose.jwk.JsonWebKeys;
import com.nimbusds.jose.jwk.JWKSet;

@FunctionalInterface
public interface KeySetProvider {
JsonWebKeys get() throws AuthenticatorUnavailableException;
JWKSet get() throws AuthenticatorUnavailableException;
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@
package com.amazon.dlic.auth.http.jwt.keybyoidc;

import java.io.IOException;
import java.text.ParseException;
import java.util.concurrent.TimeUnit;

import com.nimbusds.jose.jwk.JWKSet;
import joptsimple.internal.Strings;
import org.apache.cxf.rs.security.jose.jwk.JsonWebKeys;
import org.apache.cxf.rs.security.jose.jwk.JwkUtils;
import org.apache.hc.client5.http.cache.HttpCacheContext;
import org.apache.hc.client5.http.cache.HttpCacheStorage;
import org.apache.hc.client5.http.classic.methods.HttpGet;
Expand Down Expand Up @@ -70,7 +70,7 @@ public class KeySetRetriever implements KeySetProvider {
configureCache(useCacheForOidConnectEndpoint);
}

public JsonWebKeys get() throws AuthenticatorUnavailableException {
public JWKSet get() throws AuthenticatorUnavailableException {
String uri = getJwksUri();

try (CloseableHttpClient httpClient = createHttpClient(null)) {
Expand All @@ -94,10 +94,11 @@ public JsonWebKeys get() throws AuthenticatorUnavailableException {
if (httpEntity == null) {
throw new AuthenticatorUnavailableException("Error while getting " + uri + ": Empty response entity");
}

JsonWebKeys keySet = JwkUtils.readJwkSet(httpEntity.getContent());
JWKSet keySet = JWKSet.load(httpEntity.getContent());

return keySet;
} catch (ParseException e) {
throw new RuntimeException(e);
}
} catch (IOException e) {
throw new AuthenticatorUnavailableException("Error while getting " + uri + ": " + e, e);
Expand Down
Loading

0 comments on commit 4f89b4a

Please sign in to comment.