Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OnBehalfOf tokens odds and ends #3593

Merged
merged 4 commits into from
Oct 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ public void shouldLoadDefaultConfiguration() {
Awaitility.await().alias("Load default configuration").until(() -> client.getAuthInfo().getStatusCode(), equalTo(200));
}
try (TestRestClient client = cluster.getRestClient(ADMIN_USER_NAME, DEFAULT_PASSWORD)) {
client.assertCorrectCredentials(ADMIN_USER_NAME);
client.confirmCorrectCredentials(ADMIN_USER_NAME);
HttpResponse response = client.get("/_plugins/_security/api/internalusers");
response.assertStatusCode(200);
Map<String, Object> users = response.getBodyAs(Map.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,10 @@ public void shouldCreateUserViaRestApi_success() {
assertThat(httpResponse.getStatusCode(), equalTo(201));
}
try (TestRestClient client = cluster.getRestClient(USER_ADMIN)) {
client.assertCorrectCredentials(USER_ADMIN.getName());
client.confirmCorrectCredentials(USER_ADMIN.getName());
}
try (TestRestClient client = cluster.getRestClient(ADDITIONAL_USER_1, ADDITIONAL_PASSWORD_1)) {
client.assertCorrectCredentials(ADDITIONAL_USER_1);
client.confirmCorrectCredentials(ADDITIONAL_USER_1);
}
}

Expand Down Expand Up @@ -160,10 +160,10 @@ public void shouldCreateUserViaRestApiWhenAdminIsAuthenticatedViaCertificate_pos
httpResponse.assertStatusCode(201);
}
try (TestRestClient client = cluster.getRestClient(USER_ADMIN)) {
client.assertCorrectCredentials(USER_ADMIN.getName());
client.confirmCorrectCredentials(USER_ADMIN.getName());
}
try (TestRestClient client = cluster.getRestClient(ADDITIONAL_USER_2, ADDITIONAL_PASSWORD_2)) {
client.assertCorrectCredentials(ADDITIONAL_USER_2);
client.confirmCorrectCredentials(ADDITIONAL_USER_2);
}
}

Expand All @@ -189,10 +189,10 @@ public void shouldStillWorkAfterUpdateOfSecurityConfig() {
cluster.updateUserConfiguration(users);

try (TestRestClient client = cluster.getRestClient(USER_ADMIN)) {
client.assertCorrectCredentials(USER_ADMIN.getName());
client.confirmCorrectCredentials(USER_ADMIN.getName());
}
try (TestRestClient client = cluster.getRestClient(newUser)) {
client.assertCorrectCredentials(newUser.getName());
client.confirmCorrectCredentials(newUser.getName());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ public void shouldAuthenticateUserWithCertificate_positiveUserSpoke() {
CertificateData userSpockCertificate = TEST_CERTIFICATES.issueUserCertificate(BACKEND_ROLE_BRIDGE, USER_SPOCK);
try (TestRestClient client = cluster.getRestClient(userSpockCertificate)) {

client.assertCorrectCredentials(USER_SPOCK);
client.confirmCorrectCredentials(USER_SPOCK);
}
}

Expand All @@ -98,7 +98,7 @@ public void shouldAuthenticateUserWithCertificate_positiveUserKirk() {
CertificateData userSpockCertificate = TEST_CERTIFICATES.issueUserCertificate(BACKEND_ROLE_BRIDGE, USER_KIRK);
try (TestRestClient client = cluster.getRestClient(userSpockCertificate)) {

client.assertCorrectCredentials(USER_KIRK);
client.confirmCorrectCredentials(USER_KIRK);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import java.util.Base64;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope;
Expand All @@ -24,7 +25,6 @@
import org.apache.hc.core5.http.Header;
import org.apache.hc.core5.http.HttpStatus;
import org.apache.hc.core5.http.message.BasicHeader;
import org.junit.Assert;
import org.junit.ClassRule;
import org.junit.Test;
import org.junit.runner.RunWith;
Expand All @@ -38,11 +38,10 @@
import org.opensearch.test.framework.cluster.TestRestClient;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.aMapWithSize;
import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasKey;
import static org.junit.Assert.assertTrue;
import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.contains;
import static org.opensearch.security.support.ConfigConstants.SECURITY_ALLOW_DEFAULT_INIT_SECURITYINDEX;
import static org.opensearch.security.support.ConfigConstants.SECURITY_RESTAPI_ROLES_ENABLED;
import static org.opensearch.test.framework.TestSecurityConfig.AuthcDomain.AUTHC_HTTPBASIC_INTERNAL;
Expand All @@ -56,6 +55,7 @@ public class OnBehalfOfJwtAuthenticationTest {

static final TestSecurityConfig.User ADMIN_USER = new TestSecurityConfig.User("admin").roles(ALL_ACCESS);

private static final String CREATE_OBO_TOKEN_PATH = "_plugins/_security/api/generateonbehalfoftoken";
private static Boolean oboEnabled = true;
private static final String signingKey = Base64.getEncoder()
.encodeToString(
Expand Down Expand Up @@ -139,7 +139,7 @@ public void shouldNotAuthenticateForUsingOBOTokenToAccessOBOEndpoint() {
Header adminOboAuthHeader = new BasicHeader("Authorization", "Bearer " + oboToken);

try (TestRestClient client = cluster.getRestClient(adminOboAuthHeader)) {
TestRestClient.HttpResponse response = client.getOnBehalfOfToken(OBO_DESCRIPTION, adminOboAuthHeader);
TestRestClient.HttpResponse response = client.postJson(CREATE_OBO_TOKEN_PATH, OBO_DESCRIPTION);
response.assertStatusCode(HttpStatus.SC_UNAUTHORIZED);
}
}
Expand All @@ -150,7 +150,7 @@ public void shouldNotAuthenticateForUsingOBOTokenToAccessAccountEndpoint() {
Header adminOboAuthHeader = new BasicHeader("Authorization", "Bearer " + oboToken);

try (TestRestClient client = cluster.getRestClient(adminOboAuthHeader)) {
TestRestClient.HttpResponse response = client.changeInternalUserPassword(CURRENT_AND_NEW_PASSWORDS, adminOboAuthHeader);
TestRestClient.HttpResponse response = client.putJson("_plugins/_security/api/account", CURRENT_AND_NEW_PASSWORDS);
response.assertStatusCode(HttpStatus.SC_UNAUTHORIZED);
}
}
Expand All @@ -173,52 +173,50 @@ public void shouldNotAuthenticateForNonAdminUserWithoutOBOPermission() {
public void shouldNotIncludeRolesFromHostMappingInOBOToken() {
String oboToken = generateOboToken(OBO_USER_NAME_WITH_HOST_MAPPING, DEFAULT_PASSWORD);

Claims claims = Jwts.parserBuilder().setSigningKey(signingKey).build().parseClaimsJws(oboToken).getBody();
Claims claims = Jwts.parserBuilder()
.setSigningKey(Base64.getDecoder().decode(signingKey))
.build()
.parseClaimsJws(oboToken)
.getBody();

Object er = claims.get("er");
EncryptionDecryptionUtil encryptionDecryptionUtil = new EncryptionDecryptionUtil(encryptionKey);
String rolesClaim = encryptionDecryptionUtil.decrypt(er.toString());
List<String> roles = Arrays.stream(rolesClaim.split(","))
.map(String::trim)
.filter(s -> !s.isEmpty())
.collect(Collectors.toUnmodifiableList());
Set<String> roles = Arrays.stream(rolesClaim.split(",")).map(String::trim).filter(s -> !s.isEmpty()).collect(Collectors.toSet());

Assert.assertFalse(roles.contains("host_mapping_role"));
assertThat(roles, equalTo(HOST_MAPPING_OBO_USER.getRoleNames()));
assertThat(roles, not(contains("host_mapping_role")));
}

@Test
public void shouldNotAuthenticateWithInvalidDurationSeconds() {
try (TestRestClient client = cluster.getRestClient(ADMIN_USER_NAME, DEFAULT_PASSWORD)) {
client.assertCorrectCredentials(ADMIN_USER_NAME);
client.confirmCorrectCredentials(ADMIN_USER_NAME);
TestRestClient.HttpResponse response = client.postJson(OBO_ENDPOINT_PREFIX, OBO_DESCRIPTION_WITH_INVALID_DURATIONSECONDS);
response.assertStatusCode(HttpStatus.SC_BAD_REQUEST);
Map<String, Object> oboEndPointResponse = (Map<String, Object>) response.getBodyAs(Map.class);
assertTrue(oboEndPointResponse.containsValue("durationSeconds must be an integer."));
assertThat(response.getTextFromJsonBody("/error"), equalTo("durationSeconds must be an integer."));
}
}

@Test
public void shouldNotAuthenticateWithInvalidAPIParameter() {
try (TestRestClient client = cluster.getRestClient(ADMIN_USER_NAME, DEFAULT_PASSWORD)) {
client.assertCorrectCredentials(ADMIN_USER_NAME);
client.confirmCorrectCredentials(ADMIN_USER_NAME);
TestRestClient.HttpResponse response = client.postJson(OBO_ENDPOINT_PREFIX, OBO_DESCRIPTION_WITH_INVALID_PARAMETERS);
response.assertStatusCode(HttpStatus.SC_BAD_REQUEST);
Map<String, Object> oboEndPointResponse = (Map<String, Object>) response.getBodyAs(Map.class);
assertTrue(oboEndPointResponse.containsValue("Unrecognized parameter: invalidParameter"));
assertThat(response.getTextFromJsonBody("/error"), equalTo("Unrecognized parameter: invalidParameter"));
}
}

private String generateOboToken(String username, String password) {
try (TestRestClient client = cluster.getRestClient(username, password)) {
client.assertCorrectCredentials(username);
client.confirmCorrectCredentials(username);
TestRestClient.HttpResponse response = client.postJson(OBO_ENDPOINT_PREFIX, OBO_TOKEN_REASON);
response.assertStatusCode(HttpStatus.SC_OK);
Map<String, Object> oboEndPointResponse = (Map<String, Object>) response.getBodyAs(Map.class);
assertThat(
oboEndPointResponse,
allOf(aMapWithSize(3), hasKey("user"), hasKey("authenticationToken"), hasKey("durationSeconds"))
);
return oboEndPointResponse.get("authenticationToken").toString();
assertThat(response.getTextFromJsonBody("/user"), notNullValue());
assertThat(response.getTextFromJsonBody("/authenticationToken"), notNullValue());
assertThat(response.getTextFromJsonBody("/durationSeconds"), notNullValue());
return response.getTextFromJsonBody("/authenticationToken").toString();
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,29 +125,7 @@ public HttpResponse getAuthInfo(Header... headers) {
return executeRequest(new HttpGet(getHttpServerUri() + "/_opendistro/_security/authinfo?pretty"), headers);
}

public HttpResponse getOnBehalfOfToken(String jsonData, Header... headers) {
try {
HttpPost httpPost = new HttpPost(
new URIBuilder(getHttpServerUri() + "/_plugins/_security/api/generateonbehalfoftoken?pretty").build()
);
httpPost.setEntity(new StringEntity(jsonData));
return executeRequest(httpPost, mergeHeaders(CONTENT_TYPE_JSON, headers));
} catch (URISyntaxException ex) {
throw new RuntimeException("Incorrect URI syntax", ex);
}
}

public HttpResponse changeInternalUserPassword(String jsonData, Header... headers) {
try {
HttpPut httpPut = new HttpPut(new URIBuilder(getHttpServerUri() + "/_plugins/_security/api/account?pretty").build());
httpPut.setEntity(new StringEntity(jsonData));
return executeRequest(httpPut, mergeHeaders(CONTENT_TYPE_JSON, headers));
} catch (URISyntaxException ex) {
throw new RuntimeException("Incorrect URI syntax", ex);
}
}

public void assertCorrectCredentials(String expectedUserName) {
public void confirmCorrectCredentials(String expectedUserName) {
HttpResponse response = getAuthInfo();
assertThat(response, notNullValue());
response.assertStatusCode(200);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

import java.io.IOException;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand Down Expand Up @@ -58,33 +57,25 @@ public class CreateOnBehalfOfTokenAction extends BaseRestHandler {

private ConfigModel configModel;

private DynamicConfigModel dcm;

public static final Integer OBO_DEFAULT_EXPIRY_SECONDS = 5 * 60;
public static final Integer OBO_MAX_EXPIRY_SECONDS = 10 * 60;

public static final String DEFAULT_SERVICE = "self-issued";

protected final Logger log = LogManager.getLogger(this.getClass());

private static final Set<String> RECOGNIZED_PARAMS = new HashSet<>(
RyanL1997 marked this conversation as resolved.
Show resolved Hide resolved
Arrays.asList("durationSeconds", "description", "roleSecurityMode", "service")
);

@Subscribe
public void onConfigModelChanged(ConfigModel configModel) {
public void onConfigModelChanged(final ConfigModel configModel) {
this.configModel = configModel;
}

@Subscribe
public void onDynamicConfigModelChanged(DynamicConfigModel dcm) {
this.dcm = dcm;
public void onDynamicConfigModelChanged(final DynamicConfigModel dcm) {
final Settings settings = dcm.getDynamicOnBehalfOfSettings();

Settings settings = dcm.getDynamicOnBehalfOfSettings();

Boolean enabled = Boolean.parseBoolean(settings.get("enabled"));
String signingKey = settings.get("signing_key");
String encryptionKey = settings.get("encryption_key");
final Boolean enabled = Boolean.parseBoolean(settings.get("enabled"));
final String signingKey = settings.get("signing_key");
final String encryptionKey = settings.get("encryption_key");

if (!Boolean.FALSE.equals(enabled) && signingKey != null && encryptionKey != null) {
this.vendor = new JwtVendor(settings, Optional.empty());
Expand All @@ -109,7 +100,7 @@ public List<Route> routes() {
}

@Override
protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {
protected RestChannelConsumer prepareRequest(final RestRequest request, final NodeClient client) throws IOException {
switch (request.method()) {
case POST:
return handlePost(request, client);
Expand All @@ -118,10 +109,10 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
}
}

private RestChannelConsumer handlePost(RestRequest request, NodeClient client) throws IOException {
private RestChannelConsumer handlePost(final RestRequest request, final NodeClient client) throws IOException {
return new RestChannelConsumer() {
@Override
public void accept(RestChannel channel) throws Exception {
public void accept(final RestChannel channel) throws Exception {
final XContentBuilder builder = channel.newBuilder();
BytesRestResponse response;
try {
Expand All @@ -141,18 +132,14 @@ public void accept(RestChannel channel) throws Exception {

validateRequestParameters(requestBody);

Integer tokenDuration = parseAndValidateDurationSeconds(requestBody.get("durationSeconds"));
Integer tokenDuration = parseAndValidateDurationSeconds(requestBody.get(InputParameters.DURATION.paramName));
tokenDuration = Math.min(tokenDuration, OBO_MAX_EXPIRY_SECONDS);

final String description = (String) requestBody.getOrDefault("description", null);

final Boolean roleSecurityMode = Optional.ofNullable(requestBody.get("roleSecurityMode"))
peternied marked this conversation as resolved.
Show resolved Hide resolved
.map(value -> (Boolean) value)
.orElse(true); // Default to false if null
final String description = (String) requestBody.getOrDefault(InputParameters.DESCRIPTION.paramName, null);

final String service = (String) requestBody.getOrDefault("service", DEFAULT_SERVICE);
final String service = (String) requestBody.getOrDefault(InputParameters.SERVICE.paramName, DEFAULT_SERVICE);
final User user = threadPool.getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_USER);
Set<String> mappedRoles = mapRoles(user);
final Set<String> mappedRoles = mapRoles(user);

builder.startObject();
builder.field("user", user.getName());
Expand All @@ -164,14 +151,14 @@ public void accept(RestChannel channel) throws Exception {
tokenDuration,
mappedRoles.stream().collect(Collectors.toList()),
user.getRoles().stream().collect(Collectors.toList()),
roleSecurityMode
false
peternied marked this conversation as resolved.
Show resolved Hide resolved
);
builder.field("authenticationToken", token);
builder.field("durationSeconds", tokenDuration);
builder.endObject();

response = new BytesRestResponse(RestStatus.OK, builder);
} catch (IllegalArgumentException iae) {
} catch (final IllegalArgumentException iae) {
builder.startObject().field("error", iae.getMessage()).endObject();
response = new BytesRestResponse(RestStatus.BAD_REQUEST, builder);
} catch (final Exception exception) {
Expand All @@ -187,19 +174,32 @@ public void accept(RestChannel channel) throws Exception {
};
}

private enum InputParameters {
DURATION("durationSeconds"),
DESCRIPTION("description"),
SERVICE("service");

final String paramName;

private InputParameters(final String paramName) {
this.paramName = paramName;
}
}

private Set<String> mapRoles(final User user) {
return this.configModel.mapSecurityRoles(user, null);
}

private void validateRequestParameters(Map<String, Object> requestBody) throws IllegalArgumentException {
for (String key : requestBody.keySet()) {
if (!RECOGNIZED_PARAMS.contains(key)) {
throw new IllegalArgumentException("Unrecognized parameter: " + key);
}
private void validateRequestParameters(final Map<String, Object> requestBody) throws IllegalArgumentException {
for (final String key : requestBody.keySet()) {
Arrays.stream(InputParameters.values())
.filter(param -> param.paramName.equalsIgnoreCase(key))
.findAny()
.orElseThrow(() -> new IllegalArgumentException("Unrecognized parameter: " + key));
}
}

private Integer parseAndValidateDurationSeconds(Object durationObj) throws IllegalArgumentException {
private Integer parseAndValidateDurationSeconds(final Object durationObj) throws IllegalArgumentException {
if (durationObj == null) {
return OBO_DEFAULT_EXPIRY_SECONDS;
}
Expand All @@ -209,7 +209,7 @@ private Integer parseAndValidateDurationSeconds(Object durationObj) throws Illeg
} else if (durationObj instanceof String) {
try {
return Integer.parseInt((String) durationObj);
} catch (NumberFormatException ignored) {}
} catch (final NumberFormatException ignored) {}
}
throw new IllegalArgumentException("durationSeconds must be an integer.");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ public String createJwt(
Integer expirySeconds,
List<String> roles,
List<String> backendRoles,
boolean roleSecurityMode
boolean includeBackendRoles
peternied marked this conversation as resolved.
Show resolved Hide resolved
) throws JOSEException, ParseException {
final Date now = new Date(timeProvider.getAsLong());

Expand Down Expand Up @@ -139,7 +139,7 @@ public String createJwt(
throw new IllegalArgumentException("Roles cannot be null");
}

if (!roleSecurityMode && backendRoles != null) {
if (includeBackendRoles && backendRoles != null) {
String listOfBackendRoles = String.join(",", backendRoles);
claimsBuilder.claim("br", listOfBackendRoles);
}
Expand Down
Loading
Loading