diff --git a/src/integrationTest/java/org/opensearch/security/DefaultConfigurationTests.java b/src/integrationTest/java/org/opensearch/security/DefaultConfigurationTests.java index b254e3182a..c2fec40fd9 100644 --- a/src/integrationTest/java/org/opensearch/security/DefaultConfigurationTests.java +++ b/src/integrationTest/java/org/opensearch/security/DefaultConfigurationTests.java @@ -87,8 +87,7 @@ public void securityRolesUgrade() throws Exception { Awaitility.await().alias("Load default configuration").until(() -> client.getAuthInfo().getStatusCode(), equalTo(200)); final var defaultRolesResponse = client.get("_plugins/_security/api/roles/"); - final var roles = defaultRolesResponse.getBodyAs(JsonNode.class); - final var rolesCount = extractFieldNames(roles).size(); + final var rolesNames = extractFieldNames(defaultRolesResponse.getBodyAs(JsonNode.class)); final var checkForUpgrade = client.get("_plugins/_security/api/_upgrade_check"); System.out.println("checkForUpgrade Response: " + checkForUpgrade.getBody()); @@ -114,6 +113,12 @@ public void securityRolesUgrade() throws Exception { final var checkForUpgrade2 = client.get("_plugins/_security/api/_upgrade_check"); System.out.println("checkForUpgrade2 Response: " + checkForUpgrade2.getBody()); + final var upgradeResponse = client.post("_plugins/_security/api/_upgrade_perform"); + System.out.println("upgrade Response: " + upgradeResponse.getBody()); + + final var afterUpgradeRolesResponse = client.get("_plugins/_security/api/roles/"); + final var afterUpgradeRolesNames = extractFieldNames(defaultRolesResponse.getBodyAs(JsonNode.class)); + assertThat(afterUpgradeRolesNames, equalTo(rolesNames)); } } diff --git a/src/main/java/org/opensearch/security/dlic/rest/api/ConfigUpgradeApiAction.java b/src/main/java/org/opensearch/security/dlic/rest/api/ConfigUpgradeApiAction.java index 6b81876af0..44b4a1656f 100644 --- a/src/main/java/org/opensearch/security/dlic/rest/api/ConfigUpgradeApiAction.java +++ b/src/main/java/org/opensearch/security/dlic/rest/api/ConfigUpgradeApiAction.java @@ -24,6 +24,7 @@ import java.security.PrivilegedExceptionAction; import java.util.ArrayList; import java.util.EnumSet; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; @@ -42,6 +43,7 @@ import org.opensearch.rest.RestRequest; import org.opensearch.rest.RestRequest.Method; import org.opensearch.security.configuration.ConfigurationRepository; +import org.opensearch.security.dlic.rest.api.RolesApiAction.RoleRequestContentValidator; import org.opensearch.security.dlic.rest.support.Utils; import org.opensearch.security.dlic.rest.validation.EndpointValidator; import org.opensearch.security.dlic.rest.validation.RequestContentValidator; @@ -58,6 +60,9 @@ import com.flipkart.zjsonpatch.DiffFlags; import com.flipkart.zjsonpatch.JsonDiff; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; + import org.opensearch.common.inject.Inject; import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.rest.RestStatus; @@ -85,9 +90,12 @@ public class ConfigUpgradeApiAction extends AbstractApiAction { private final static Logger LOGGER = LogManager.getLogger(ConfigUpgradeApiAction.class); + private final static Set SUPPORTED_CTYPES = ImmutableSet.of(CType.ROLES); + private static final List routes = addRoutesPrefix(ImmutableList.of( new Route(Method.GET, "/_upgrade_check"), - new Route(Method.POST, "/_upgrade_perform"))); + new Route(Method.POST, "/_upgrade_perform") + )); @Inject public ConfigUpgradeApiAction( @@ -103,14 +111,7 @@ public ConfigUpgradeApiAction( void handleCanUpgrade(final RestChannel channel, final RestRequest request, final Client client) throws IOException { withIOException(() -> getAndValidateConfigurationsToUpgrade(request) - .map(configurations -> { - final var differencesList = new ArrayList>>(); - for (final var configuration : configurations) { - differencesList.add(computeDifferenceToUpdate(configuration) - .map(differences -> ValidationResult.success(new Tuple(configuration, differences)))); - } - return ValidationResult.combine(differencesList); - })) + .map(this::configurationDifferences)) .valid(differencesList -> { final var canUpgrade = differencesList.stream().anyMatch(entry -> entry.v2().size() > 0); @@ -120,9 +121,9 @@ void handleCanUpgrade(final RestChannel channel, final RestRequest request, fina if (canUpgrade) { final ObjectNode differences = JsonNodeFactory.instance.objectNode(); differencesList.forEach(t -> { - differences.put(t.v1().toLCString(), t.v2()); + differences.set(t.v1().toLCString(), t.v2()); }); - response.put("differences", differences); + response.set("differences", differences); } channel.sendResponse(new BytesRestResponse(RestStatus.OK, XContentType.JSON.mediaType(), response.toPrettyString())); }) @@ -130,15 +131,57 @@ void handleCanUpgrade(final RestChannel channel, final RestRequest request, fina } private void handleUpgrade(final RestChannel channel, final RestRequest request, final Client client) throws IOException { - throw new UnsupportedOperationException("Unimplemented method 'handleUpgrade'"); + withIOException(() -> getConfigurations(request) + .map(this::configurationDifferences)) + .map(diffs -> applyDifferences(request, diffs)) + .valid(updatedResources -> { + ok(channel, "Applied all differences: " + updatedResources); + }) + .error((status, toXContent) -> response(channel, status, toXContent)); + } + + ValidationResult>>>> applyDifferences(final RestRequest request, final List> differencesToUpdate) throws IOException { + final var updatedResources = new ArrayList>>>>(); + for (final Tuple difference : differencesToUpdate) { + updatedResources.add(loadConfiguration(difference.v1(), false, false) + .map(configuration -> patchEntities(request, difference.v2(), SecurityConfiguration.of(null, configuration)) + .map(patchResults -> { + final var items = new HashMap(); + difference.v2().forEach(node -> { + final var item = pathRoot(node); + final var operation = node.get("op").asText(); + if (items.containsKey(item) && !items.get(item).equals(operation)) { + items.put(item, "modified"); + } else { + items.put(item, operation); + } + }); + + final var itemsGroupedByOperation = items.entrySet().stream().collect(Collectors.groupingBy(Map.Entry::getValue, Collectors.mapping(Map.Entry::getKey, Collectors.toList()))); + + return ValidationResult.success(new Tuple<>(difference.v1(), itemsGroupedByOperation)); + }) + ) + ); + } + + return ValidationResult.merge(updatedResources); + } + + private ValidationResult>> configurationDifferences(final Set configurations) throws IOException { + final var differences = new ArrayList>>(); + for (final var configuration : configurations) { + differences.add(computeDifferenceToUpdate(configuration)); + } + return ValidationResult.merge(differences); } - private ValidationResult computeDifferenceToUpdate(final CType configType) throws IOException { + private ValidationResult> computeDifferenceToUpdate(final CType configType) throws IOException { return loadConfiguration(configType, false, false).map(activeRoles -> { final var activeRolesJson = Utils.convertJsonToJackson(activeRoles, false); final var defaultRolesJson = loadConfigFileAsJson(configType); final var rawDiff = JsonDiff.asJson(activeRolesJson, defaultRolesJson, EnumSet.of(DiffFlags.OMIT_VALUE_ON_REMOVE)); - return ValidationResult.success(filterRemoveOperations(rawDiff)); + return ValidationResult.success(new Tuple<>(configType, filterRemoveOperations(rawDiff))); }); } @@ -147,22 +190,18 @@ private ValidationResult> getAndValidateConfigurationsToUpgrade(final final var configurations = Optional.ofNullable(configs) .map(CType::fromStringValues) - .orElse(supportedConfigs()); + .orElse(SUPPORTED_CTYPES); - if (!configurations.stream().allMatch(supportedConfigs()::contains)) { + if (!configurations.stream().allMatch(SUPPORTED_CTYPES::contains)) { // Remove all supported configurations - configurations.removeAll(supportedConfigs()); + configurations.removeAll(SUPPORTED_CTYPES); return ValidationResult.error(RestStatus.BAD_REQUEST, badRequestMessage("Unsupported configurations for upgrade" + configurations)); } return ValidationResult.success(configurations); } - private Set supportedConfigs() { - return Set.of(CType.ROLES); - } - - private JsonNode filterRemoveOperations(final JsonNode diff) { + JsonNode filterRemoveOperations(final JsonNode diff) { final ArrayNode filteredDiff = JsonNodeFactory.instance.arrayNode(); diff.forEach(node -> { if (!isRemoveOperation(node)) { @@ -177,6 +216,10 @@ private JsonNode filterRemoveOperations(final JsonNode diff) { return filteredDiff; } + private String pathRoot(final JsonNode node) { + return node.get("path").asText().split("/")[1]; + } + private boolean hasRootLevelPath(final JsonNode node) { final var jsonPath = node.get("path").asText(); return jsonPath.charAt(0) == '/' && !jsonPath.substring(1).contains("/"); @@ -207,7 +250,7 @@ public List routes() { @Override protected CType getConfigType() { - return CType.ROLES; + throw new UnsupportedOperationException("This class supports multiple configuration types"); } @Override @@ -226,18 +269,7 @@ public RestApiAdminPrivilegesEvaluator restApiAdminPrivilegesEvaluator() { @Override public RequestContentValidator createRequestContentValidator(final Object... params) { - return RequestContentValidator.of(new RequestContentValidator.ValidationContext() { - - @Override - public Set mandatoryKeys() { - return Set.of("configs"); - } - - @Override - public Map allowedKeys() { - return Map.of("configs", DataType.ARRAY); - } - + return new RoleRequestContentValidator(new RequestContentValidator.ValidationContext() { @Override public Object[] params() { return params; @@ -247,8 +279,33 @@ public Object[] params() { public Settings settings() { return securityApiDependencies.settings(); } + + @Override + public Map allowedKeys() { + return Map.of("configs", DataType.ARRAY); + } }); } }; } + + public static class ConfigUpgradeContentValidator extends RequestContentValidator { + + protected ConfigUpgradeContentValidator(final ValidationContext validationContext) { + super(validationContext); + } + + @Override + public ValidationResult validate(RestRequest request) throws IOException { + return super.validate(request); + } + + @Override + public ValidationResult validate(RestRequest request, JsonNode jsonContent) throws IOException { + return super.validate(request, jsonContent); + } + + + } + } diff --git a/src/main/java/org/opensearch/security/dlic/rest/validation/ValidationResult.java b/src/main/java/org/opensearch/security/dlic/rest/validation/ValidationResult.java index b0d2bf1a73..14c9d1546d 100644 --- a/src/main/java/org/opensearch/security/dlic/rest/validation/ValidationResult.java +++ b/src/main/java/org/opensearch/security/dlic/rest/validation/ValidationResult.java @@ -11,16 +11,25 @@ package org.opensearch.security.dlic.rest.validation; +import static org.opensearch.security.dlic.rest.api.Responses.badRequestMessage; + import java.io.IOException; import java.util.ArrayList; +import java.util.Collection; import java.util.List; import java.util.Objects; +import java.util.stream.Collectors; +import java.util.stream.Stream; import org.opensearch.common.CheckedBiConsumer; import org.opensearch.common.CheckedConsumer; import org.opensearch.common.CheckedFunction; +import org.opensearch.common.collect.Tuple; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.security.securityconf.impl.CType; + +import com.fasterxml.jackson.databind.JsonNode; public class ValidationResult { @@ -52,15 +61,19 @@ public static ValidationResult error(final RestStatus status, final ToXCo return new ValidationResult<>(status, errorMessage); } - public static ValidationResult> combine(final List> entries) { - final var returnList = new ArrayList(); - for (final var entry : entries) { - if (!entry.isValid()) { - return error(entry.status(), entry.errorMessage()); - } - returnList.add(entry.content); + /** + * Transforms a list of validation results into a single validation result of that lists contents. + * If any of the validation results are not valid, the first is returned as the error. + */ + public static ValidationResult> merge(final List> results) { + if (results.stream().allMatch(ValidationResult::isValid)) { + return success(results.stream().map(result -> result.content).collect(Collectors.toList())); } - return success(returnList); + + return results.stream().filter(result -> !result.isValid()) + .map(failedResult -> new ValidationResult>(failedResult.status, failedResult.errorMessage)) + .findFirst() + .get(); } public ValidationResult map(final CheckedFunction, IOException> mapper) throws IOException { @@ -99,4 +112,9 @@ public ToXContent errorMessage() { public C getContent() { return content; } + + public ValidationResult>> map(Object mapper) { + // TODO Auto-generated method stub + throw new UnsupportedOperationException("Unimplemented method 'map'"); + } }