From 655064b6b03aac054d8ffd8f71c85b8984fb8735 Mon Sep 17 00:00:00 2001 From: Surya Sashank Nistala Date: Thu, 7 Sep 2023 16:14:24 -0700 Subject: [PATCH] support feeding findings to chained finding monitors ONLY from rules mentioned in detector triggers Signed-off-by: Surya Sashank Nistala --- .../model/DetectorTrigger.java | 16 ++ .../TransportIndexDetectorAction.java | 32 ++- .../securityanalytics/util/DetectorUtils.java | 37 +++ .../util/WorkflowService.java | 33 +-- .../SecurityAnalyticsRestTestCase.java | 1 - .../securityanalytics/TestHelpers.java | 7 +- .../resthandler/DetectorMonitorRestApiIT.java | 215 ++++++++++++++++-- 7 files changed, 282 insertions(+), 59 deletions(-) diff --git a/src/main/java/org/opensearch/securityanalytics/model/DetectorTrigger.java b/src/main/java/org/opensearch/securityanalytics/model/DetectorTrigger.java index f4cdd6f06..b74a71048 100644 --- a/src/main/java/org/opensearch/securityanalytics/model/DetectorTrigger.java +++ b/src/main/java/org/opensearch/securityanalytics/model/DetectorTrigger.java @@ -309,6 +309,22 @@ public String getSeverity() { return severity; } + public List getRuleTypes() { + return ruleTypes; + } + + public List getRuleIds() { + return ruleIds; + } + + public List getRuleSeverityLevels() { + return ruleSeverityLevels; + } + + public List getTags() { + return tags; + } + public List getActions() { List transformedActions = new ArrayList<>(); diff --git a/src/main/java/org/opensearch/securityanalytics/transport/TransportIndexDetectorAction.java b/src/main/java/org/opensearch/securityanalytics/transport/TransportIndexDetectorAction.java index 08a00c86e..c5ac8611c 100644 --- a/src/main/java/org/opensearch/securityanalytics/transport/TransportIndexDetectorAction.java +++ b/src/main/java/org/opensearch/securityanalytics/transport/TransportIndexDetectorAction.java @@ -96,6 +96,7 @@ import org.opensearch.securityanalytics.rules.exceptions.SigmaError; import org.opensearch.securityanalytics.settings.SecurityAnalyticsSettings; import org.opensearch.securityanalytics.util.DetectorIndices; +import org.opensearch.securityanalytics.util.DetectorUtils; import org.opensearch.securityanalytics.util.IndexUtils; import org.opensearch.securityanalytics.util.MonitorService; import org.opensearch.securityanalytics.util.RuleIndices; @@ -155,7 +156,7 @@ public class TransportIndexDetectorAction extends HandledTransportAction> rulesById, Detect StepListener> indexMonitorsStep = new StepListener<>(); indexMonitorsStep.whenComplete( - indexMonitorResponses -> saveWorkflow(detector, indexMonitorResponses, refreshPolicy, listener), + indexMonitorResponses -> saveWorkflow(rulesById, detector, indexMonitorResponses, refreshPolicy, listener), e -> { log.error("Failed to index the workflow", e); listener.onFailure(e); @@ -283,7 +284,7 @@ private void createMonitorFromQueries(List> rulesById, Detect int numberOfUnprocessedResponses = monitorRequests.size() - 1; if (numberOfUnprocessedResponses == 0) { - saveWorkflow(detector, monitorResponses, refreshPolicy, listener); + saveWorkflow(rulesById, detector, monitorResponses, refreshPolicy, listener); } else { // Saves the rest of the monitors and saves the workflow if supported saveMonitors( @@ -312,7 +313,7 @@ private void createMonitorFromQueries(List> rulesById, Detect AlertingPluginInterface.INSTANCE.indexMonitor((NodeClient) client, monitorRequests.get(0), namedWriteableRegistry, indexDocLevelMonitorStep); indexDocLevelMonitorStep.whenComplete(addedFirstMonitorResponse -> { monitorResponses.add(addedFirstMonitorResponse); - saveWorkflow(detector, monitorResponses, refreshPolicy, listener); + saveWorkflow(rulesById, detector, monitorResponses, refreshPolicy, listener); }, listener::onFailure ); @@ -346,19 +347,22 @@ public void onFailure(Exception e) { /** * If the workflow is enabled, saves the workflow, updates the detector and returns the saved monitors * if not, returns the saved monitors + * + * @param rulesById * @param detector * @param monitorResponses * @param refreshPolicy * @param actionListener */ private void saveWorkflow( - Detector detector, - List monitorResponses, - RefreshPolicy refreshPolicy, - ActionListener> actionListener + List> rulesById, Detector detector, + List monitorResponses, + RefreshPolicy refreshPolicy, + ActionListener> actionListener ) { if (enabledWorkflowUsage) { workflowService.upsertWorkflow( + rulesById, monitorResponses, null, detector, @@ -446,7 +450,7 @@ public void onResponse(Map> ruleFieldMappings) { monitorIdsToBeDeleted.removeAll(monitorsToBeUpdated.stream().map(IndexMonitorRequest::getMonitorId).collect( Collectors.toList())); - updateAlertingMonitors(detector, monitorsToBeAdded, monitorsToBeUpdated, monitorIdsToBeDeleted, refreshPolicy, listener); + updateAlertingMonitors(rulesById, detector, monitorsToBeAdded, monitorsToBeUpdated, monitorIdsToBeDeleted, refreshPolicy, listener); } catch (IOException | SigmaError ex) { listener.onFailure(ex); } @@ -474,7 +478,7 @@ public void onFailure(Exception e) { monitorIdsToBeDeleted.removeAll(monitorsToBeUpdated.stream().map(IndexMonitorRequest::getMonitorId).collect( Collectors.toList())); - updateAlertingMonitors(detector, monitorsToBeAdded, monitorsToBeUpdated, monitorIdsToBeDeleted, refreshPolicy, listener); + updateAlertingMonitors(rulesById, detector, monitorsToBeAdded, monitorsToBeUpdated, monitorIdsToBeDeleted, refreshPolicy, listener); } } @@ -493,6 +497,7 @@ public void onFailure(Exception e) { * @param listener Listener that accepts the list of updated monitors if the action was successful */ private void updateAlertingMonitors( + List> rulesById, Detector detector, List monitorsToBeAdded, List monitorsToBeUpdated, @@ -519,6 +524,7 @@ private void updateAlertingMonitors( } if (detector.isWorkflowSupported() && enabledWorkflowUsage) { updateWorkflowStep( + rulesById, detector, monitorsToBeDeleted, refreshPolicy, @@ -560,6 +566,7 @@ public void onFailure(Exception e) { } private void updateWorkflowStep( + List> rulesById, Detector detector, List monitorsToBeDeleted, RefreshPolicy refreshPolicy, @@ -596,6 +603,7 @@ public void onFailure(Exception e) { } else { // Update workflow and delete the monitors workflowService.upsertWorkflow( + rulesById, addNewMonitorsResponse, updateMonitorResponse, detector, @@ -749,8 +757,8 @@ public void onResponse(Map> ruleFieldMappings) { queryBackendMap.get(rule.getCategory()))); } } - // if workflow usage enabled, add chained findings monitor request since there are bucket level requests - if(enabledWorkflowUsage && false == monitorRequests.isEmpty()) { + // if workflow usage enabled, add chained findings monitor request if there are bucket level requests and if the detector triggers have any group by rules configured to trigger + if (enabledWorkflowUsage && !monitorRequests.isEmpty() && !DetectorUtils.getAggRuleIdsConfiguredToTrigger(detector, queries).isEmpty()) { monitorRequests.add(createDocLevelMonitorMatchAllRequest(detector, RefreshPolicy.IMMEDIATE, detector.getId()+"_chained_findings", Method.POST)); } listener.onResponse(monitorRequests); diff --git a/src/main/java/org/opensearch/securityanalytics/util/DetectorUtils.java b/src/main/java/org/opensearch/securityanalytics/util/DetectorUtils.java index 5e9d25c38..28e316e06 100644 --- a/src/main/java/org/opensearch/securityanalytics/util/DetectorUtils.java +++ b/src/main/java/org/opensearch/securityanalytics/util/DetectorUtils.java @@ -4,8 +4,11 @@ */ package org.opensearch.securityanalytics.util; +import org.apache.commons.lang3.tuple.Pair; import org.apache.lucene.search.TotalHits; import org.opensearch.cluster.routing.Preference; +import org.opensearch.commons.alerting.action.IndexMonitorResponse; +import org.opensearch.commons.alerting.model.Monitor; import org.opensearch.core.action.ActionListener; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; @@ -25,6 +28,7 @@ import org.opensearch.search.suggest.Suggest; import org.opensearch.securityanalytics.model.Detector; import org.opensearch.securityanalytics.model.DetectorInput; +import org.opensearch.securityanalytics.model.Rule; import java.io.IOException; import java.util.Collections; @@ -32,6 +36,7 @@ import java.util.LinkedList; import java.util.List; import java.util.Set; +import java.util.stream.Collectors; public class DetectorUtils { @@ -95,4 +100,36 @@ public void onFailure(Exception e) { } }); } + + public static List getBucketLevelMonitorIdsWhoseRulesAreConfiguredToTrigger( + Detector detector, + List> rulesById, + List monitorResponses + ) { + List aggRuleIdsConfiguredToTrigger = getAggRuleIdsConfiguredToTrigger(detector, rulesById); + return monitorResponses.stream().filter( + // In the case of bucket level monitors rule id is trigger id + it -> Monitor.MonitorType.BUCKET_LEVEL_MONITOR == it.getMonitor().getMonitorType() + && !it.getMonitor().getTriggers().isEmpty() + && aggRuleIdsConfiguredToTrigger.contains(it.getMonitor().getTriggers().get(0).getId()) + ).map(IndexMonitorResponse::getId).collect(Collectors.toList()); + } + public static List getAggRuleIdsConfiguredToTrigger(Detector detector, List> rulesById) { + Set ruleIdsConfiguredToTrigger = detector.getTriggers().stream().flatMap(t -> t.getRuleIds().stream()).collect(Collectors.toSet()); + Set tagsConfiguredToTrigger = detector.getTriggers().stream().flatMap(t -> t.getTags().stream()).collect(Collectors.toSet()); + return rulesById.stream() + .filter(it -> checkIfRuleIsAggAndTriggerable( it.getRight(), ruleIdsConfiguredToTrigger, tagsConfiguredToTrigger)) + .map(stringRulePair -> stringRulePair.getRight().getId()) + .collect(Collectors.toList()); + } + + private static boolean checkIfRuleIsAggAndTriggerable(Rule rule, Set ruleIdsConfiguredToTrigger, Set tagsConfiguredToTrigger) { + if (rule.isAggregationRule()) { + return ruleIdsConfiguredToTrigger.contains(rule.getId()) + || rule.getTags().stream().anyMatch(tag -> tagsConfiguredToTrigger.contains(tag.getValue())); + } + return false; + } + + } \ No newline at end of file diff --git a/src/main/java/org/opensearch/securityanalytics/util/WorkflowService.java b/src/main/java/org/opensearch/securityanalytics/util/WorkflowService.java index 21a0013c7..5ce495b98 100644 --- a/src/main/java/org/opensearch/securityanalytics/util/WorkflowService.java +++ b/src/main/java/org/opensearch/securityanalytics/util/WorkflowService.java @@ -4,6 +4,7 @@ */ package org.opensearch.securityanalytics.util; +import org.apache.commons.lang3.tuple.Pair; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.OpenSearchException; @@ -28,6 +29,7 @@ import org.opensearch.index.seqno.SequenceNumbers; import org.opensearch.rest.RestRequest.Method; import org.opensearch.securityanalytics.model.Detector; +import org.opensearch.securityanalytics.model.Rule; import java.util.ArrayList; import java.util.Collections; @@ -37,6 +39,8 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; +import static org.opensearch.securityanalytics.util.DetectorUtils.getBucketLevelMonitorIdsWhoseRulesAreConfiguredToTrigger; + /** * Alerting common clas used for workflow manipulation */ @@ -67,6 +71,7 @@ public WorkflowService(Client client, MonitorService monitorService) { * @param listener */ public void upsertWorkflow( + List> rulesById, List addedMonitorResponses, List updatedMonitorResponses, Detector detector, @@ -90,13 +95,13 @@ public void upsertWorkflow( } ChainedMonitorFindings chainedMonitorFindings = null; String cmfMonitorId = null; - if(addedMonitorResponses.stream().anyMatch(res -> (detector.getName() + "_chained_findings").equals(res.getMonitor().getName()))) { - List bucketMonitorIds = addedMonitorResponses.stream().filter(res -> res.getMonitor().getMonitorType().equals(MonitorType.BUCKET_LEVEL_MONITOR)).map(IndexMonitorResponse::getId).collect(Collectors.toList()); - if(!updatedMonitors.isEmpty()) { - bucketMonitorIds.addAll(updatedMonitorResponses.stream().filter(res -> res.getMonitor().getMonitorType().equals(MonitorType.BUCKET_LEVEL_MONITOR)).map(IndexMonitorResponse::getId).collect(Collectors.toList())); + if (addedMonitorResponses.stream().anyMatch(res -> (detector.getName() + "_chained_findings").equals(res.getMonitor().getName()))) { + List monitorResponses = new ArrayList<>(addedMonitorResponses); + if (updatedMonitorResponses != null) { + monitorResponses.addAll(updatedMonitorResponses); } cmfMonitorId = addedMonitorResponses.stream().filter(res -> (detector.getName() + "_chained_findings").equals(res.getMonitor().getName())).findFirst().get().getId(); - chainedMonitorFindings = new ChainedMonitorFindings(null, bucketMonitorIds); + chainedMonitorFindings = new ChainedMonitorFindings(null, getBucketLevelMonitorIdsWhoseRulesAreConfiguredToTrigger(detector, rulesById, monitorResponses)); } IndexWorkflowRequest indexWorkflowRequest = createWorkflowRequest(monitorIds, @@ -154,7 +159,7 @@ private IndexWorkflowRequest createWorkflowRequest(List monitorIds, Dete return delegate; } ).collect(Collectors.toList()); - + Sequence sequence = new Sequence(delegates); CompositeInput compositeInput = new CompositeInput(sequence); @@ -185,21 +190,5 @@ private IndexWorkflowRequest createWorkflowRequest(List monitorIds, Dete null ); } - - private Map mapMonitorIds(List monitorResponses) { - return monitorResponses.stream().collect( - Collectors.toMap( - // In the case of bucket level monitors rule id is trigger id - it -> { - if (MonitorType.BUCKET_LEVEL_MONITOR == it.getMonitor().getMonitorType()) { - return it.getMonitor().getTriggers().get(0).getId(); - } else { - return Detector.DOC_LEVEL_MONITOR; - } - }, - IndexMonitorResponse::getId - ) - ); - } } diff --git a/src/test/java/org/opensearch/securityanalytics/SecurityAnalyticsRestTestCase.java b/src/test/java/org/opensearch/securityanalytics/SecurityAnalyticsRestTestCase.java index 5f03b4e5d..f8ed33062 100644 --- a/src/test/java/org/opensearch/securityanalytics/SecurityAnalyticsRestTestCase.java +++ b/src/test/java/org/opensearch/securityanalytics/SecurityAnalyticsRestTestCase.java @@ -127,7 +127,6 @@ protected void createRuleTopicIndex(String detectorType, String additionalMappin assertEquals(RestStatus.OK, restStatus(response)); } } - protected void verifyWorkflow(Map detectorMap, List monitorIds, int expectedDelegatesNum) throws IOException{ String workflowId = ((List) detectorMap.get("workflow_ids")).get(0); diff --git a/src/test/java/org/opensearch/securityanalytics/TestHelpers.java b/src/test/java/org/opensearch/securityanalytics/TestHelpers.java index b98a6e641..ca6da3609 100644 --- a/src/test/java/org/opensearch/securityanalytics/TestHelpers.java +++ b/src/test/java/org/opensearch/securityanalytics/TestHelpers.java @@ -65,6 +65,10 @@ public static Detector randomDetector(List rules, String detectorType) { public static Detector randomDetectorWithInputs(List inputs) { return randomDetector(null, null, null, inputs, List.of(), null, null, null, null); } + + public static Detector randomDetectorWithInputsAndTriggers(List inputs, List triggers) { + return randomDetector(null, null, null, inputs, List.of(), null, null, null, null); + } public static Detector randomDetectorWithInputs(List inputs, String detectorType) { return randomDetector(null, detectorType, null, inputs, List.of(), null, null, null, null); } @@ -84,9 +88,6 @@ public static Detector randomDetectorWithTriggers(List rules, List inputs, List triggers) { - return randomDetector(null, null, null, inputs, triggers, null, null, null, null); - } public static Detector randomDetectorWithTriggers(List rules, List triggers, String detectorType, DetectorInput input) { return randomDetector(null, detectorType, null, List.of(input), triggers, null, null, null, null); diff --git a/src/test/java/org/opensearch/securityanalytics/resthandler/DetectorMonitorRestApiIT.java b/src/test/java/org/opensearch/securityanalytics/resthandler/DetectorMonitorRestApiIT.java index 95d8ff6cb..53f56afa6 100644 --- a/src/test/java/org/opensearch/securityanalytics/resthandler/DetectorMonitorRestApiIT.java +++ b/src/test/java/org/opensearch/securityanalytics/resthandler/DetectorMonitorRestApiIT.java @@ -4,27 +4,6 @@ */ package org.opensearch.securityanalytics.resthandler; -import static org.opensearch.securityanalytics.TestHelpers.randomAggregationRule; -import static org.opensearch.securityanalytics.TestHelpers.randomDetector; -import static org.opensearch.securityanalytics.TestHelpers.randomDetectorType; -import static org.opensearch.securityanalytics.TestHelpers.randomDetectorWithInputs; -import static org.opensearch.securityanalytics.TestHelpers.randomDoc; -import static org.opensearch.securityanalytics.TestHelpers.randomIndex; -import static org.opensearch.securityanalytics.TestHelpers.randomRule; -import static org.opensearch.securityanalytics.TestHelpers.windowsIndexMapping; -import static org.opensearch.securityanalytics.settings.SecurityAnalyticsSettings.ENABLE_WORKFLOW_USAGE; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.stream.Collectors; - import org.apache.hc.core5.http.HttpStatus; import org.junit.Assert; import org.opensearch.action.search.SearchResponse; @@ -39,8 +18,31 @@ import org.opensearch.securityanalytics.model.Detector; import org.opensearch.securityanalytics.model.DetectorInput; import org.opensearch.securityanalytics.model.DetectorRule; +import org.opensearch.securityanalytics.model.DetectorTrigger; import org.opensearch.securityanalytics.model.Rule; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import static org.opensearch.securityanalytics.TestHelpers.randomAggregationRule; +import static org.opensearch.securityanalytics.TestHelpers.randomDetector; +import static org.opensearch.securityanalytics.TestHelpers.randomDetectorType; +import static org.opensearch.securityanalytics.TestHelpers.randomDetectorWithInputs; +import static org.opensearch.securityanalytics.TestHelpers.randomDetectorWithInputsAndTriggers; +import static org.opensearch.securityanalytics.TestHelpers.randomDoc; +import static org.opensearch.securityanalytics.TestHelpers.randomIndex; +import static org.opensearch.securityanalytics.TestHelpers.randomRule; +import static org.opensearch.securityanalytics.TestHelpers.windowsIndexMapping; +import static org.opensearch.securityanalytics.settings.SecurityAnalyticsSettings.ENABLE_WORKFLOW_USAGE; + public class DetectorMonitorRestApiIT extends SecurityAnalyticsRestTestCase { /** * 1. Creates detector with 5 doc prepackaged level rules and one doc level monitor based on the given rules @@ -979,6 +981,177 @@ else if (ruleId == minRuleId) { assertTrue(Arrays.asList("1", "2", "3", "4", "5", "6", "7", "8").containsAll(docLevelFinding)); } + /** + * 1. Creates detector with aggregation and prepackaged rules + * (sum rule - should match docIds: 1, 2, 3; maxRule - 4, 5, 6, 7; minRule - 7) + * 2. Verifies monitor execution + * 3. Verifies findings + * + * @throws IOException + */ + public void testMultipleAggregationAndDocRules_findingSuccessWithBucketLevelTriggersOnRuleIds() throws IOException { + String index = createTestIndex(randomIndex(), windowsIndexMapping()); + + // Execute CreateMappingsAction to add alias mapping for index + Request createMappingRequest = new Request("POST", SecurityAnalyticsPlugin.MAPPER_BASE_URI); + // both req params and req body are supported + createMappingRequest.setJsonEntity( + "{ \"index_name\":\"" + index + "\"," + + " \"rule_topic\":\"" + randomDetectorType() + "\", " + + " \"partial\":true" + + "}" + ); + + Response createMappingResponse = client().performRequest(createMappingRequest); + + assertEquals(HttpStatus.SC_OK, createMappingResponse.getStatusLine().getStatusCode()); + + String infoOpCode = "Info"; + String testOpCode = "Test"; + + // 5 custom aggregation rules + String sumRuleId = createRule(randomAggregationRule("sum", " > 1", infoOpCode)); + String maxRuleId = createRule(randomAggregationRule("max", " > 3", testOpCode)); + String minRuleId = createRule(randomAggregationRule("min", " > 3", testOpCode)); + String avgRuleId = createRule(randomAggregationRule("avg", " > 3", infoOpCode)); + String cntRuleId = createRule(randomAggregationRule("count", " > 3", "randomTestCode")); + List aggRuleIds = List.of(sumRuleId, maxRuleId); + String randomDocRuleId = createRule(randomRule()); + List prepackagedRules = getRandomPrePackagedRules(); + + List detectorRules = List.of(new DetectorRule(sumRuleId), new DetectorRule(maxRuleId), new DetectorRule(minRuleId), + new DetectorRule(avgRuleId), new DetectorRule(cntRuleId), new DetectorRule(randomDocRuleId)); + + DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), detectorRules, + prepackagedRules.stream().map(DetectorRule::new).collect(Collectors.toList())); + DetectorTrigger t1, t2; + t1 = new DetectorTrigger(null, "test-trigger", "1", List.of(), List.of(sumRuleId, maxRuleId), List.of(), List.of(), List.of()); + t2 = new DetectorTrigger(null, "test-trigger", "1", List.of(), List.of(minRuleId, avgRuleId, cntRuleId), List.of(), List.of(), List.of()); + Detector detector = randomDetectorWithInputsAndTriggers(List.of(input), List.of(t1, t2)); + + Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); + + + + String request = "{\n" + + " \"query\" : {\n" + + " \"match_all\":{\n" + + " }\n" + + " }\n" + + "}"; + SearchResponse response = executeSearchAndGetResponse(DetectorMonitorConfig.getRuleIndex(randomDetectorType()), request, true); + + assertEquals(7, response.getHits().getTotalHits().value); + + assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); + Map responseBody = asMap(createResponse); + String detectorId = responseBody.get("_id").toString(); + request = "{\n" + + " \"query\" : {\n" + + " \"match\":{\n" + + " \"_id\": \"" + detectorId + "\"\n" + + " }\n" + + " }\n" + + "}"; + List hits = executeSearch(Detector.DETECTORS_INDEX, request); + SearchHit hit = hits.get(0); + Map updatedDetectorMap = (HashMap)(hit.getSourceAsMap().get("detector")); + List inputArr = updatedDetectorMap.get("inputs"); + + assertEquals(6, ((Map>) inputArr.get(0)).get("detector_input").get("custom_rules").size()); + + List monitorIds = ((List) (updatedDetectorMap).get("monitor_id")); + + assertEquals(7, monitorIds.size()); + + indexDoc(index, "1", randomDoc(2, 4, infoOpCode)); + indexDoc(index, "2", randomDoc(3, 4, infoOpCode)); + indexDoc(index, "3", randomDoc(1, 4, infoOpCode)); + indexDoc(index, "4", randomDoc(5, 3, testOpCode)); + indexDoc(index, "5", randomDoc(2, 3, testOpCode)); + indexDoc(index, "6", randomDoc(4, 3, testOpCode)); + indexDoc(index, "7", randomDoc(6, 2, testOpCode)); + indexDoc(index, "8", randomDoc(1, 1, testOpCode)); + + Map numberOfMonitorTypes = new HashMap<>(); + + for (String monitorId: monitorIds) { + Map monitor = (Map)(entityAsMap(client().performRequest(new Request("GET", "/_plugins/_alerting/monitors/" + monitorId)))).get("monitor"); + numberOfMonitorTypes.merge(monitor.get("monitor_type"), 1, Integer::sum); + Response executeResponse = executeAlertingMonitor(monitorId, Collections.emptyMap()); + + // Assert monitor executions + Map executeResults = entityAsMap(executeResponse); + if (MonitorType.DOC_LEVEL_MONITOR.getValue().equals(monitor.get("monitor_type")) && false == monitor.get("name").equals(detector.getName() + "_chained_findings")) { + int noOfSigmaRuleMatches = ((List>) ((Map) executeResults.get("input_results")).get("results")).get(0).size(); + // 5 prepackaged and 1 custom doc level rule + assertEquals(6, noOfSigmaRuleMatches); + } else if (MonitorType.BUCKET_LEVEL_MONITOR.getValue().equals(monitor.get("monitor_type"))) { + for(String ruleId: aggRuleIds) { + Object rule = (((Map)((Map)((List)((Map)executeResults.get("input_results")).get("results")).get(0)).get("aggregations")).get(ruleId)); + if(rule != null) { + if(ruleId == sumRuleId) { + assertRuleMonitorFinding(executeResults, ruleId,3, List.of("4")); + } else if (ruleId == maxRuleId) { + assertRuleMonitorFinding(executeResults, ruleId,5, List.of("2", "3")); + } + else if (ruleId == minRuleId) { + assertRuleMonitorFinding(executeResults, ruleId,1, List.of("2")); + } + } + } + } + } + + assertEquals(5, numberOfMonitorTypes.get(MonitorType.BUCKET_LEVEL_MONITOR.getValue()).intValue()); + assertEquals(2, numberOfMonitorTypes.get(MonitorType.DOC_LEVEL_MONITOR.getValue()).intValue()); + + Map params = new HashMap<>(); + params.put("detector_id", detectorId); + Response getFindingsResponse = makeRequest(client(), "GET", SecurityAnalyticsPlugin.FINDINGS_BASE_URI + "/_search", params, null); + Map getFindingsBody = entityAsMap(getFindingsResponse); + + // Assert findings + assertNotNull(getFindingsBody); + // 8 findings from doc level rules, and 3 findings for aggregation (sum, max and min) + assertEquals(19, getFindingsBody.get("total_findings")); + + String findingDetectorId = ((Map)((List)getFindingsBody.get("findings")).get(0)).get("detectorId").toString(); + assertEquals(detectorId, findingDetectorId); + + String findingIndex = ((Map)((List)getFindingsBody.get("findings")).get(0)).get("index").toString(); + assertEquals(index, findingIndex); + + List docLevelFinding = new ArrayList<>(); + List> findings = (List) getFindingsBody.get("findings"); + + Set docLevelRules = new HashSet<>(prepackagedRules); + docLevelRules.add(randomDocRuleId); + + for(Map finding : findings) { + List> queries = (List>)finding.get("queries"); + Set findingRuleIds = queries.stream().map(it -> it.get("id").toString()).collect(Collectors.toSet()); + // Doc level finding matches all doc level rules (including the custom one) in this test case + if(docLevelRules.containsAll(findingRuleIds)) { + docLevelFinding.addAll((List)finding.get("related_doc_ids")); + } else { + // In the case of bucket level monitors, queries will always contain one value + String aggRuleId = findingRuleIds.iterator().next(); + List findingDocs = (List)finding.get("related_doc_ids"); + + if(aggRuleId.equals(sumRuleId)) { + assertTrue(List.of("1", "2", "3").containsAll(findingDocs)); + } else if(aggRuleId.equals(maxRuleId)) { + assertTrue(List.of("4", "5", "6", "7").containsAll(findingDocs)); + } else if(aggRuleId.equals( minRuleId)) { + assertTrue(List.of("7").containsAll(findingDocs)); + } + } + } + + assertTrue(Arrays.asList("1", "2", "3", "4", "5", "6", "7", "8").containsAll(docLevelFinding)); + } + public void testCreateDetector_verifyWorkflowCreation_success() throws IOException { updateClusterSetting(ENABLE_WORKFLOW_USAGE.getKey(), "true"); String index = createTestIndex(randomIndex(), windowsIndexMapping());