From a82fbc93facd003c60cb8927ca7cc8de5e5af576 Mon Sep 17 00:00:00 2001 From: Surya Sashank Nistala Date: Wed, 6 Sep 2023 11:32:56 -0700 Subject: [PATCH] Using alerting workflows in detectors Signed-off-by: Subhobrata Dey --- .../SecurityAnalyticsPlugin.java | 3 +- .../securityanalytics/model/Detector.java | 47 +- .../settings/SecurityAnalyticsSettings.java | 6 + .../TransportDeleteDetectorAction.java | 130 ++-- .../transport/TransportGetFindingsAction.java | 2 +- .../TransportIndexDetectorAction.java | 617 +++++++++++------- .../util/MonitorService.java | 84 +++ .../util/WorkflowService.java | 185 ++++++ src/main/resources/mappings/detectors.json | 5 +- .../SecurityAnalyticsRestTestCase.java | 76 ++- .../securityanalytics/TestHelpers.java | 25 +- .../action/IndexDetectorResponseTests.java | 3 +- .../alerts/AlertingServiceTests.java | 6 +- .../findings/FindingServiceTests.java | 6 +- .../resthandler/DetectorMonitorRestApiIT.java | 468 ++++++++++++- .../resthandler/DetectorRestApiIT.java | 111 ++++ 16 files changed, 1472 insertions(+), 302 deletions(-) create mode 100644 src/main/java/org/opensearch/securityanalytics/util/MonitorService.java create mode 100644 src/main/java/org/opensearch/securityanalytics/util/WorkflowService.java diff --git a/src/main/java/org/opensearch/securityanalytics/SecurityAnalyticsPlugin.java b/src/main/java/org/opensearch/securityanalytics/SecurityAnalyticsPlugin.java index dd119dfc1..232b2ea97 100644 --- a/src/main/java/org/opensearch/securityanalytics/SecurityAnalyticsPlugin.java +++ b/src/main/java/org/opensearch/securityanalytics/SecurityAnalyticsPlugin.java @@ -242,7 +242,8 @@ public List> getSettings() { SecurityAnalyticsSettings.FINDING_HISTORY_RETENTION_PERIOD, SecurityAnalyticsSettings.IS_CORRELATION_INDEX_SETTING, SecurityAnalyticsSettings.CORRELATION_TIME_WINDOW, - SecurityAnalyticsSettings.DEFAULT_MAPPING_SCHEMA + SecurityAnalyticsSettings.DEFAULT_MAPPING_SCHEMA, + SecurityAnalyticsSettings.ENABLE_WORKFLOW_USAGE ); } diff --git a/src/main/java/org/opensearch/securityanalytics/model/Detector.java b/src/main/java/org/opensearch/securityanalytics/model/Detector.java index 9c7e48362..276458882 100644 --- a/src/main/java/org/opensearch/securityanalytics/model/Detector.java +++ b/src/main/java/org/opensearch/securityanalytics/model/Detector.java @@ -52,6 +52,8 @@ public class Detector implements Writeable, ToXContentObject { public static final String ENABLED_TIME_FIELD = "enabled_time"; public static final String ALERTING_MONITOR_ID = "monitor_id"; + public static final String ALERTING_WORKFLOW_ID = "workflow_ids"; + public static final String BUCKET_MONITOR_ID_RULE_ID = "bucket_monitor_id_rule_id"; private static final String RULE_TOPIC_INDEX = "rule_topic_index"; @@ -99,6 +101,8 @@ public class Detector implements Writeable, ToXContentObject { private Map ruleIdMonitorIdMap; + private List workflowIds; + private String ruleIndex; private String alertsIndex; @@ -117,7 +121,7 @@ public Detector(String id, Long version, String name, Boolean enabled, Schedule Instant lastUpdateTime, Instant enabledTime, String logType, User user, List inputs, List triggers, List monitorIds, String ruleIndex, String alertsIndex, String alertsHistoryIndex, String alertsHistoryIndexPattern, - String findingsIndex, String findingsIndexPattern, Map rulePerMonitor) { + String findingsIndex, String findingsIndexPattern, Map rulePerMonitor, List workflowIds) { this.type = DETECTOR_TYPE; this.id = id != null ? id : NO_ID; @@ -139,6 +143,7 @@ public Detector(String id, Long version, String name, Boolean enabled, Schedule this.findingsIndexPattern = findingsIndexPattern; this.ruleIdMonitorIdMap = rulePerMonitor; this.logType = logType; + this.workflowIds = workflowIds != null ? workflowIds : null; if (enabled) { Objects.requireNonNull(enabledTime); @@ -165,7 +170,8 @@ public Detector(StreamInput sin) throws IOException { sin.readString(), sin.readString(), sin.readString(), - sin.readMap(StreamInput::readString, StreamInput::readString) + sin.readMap(StreamInput::readString, StreamInput::readString), + sin.readStringList() ); } @@ -200,6 +206,10 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(ruleIndex); out.writeMap(ruleIdMonitorIdMap, StreamOutput::writeString, StreamOutput::writeString); + + if (workflowIds != null) { + out.writeStringCollection(workflowIds); + } } public XContentBuilder toXContentWithUser(XContentBuilder builder, Params params) throws IOException { @@ -253,6 +263,14 @@ private XContentBuilder createXContentBuilder(XContentBuilder builder, ToXConten } builder.field(ALERTING_MONITOR_ID, monitorIds); + + if (workflowIds == null) { + builder.nullField(ALERTING_WORKFLOW_ID); + } else { + builder.field(ALERTING_WORKFLOW_ID, workflowIds); + } + + builder.field(BUCKET_MONITOR_ID_RULE_ID, ruleIdMonitorIdMap); builder.field(RULE_TOPIC_INDEX, ruleIndex); builder.field(ALERTS_INDEX, alertsIndex); @@ -299,6 +317,7 @@ public static Detector parse(XContentParser xcp, String id, Long version) throws List inputs = new ArrayList<>(); List triggers = new ArrayList<>(); List monitorIds = new ArrayList<>(); + List workflowIds = new ArrayList<>(); Map rulePerMonitor = new HashMap<>(); String ruleIndex = null; @@ -374,6 +393,15 @@ public static Detector parse(XContentParser xcp, String id, Long version) throws monitorIds.add(monitorId); } break; + case ALERTING_WORKFLOW_ID: + XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, xcp.currentToken(), xcp); + while (xcp.nextToken() != XContentParser.Token.END_ARRAY) { + String workflowId = xcp.textOrNull(); + if (workflowId != null) { + workflowIds.add(workflowId); + } + } + break; case BUCKET_MONITOR_ID_RULE_ID: rulePerMonitor= xcp.mapStrings(); break; @@ -429,7 +457,8 @@ public static Detector parse(XContentParser xcp, String id, Long version) throws alertsHistoryIndexPattern, findingsIndex, findingsIndexPattern, - rulePerMonitor + rulePerMonitor, + workflowIds ); } @@ -566,10 +595,22 @@ public void setRuleIdMonitorIdMap(Map ruleIdMonitorIdMap) { this.ruleIdMonitorIdMap = ruleIdMonitorIdMap; } + public void setWorkflowIds(List workflowIds) { + this.workflowIds = workflowIds; + } + + public List getWorkflowIds() { + return workflowIds; + } + public String getDocLevelMonitorId() { return ruleIdMonitorIdMap.get(DOC_LEVEL_MONITOR); } + public boolean isWorkflowSupported() { + return workflowIds != null && !workflowIds.isEmpty(); + } + @Override public boolean equals(Object o) { if (this == o) return true; diff --git a/src/main/java/org/opensearch/securityanalytics/settings/SecurityAnalyticsSettings.java b/src/main/java/org/opensearch/securityanalytics/settings/SecurityAnalyticsSettings.java index 9a0aebc3c..43d358b85 100644 --- a/src/main/java/org/opensearch/securityanalytics/settings/SecurityAnalyticsSettings.java +++ b/src/main/java/org/opensearch/securityanalytics/settings/SecurityAnalyticsSettings.java @@ -98,6 +98,12 @@ public class SecurityAnalyticsSettings { Setting.Property.NodeScope, Setting.Property.Dynamic ); + public static final Setting ENABLE_WORKFLOW_USAGE = Setting.boolSetting( + "plugins.security_analytics.enable_workflow_usage", + false, + Setting.Property.NodeScope, Setting.Property.Dynamic + ); + public static final Setting IS_CORRELATION_INDEX_SETTING = Setting.boolSetting(CORRELATION_INDEX, false, Setting.Property.IndexScope); public static final Setting CORRELATION_TIME_WINDOW = Setting.positiveTimeSetting( diff --git a/src/main/java/org/opensearch/securityanalytics/transport/TransportDeleteDetectorAction.java b/src/main/java/org/opensearch/securityanalytics/transport/TransportDeleteDetectorAction.java index f23a0d9c7..1e8a9880d 100644 --- a/src/main/java/org/opensearch/securityanalytics/transport/TransportDeleteDetectorAction.java +++ b/src/main/java/org/opensearch/securityanalytics/transport/TransportDeleteDetectorAction.java @@ -4,17 +4,12 @@ */ package org.opensearch.securityanalytics.transport; -import java.util.Collection; -import java.util.List; -import java.util.Locale; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicReference; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.common.SetOnce; import org.opensearch.OpenSearchStatusException; -import org.opensearch.core.action.ActionListener; import org.opensearch.action.ActionRunnable; +import org.opensearch.action.StepListener; import org.opensearch.action.delete.DeleteRequest; import org.opensearch.action.delete.DeleteResponse; import org.opensearch.action.get.GetRequest; @@ -25,28 +20,41 @@ import org.opensearch.action.support.WriteRequest; import org.opensearch.client.Client; import org.opensearch.client.node.NodeClient; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.XContentHelper; import org.opensearch.common.xcontent.XContentType; import org.opensearch.commons.alerting.AlertingPluginInterface; import org.opensearch.commons.alerting.action.DeleteMonitorRequest; import org.opensearch.commons.alerting.action.DeleteMonitorResponse; +import org.opensearch.commons.alerting.action.DeleteWorkflowResponse; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.core.rest.RestStatus; +import org.opensearch.extensions.AcknowledgedResponse; import org.opensearch.securityanalytics.action.DeleteDetectorAction; import org.opensearch.securityanalytics.action.DeleteDetectorRequest; import org.opensearch.securityanalytics.action.DeleteDetectorResponse; import org.opensearch.securityanalytics.mapper.IndexTemplateManager; import org.opensearch.securityanalytics.model.Detector; +import org.opensearch.securityanalytics.settings.SecurityAnalyticsSettings; import org.opensearch.securityanalytics.util.DetectorIndices; +import org.opensearch.securityanalytics.util.MonitorService; import org.opensearch.securityanalytics.util.RuleTopicIndices; import org.opensearch.securityanalytics.util.SecurityAnalyticsException; +import org.opensearch.securityanalytics.util.WorkflowService; import org.opensearch.tasks.Task; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; +import java.util.Collection; +import java.util.List; +import java.util.Locale; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; import static org.opensearch.securityanalytics.model.Detector.NO_VERSION; @@ -60,14 +68,24 @@ public class TransportDeleteDetectorAction extends HandledTransportAction monitorIds = detector.getMonitorIds(); - ActionListener deletesListener = new GroupedActionListener<>(new ActionListener<>() { - @Override - public void onResponse(Collection responses) { - SetOnce errorStatusSupplier = new SetOnce<>(); - if (responses.stream().filter(response -> { - if (response.getStatus() != RestStatus.OK) { - log.error("Detector not being deleted because monitor [{}] could not be deleted. Status [{}]", response.getId(), response.getStatus()); - errorStatusSupplier.trySet(response.getStatus()); - return true; + StepListener onDeleteWorkflowStep = new StepListener<>(); + // 1. Delete the workflow if the workflow is supported + deleteWorkflow(detector, onDeleteWorkflowStep); + onDeleteWorkflowStep.whenComplete(acknowledgedResponse -> { + List monitorIds = detector.getMonitorIds(); + ActionListener deletesListener = new GroupedActionListener<>(new ActionListener<>() { + @Override + public void onResponse(Collection responses) { + SetOnce errorStatusSupplier = new SetOnce<>(); + if (responses.stream().filter(response -> { + if (response.getStatus() != RestStatus.OK) { + log.error("Detector not being deleted because monitor [{}] could not be deleted. Status [{}]", response.getId(), response.getStatus()); + errorStatusSupplier.trySet(response.getStatus()); + return true; + } + return false; + }).count() > 0) { + onFailures(new OpenSearchStatusException("Monitor associated with detected could not be deleted", errorStatusSupplier.get())); } - return false; - }).count() > 0) { - onFailures(new OpenSearchStatusException("Monitor associated with detected could not be deleted", errorStatusSupplier.get())); + deleteDetectorFromConfig(detector.getId(), request.getRefreshPolicy()); } - deleteDetectorFromConfig(detector.getId(), request.getRefreshPolicy()); - } - @Override - public void onFailure(Exception e) { - if(isOnlyMonitorOrIndexMissingExceptionThrownByGroupedActionListener(e, detector.getId())) { - deleteDetectorFromConfig(detector.getId(), request.getRefreshPolicy()); - } else { - log.error(String.format(Locale.ROOT, "Failed to delete detector %s", detector.getId()), e); - if (counter.compareAndSet(false, true)) { - finishHim(null, e); + @Override + public void onFailure(Exception e) { + if (isOnlyMonitorOrIndexMissingExceptionThrownByGroupedActionListener(e, detector.getId())) { + deleteDetectorFromConfig(detector.getId(), request.getRefreshPolicy()); + } else { + log.error(String.format(Locale.ROOT, "Failed to delete detector %s", detector.getId()), e); + if (counter.compareAndSet(false, true)) { + finishHim(null, e); + } } } + }, monitorIds.size()); + for (String monitorId : monitorIds) { + deleteAlertingMonitor(monitorId, request.getRefreshPolicy(), + deletesListener); + } + }, e -> { + if (counter.compareAndSet(false, true)) { + finishHim(null, e); } - }, monitorIds.size()); - for (String monitorId : monitorIds) { - deleteAlertingMonitor(monitorId, request.getRefreshPolicy(), - deletesListener); + }); + } + + private void deleteWorkflow(Detector detector, ActionListener actionListener) { + if (detector.isWorkflowSupported() && enabledWorkflowUsage) { + var workflowId = detector.getWorkflowIds().get(0); + log.debug(String.format("Deleting the workflow %s before deleting the detector", workflowId)); + StepListener onDeleteWorkflowStep = new StepListener<>(); + workflowService.deleteWorkflow(workflowId, onDeleteWorkflowStep); + onDeleteWorkflowStep.whenComplete(deleteWorkflowResponse -> { + actionListener.onResponse(new AcknowledgedResponse(true)); + }, actionListener::onFailure); + } else { + // If detector doesn't have the workflows it means that older version of the plugin is used and just skip the step + actionListener.onResponse(new AcknowledgedResponse(true)); } } @@ -211,6 +260,7 @@ public void onFailure(Exception e) { }); } + @Override public void onFailure(Exception t) { onFailures(t); @@ -235,7 +285,7 @@ private void onFailures(Exception t) { private void finishHim(String detectorId, Exception t) { threadPool.executor(ThreadPool.Names.GENERIC).execute(ActionRunnable.supply(listener, () -> { if (t != null) { - log.error(String.format(Locale.ROOT, "Failed to delete detector %s",detectorId), t); + log.error(String.format(Locale.ROOT, "Failed to delete detector %s", detectorId), t); if (t instanceof OpenSearchStatusException) { throw t; } @@ -256,8 +306,8 @@ private boolean isOnlyMonitorOrIndexMissingExceptionThrownByGroupedActionListene for (int i = 0; i <= len; i++) { Throwable e = i == len ? ex : ex.getSuppressed()[i]; if (e.getMessage().matches("(.*)Monitor(.*) is not found(.*)") - || e.getMessage().contains( - "Configured indices are not found: [.opendistro-alerting-config]") + || e.getMessage().contains( + "Configured indices are not found: [.opendistro-alerting-config]") ) { log.error( String.format(Locale.ROOT, "Monitor or jobs index already deleted." + @@ -270,4 +320,8 @@ private boolean isOnlyMonitorOrIndexMissingExceptionThrownByGroupedActionListene return true; } } + + private void setEnabledWorkflowUsage(boolean enabledWorkflowUsage) { + this.enabledWorkflowUsage = enabledWorkflowUsage; + } } \ No newline at end of file diff --git a/src/main/java/org/opensearch/securityanalytics/transport/TransportGetFindingsAction.java b/src/main/java/org/opensearch/securityanalytics/transport/TransportGetFindingsAction.java index f9e7856db..de54400db 100644 --- a/src/main/java/org/opensearch/securityanalytics/transport/TransportGetFindingsAction.java +++ b/src/main/java/org/opensearch/securityanalytics/transport/TransportGetFindingsAction.java @@ -111,7 +111,7 @@ protected void doExecute(Task task, GetFindingsRequest request, ActionListener> rulesById, Detector detector, ActionListener> listener, WriteRequest.RefreshPolicy refreshPolicy) throws SigmaError, IOException { + private void createMonitorFromQueries(List> rulesById, Detector detector, ActionListener> listener, WriteRequest.RefreshPolicy refreshPolicy) throws SigmaError, IOException { List> docLevelRules = rulesById.stream().filter(it -> !it.getRight().isAggregationRule()).collect( Collectors.toList()); List> bucketLevelRules = rulesById.stream().filter(it -> it.getRight().isAggregationRule()).collect( @@ -264,27 +272,26 @@ private void createMonitorFromQueries(String index, List> rul AlertingPluginInterface.INSTANCE.indexMonitor((NodeClient) client, monitorRequests.get(0), namedWriteableRegistry, addFirstMonitorStep); addFirstMonitorStep.whenComplete(addedFirstMonitorResponse -> { monitorResponses.add(addedFirstMonitorResponse); + + StepListener> indexMonitorsStep = new StepListener<>(); + indexMonitorsStep.whenComplete( + indexMonitorResponses -> saveWorkflow(detector, indexMonitorResponses, refreshPolicy, listener), + e -> { + log.error("Failed to index the workflow", e); + listener.onFailure(e); + }); + int numberOfUnprocessedResponses = monitorRequests.size() - 1; if (numberOfUnprocessedResponses == 0) { - listener.onResponse(monitorResponses); + saveWorkflow(detector, monitorResponses, refreshPolicy, listener); } else { - GroupedActionListener monitorResponseListener = new GroupedActionListener( - new ActionListener>() { - @Override - public void onResponse(Collection indexMonitorResponse) { - monitorResponses.addAll(indexMonitorResponse.stream().collect(Collectors.toList())); - listener.onResponse(monitorResponses); - } - - @Override - public void onFailure(Exception e) { - listener.onFailure(e); - } - }, numberOfUnprocessedResponses); - - for (int i = 1; i < monitorRequests.size(); i++) { - AlertingPluginInterface.INSTANCE.indexMonitor((NodeClient) client, monitorRequests.get(i), namedWriteableRegistry, monitorResponseListener); - } + // Saves the rest of the monitors and saves the workflow if supported + saveMonitors( + monitorRequests, + monitorResponses, + numberOfUnprocessedResponses, + indexMonitorsStep + ); } }, listener::onFailure @@ -298,41 +305,85 @@ public void onFailure(Exception e) { } List monitorResponses = new ArrayList<>(); - StepListener addFirstMonitorStep = new StepListener(); + StepListener indexDocLevelMonitorStep = new StepListener(); // Indexing monitors in two steps in order to prevent all shards failed error from alerting // https://github.com/opensearch-project/alerting/issues/646 - AlertingPluginInterface.INSTANCE.indexMonitor((NodeClient) client, monitorRequests.get(0), namedWriteableRegistry, addFirstMonitorStep); - addFirstMonitorStep.whenComplete(addedFirstMonitorResponse -> { + AlertingPluginInterface.INSTANCE.indexMonitor((NodeClient) client, monitorRequests.get(0), namedWriteableRegistry, indexDocLevelMonitorStep); + indexDocLevelMonitorStep.whenComplete(addedFirstMonitorResponse -> { monitorResponses.add(addedFirstMonitorResponse); - int numberOfUnprocessedResponses = monitorRequests.size() - 1; - if (numberOfUnprocessedResponses == 0) { - listener.onResponse(monitorResponses); - } else { - GroupedActionListener monitorResponseListener = new GroupedActionListener( - new ActionListener>() { - @Override - public void onResponse(Collection indexMonitorResponse) { - monitorResponses.addAll(indexMonitorResponse.stream().collect(Collectors.toList())); - listener.onResponse(monitorResponses); - } - - @Override - public void onFailure(Exception e) { - listener.onFailure(e); - } - }, numberOfUnprocessedResponses); - - for (int i = 1; i < monitorRequests.size(); i++) { - AlertingPluginInterface.INSTANCE.indexMonitor((NodeClient) client, monitorRequests.get(i), namedWriteableRegistry, monitorResponseListener); - } - } + saveWorkflow(detector, monitorResponses, refreshPolicy, listener); }, listener::onFailure ); } } + private void saveMonitors( + List monitorRequests, + List monitorResponses, + int numberOfUnprocessedResponses, + ActionListener> listener + ) { + GroupedActionListener monitorResponseListener = new GroupedActionListener( + new ActionListener>() { + @Override + public void onResponse(Collection indexMonitorResponses) { + monitorResponses.addAll(indexMonitorResponses.stream().collect(Collectors.toList())); + listener.onResponse(monitorResponses); + } + @Override + public void onFailure(Exception e) { + listener.onFailure(e); + } + }, numberOfUnprocessedResponses); + + for (int i = 1; i < monitorRequests.size(); i++) { + AlertingPluginInterface.INSTANCE.indexMonitor((NodeClient) client, monitorRequests.get(i), namedWriteableRegistry, monitorResponseListener); + } + } + + /** + * If the workflow is enabled, saves the workflow, updates the detector and returns the saved monitors + * if not, returns the saved monitors + * @param detector + * @param monitorResponses + * @param refreshPolicy + * @param actionListener + */ + private void saveWorkflow( + Detector detector, + List monitorResponses, + RefreshPolicy refreshPolicy, + ActionListener> actionListener + ) { + if (enabledWorkflowUsage) { + workflowService.upsertWorkflow( + monitorResponses.stream().map(IndexMonitorResponse::getId).collect(Collectors.toList()), + null, + detector, + refreshPolicy, + Workflow.NO_ID, + Method.POST, + new ActionListener<>() { + @Override + public void onResponse(IndexWorkflowResponse workflowResponse) { + // Update passed detector with the workflowId + detector.setWorkflowIds(List.of(workflowResponse.getId())); + actionListener.onResponse(monitorResponses); + } + + @Override + public void onFailure(Exception e) { + log.error("Error saving workflow", e); + actionListener.onFailure(e); + } + }); + } else { + actionListener.onResponse(monitorResponses); + } + } + private void updateMonitorFromQueries(String index, List> rulesById, Detector detector, ActionListener> listener, WriteRequest.RefreshPolicy refreshPolicy) throws SigmaError, IOException { List monitorsToBeUpdated = new ArrayList<>(); @@ -395,7 +446,7 @@ public void onResponse(Map> ruleFieldMappings) { monitorIdsToBeDeleted.removeAll(monitorsToBeUpdated.stream().map(IndexMonitorRequest::getMonitorId).collect( Collectors.toList())); - updateAlertingMonitors(monitorsToBeAdded, monitorsToBeUpdated, monitorIdsToBeDeleted, refreshPolicy, listener); + updateAlertingMonitors(detector, monitorsToBeAdded, monitorsToBeUpdated, monitorIdsToBeDeleted, refreshPolicy, listener); } catch (IOException | SigmaError ex) { listener.onFailure(ex); } @@ -423,7 +474,7 @@ public void onFailure(Exception e) { monitorIdsToBeDeleted.removeAll(monitorsToBeUpdated.stream().map(IndexMonitorRequest::getMonitorId).collect( Collectors.toList())); - updateAlertingMonitors(monitorsToBeAdded, monitorsToBeUpdated, monitorIdsToBeDeleted, refreshPolicy, listener); + updateAlertingMonitors(detector, monitorsToBeAdded, monitorsToBeUpdated, monitorIdsToBeDeleted, refreshPolicy, listener); } } @@ -432,8 +483,9 @@ public void onFailure(Exception e) { * Executed in a steps: * 1. Add new monitors; * 2. Update existing monitors; - * 3. Delete the monitors omitted from request - * 4. Respond with updated list of monitors + * 3. Updates the workflow + * 4. Delete the monitors omitted from request + * 5. Respond with updated list of monitors * @param monitorsToBeAdded Newly added monitors by the user * @param monitorsToBeUpdated Existing monitors that will be updated * @param monitorsToBeDeleted Monitors omitted by the user @@ -441,6 +493,7 @@ public void onFailure(Exception e) { * @param listener Listener that accepts the list of updated monitors if the action was successful */ private void updateAlertingMonitors( + Detector detector, List monitorsToBeAdded, List monitorsToBeUpdated, List monitorsToBeDeleted, @@ -457,29 +510,112 @@ private void updateAlertingMonitors( if (addNewMonitorsResponse != null && !addNewMonitorsResponse.isEmpty()) { updatedMonitors.addAll(addNewMonitorsResponse); } - StepListener> updateMonitorsStep = new StepListener<>(); executeMonitorActionRequest(monitorsToBeUpdated, updateMonitorsStep); // 2. Update existing alerting monitors (based on the common rules) updateMonitorsStep.whenComplete(updateMonitorResponse -> { - if (updateMonitorResponse!=null && !updateMonitorResponse.isEmpty()) { - updatedMonitors.addAll(updateMonitorResponse); - } - - StepListener> deleteMonitorStep = new StepListener<>(); - deleteAlertingMonitors(monitorsToBeDeleted, refreshPolicy, deleteMonitorStep); - // 3. Delete alerting monitors (rules that are not provided by the user) - deleteMonitorStep.whenComplete(deleteMonitorResponses -> - // Return list of all updated + newly added monitors - listener.onResponse(updatedMonitors), - // Handle delete monitors (step 3) - listener::onFailure); - }, // Handle update monitor failed (step 2) + if (updateMonitorResponse != null && !updateMonitorResponse.isEmpty()) { + updatedMonitors.addAll(updateMonitorResponse); + } + if (detector.isWorkflowSupported() && enabledWorkflowUsage) { + updateWorkflowStep( + detector, + monitorsToBeDeleted, + refreshPolicy, + listener, + updatedMonitors, + addNewMonitorsResponse, + updateMonitorResponse + ); + } else { + deleteMonitorStep(monitorsToBeDeleted, refreshPolicy, updatedMonitors, listener); + } + }, + // Handle update monitor failed (step 2) listener::onFailure); // Handle add failed (step 1) }, listener::onFailure); } + private void deleteMonitorStep( + List monitorsToBeDeleted, + RefreshPolicy refreshPolicy, + List updatedMonitors, + ActionListener> listener + ) { + monitorService.deleteAlertingMonitors(monitorsToBeDeleted, + refreshPolicy, + new ActionListener<>() { + @Override + public void onResponse(List deleteMonitorResponses) { + listener.onResponse(updatedMonitors); + } + + @Override + public void onFailure(Exception e) { + log.error("Failed to delete the monitors", e); + listener.onFailure(e); + } + }); + } + + private void updateWorkflowStep( + Detector detector, + List monitorsToBeDeleted, + RefreshPolicy refreshPolicy, + ActionListener> listener, + List updatedMonitors, + List addNewMonitorsResponse, + List updateMonitorResponse + ) { + List addedMonitorIds = addNewMonitorsResponse.stream().map(IndexMonitorResponse::getId) + .collect(Collectors.toList()); + List updatedMonitorIds = updateMonitorResponse.stream().map(IndexMonitorResponse::getId) + .collect(Collectors.toList()); + + // If there are no added or updated monitors - all monitors should be deleted + // Before deleting the monitors, workflow should be removed so there are no monitors that are part of the workflow + // which means that the workflow should be removed + if (addedMonitorIds.isEmpty() && updatedMonitorIds.isEmpty()) { + workflowService.deleteWorkflow( + detector.getWorkflowIds().get(0), + new ActionListener<>() { + @Override + public void onResponse(DeleteWorkflowResponse deleteWorkflowResponse) { + detector.setWorkflowIds(Collections.emptyList()); + deleteMonitorStep(monitorsToBeDeleted, refreshPolicy, updatedMonitors, listener); + } + @Override + public void onFailure(Exception e) { + log.error("Failed to delete the workflow", e); + listener.onFailure(e); + } + } + ); + + } else { + // Update workflow and delete the monitors + workflowService.upsertWorkflow( + addedMonitorIds, + updatedMonitorIds, + detector, + refreshPolicy, + detector.getWorkflowIds().get(0), + Method.PUT, + new ActionListener<>() { + @Override + public void onResponse(IndexWorkflowResponse workflowResponse) { + deleteMonitorStep(monitorsToBeDeleted, refreshPolicy, updatedMonitors, listener); + } + @Override + public void onFailure(Exception e) { + log.error("Failed to update the workflow"); + listener.onFailure(e); + } + }); + } + } + private IndexMonitorRequest createDocLevelMonitorRequest(List> queries, Detector detector, WriteRequest.RefreshPolicy refreshPolicy, String monitorId, RestRequest.Method restMethod) { List docLevelMonitorInputs = new ArrayList<>(); @@ -517,7 +653,7 @@ private IndexMonitorRequest createDocLevelMonitorRequest(List triggers.add(new DocumentLevelTrigger(id, name, severity, actions, condition)); } - Monitor monitor = new Monitor(monitorId, Monitor.NO_VERSION, detector.getName(), detector.getEnabled(), detector.getSchedule(), detector.getLastUpdateTime(), detector.getEnabledTime(), + Monitor monitor = new Monitor(monitorId, Monitor.NO_VERSION, detector.getName(), false, detector.getSchedule(), detector.getLastUpdateTime(), null, Monitor.MonitorType.DOC_LEVEL_MONITOR, detector.getUser(), 1, docLevelMonitorInputs, triggers, Map.of(), new DataSources(detector.getRuleIndex(), detector.getFindingsIndex(), @@ -595,36 +731,36 @@ private IndexMonitorRequest createBucketLevelMonitorRequest( .aggregation(aggregationQueries.getAggBuilder()); // input index can also be an index pattern or alias so we have to resolve it to concrete index String concreteIndex = IndexUtils.getNewIndexByCreationDate( - clusterService.state(), - indexNameExpressionResolver, - indices.get(0) // taking first one is fine because we expect that all indices in list share same mappings + clusterService.state(), + indexNameExpressionResolver, + indices.get(0) // taking first one is fine because we expect that all indices in list share same mappings ); try { GetIndexMappingsResponse getIndexMappingsResponse = client.execute( GetIndexMappingsAction.INSTANCE, new GetIndexMappingsRequest(concreteIndex)) - .actionGet(); + .actionGet(); MappingMetadata mappingMetadata = getIndexMappingsResponse.mappings().get(concreteIndex); List> pairs = MapperUtils.getAllAliasPathPairs(mappingMetadata); boolean timeStampAliasPresent = pairs. - stream() - .anyMatch(p -> - TIMESTAMP_FIELD_ALIAS.equals(p.getLeft()) || TIMESTAMP_FIELD_ALIAS.equals(p.getRight())); + stream() + .anyMatch(p -> + TIMESTAMP_FIELD_ALIAS.equals(p.getLeft()) || TIMESTAMP_FIELD_ALIAS.equals(p.getRight())); if(timeStampAliasPresent) { BoolQueryBuilder boolQueryBuilder = searchSourceBuilder.query() == null - ? new BoolQueryBuilder() - : QueryBuilders.boolQuery().must(searchSourceBuilder.query()); + ? new BoolQueryBuilder() + : QueryBuilders.boolQuery().must(searchSourceBuilder.query()); RangeQueryBuilder timeRangeFilter = QueryBuilders.rangeQuery(TIMESTAMP_FIELD_ALIAS) - .gt("{{period_end}}||-1h") - .lte("{{period_end}}") - .format("epoch_millis"); + .gt("{{period_end}}||-1h") + .lte("{{period_end}}") + .format("epoch_millis"); boolQueryBuilder.must(timeRangeFilter); searchSourceBuilder.query(boolQueryBuilder); } } catch (Exception e) { log.error( - String.format(Locale.getDefault(), - "Unable to verify presence of timestamp alias for index [%s] in detector [%s]. Not setting time range filter for bucket level monitor.", + String.format(Locale.getDefault(), + "Unable to verify presence of timestamp alias for index [%s] in detector [%s]. Not setting time range filter for bucket level monitor.", concreteIndex, detector.getName()), e); } @@ -649,7 +785,7 @@ private IndexMonitorRequest createBucketLevelMonitorRequest( triggers.add(bucketLevelTrigger1); } **/ - Monitor monitor = new Monitor(monitorId, Monitor.NO_VERSION, detector.getName(), detector.getEnabled(), detector.getSchedule(), detector.getLastUpdateTime(), detector.getEnabledTime(), + Monitor monitor = new Monitor(monitorId, Monitor.NO_VERSION, detector.getName(), false, detector.getSchedule(), detector.getLastUpdateTime(), null, MonitorType.BUCKET_LEVEL_MONITOR, detector.getUser(), 1, bucketLevelMonitorInputs, triggers, Map.of(), new DataSources(detector.getRuleIndex(), detector.getFindingsIndex(), @@ -696,48 +832,6 @@ public void onFailure(Exception e) { } } - /** - * Deletes the alerting monitors based on the given ids and notifies the listener that will be notified once all monitors have been deleted - * @param monitorIds monitor ids to be deleted - * @param refreshPolicy - * @param listener listener that will be notified once all the monitors are being deleted - */ - private void deleteAlertingMonitors(List monitorIds, WriteRequest.RefreshPolicy refreshPolicy, ActionListener> listener){ - if (monitorIds == null || monitorIds.isEmpty()) { - listener.onResponse(new ArrayList<>()); - return; - } - ActionListener deletesListener = new GroupedActionListener<>(new ActionListener<>() { - @Override - public void onResponse(Collection responses) { - SetOnce errorStatusSupplier = new SetOnce<>(); - if (responses.stream().filter(response -> { - if (response.getStatus() != RestStatus.OK) { - log.error("Monitor [{}] could not be deleted. Status [{}]", response.getId(), response.getStatus()); - errorStatusSupplier.trySet(response.getStatus()); - return true; - } - return false; - }).count() > 0) { - listener.onFailure(new OpenSearchStatusException("Monitor associated with detected could not be deleted", errorStatusSupplier.get())); - } - listener.onResponse(responses.stream().collect(Collectors.toList())); - } - @Override - public void onFailure(Exception e) { - listener.onFailure(e); - } - }, monitorIds.size()); - - for (String monitorId : monitorIds) { - deleteAlertingMonitor(monitorId, refreshPolicy, deletesListener); - } - } - private void deleteAlertingMonitor(String monitorId, WriteRequest.RefreshPolicy refreshPolicy, ActionListener listener) { - DeleteMonitorRequest request = new DeleteMonitorRequest(monitorId, refreshPolicy); - AlertingPluginInterface.INSTANCE.deleteMonitor((NodeClient) client, request, listener); - } - private void onCreateMappingsResponse(CreateIndexResponse response) throws IOException { if (response.isAcknowledged()) { log.info(String.format(Locale.getDefault(), "Created %s with mappings.", Detector.DETECTORS_INDEX)); @@ -919,19 +1013,19 @@ public void onResponse(GetResponse response) { try { XContentParser xcp = XContentHelper.createParser( - xContentRegistry, LoggingDeprecationHandler.INSTANCE, - response.getSourceAsBytesRef(), XContentType.JSON + xContentRegistry, LoggingDeprecationHandler.INSTANCE, + response.getSourceAsBytesRef(), XContentType.JSON ); Detector detector = Detector.docParse(xcp, response.getId(), response.getVersion()); // security is enabled and filterby is enabled if (!checkUserPermissionsWithResource( - originalContextUser, - detector.getUser(), - "detector", - detector.getId(), - TransportIndexDetectorAction.this.filterByEnabled + originalContextUser, + detector.getUser(), + "detector", + detector.getId(), + TransportIndexDetectorAction.this.filterByEnabled ) ) { @@ -957,6 +1051,7 @@ void onGetResponse(Detector currentDetector, User user) { } request.getDetector().setMonitorIds(currentDetector.getMonitorIds()); request.getDetector().setRuleIdMonitorIdMap(currentDetector.getRuleIdMonitorIdMap()); + request.getDetector().setWorkflowIds(currentDetector.getWorkflowIds()); Detector detector = request.getDetector(); String ruleTopic = detector.getDetectorType(); @@ -1009,11 +1104,41 @@ public void onFailure(Exception e) { public void initRuleIndexAndImportRules(IndexDetectorRequest request, ActionListener> listener) { ruleIndices.initPrepackagedRulesIndex( - new ActionListener<>() { - @Override - public void onResponse(CreateIndexResponse response) { - ruleIndices.onCreateMappingsResponse(response, true); - ruleIndices.importRules(RefreshPolicy.IMMEDIATE, indexTimeout, + new ActionListener<>() { + @Override + public void onResponse(CreateIndexResponse response) { + ruleIndices.onCreateMappingsResponse(response, true); + ruleIndices.importRules(RefreshPolicy.IMMEDIATE, indexTimeout, + new ActionListener<>() { + @Override + public void onResponse(BulkResponse response) { + if (!response.hasFailures()) { + importRules(request, listener); + } else { + onFailures(new OpenSearchStatusException(response.buildFailureMessage(), RestStatus.INTERNAL_SERVER_ERROR)); + } + } + + @Override + public void onFailure(Exception e) { + onFailures(e); + } + }); + } + + @Override + public void onFailure(Exception e) { + onFailures(e); + } + }, + new ActionListener<>() { + @Override + public void onResponse(AcknowledgedResponse response) { + ruleIndices.onUpdateMappingsResponse(response, true); + ruleIndices.deleteRules(new ActionListener<>() { + @Override + public void onResponse(BulkByScrollResponse response) { + ruleIndices.importRules(WriteRequest.RefreshPolicy.IMMEDIATE, indexTimeout, new ActionListener<>() { @Override public void onResponse(BulkResponse response) { @@ -1029,85 +1154,55 @@ public void onFailure(Exception e) { onFailures(e); } }); - } - - @Override - public void onFailure(Exception e) { - onFailures(e); - } - }, - new ActionListener<>() { - @Override - public void onResponse(AcknowledgedResponse response) { - ruleIndices.onUpdateMappingsResponse(response, true); - ruleIndices.deleteRules(new ActionListener<>() { - @Override - public void onResponse(BulkByScrollResponse response) { - ruleIndices.importRules(WriteRequest.RefreshPolicy.IMMEDIATE, indexTimeout, - new ActionListener<>() { - @Override - public void onResponse(BulkResponse response) { - if (!response.hasFailures()) { - importRules(request, listener); - } else { - onFailures(new OpenSearchStatusException(response.buildFailureMessage(), RestStatus.INTERNAL_SERVER_ERROR)); - } - } - - @Override - public void onFailure(Exception e) { - onFailures(e); - } - }); - } + } - @Override - public void onFailure(Exception e) { - onFailures(e); - } - }); - } + @Override + public void onFailure(Exception e) { + onFailures(e); + } + }); + } - @Override - public void onFailure(Exception e) { - onFailures(e); + @Override + public void onFailure(Exception e) { + onFailures(e); + } + }, + new ActionListener<>() { + @Override + public void onResponse(SearchResponse response) { + if (response.isTimedOut()) { + onFailures(new OpenSearchStatusException(response.toString(), RestStatus.REQUEST_TIMEOUT)); } - }, - new ActionListener<>() { - @Override - public void onResponse(SearchResponse response) { - if (response.isTimedOut()) { - onFailures(new OpenSearchStatusException(response.toString(), RestStatus.REQUEST_TIMEOUT)); - } - long count = response.getHits().getTotalHits().value; - if (count == 0) { - ruleIndices.importRules(WriteRequest.RefreshPolicy.IMMEDIATE, indexTimeout, - new ActionListener<>() { - @Override - public void onResponse(BulkResponse response) { - if (!response.hasFailures()) { - importRules(request, listener); - } else { - onFailures(new OpenSearchStatusException(response.buildFailureMessage(), RestStatus.INTERNAL_SERVER_ERROR)); - } - } + long count = response.getHits().getTotalHits().value; + if (count == 0) { + ruleIndices.importRules(WriteRequest.RefreshPolicy.IMMEDIATE, indexTimeout, + new ActionListener<>() { + @Override + public void onResponse(BulkResponse response) { + if (!response.hasFailures()) { + importRules(request, listener); + } else { + onFailures(new OpenSearchStatusException(response.buildFailureMessage(), RestStatus.INTERNAL_SERVER_ERROR)); + } + } - @Override - public void onFailure(Exception e) { - onFailures(e); - } - }); - } else { - importRules(request, listener); - } + @Override + public void onFailure(Exception e) { + onFailures(e); + } + }); + } else { + importRules(request, listener); } + } - @Override - public void onFailure(Exception e) { - onFailures(e); - } + @Override + public void onFailure(Exception e) { + onFailures(e); } + } ); } @@ -1121,14 +1216,14 @@ public void importRules(IndexDetectorRequest request, ActionListener ruleIds = detectorInput.getPrePackagedRules().stream().map(DetectorRule::getId).collect(Collectors.toList()); QueryBuilder queryBuilder = - QueryBuilders.nestedQuery("rule", - QueryBuilders.boolQuery().must( - QueryBuilders.matchQuery("rule.category", ruleTopic) - ).must( - QueryBuilders.termsQuery("_id", ruleIds.toArray(new String[]{})) - ), - ScoreMode.Avg - ); + QueryBuilders.nestedQuery("rule", + QueryBuilders.boolQuery().must( + QueryBuilders.matchQuery("rule.category", ruleTopic) + ).must( + QueryBuilders.termsQuery("_id", ruleIds.toArray(new String[]{})) + ), + ScoreMode.Avg + ); SearchRequest searchRequest = new SearchRequest(Rule.PRE_PACKAGED_RULES_INDEX) .source(new SearchSourceBuilder() @@ -1151,8 +1246,8 @@ public void onResponse(SearchResponse response) { try { for (SearchHit hit: hits) { XContentParser xcp = XContentType.JSON.xContent().createParser( - xContentRegistry, - LoggingDeprecationHandler.INSTANCE, hit.getSourceAsString() + xContentRegistry, + LoggingDeprecationHandler.INSTANCE, hit.getSourceAsString() ); Rule rule = Rule.docParse(xcp, hit.getId(), hit.getVersion()); @@ -1167,7 +1262,7 @@ public void onResponse(SearchResponse response) { onFailures(new OpenSearchStatusException("Custom Rule Index not found", RestStatus.NOT_FOUND)); } else { if (request.getMethod() == RestRequest.Method.POST) { - createMonitorFromQueries(logIndex, queries, detector, listener, request.getRefreshPolicy()); + createMonitorFromQueries(queries, detector, listener, request.getRefreshPolicy()); } else if (request.getMethod() == RestRequest.Method.PUT) { updateMonitorFromQueries(logIndex, queries, detector, listener, request.getRefreshPolicy()); } @@ -1210,8 +1305,8 @@ public void onResponse(SearchResponse response) { try { for (SearchHit hit : hits) { XContentParser xcp = XContentType.JSON.xContent().createParser( - xContentRegistry, - LoggingDeprecationHandler.INSTANCE, hit.getSourceAsString() + xContentRegistry, + LoggingDeprecationHandler.INSTANCE, hit.getSourceAsString() ); Rule rule = Rule.docParse(xcp, hit.getId(), hit.getVersion()); @@ -1221,7 +1316,7 @@ public void onResponse(SearchResponse response) { } if (request.getMethod() == RestRequest.Method.POST) { - createMonitorFromQueries(logIndex, queries, detector, listener, request.getRefreshPolicy()); + createMonitorFromQueries(queries, detector, listener, request.getRefreshPolicy()); } else if (request.getMethod() == RestRequest.Method.PUT) { updateMonitorFromQueries(logIndex, queries, detector, listener, request.getRefreshPolicy()); } @@ -1241,15 +1336,15 @@ public void indexDetector() throws IOException { IndexRequest indexRequest; if (request.getMethod() == RestRequest.Method.POST) { indexRequest = new IndexRequest(Detector.DETECTORS_INDEX) - .setRefreshPolicy(request.getRefreshPolicy()) - .source(request.getDetector().toXContentWithUser(XContentFactory.jsonBuilder(), new ToXContent.MapParams(Map.of("with_type", "true")))) - .timeout(indexTimeout); + .setRefreshPolicy(request.getRefreshPolicy()) + .source(request.getDetector().toXContentWithUser(XContentFactory.jsonBuilder(), new ToXContent.MapParams(Map.of("with_type", "true")))) + .timeout(indexTimeout); } else { indexRequest = new IndexRequest(Detector.DETECTORS_INDEX) - .setRefreshPolicy(request.getRefreshPolicy()) - .source(request.getDetector().toXContentWithUser(XContentFactory.jsonBuilder(), new ToXContent.MapParams(Map.of("with_type", "true")))) - .id(request.getDetectorId()) - .timeout(indexTimeout); + .setRefreshPolicy(request.getRefreshPolicy()) + .source(request.getDetector().toXContentWithUser(XContentFactory.jsonBuilder(), new ToXContent.MapParams(Map.of("with_type", "true")))) + .id(request.getDetectorId()) + .timeout(indexTimeout); } client.index(indexRequest, new ActionListener<>() { @@ -1262,7 +1357,30 @@ public void onResponse(IndexResponse response) { @Override public void onFailure(Exception e) { - onFailures(e); + // Revert the workflow and monitors created in previous steps + workflowService.deleteWorkflow(request.getDetector().getWorkflowIds().get(0), + new ActionListener<>() { + @Override + public void onResponse(DeleteWorkflowResponse deleteWorkflowResponse) { + monitorService.deleteAlertingMonitors(request.getDetector().getMonitorIds(), + request.getRefreshPolicy(), + new ActionListener<>() { + @Override + public void onResponse(List deleteMonitorResponses) { + onFailures(e); + } + + @Override + public void onFailure(Exception e) { + onFailures(e); + } + }); + } + @Override + public void onFailure(Exception e) { + onFailures(e); + } + }); } }); } @@ -1307,18 +1425,18 @@ private List getMonitorIds(List monitorResponses) */ private Map mapMonitorIds(List monitorResponses) { return monitorResponses.stream().collect( - Collectors.toMap( - // In the case of bucket level monitors rule id is trigger id - it -> { - if (MonitorType.BUCKET_LEVEL_MONITOR == it.getMonitor().getMonitorType()) { - return it.getMonitor().getTriggers().get(0).getId(); - } else { - return Detector.DOC_LEVEL_MONITOR; - } - }, - IndexMonitorResponse::getId - ) - ); + Collectors.toMap( + // In the case of bucket level monitors rule id is trigger id + it -> { + if (MonitorType.BUCKET_LEVEL_MONITOR == it.getMonitor().getMonitorType()) { + return it.getMonitor().getTriggers().get(0).getId(); + } else { + return Detector.DOC_LEVEL_MONITOR; + } + }, + IndexMonitorResponse::getId + ) + ); } } @@ -1326,4 +1444,7 @@ private void setFilterByEnabled(boolean filterByEnabled) { this.filterByEnabled = filterByEnabled; } -} \ No newline at end of file + private void setEnabledWorkflowUsage(boolean enabledWorkflowUsage) { + this.enabledWorkflowUsage = enabledWorkflowUsage; + } +} diff --git a/src/main/java/org/opensearch/securityanalytics/util/MonitorService.java b/src/main/java/org/opensearch/securityanalytics/util/MonitorService.java new file mode 100644 index 000000000..e0fa7163d --- /dev/null +++ b/src/main/java/org/opensearch/securityanalytics/util/MonitorService.java @@ -0,0 +1,84 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.securityanalytics.util; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.lucene.util.SetOnce; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.support.GroupedActionListener; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.client.Client; +import org.opensearch.client.node.NodeClient; +import org.opensearch.commons.alerting.AlertingPluginInterface; +import org.opensearch.commons.alerting.action.DeleteMonitorRequest; +import org.opensearch.commons.alerting.action.DeleteMonitorResponse; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.stream.Collectors; + +/** + * Alerting common class used for monitors manipulation + */ +public class MonitorService { + private static final Logger log = LogManager.getLogger(MonitorService.class); + + private Client client; + + public MonitorService() { + } + + public MonitorService(Client client) { + this.client = client; + } + + /** + * Deletes the alerting monitors based on the given ids and notifies the listener that will be notified once all monitors have been deleted + * @param monitorIds monitor ids to be deleted + * @param refreshPolicy + * @param listener listener that will be notified once all the monitors are being deleted + */ + public void deleteAlertingMonitors(List monitorIds, WriteRequest.RefreshPolicy refreshPolicy, ActionListener> listener){ + if (monitorIds == null || monitorIds.isEmpty()) { + listener.onResponse(new ArrayList<>()); + return; + } + ActionListener deletesListener = new GroupedActionListener<>(new ActionListener<>() { + @Override + public void onResponse(Collection responses) { + SetOnce errorStatusSupplier = new SetOnce<>(); + if (responses.stream().filter(response -> { + if (response.getStatus() != RestStatus.OK) { + log.error("Monitor [{}] could not be deleted. Status [{}]", response.getId(), response.getStatus()); + errorStatusSupplier.trySet(response.getStatus()); + return true; + } + return false; + }).count() > 0) { + listener.onFailure(new OpenSearchStatusException("Monitor associated with detected could not be deleted", errorStatusSupplier.get())); + } + listener.onResponse(responses.stream().collect(Collectors.toList())); + } + @Override + public void onFailure(Exception e) { + log.error("Error deleting monitors", e.getSuppressed()); + listener.onFailure(e); + } + }, monitorIds.size()); + + for (String monitorId : monitorIds) { + deleteAlertingMonitor(monitorId, refreshPolicy, deletesListener); + } + } + + private void deleteAlertingMonitor(String monitorId, WriteRequest.RefreshPolicy refreshPolicy, ActionListener listener) { + DeleteMonitorRequest request = new DeleteMonitorRequest(monitorId, refreshPolicy); + AlertingPluginInterface.INSTANCE.deleteMonitor((NodeClient) client, request, listener); + } +} diff --git a/src/main/java/org/opensearch/securityanalytics/util/WorkflowService.java b/src/main/java/org/opensearch/securityanalytics/util/WorkflowService.java new file mode 100644 index 000000000..e75e17fe8 --- /dev/null +++ b/src/main/java/org/opensearch/securityanalytics/util/WorkflowService.java @@ -0,0 +1,185 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.securityanalytics.util; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.OpenSearchException; +import org.opensearch.action.support.WriteRequest.RefreshPolicy; +import org.opensearch.client.Client; +import org.opensearch.client.node.NodeClient; +import org.opensearch.commons.alerting.AlertingPluginInterface; +import org.opensearch.commons.alerting.action.DeleteMonitorResponse; +import org.opensearch.commons.alerting.action.DeleteWorkflowRequest; +import org.opensearch.commons.alerting.action.DeleteWorkflowResponse; +import org.opensearch.commons.alerting.action.IndexMonitorResponse; +import org.opensearch.commons.alerting.action.IndexWorkflowRequest; +import org.opensearch.commons.alerting.action.IndexWorkflowResponse; +import org.opensearch.commons.alerting.model.CompositeInput; +import org.opensearch.commons.alerting.model.Delegate; +import org.opensearch.commons.alerting.model.Monitor.MonitorType; +import org.opensearch.commons.alerting.model.Sequence; +import org.opensearch.commons.alerting.model.Workflow; +import org.opensearch.commons.alerting.model.Workflow.WorkflowType; +import org.opensearch.core.action.ActionListener; +import org.opensearch.index.seqno.SequenceNumbers; +import org.opensearch.rest.RestRequest.Method; +import org.opensearch.securityanalytics.model.Detector; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; + +/** + * Alerting common clas used for workflow manipulation + */ +public class WorkflowService { + private static final Logger log = LogManager.getLogger(WorkflowService.class); + private Client client; + + private MonitorService monitorService; + + public WorkflowService() { + } + + public WorkflowService(Client client, MonitorService monitorService) { + this.client = client; + this.monitorService = monitorService; + } + + /** + * Upserts the workflow - depending on the method and lists forwarded; If the method is put and updated + * If the workflow upsert failed, deleting monitors will be performed + * @param addedMonitors monitors to be added + * @param updatedMonitors monitors to be updated + * @param detector detector for which monitors needs to be added/updated + * @param refreshPolicy + * @param workflowId + * @param method http method POST/PUT + * @param listener + */ + public void upsertWorkflow( + List addedMonitors, + List updatedMonitors, + Detector detector, + RefreshPolicy refreshPolicy, + String workflowId, + Method method, + ActionListener listener + ) { + if (method != Method.POST && method != Method.PUT) { + log.error(String.format("Method %s not supported when upserting the workflow", method.name())); + listener.onFailure(SecurityAnalyticsException.wrap(new OpenSearchException("Method not supported"))); + return; + } + + List monitorIds = new ArrayList<>(); + monitorIds.addAll(addedMonitors); + + if (updatedMonitors != null && !updatedMonitors.isEmpty()) { + monitorIds.addAll(updatedMonitors); + } + + IndexWorkflowRequest indexWorkflowRequest = createWorkflowRequest(monitorIds, + detector, + refreshPolicy, workflowId, method); + + AlertingPluginInterface.INSTANCE.indexWorkflow((NodeClient) client, + indexWorkflowRequest, + new ActionListener<>() { + @Override + public void onResponse(IndexWorkflowResponse workflowResponse) { + listener.onResponse(workflowResponse); + } + + @Override + public void onFailure(Exception e) { + // Remove created monitors and fail creation of workflow + log.error("Failed workflow saving. Removing created monitors: " + addedMonitors.stream().collect( + Collectors.joining()) , e); + + monitorService.deleteAlertingMonitors(addedMonitors, + refreshPolicy, + new ActionListener<>() { + @Override + public void onResponse(List deleteMonitorResponses) { + log.debug("Monitors successfully deleted"); + listener.onFailure(e); + } + + @Override + public void onFailure(Exception e) { + log.error("Error deleting monitors", e); + listener.onFailure(e); + } + }); + } + }); + } + + public void deleteWorkflow(String workflowId, ActionListener deleteWorkflowListener) { + DeleteWorkflowRequest deleteWorkflowRequest = new DeleteWorkflowRequest(workflowId, false); + AlertingPluginInterface.INSTANCE.deleteWorkflow((NodeClient) client, deleteWorkflowRequest, deleteWorkflowListener); + } + + private IndexWorkflowRequest createWorkflowRequest(List monitorIds, Detector detector, RefreshPolicy refreshPolicy, String workflowId, Method method) { + AtomicInteger index = new AtomicInteger(); + + // TODO - update chained findings + List delegates = monitorIds.stream().map( + monitorId -> new Delegate(index.incrementAndGet(), monitorId, null) + ).collect(Collectors.toList()); + + Sequence sequence = new Sequence(delegates); + CompositeInput compositeInput = new CompositeInput(sequence); + + Workflow workflow = new Workflow( + workflowId, + Workflow.NO_VERSION, + detector.getName(), + detector.getEnabled(), + detector.getSchedule(), + detector.getLastUpdateTime(), + detector.getEnabledTime(), + WorkflowType.COMPOSITE, + detector.getUser(), + 1, + List.of(compositeInput), + "security_analytics", + Collections.emptyList(), + false + ); + + return new IndexWorkflowRequest( + workflowId, + SequenceNumbers.UNASSIGNED_SEQ_NO, + SequenceNumbers.UNASSIGNED_PRIMARY_TERM, + refreshPolicy, + method, + workflow, + null + ); + } + + private Map mapMonitorIds(List monitorResponses) { + return monitorResponses.stream().collect( + Collectors.toMap( + // In the case of bucket level monitors rule id is trigger id + it -> { + if (MonitorType.BUCKET_LEVEL_MONITOR == it.getMonitor().getMonitorType()) { + return it.getMonitor().getTriggers().get(0).getId(); + } else { + return Detector.DOC_LEVEL_MONITOR; + } + }, + IndexMonitorResponse::getId + ) + ); + } +} + diff --git a/src/main/resources/mappings/detectors.json b/src/main/resources/mappings/detectors.json index 776ed1d39..e1e160d5f 100644 --- a/src/main/resources/mappings/detectors.json +++ b/src/main/resources/mappings/detectors.json @@ -1,6 +1,6 @@ { "_meta" : { - "schema_version": 1 + "schema_version": 2 }, "properties": { "detector": { @@ -88,6 +88,9 @@ } } }, + "workflow_ids": { + "type": "keyword" + }, "rule_index": { "type": "text", "fields": { diff --git a/src/test/java/org/opensearch/securityanalytics/SecurityAnalyticsRestTestCase.java b/src/test/java/org/opensearch/securityanalytics/SecurityAnalyticsRestTestCase.java index c47bc148b..316730163 100644 --- a/src/test/java/org/opensearch/securityanalytics/SecurityAnalyticsRestTestCase.java +++ b/src/test/java/org/opensearch/securityanalytics/SecurityAnalyticsRestTestCase.java @@ -132,6 +132,66 @@ protected void createRuleTopicIndex(String detectorType, String additionalMappin } } + protected void verifyWorkflow(Map detectorMap, List monitorIds, int expectedDelegatesNum) throws IOException{ + String workflowId = ((List) detectorMap.get("workflow_ids")).get(0); + + Map workflow = searchWorkflow(workflowId); + assertNotNull("Workflow not found", workflow); + + List> workflowInputs = (List>) workflow.get("inputs"); + assertEquals("Workflow not found", 1, workflowInputs.size()); + + Map sequence = ((Map)((Map)workflowInputs.get(0).get("composite_input")).get("sequence")); + assertNotNull("Sequence is null", sequence); + + List> delegates = (List>) sequence.get("delegates"); + assertEquals(expectedDelegatesNum, delegates.size()); + // Assert that all monitors are present + for (Map delegate: delegates) { + assertTrue("Monitor doesn't exist in monitor list", monitorIds.contains(delegate.get("monitor_id"))); + } + } + + protected Map searchWorkflow(String workflowId) throws IOException{ + String workflowRequest = "{\n" + + " \"query\":{\n" + + " \"term\":{\n" + + " \"_id\":{\n" + + " \"value\":\"" + workflowId + "\"\n" + + " }\n" + + " }\n" + + " }\n" + + "}"; + List hits = executeSearch(ScheduledJob.SCHEDULED_JOBS_INDEX, workflowRequest); + if (hits.size() == 0) { + return new HashMap<>(); + } + + SearchHit hit = hits.get(0); + return (Map) hit.getSourceAsMap().get("workflow"); + } + + + protected List> getAllWorkflows() throws IOException{ + String workflowRequest = "{\n" + + " \"query\":{\n" + + " \"exists\":{\n" + + " \"field\": \"workflow\"" + + " }\n" + + " }\n" + + " }"; + + List hits = executeSearch(ScheduledJob.SCHEDULED_JOBS_INDEX, workflowRequest); + if (hits.size() == 0) { + return new ArrayList<>(); + } + List> result = new ArrayList<>(); + for (SearchHit hit: hits) { + result.add((Map) hit.getSourceAsMap().get("workflow")); + } + return result; + } + protected String createDetector(Detector detector) throws IOException { Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); Assert.assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); @@ -336,6 +396,14 @@ protected Response deleteAlertingMonitor(RestClient client, String monitorId) th return makeRequest(client, "DELETE", String.format(Locale.getDefault(), "/_plugins/_alerting/monitors/%s", monitorId), new HashMap<>(), null); } + protected Response executeAlertingWorkflow(String monitorId, Map params) throws IOException { + return executeAlertingWorkflow(client(), monitorId, params); + } + + protected Response executeAlertingWorkflow(RestClient client, String workflowId, Map params) throws IOException { + return makeRequest(client, "POST", String.format(Locale.getDefault(), "/_plugins/_alerting/workflows/%s/_execute", workflowId), params, null); + } + protected List executeSearch(String index, String request) throws IOException { return executeSearch(index, request, true); } @@ -1655,7 +1723,6 @@ protected void createSampleDatastream(String datastreamName, String mappings, bo createDatastreamAPI(datastreamName); } - protected void restoreAlertsFindingsIMSettings() throws IOException { updateClusterSetting(ALERT_HISTORY_ROLLOVER_PERIOD.getKey(), "720m"); updateClusterSetting(ALERT_HISTORY_MAX_DOCS.getKey(), "100000"); @@ -1668,4 +1735,11 @@ protected void restoreAlertsFindingsIMSettings() throws IOException { updateClusterSetting(FINDING_HISTORY_RETENTION_PERIOD.getKey(), "60d"); } + + protected void enableOrDisableWorkflow(String trueOrFalse) throws IOException { + Request request = new Request("PUT", "_cluster/settings"); + String entity = "{\"persistent\":{\"plugins.security_analytics.filter_by_backend_roles\" : " + trueOrFalse + "}}"; + request.setJsonEntity(entity); + client().performRequest(request); + } } \ No newline at end of file diff --git a/src/test/java/org/opensearch/securityanalytics/TestHelpers.java b/src/test/java/org/opensearch/securityanalytics/TestHelpers.java index b8015e05c..6033d4084 100644 --- a/src/test/java/org/opensearch/securityanalytics/TestHelpers.java +++ b/src/test/java/org/opensearch/securityanalytics/TestHelpers.java @@ -143,7 +143,7 @@ public static Detector randomDetector(String name, DetectorTrigger trigger = new DetectorTrigger(null, "windows-trigger", "1", List.of(randomDetectorType()), List.of("QuarksPwDump Clearing Access History"), List.of("high"), List.of("T0008"), List.of()); triggers.add(trigger); } - return new Detector(null, null, name, enabled, schedule, lastUpdateTime, enabledTime, detectorType, user, inputs, triggers, Collections.singletonList(""), "", "", "", "", "", "", Collections.emptyMap()); + return new Detector(null, null, name, enabled, schedule, lastUpdateTime, enabledTime, detectorType, user, inputs, triggers, Collections.singletonList(""), "", "", "", "", "", "", Collections.emptyMap(), Collections.emptyList()); } public static CustomLogType randomCustomLogType(String name, String description, String source) { @@ -168,7 +168,28 @@ public static Detector randomDetectorWithNoUser() { Instant enabledTime = enabled ? Instant.now().truncatedTo(ChronoUnit.MILLIS) : null; Instant lastUpdateTime = Instant.now().truncatedTo(ChronoUnit.MILLIS); - return new Detector(null, null, name, enabled, schedule, lastUpdateTime, enabledTime, detectorType, null, inputs, Collections.emptyList(),Collections.singletonList(""), "", "", "", "", "", "", Collections.emptyMap()); + return new Detector( + null, + null, + name, + enabled, + schedule, + lastUpdateTime, + enabledTime, + detectorType, + null, + inputs, + Collections.emptyList(), + Collections.singletonList(""), + "", + "", + "", + "", + "", + "", + Collections.emptyMap(), + Collections.emptyList() + ); } public static CorrelationRule randomCorrelationRule(String name) { diff --git a/src/test/java/org/opensearch/securityanalytics/action/IndexDetectorResponseTests.java b/src/test/java/org/opensearch/securityanalytics/action/IndexDetectorResponseTests.java index d47650411..db366056b 100644 --- a/src/test/java/org/opensearch/securityanalytics/action/IndexDetectorResponseTests.java +++ b/src/test/java/org/opensearch/securityanalytics/action/IndexDetectorResponseTests.java @@ -49,7 +49,8 @@ public void testIndexDetectorPostResponse() throws IOException { null, null, DetectorMonitorConfig.getFindingsIndex("others_application"), - Collections.emptyMap() + Collections.emptyMap(), + Collections.emptyList() ); IndexDetectorResponse response = new IndexDetectorResponse("1234", 1L, RestStatus.OK, detector); Assert.assertNotNull(response); diff --git a/src/test/java/org/opensearch/securityanalytics/alerts/AlertingServiceTests.java b/src/test/java/org/opensearch/securityanalytics/alerts/AlertingServiceTests.java index ae3ecbeda..78dacd6e1 100644 --- a/src/test/java/org/opensearch/securityanalytics/alerts/AlertingServiceTests.java +++ b/src/test/java/org/opensearch/securityanalytics/alerts/AlertingServiceTests.java @@ -64,7 +64,8 @@ public void testGetAlerts_success() { null, null, DetectorMonitorConfig.getFindingsIndex("others_application"), - Collections.emptyMap() + Collections.emptyMap(), + Collections.emptyList() ); GetDetectorResponse getDetectorResponse = new GetDetectorResponse("detector_id123", 1L, RestStatus.OK, detector); @@ -240,7 +241,8 @@ public void testGetFindings_getFindingsByMonitorIdFailures() { null, null, DetectorMonitorConfig.getFindingsIndex("others_application"), - Collections.emptyMap() + Collections.emptyMap(), + Collections.emptyList() ); GetDetectorResponse getDetectorResponse = new GetDetectorResponse("detector_id123", 1L, RestStatus.OK, detector); diff --git a/src/test/java/org/opensearch/securityanalytics/findings/FindingServiceTests.java b/src/test/java/org/opensearch/securityanalytics/findings/FindingServiceTests.java index 232c9a221..0fb9376b6 100644 --- a/src/test/java/org/opensearch/securityanalytics/findings/FindingServiceTests.java +++ b/src/test/java/org/opensearch/securityanalytics/findings/FindingServiceTests.java @@ -64,7 +64,8 @@ public void testGetFindings_success() { null, null, DetectorMonitorConfig.getFindingsIndex("others_application"), - Collections.emptyMap() + Collections.emptyMap(), + Collections.emptyList() ); GetDetectorResponse getDetectorResponse = new GetDetectorResponse("detector_id123", 1L, RestStatus.OK, detector); @@ -184,7 +185,8 @@ public void testGetFindings_getFindingsByMonitorIdFailure() { null, null, DetectorMonitorConfig.getFindingsIndex("others_application"), - Collections.emptyMap() + Collections.emptyMap(), + Collections.emptyList() ); GetDetectorResponse getDetectorResponse = new GetDetectorResponse("detector_id123", 1L, RestStatus.OK, detector); diff --git a/src/test/java/org/opensearch/securityanalytics/resthandler/DetectorMonitorRestApiIT.java b/src/test/java/org/opensearch/securityanalytics/resthandler/DetectorMonitorRestApiIT.java index fb110ed50..8358ae87e 100644 --- a/src/test/java/org/opensearch/securityanalytics/resthandler/DetectorMonitorRestApiIT.java +++ b/src/test/java/org/opensearch/securityanalytics/resthandler/DetectorMonitorRestApiIT.java @@ -12,6 +12,7 @@ import static org.opensearch.securityanalytics.TestHelpers.randomIndex; import static org.opensearch.securityanalytics.TestHelpers.randomRule; import static org.opensearch.securityanalytics.TestHelpers.windowsIndexMapping; +import static org.opensearch.securityanalytics.settings.SecurityAnalyticsSettings.ENABLE_WORKFLOW_USAGE; import java.io.IOException; import java.util.ArrayList; @@ -855,6 +856,8 @@ public void testMultipleAggregationAndDocRules_findingSuccess() throws IOExcepti Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); + + String request = "{\n" + " \"query\" : {\n" + " \"match_all\":{\n" + @@ -974,12 +977,473 @@ else if (ruleId == minRuleId) { assertTrue(Arrays.asList("1", "2", "3", "4", "5", "6", "7", "8").containsAll(docLevelFinding)); } - private static void assertRuleMonitorFinding(Map executeResults, String ruleId, int expectedDocCount, List expectedTriggerResult) { + public void testCreateDetector_verifyWorkflowCreation_success() throws IOException { + updateClusterSetting(ENABLE_WORKFLOW_USAGE.getKey(), "true"); + String index = createTestIndex(randomIndex(), windowsIndexMapping()); + + // Execute CreateMappingsAction to add alias mapping for index + Request createMappingRequest = new Request("POST", SecurityAnalyticsPlugin.MAPPER_BASE_URI); + // both req params and req body are supported + createMappingRequest.setJsonEntity( + "{ \"index_name\":\"" + index + "\"," + + " \"rule_topic\":\"" + randomDetectorType() + "\", " + + " \"partial\":true" + + "}" + ); + + Response createMappingResponse = client().performRequest(createMappingRequest); + + assertEquals(HttpStatus.SC_OK, createMappingResponse.getStatusLine().getStatusCode()); + + String testOpCode = "Test"; + + String maxRuleId = createRule(randomAggregationRule("max", " > 3", testOpCode)); + String randomDocRuleId = createRule(randomRule()); + List detectorRules = List.of(new DetectorRule(maxRuleId), new DetectorRule(randomDocRuleId)); + DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), detectorRules, + Collections.emptyList()); + Detector detector = randomDetectorWithInputs(List.of(input)); + + Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); + + String request = "{\n" + + " \"query\" : {\n" + + " \"match_all\":{\n" + + " }\n" + + " }\n" + + "}"; + SearchResponse response = executeSearchAndGetResponse(DetectorMonitorConfig.getRuleIndex(randomDetectorType()), request, true); + + assertEquals(1, response.getHits().getTotalHits().value); + + assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); + Map responseBody = asMap(createResponse); + + String detectorId = responseBody.get("_id").toString(); + request = "{\n" + + " \"query\" : {\n" + + " \"match\":{\n" + + " \"_id\": \"" + detectorId + "\"\n" + + " }\n" + + " }\n" + + "}"; + List hits = executeSearch(Detector.DETECTORS_INDEX, request); + SearchHit hit = hits.get(0); + Map detectorMap = (HashMap)(hit.getSourceAsMap().get("detector")); + List inputArr = (List) detectorMap.get("inputs"); + + assertEquals(2, ((Map>) inputArr.get(0)).get("detector_input").get("custom_rules").size()); + + List monitorIds = ((List) (detectorMap).get("monitor_id")); + assertEquals(2, monitorIds.size()); + + assertNotNull("Workflow not created", detectorMap.get("workflow_ids")); + assertEquals("Number of workflows not correct", 1, ((List) detectorMap.get("workflow_ids")).size()); + + // Verify workflow + verifyWorkflow(detectorMap, monitorIds, 2); + } + + public void testUpdateDetector_disabledWorkflowUsage_verifyWorkflowNotCreated_success() throws IOException { + // By default, workflow usage is disabled - disabling it just in any case + updateClusterSetting(ENABLE_WORKFLOW_USAGE.getKey(), "false"); + String index = createTestIndex(randomIndex(), windowsIndexMapping()); + + // Execute CreateMappingsAction to add alias mapping for index + Request createMappingRequest = new Request("POST", SecurityAnalyticsPlugin.MAPPER_BASE_URI); + // both req params and req body are supported + createMappingRequest.setJsonEntity( + "{ \"index_name\":\"" + index + "\"," + + " \"rule_topic\":\"" + randomDetectorType() + "\", " + + " \"partial\":true" + + "}" + ); + + Response createMappingResponse = client().performRequest(createMappingRequest); + + assertEquals(HttpStatus.SC_OK, createMappingResponse.getStatusLine().getStatusCode()); + + String randomDocRuleId = createRule(randomRule()); + + List detectorRules = List.of(new DetectorRule(randomDocRuleId)); + DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), detectorRules, + Collections.emptyList()); + Detector detector = randomDetectorWithInputs(List.of(input)); + + Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); + + String request = "{\n" + + " \"query\" : {\n" + + " \"match_all\":{\n" + + " }\n" + + " }\n" + + "}"; + SearchResponse response = executeSearchAndGetResponse(DetectorMonitorConfig.getRuleIndex(randomDetectorType()), request, true); + + assertEquals(1, response.getHits().getTotalHits().value); + + assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); + Map responseBody = asMap(createResponse); + + String detectorId = responseBody.get("_id").toString(); + request = "{\n" + + " \"query\" : {\n" + + " \"match\":{\n" + + " \"_id\": \"" + detectorId + "\"\n" + + " }\n" + + " }\n" + + "}"; + List hits = executeSearch(Detector.DETECTORS_INDEX, request); + SearchHit hit = hits.get(0); + Map detectorMap = (HashMap)(hit.getSourceAsMap().get("detector")); + List monitorIds = ((List) (detectorMap).get("monitor_id")); + assertEquals(1, monitorIds.size()); + + assertTrue("Workflow created", ((List) detectorMap.get("workflow_ids")).size() == 0); + List workflows = getAllWorkflows(); + assertTrue("Workflow created", workflows.size() == 0); + + // Enable workflow usage and verify detector update + updateClusterSetting(ENABLE_WORKFLOW_USAGE.getKey(), "true"); + var updateResponse = makeRequest(client(), "PUT", SecurityAnalyticsPlugin.DETECTOR_BASE_URI + "/" + detectorId, Collections.emptyMap(), toHttpEntity(detector)); + + assertEquals("Update detector failed", RestStatus.OK, restStatus(updateResponse)); + hits = executeSearch(Detector.DETECTORS_INDEX, request); + hit = hits.get(0); + detectorMap = (HashMap)(hit.getSourceAsMap().get("detector")); + + // Verify that the workflow for the given detector is not added + assertTrue("Workflow created", ((List) detectorMap.get("workflow_ids")).size() == 0); + workflows = getAllWorkflows(); + assertTrue("Workflow created", workflows.size() == 0); + } + + public void testUpdateDetector_removeRule_verifyWorkflowUpdate_success() throws IOException { + updateClusterSetting(ENABLE_WORKFLOW_USAGE.getKey(), "true"); + String index = createTestIndex(randomIndex(), windowsIndexMapping()); + + // Execute CreateMappingsAction to add alias mapping for index + Request createMappingRequest = new Request("POST", SecurityAnalyticsPlugin.MAPPER_BASE_URI); + // both req params and req body are supported + createMappingRequest.setJsonEntity( + "{ \"index_name\":\"" + index + "\"," + + " \"rule_topic\":\"" + randomDetectorType() + "\", " + + " \"partial\":true" + + "}" + ); + + Response createMappingResponse = client().performRequest(createMappingRequest); + + assertEquals(HttpStatus.SC_OK, createMappingResponse.getStatusLine().getStatusCode()); + + String testOpCode = "Test"; + + String maxRuleId = createRule(randomAggregationRule("max", " > 3", testOpCode)); + String randomDocRuleId = createRule(randomRule()); + + List detectorRules = List.of(new DetectorRule(maxRuleId), new DetectorRule(randomDocRuleId)); + + DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), detectorRules, + Collections.emptyList()); + Detector detector = randomDetectorWithInputs(List.of(input)); + + Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); + + String request = "{\n" + + " \"query\" : {\n" + + " \"match_all\":{\n" + + " }\n" + + " }\n" + + "}"; + SearchResponse response = executeSearchAndGetResponse(DetectorMonitorConfig.getRuleIndex(randomDetectorType()), request, true); + + assertEquals(1, response.getHits().getTotalHits().value); + + assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); + Map responseBody = asMap(createResponse); + + String detectorId = responseBody.get("_id").toString(); + request = "{\n" + + " \"query\" : {\n" + + " \"match\":{\n" + + " \"_id\": \"" + detectorId + "\"\n" + + " }\n" + + " }\n" + + "}"; + List hits = executeSearch(Detector.DETECTORS_INDEX, request); + SearchHit hit = hits.get(0); + Map detectorMap = (HashMap)(hit.getSourceAsMap().get("detector")); + List inputArr = (List) detectorMap.get("inputs"); + + assertEquals(2, ((Map>) inputArr.get(0)).get("detector_input").get("custom_rules").size()); + + List monitorIds = ((List) (detectorMap).get("monitor_id")); + assertEquals(2, monitorIds.size()); + + assertNotNull("Workflow not created", detectorMap.get("workflow_ids")); + assertEquals("Number of workflows not correct", 1, ((List) detectorMap.get("workflow_ids")).size()); + + // Verify workflow + verifyWorkflow(detectorMap, monitorIds, 2); + + // Update detector - remove one agg rule; Verify workflow + DetectorInput newInput = new DetectorInput("windows detector for security analytics", List.of("windows"), Arrays.asList(new DetectorRule(randomDocRuleId)) , getRandomPrePackagedRules().stream().map(DetectorRule::new).collect(Collectors.toList())); + detector = randomDetectorWithInputs(List.of(newInput)); + createResponse = makeRequest(client(), "PUT", SecurityAnalyticsPlugin.DETECTOR_BASE_URI + "/" + detectorId, Collections.emptyMap(), toHttpEntity(detector)); + + assertEquals("Update detector failed", RestStatus.OK, restStatus(createResponse)); + hits = executeSearch(Detector.DETECTORS_INDEX, request); + hit = hits.get(0); + detectorMap = (HashMap)(hit.getSourceAsMap().get("detector")); + inputArr = (List) detectorMap.get("inputs"); + + assertEquals(1, ((Map>) inputArr.get(0)).get("detector_input").get("custom_rules").size()); + + monitorIds = ((List) (detectorMap).get("monitor_id")); + assertEquals(1, monitorIds.size()); + + assertNotNull("Workflow not created", detectorMap.get("workflow_ids")); + assertEquals("Number of workflows not correct", 1, ((List) detectorMap.get("workflow_ids")).size()); + + // Verify workflow + verifyWorkflow(detectorMap, monitorIds, 1); + + indexDoc(index, "1", randomDoc(5, 3, testOpCode)); + String workflowId = ((List) detectorMap.get("workflow_ids")).get(0); + + Response executeResponse = executeAlertingWorkflow(workflowId, Collections.emptyMap()); + + List> monitorRunResults = (List>) entityAsMap(executeResponse).get("monitor_run_results"); + assertEquals(1, monitorRunResults.size()); + + int noOfSigmaRuleMatches = ((List>) ((Map) monitorRunResults.get(0).get("input_results")).get("results")).get(0).size(); + assertEquals(6, noOfSigmaRuleMatches); + + // Verify findings + Map params = new HashMap<>(); + params.put("detector_id", detectorId); + Response getFindingsResponse = makeRequest(client(), "GET", SecurityAnalyticsPlugin.FINDINGS_BASE_URI + "/_search", params, null); + Map getFindingsBody = entityAsMap(getFindingsResponse); + + assertNotNull(getFindingsBody); + assertEquals(1, getFindingsBody.get("total_findings")); + + String findingDetectorId = ((Map)((List)getFindingsBody.get("findings")).get(0)).get("detectorId").toString(); + assertEquals(detectorId, findingDetectorId); + + String findingIndex = ((Map)((List)getFindingsBody.get("findings")).get(0)).get("index").toString(); + assertEquals(index, findingIndex); + + List> findings = (List)getFindingsBody.get("findings"); + + assertEquals(1, findings.size()); + List findingDocs = (List) findings.get(0).get("related_doc_ids"); + Assert.assertEquals(1, findingDocs.size()); + assertTrue(Arrays.asList("1").containsAll(findingDocs)); + } + + public void testCreateDetector_workflowWithDuplicateMonitor_failure() throws IOException { + updateClusterSetting(ENABLE_WORKFLOW_USAGE.getKey(), "true"); + String index = createTestIndex(randomIndex(), windowsIndexMapping()); + + // Execute CreateMappingsAction to add alias mapping for index + Request createMappingRequest = new Request("POST", SecurityAnalyticsPlugin.MAPPER_BASE_URI); + // both req params and req body are supported + createMappingRequest.setJsonEntity( + "{ \"index_name\":\"" + index + "\"," + + " \"rule_topic\":\"" + randomDetectorType() + "\", " + + " \"partial\":true" + + "}" + ); + + Response createMappingResponse = client().performRequest(createMappingRequest); + + assertEquals(HttpStatus.SC_OK, createMappingResponse.getStatusLine().getStatusCode()); + + String testOpCode = "Test"; + + String maxRuleId = createRule(randomAggregationRule("max", " > 3", testOpCode)); + String randomDocRuleId = createRule(randomRule()); + + List detectorRules = List.of(new DetectorRule(maxRuleId), new DetectorRule(randomDocRuleId)); + + DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), detectorRules, + Collections.emptyList()); + Detector detector = randomDetectorWithInputs(List.of(input)); + + Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); + + String request = "{\n" + + " \"query\" : {\n" + + " \"match_all\":{\n" + + " }\n" + + " }\n" + + "}"; + SearchResponse response = executeSearchAndGetResponse(DetectorMonitorConfig.getRuleIndex(randomDetectorType()), request, true); + + assertEquals(1, response.getHits().getTotalHits().value); + + assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); + Map responseBody = asMap(createResponse); + + String detectorId = responseBody.get("_id").toString(); + request = "{\n" + + " \"query\" : {\n" + + " \"match\":{\n" + + " \"_id\": \"" + detectorId + "\"\n" + + " }\n" + + " }\n" + + "}"; + List hits = executeSearch(Detector.DETECTORS_INDEX, request); + SearchHit hit = hits.get(0); + Map detectorMap = (HashMap)(hit.getSourceAsMap().get("detector")); + List inputArr = (List) detectorMap.get("inputs"); + + assertEquals(2, ((Map>) inputArr.get(0)).get("detector_input").get("custom_rules").size()); + + List monitorIds = ((List) (detectorMap).get("monitor_id")); + assertEquals(2, monitorIds.size()); + + assertNotNull("Workflow not created", detectorMap.get("workflow_ids")); + assertEquals("Number of workflows not correct", 1, ((List) detectorMap.get("workflow_ids")).size()); + + // Verify workflow + verifyWorkflow(detectorMap, monitorIds, 2); + } + + public void testCreateDetector_verifyWorkflowExecutionBucketLevelDocLevelMonitors_success() throws IOException { + updateClusterSetting(ENABLE_WORKFLOW_USAGE.getKey(), "true"); + String index = createTestIndex(randomIndex(), windowsIndexMapping()); + + // Execute CreateMappingsAction to add alias mapping for index + Request createMappingRequest = new Request("POST", SecurityAnalyticsPlugin.MAPPER_BASE_URI); + // both req params and req body are supported + createMappingRequest.setJsonEntity( + "{ \"index_name\":\"" + index + "\"," + + " \"rule_topic\":\"" + randomDetectorType() + "\", " + + " \"partial\":true" + + "}" + ); + + Response createMappingResponse = client().performRequest(createMappingRequest); + + assertEquals(HttpStatus.SC_OK, createMappingResponse.getStatusLine().getStatusCode()); + + String testOpCode = "Test"; + + String maxRuleId = createRule(randomAggregationRule("max", " > 3", testOpCode)); + String randomDocRuleId = createRule(randomRule()); + + List detectorRules = List.of(new DetectorRule(maxRuleId), new DetectorRule(randomDocRuleId)); + + DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), detectorRules, + Collections.emptyList()); + Detector detector = randomDetectorWithInputs(List.of(input)); + + Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); + + String request = "{\n" + + " \"query\" : {\n" + + " \"match_all\":{\n" + + " }\n" + + " }\n" + + "}"; + SearchResponse response = executeSearchAndGetResponse(DetectorMonitorConfig.getRuleIndex(randomDetectorType()), request, true); + + assertEquals(1, response.getHits().getTotalHits().value); + + assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); + Map responseBody = asMap(createResponse); + + String detectorId = responseBody.get("_id").toString(); + request = "{\n" + + " \"query\" : {\n" + + " \"match\":{\n" + + " \"_id\": \"" + detectorId + "\"\n" + + " }\n" + + " }\n" + + "}"; + List hits = executeSearch(Detector.DETECTORS_INDEX, request); + SearchHit hit = hits.get(0); + Map detectorMap = (HashMap)(hit.getSourceAsMap().get("detector")); + List inputArr = (List) detectorMap.get("inputs"); + + assertEquals(2, ((Map>) inputArr.get(0)).get("detector_input").get("custom_rules").size()); + + List monitorIds = ((List) (detectorMap).get("monitor_id")); + assertEquals(2, monitorIds.size()); + + assertNotNull("Workflow not created", detectorMap.get("workflow_ids")); + assertEquals("Number of workflows not correct", 1, ((List) detectorMap.get("workflow_ids")).size()); + + indexDoc(index, "1", randomDoc(5, 3, testOpCode)); + indexDoc(index, "2", randomDoc(2, 3, testOpCode)); + indexDoc(index, "3", randomDoc(4, 3, testOpCode)); + indexDoc(index, "4", randomDoc(6, 2, testOpCode)); + indexDoc(index, "5", randomDoc(1, 1, testOpCode)); + // Verify workflow + verifyWorkflow(detectorMap, monitorIds, 2); + + String workflowId = ((List) detectorMap.get("workflow_ids")).get(0); + + Response executeResponse = executeAlertingWorkflow(workflowId, Collections.emptyMap()); + + Map executeWorkflowResponseMap = entityAsMap(executeResponse); + List> monitorRunResults = (List>) executeWorkflowResponseMap.get("monitor_run_results"); + + for (Map runResult : monitorRunResults) { + if (((Map) runResult.get("trigger_results")).get(maxRuleId) != null) { + assertRuleMonitorFinding(runResult, maxRuleId, 5, List.of("2", "3")); + } else { + int noOfSigmaRuleMatches = ((List>) ((Map) runResult.get("input_results")).get("results")).get(0).size(); + // 5 prepackaged and 1 custom doc level rule + assertEquals(1, noOfSigmaRuleMatches); + } + } + + // Verify findings + Map params = new HashMap<>(); + params.put("detector_id", detectorId); + Response getFindingsResponse = makeRequest(client(), "GET", SecurityAnalyticsPlugin.FINDINGS_BASE_URI + "/_search", params, null); + Map getFindingsBody = entityAsMap(getFindingsResponse); + + assertNotNull(getFindingsBody); + assertEquals(6, getFindingsBody.get("total_findings")); + + String findingDetectorId = ((Map)((List)getFindingsBody.get("findings")).get(0)).get("detectorId").toString(); + assertEquals(detectorId, findingDetectorId); + + String findingIndex = ((Map)((List)getFindingsBody.get("findings")).get(0)).get("index").toString(); + assertEquals(index, findingIndex); + + List docLevelFinding = new ArrayList<>(); + List> findings = (List)getFindingsBody.get("findings"); + + Set docLevelRules = new HashSet<>(List.of(randomDocRuleId)); + + for(Map finding : findings) { + List> queries = (List>) finding.get("queries"); + Set findingRules = queries.stream().map(it -> it.get("id").toString()).collect(Collectors.toSet()); + // In this test case all doc level rules are matching the finding rule ids + if(docLevelRules.containsAll(findingRules)) { + docLevelFinding.addAll((List) finding.get("related_doc_ids")); + } else { + List findingDocs = (List) finding.get("related_doc_ids"); + Assert.assertEquals(4, findingDocs.size()); + assertTrue(Arrays.asList("1", "2", "3", "4").containsAll(findingDocs)); + } + } + // Verify doc level finding + assertTrue(Arrays.asList("1", "2", "3", "4", "5").containsAll(docLevelFinding)); + } + + + private static void assertRuleMonitorFinding(Map executeResults, String ruleId, int expectedDocCount, List expectedTriggerResult) { List> buckets = ((List>)(((Map)((Map)((Map)((List)((Map) executeResults.get("input_results")).get("results")).get(0)).get("aggregations")).get("result_agg")).get("buckets"))); Integer docCount = buckets.stream().mapToInt(it -> (Integer)it.get("doc_count")).sum(); assertEquals(expectedDocCount, docCount.intValue()); List triggerResultBucketKeys = ((Map)((Map) ((Map)executeResults.get("trigger_results")).get(ruleId)).get("agg_result_buckets")).keySet().stream().collect(Collectors.toList()); - assertEquals(expectedTriggerResult, triggerResultBucketKeys); + Assert.assertEquals(expectedTriggerResult, triggerResultBucketKeys); } } diff --git a/src/test/java/org/opensearch/securityanalytics/resthandler/DetectorRestApiIT.java b/src/test/java/org/opensearch/securityanalytics/resthandler/DetectorRestApiIT.java index d32632d77..b164c71f9 100644 --- a/src/test/java/org/opensearch/securityanalytics/resthandler/DetectorRestApiIT.java +++ b/src/test/java/org/opensearch/securityanalytics/resthandler/DetectorRestApiIT.java @@ -39,6 +39,7 @@ import org.opensearch.securityanalytics.model.DetectorTrigger; import static org.opensearch.securityanalytics.TestHelpers.*; +import static org.opensearch.securityanalytics.settings.SecurityAnalyticsSettings.ENABLE_WORKFLOW_USAGE; public class DetectorRestApiIT extends SecurityAnalyticsRestTestCase { @@ -809,6 +810,116 @@ public void testDeletingADetector_single_ruleTopicIndex() throws IOException { Assert.assertEquals(0, hits.size()); } + + public void testDeletingADetector_single_Monitor() throws IOException { + updateClusterSetting(ENABLE_WORKFLOW_USAGE.getKey(), "true"); + String index = createTestIndex(randomIndex(), windowsIndexMapping()); + + // Execute CreateMappingsAction to add alias mapping for index + Request createMappingRequest = new Request("POST", SecurityAnalyticsPlugin.MAPPER_BASE_URI); + // both req params and req body are supported + createMappingRequest.setJsonEntity( + "{ \"index_name\":\"" + index + "\"," + + " \"rule_topic\":\"" + randomDetectorType() + "\", " + + " \"partial\":true" + + "}" + ); + + Response response = client().performRequest(createMappingRequest); + assertEquals(HttpStatus.SC_OK, response.getStatusLine().getStatusCode()); + // Create detector #1 of type test_windows + Detector detector1 = randomDetectorWithTriggers(getRandomPrePackagedRules(), List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of()))); + String detectorId1 = createDetector(detector1); + + String request = "{\n" + + " \"query\" : {\n" + + " \"match\":{\n" + + " \"_id\": \"" + detectorId1 + "\"\n" + + " }\n" + + " }\n" + + "}"; + List hits = executeSearch(Detector.DETECTORS_INDEX, request); + SearchHit hit = hits.get(0); + + Map responseBody = hit.getSourceAsMap(); + Map detectorResponse1 = (Map) responseBody.get("detector"); + + indexDoc(index, "1", randomDoc()); + String monitorId = ((List) (detectorResponse1).get("monitor_id")).get(0); + + verifyWorkflow(detectorResponse1, Arrays.asList(monitorId), 1); + + Response executeResponse = executeAlertingMonitor(monitorId, Collections.emptyMap()); + Map executeResults = entityAsMap(executeResponse); + + int noOfSigmaRuleMatches = ((List>) ((Map) executeResults.get("input_results")).get("results")).get(0).size(); + Assert.assertEquals(5, noOfSigmaRuleMatches); + // Create detector #2 of type windows + Detector detector2 = randomDetectorWithTriggers(getRandomPrePackagedRules(), List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of()))); + String detectorId2 = createDetector(detector2); + + request = "{\n" + + " \"query\" : {\n" + + " \"match\":{\n" + + " \"_id\": \"" + detectorId2 + "\"\n" + + " }\n" + + " }\n" + + "}"; + hits = executeSearch(Detector.DETECTORS_INDEX, request); + hit = hits.get(0); + + responseBody = hit.getSourceAsMap(); + Map detectorResponse2 = (Map) responseBody.get("detector"); + monitorId = ((List) (detectorResponse2).get("monitor_id")).get(0); + + verifyWorkflow(detectorResponse2, Arrays.asList(monitorId), 1); + + indexDoc(index, "2", randomDoc()); + + executeResponse = executeAlertingMonitor(monitorId, Collections.emptyMap()); + executeResults = entityAsMap(executeResponse); + noOfSigmaRuleMatches = ((List>) ((Map) executeResults.get("input_results")).get("results")).get(0).size(); + Assert.assertEquals(5, noOfSigmaRuleMatches); + + Response deleteResponse = makeRequest(client(), "DELETE", SecurityAnalyticsPlugin.DETECTOR_BASE_URI + "/" + detectorId1, Collections.emptyMap(), null); + Assert.assertEquals("Delete detector failed", RestStatus.OK, restStatus(deleteResponse)); + + String workflowId1 = ((List) detectorResponse1.get("workflow_ids")).get(0); + + Map workflow1 = searchWorkflow(workflowId1); + assertEquals("Workflow " + workflowId1 + " not deleted", Collections.emptyMap(), workflow1); + + deleteResponse = makeRequest(client(), "DELETE", SecurityAnalyticsPlugin.DETECTOR_BASE_URI + "/" + detectorId2, Collections.emptyMap(), null); + Assert.assertEquals("Delete detector failed", RestStatus.OK, restStatus(deleteResponse)); + + String workflowId2 = ((List) detectorResponse2.get("workflow_ids")).get(0); + Map workflow2 = searchWorkflow(workflowId2); + assertEquals("Workflow " + workflowId2 + " not deleted", Collections.emptyMap(), workflow2); + + // We deleted all detectors of type windows, so we expect that queryIndex is deleted + Assert.assertFalse(doesIndexExist(String.format(Locale.ROOT, ".opensearch-sap-%s-detectors-queries-000001", "test_windows"))); + + request = "{\n" + + " \"query\" : {\n" + + " \"match\":{\n" + + " \"_id\": \"" + detectorId1 + "\"\n" + + " }\n" + + " }\n" + + "}"; + hits = executeSearch(Detector.DETECTORS_INDEX, request); + Assert.assertEquals(0, hits.size()); + + request = "{\n" + + " \"query\" : {\n" + + " \"match\":{\n" + + " \"_id\": \"" + detectorId2 + "\"\n" + + " }\n" + + " }\n" + + "}"; + hits = executeSearch(Detector.DETECTORS_INDEX, request); + Assert.assertEquals(0, hits.size()); + } + public void testDeletingADetector_oneDetectorType_multiple_ruleTopicIndex() throws IOException { String index1 = "test_index_1"; createIndex(index1, Settings.EMPTY);