Skip to content

Commit

Permalink
Allow multiple signing keys to be provided (opensearch-project#4632)
Browse files Browse the repository at this point in the history
Signed-off-by: Stephen Crawford <steecraw@amazon.com>
  • Loading branch information
stephen-crawford authored Aug 20, 2024
1 parent d76bbfb commit e2cd610
Show file tree
Hide file tree
Showing 7 changed files with 415 additions and 61 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -90,15 +90,25 @@ public class JwtAuthenticationTests {

public static final String QA_SONG_INDEX_NAME = String.format("song_lyrics_%s", QA_DEPARTMENT);

private static final KeyPair KEY_PAIR = Keys.keyPairFor(SignatureAlgorithm.RS256);
private static final String PUBLIC_KEY = new String(Base64.getEncoder().encode(KEY_PAIR.getPublic().getEncoded()), US_ASCII);
private static final KeyPair KEY_PAIR1 = Keys.keyPairFor(SignatureAlgorithm.RS256);
private static final String PUBLIC_KEY1 = new String(Base64.getEncoder().encode(KEY_PAIR1.getPublic().getEncoded()), US_ASCII);

private static final KeyPair KEY_PAIR2 = Keys.keyPairFor(SignatureAlgorithm.RS256);
private static final String PUBLIC_KEY2 = new String(Base64.getEncoder().encode(KEY_PAIR2.getPublic().getEncoded()), US_ASCII);

static final TestSecurityConfig.User ADMIN_USER = new TestSecurityConfig.User("admin").roles(ALL_ACCESS);

private static final String JWT_AUTH_HEADER = "jwt-auth";

private static final JwtAuthorizationHeaderFactory tokenFactory = new JwtAuthorizationHeaderFactory(
KEY_PAIR.getPrivate(),
private static final JwtAuthorizationHeaderFactory tokenFactory1 = new JwtAuthorizationHeaderFactory(
KEY_PAIR1.getPrivate(),
CLAIM_USERNAME,
CLAIM_ROLES,
JWT_AUTH_HEADER
);

private static final JwtAuthorizationHeaderFactory tokenFactory2 = new JwtAuthorizationHeaderFactory(
KEY_PAIR2.getPrivate(),
CLAIM_USERNAME,
CLAIM_ROLES,
JWT_AUTH_HEADER
Expand All @@ -108,7 +118,10 @@ public class JwtAuthenticationTests {
"jwt",
BASIC_AUTH_DOMAIN_ORDER - 1
).jwtHttpAuthenticator(
new JwtConfigBuilder().jwtHeader(JWT_AUTH_HEADER).signingKey(PUBLIC_KEY).subjectKey(CLAIM_USERNAME).rolesKey(CLAIM_ROLES)
new JwtConfigBuilder().jwtHeader(JWT_AUTH_HEADER)
.signingKey(List.of(PUBLIC_KEY1, PUBLIC_KEY2))
.subjectKey(CLAIM_USERNAME)
.rolesKey(CLAIM_ROLES)
).backend("noop");
public static final String SONG_ID_1 = "song-id-01";

Expand Down Expand Up @@ -143,7 +156,7 @@ public static void createTestData() {

@Test
public void shouldAuthenticateWithJwtToken_positive() {
try (TestRestClient client = cluster.getRestClient(tokenFactory.generateValidToken(USER_SUPERHERO))) {
try (TestRestClient client = cluster.getRestClient(tokenFactory1.generateValidToken(USER_SUPERHERO))) {

HttpResponse response = client.getAuthInfo();

Expand All @@ -155,7 +168,7 @@ public void shouldAuthenticateWithJwtToken_positive() {

@Test
public void shouldAuthenticateWithJwtToken_positiveWithAnotherUsername() {
try (TestRestClient client = cluster.getRestClient(tokenFactory.generateValidToken(USERNAME_ROOT))) {
try (TestRestClient client = cluster.getRestClient(tokenFactory1.generateValidToken(USERNAME_ROOT))) {

HttpResponse response = client.getAuthInfo();

Expand All @@ -167,7 +180,7 @@ public void shouldAuthenticateWithJwtToken_positiveWithAnotherUsername() {

@Test
public void shouldAuthenticateWithJwtToken_failureLackingUserName() {
try (TestRestClient client = cluster.getRestClient(tokenFactory.generateTokenWithoutPreferredUsername(USER_SUPERHERO))) {
try (TestRestClient client = cluster.getRestClient(tokenFactory1.generateTokenWithoutPreferredUsername(USER_SUPERHERO))) {

HttpResponse response = client.getAuthInfo();

Expand All @@ -178,7 +191,7 @@ public void shouldAuthenticateWithJwtToken_failureLackingUserName() {

@Test
public void shouldAuthenticateWithJwtToken_failureExpiredToken() {
try (TestRestClient client = cluster.getRestClient(tokenFactory.generateExpiredToken(USER_SUPERHERO))) {
try (TestRestClient client = cluster.getRestClient(tokenFactory1.generateExpiredToken(USER_SUPERHERO))) {

HttpResponse response = client.getAuthInfo();

Expand All @@ -202,7 +215,7 @@ public void shouldAuthenticateWithJwtToken_failureIncorrectFormatOfToken() {
@Test
public void shouldAuthenticateWithJwtToken_failureIncorrectSignature() {
KeyPair incorrectKeyPair = Keys.keyPairFor(SignatureAlgorithm.RS256);
Header header = tokenFactory.generateTokenSignedWithKey(incorrectKeyPair.getPrivate(), USER_SUPERHERO);
Header header = tokenFactory1.generateTokenSignedWithKey(incorrectKeyPair.getPrivate(), USER_SUPERHERO);
try (TestRestClient client = cluster.getRestClient(header)) {

HttpResponse response = client.getAuthInfo();
Expand All @@ -214,7 +227,7 @@ public void shouldAuthenticateWithJwtToken_failureIncorrectSignature() {

@Test
public void shouldReadRolesFromToken_positiveFirstRoleSet() {
Header header = tokenFactory.generateValidToken(USER_SUPERHERO, ROLE_ADMIN, ROLE_DEVELOPER, ROLE_QA);
Header header = tokenFactory1.generateValidToken(USER_SUPERHERO, ROLE_ADMIN, ROLE_DEVELOPER, ROLE_QA);
try (TestRestClient client = cluster.getRestClient(header)) {

HttpResponse response = client.getAuthInfo();
Expand All @@ -228,7 +241,7 @@ public void shouldReadRolesFromToken_positiveFirstRoleSet() {

@Test
public void shouldReadRolesFromToken_positiveSecondRoleSet() {
Header header = tokenFactory.generateValidToken(USER_SUPERHERO, ROLE_CTO, ROLE_CEO, ROLE_VP);
Header header = tokenFactory1.generateValidToken(USER_SUPERHERO, ROLE_CTO, ROLE_CEO, ROLE_VP);
try (TestRestClient client = cluster.getRestClient(header)) {

HttpResponse response = client.getAuthInfo();
Expand All @@ -244,7 +257,7 @@ public void shouldReadRolesFromToken_positiveSecondRoleSet() {
public void shouldExposeTokenClaimsAsUserAttributes_positive() throws IOException {
String[] roles = { ROLE_VP };
Map<String, Object> additionalClaims = Map.of(CLAIM_DEPARTMENT, QA_DEPARTMENT);
Header header = tokenFactory.generateValidTokenWithCustomClaims(USER_SUPERHERO, roles, additionalClaims);
Header header = tokenFactory1.generateValidTokenWithCustomClaims(USER_SUPERHERO, roles, additionalClaims);
try (RestHighLevelClient client = cluster.getRestHighLevelClient(List.of(header))) {
SearchRequest searchRequest = queryStringQueryRequest(QA_SONG_INDEX_NAME, QUERY_TITLE_MAGNUM_OPUS);

Expand All @@ -261,11 +274,36 @@ public void shouldExposeTokenClaimsAsUserAttributes_positive() throws IOExceptio
public void shouldExposeTokenClaimsAsUserAttributes_negative() throws IOException {
String[] roles = { ROLE_VP };
Map<String, Object> additionalClaims = Map.of(CLAIM_DEPARTMENT, "department-without-access-to-qa-song-index");
Header header = tokenFactory.generateValidTokenWithCustomClaims(USER_SUPERHERO, roles, additionalClaims);
Header header = tokenFactory1.generateValidTokenWithCustomClaims(USER_SUPERHERO, roles, additionalClaims);
try (RestHighLevelClient client = cluster.getRestHighLevelClient(List.of(header))) {
SearchRequest searchRequest = queryStringQueryRequest(QA_SONG_INDEX_NAME, QUERY_TITLE_MAGNUM_OPUS);

assertThatThrownBy(() -> client.search(searchRequest, DEFAULT), statusException(FORBIDDEN));
}
}

@Test
public void secondKeypairShouldAuthenticateWithJwtToken_positive() {
try (TestRestClient client = cluster.getRestClient(tokenFactory2.generateValidToken(USER_SUPERHERO))) {

HttpResponse response = client.getAuthInfo();

response.assertStatusCode(200);
String username = response.getTextFromJsonBody(POINTER_USERNAME);
assertThat(username, equalTo(USER_SUPERHERO));
}
}

@Test
public void secondKeypairShouldAuthenticateWithJwtToken_positiveWithAnotherUsername() {
try (TestRestClient client = cluster.getRestClient(tokenFactory2.generateValidToken(USERNAME_ROOT))) {

HttpResponse response = client.getAuthInfo();

response.assertStatusCode(200);
String username = response.getTextFromJsonBody(POINTER_USERNAME);
assertThat(username, equalTo(USERNAME_ROOT));
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,10 @@ public class JwtAuthenticationWithUrlParamTests {
"jwt",
BASIC_AUTH_DOMAIN_ORDER - 1
).jwtHttpAuthenticator(
new JwtConfigBuilder().jwtUrlParameter(TOKEN_URL_PARAM).signingKey(PUBLIC_KEY).subjectKey(CLAIM_USERNAME).rolesKey(CLAIM_ROLES)
new JwtConfigBuilder().jwtUrlParameter(TOKEN_URL_PARAM)
.signingKey(List.of(PUBLIC_KEY))
.subjectKey(CLAIM_USERNAME)
.rolesKey(CLAIM_ROLES)
).backend("noop");

@Rule
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
*/
package org.opensearch.test.framework;

import java.util.List;
import java.util.Map;
import java.util.Objects;

Expand All @@ -19,7 +20,7 @@
public class JwtConfigBuilder {
private String jwtHeader;
private String jwtUrlParameter;
private String signingKey;
private List<String> signingKeys;
private String subjectKey;
private String rolesKey;

Expand All @@ -33,8 +34,8 @@ public JwtConfigBuilder jwtUrlParameter(String jwtUrlParameter) {
return this;
}

public JwtConfigBuilder signingKey(String signingKey) {
this.signingKey = signingKey;
public JwtConfigBuilder signingKey(List<String> signingKeys) {
this.signingKeys = signingKeys;
return this;
}

Expand All @@ -50,10 +51,10 @@ public JwtConfigBuilder rolesKey(String rolesKey) {

public Map<String, Object> build() {
Builder<String, Object> builder = new Builder<>();
if (Objects.isNull(signingKey)) {
if (Objects.isNull(signingKeys)) {
throw new IllegalStateException("Signing key is required.");
}
builder.put("signing_key", signingKey);
builder.put("signing_key", signingKeys);
if (isNoneBlank(jwtHeader)) {
builder.put("jwt_header", jwtHeader);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -815,7 +815,7 @@ public static class AuthcDomain implements ToXContentObject {
).httpAuthenticator("basic").backend("internal");

public final static AuthcDomain JWT_AUTH_DOMAIN = new TestSecurityConfig.AuthcDomain("jwt", 1).jwtHttpAuthenticator(
new JwtConfigBuilder().jwtHeader(AUTHORIZATION).signingKey(PUBLIC_KEY)
new JwtConfigBuilder().jwtHeader(AUTHORIZATION).signingKey(List.of(PUBLIC_KEY))
).backend("noop");

private final String id;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import java.nio.file.Path;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
Expand Down Expand Up @@ -54,7 +55,7 @@ public class HTTPJwtAuthenticator implements HTTPAuthenticator {
private static final Pattern BASIC = Pattern.compile("^\\s*Basic\\s.*", Pattern.CASE_INSENSITIVE);
private static final String BEARER = "bearer ";

private final JwtParser jwtParser;
private final List<JwtParser> jwtParsers = new ArrayList<>();
private final String jwtHeaderName;
private final boolean isDefaultAuthHeader;
private final String jwtUrlParameter;
Expand All @@ -67,7 +68,8 @@ public class HTTPJwtAuthenticator implements HTTPAuthenticator {
public HTTPJwtAuthenticator(final Settings settings, final Path configPath) {
super();

String signingKey = settings.get("signing_key");
List<String> signingKeys = settings.getAsList("signing_key");

jwtUrlParameter = settings.get("jwt_url_parameter");
jwtHeaderName = settings.get("jwt_header", AUTHORIZATION);
isDefaultAuthHeader = AUTHORIZATION.equalsIgnoreCase(jwtHeaderName);
Expand All @@ -83,19 +85,23 @@ public HTTPJwtAuthenticator(final Settings settings, final Path configPath) {
);
}

final JwtParserBuilder jwtParserBuilder = KeyUtils.createJwtParserBuilderFromSigningKey(signingKey, log);
if (jwtParserBuilder == null) {
jwtParser = null;
} else {
if (requireIssuer != null) {
jwtParserBuilder.requireIssuer(requireIssuer);
}

final SecurityManager sm = System.getSecurityManager();
if (sm != null) {
sm.checkPermission(new SpecialPermission());
for (String key : signingKeys) {
JwtParser jwtParser;
final JwtParserBuilder jwtParserBuilder = KeyUtils.createJwtParserBuilderFromSigningKey(key, log);
if (jwtParserBuilder == null) {
jwtParser = null;
} else {
if (requireIssuer != null) {
jwtParserBuilder.requireIssuer(requireIssuer);
}

final SecurityManager sm = System.getSecurityManager();
if (sm != null) {
sm.checkPermission(new SpecialPermission());
}
jwtParser = AccessController.doPrivileged((PrivilegedAction<JwtParser>) jwtParserBuilder::build);
}
jwtParser = AccessController.doPrivileged((PrivilegedAction<JwtParser>) jwtParserBuilder::build);
jwtParsers.add(jwtParser);
}
}

Expand All @@ -120,7 +126,8 @@ public AuthCredentials run() {
}

private AuthCredentials extractCredentials0(final SecurityRequest request) {
if (jwtParser == null) {

if (jwtParsers.isEmpty() || jwtParsers.getFirst() == null) {
log.error("Missing Signing Key. JWT authentication will not work");
return null;
}
Expand Down Expand Up @@ -157,39 +164,43 @@ private AuthCredentials extractCredentials0(final SecurityRequest request) {
}
}

try {
final Claims claims = jwtParser.parseClaimsJws(jwtToken).getBody();
for (JwtParser jwtParser : jwtParsers) {
try {

if (!requiredAudience.isEmpty()) {
assertValidAudienceClaim(claims);
}
final Claims claims = jwtParser.parseClaimsJws(jwtToken).getBody();

final String subject = extractSubject(claims, request);
if (!requiredAudience.isEmpty()) {
assertValidAudienceClaim(claims);
}

if (subject == null) {
log.error("No subject found in JWT token");
return null;
}
final String subject = extractSubject(claims, request);

final String[] roles = extractRoles(claims, request);
if (subject == null) {
log.error("No subject found in JWT token");
return null;
}

final AuthCredentials ac = new AuthCredentials(subject, roles).markComplete();
final String[] roles = extractRoles(claims, request);

for (Entry<String, Object> claim : claims.entrySet()) {
ac.addAttribute("attr.jwt." + claim.getKey(), String.valueOf(claim.getValue()));
}
final AuthCredentials ac = new AuthCredentials(subject, roles).markComplete();

return ac;
for (Entry<String, Object> claim : claims.entrySet()) {
ac.addAttribute("attr.jwt." + claim.getKey(), String.valueOf(claim.getValue()));
}

} catch (WeakKeyException e) {
log.error("Cannot authenticate user with JWT because of ", e);
return null;
} catch (Exception e) {
if (log.isDebugEnabled()) {
log.debug("Invalid or expired JWT token.", e);
return ac;

} catch (WeakKeyException e) {
log.error("Cannot authenticate user with JWT because of ", e);
return null;
} catch (Exception e) {
if (log.isDebugEnabled()) {
log.debug("Invalid or expired JWT token.", e);
}
}
return null;
}
log.error("Failed to parse JWT token using any of the available parsers");
return null;
}

private void assertValidAudienceClaim(Claims claims) throws BadJWTException {
Expand Down
4 changes: 2 additions & 2 deletions src/main/java/org/opensearch/security/util/KeyUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ public JwtParserBuilder run() {
PublicKey key = null;

final String minimalKeyFormat = signingKey.replace("-----BEGIN PUBLIC KEY-----\n", "")
.replace("-----END PUBLIC KEY-----", "");

.replace("-----END PUBLIC KEY-----", "")
.trim();
final byte[] decoded = Base64.getDecoder().decode(minimalKeyFormat);

try {
Expand Down
Loading

0 comments on commit e2cd610

Please sign in to comment.