Skip to content

Commit

Permalink
support feeding findings to chained finding monitors ONLY from rules …
Browse files Browse the repository at this point in the history
…mentioned in detector triggers

Signed-off-by: Surya Sashank Nistala <snistala@amazon.com>
  • Loading branch information
eirsep committed Sep 7, 2023
1 parent ac2323a commit 655064b
Show file tree
Hide file tree
Showing 7 changed files with 282 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,22 @@ public String getSeverity() {
return severity;
}

public List<String> getRuleTypes() {
return ruleTypes;
}

public List<String> getRuleIds() {
return ruleIds;
}

public List<String> getRuleSeverityLevels() {
return ruleSeverityLevels;
}

public List<String> getTags() {
return tags;
}

public List<Action> getActions() {
List<Action> transformedActions = new ArrayList<>();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@
import org.opensearch.securityanalytics.rules.exceptions.SigmaError;
import org.opensearch.securityanalytics.settings.SecurityAnalyticsSettings;
import org.opensearch.securityanalytics.util.DetectorIndices;
import org.opensearch.securityanalytics.util.DetectorUtils;
import org.opensearch.securityanalytics.util.IndexUtils;
import org.opensearch.securityanalytics.util.MonitorService;
import org.opensearch.securityanalytics.util.RuleIndices;
Expand Down Expand Up @@ -155,7 +156,7 @@ public class TransportIndexDetectorAction extends HandledTransportAction<IndexDe
private final MonitorService monitorService;
private final IndexNameExpressionResolver indexNameExpressionResolver;

private volatile TimeValue indexTimeout;
private final TimeValue indexTimeout;
@Inject
public TransportIndexDetectorAction(TransportService transportService,
Client client,
Expand Down Expand Up @@ -275,15 +276,15 @@ private void createMonitorFromQueries(List<Pair<String, Rule>> rulesById, Detect

StepListener<List<IndexMonitorResponse>> indexMonitorsStep = new StepListener<>();
indexMonitorsStep.whenComplete(
indexMonitorResponses -> saveWorkflow(detector, indexMonitorResponses, refreshPolicy, listener),
indexMonitorResponses -> saveWorkflow(rulesById, detector, indexMonitorResponses, refreshPolicy, listener),
e -> {
log.error("Failed to index the workflow", e);
listener.onFailure(e);
});

int numberOfUnprocessedResponses = monitorRequests.size() - 1;
if (numberOfUnprocessedResponses == 0) {
saveWorkflow(detector, monitorResponses, refreshPolicy, listener);
saveWorkflow(rulesById, detector, monitorResponses, refreshPolicy, listener);
} else {
// Saves the rest of the monitors and saves the workflow if supported
saveMonitors(
Expand Down Expand Up @@ -312,7 +313,7 @@ private void createMonitorFromQueries(List<Pair<String, Rule>> rulesById, Detect
AlertingPluginInterface.INSTANCE.indexMonitor((NodeClient) client, monitorRequests.get(0), namedWriteableRegistry, indexDocLevelMonitorStep);
indexDocLevelMonitorStep.whenComplete(addedFirstMonitorResponse -> {
monitorResponses.add(addedFirstMonitorResponse);
saveWorkflow(detector, monitorResponses, refreshPolicy, listener);
saveWorkflow(rulesById, detector, monitorResponses, refreshPolicy, listener);
},
listener::onFailure
);
Expand Down Expand Up @@ -346,19 +347,22 @@ public void onFailure(Exception e) {
/**
* If the workflow is enabled, saves the workflow, updates the detector and returns the saved monitors
* if not, returns the saved monitors
*
* @param rulesById
* @param detector
* @param monitorResponses
* @param refreshPolicy
* @param actionListener
*/
private void saveWorkflow(
Detector detector,
List<IndexMonitorResponse> monitorResponses,
RefreshPolicy refreshPolicy,
ActionListener<List<IndexMonitorResponse>> actionListener
List<Pair<String, Rule>> rulesById, Detector detector,
List<IndexMonitorResponse> monitorResponses,
RefreshPolicy refreshPolicy,
ActionListener<List<IndexMonitorResponse>> actionListener
) {
if (enabledWorkflowUsage) {
workflowService.upsertWorkflow(
rulesById,
monitorResponses,
null,
detector,
Expand Down Expand Up @@ -446,7 +450,7 @@ public void onResponse(Map<String, Map<String, String>> ruleFieldMappings) {
monitorIdsToBeDeleted.removeAll(monitorsToBeUpdated.stream().map(IndexMonitorRequest::getMonitorId).collect(
Collectors.toList()));

updateAlertingMonitors(detector, monitorsToBeAdded, monitorsToBeUpdated, monitorIdsToBeDeleted, refreshPolicy, listener);
updateAlertingMonitors(rulesById, detector, monitorsToBeAdded, monitorsToBeUpdated, monitorIdsToBeDeleted, refreshPolicy, listener);
} catch (IOException | SigmaError ex) {
listener.onFailure(ex);
}
Expand Down Expand Up @@ -474,7 +478,7 @@ public void onFailure(Exception e) {
monitorIdsToBeDeleted.removeAll(monitorsToBeUpdated.stream().map(IndexMonitorRequest::getMonitorId).collect(
Collectors.toList()));

updateAlertingMonitors(detector, monitorsToBeAdded, monitorsToBeUpdated, monitorIdsToBeDeleted, refreshPolicy, listener);
updateAlertingMonitors(rulesById, detector, monitorsToBeAdded, monitorsToBeUpdated, monitorIdsToBeDeleted, refreshPolicy, listener);
}
}

Expand All @@ -493,6 +497,7 @@ public void onFailure(Exception e) {
* @param listener Listener that accepts the list of updated monitors if the action was successful
*/
private void updateAlertingMonitors(
List<Pair<String, Rule>> rulesById,
Detector detector,
List<IndexMonitorRequest> monitorsToBeAdded,
List<IndexMonitorRequest> monitorsToBeUpdated,
Expand All @@ -519,6 +524,7 @@ private void updateAlertingMonitors(
}
if (detector.isWorkflowSupported() && enabledWorkflowUsage) {
updateWorkflowStep(
rulesById,
detector,
monitorsToBeDeleted,
refreshPolicy,
Expand Down Expand Up @@ -560,6 +566,7 @@ public void onFailure(Exception e) {
}

private void updateWorkflowStep(
List<Pair<String, Rule>> rulesById,
Detector detector,
List<String> monitorsToBeDeleted,
RefreshPolicy refreshPolicy,
Expand Down Expand Up @@ -596,6 +603,7 @@ public void onFailure(Exception e) {
} else {
// Update workflow and delete the monitors
workflowService.upsertWorkflow(
rulesById,
addNewMonitorsResponse,
updateMonitorResponse,
detector,
Expand Down Expand Up @@ -749,8 +757,8 @@ public void onResponse(Map<String, Map<String, String>> ruleFieldMappings) {
queryBackendMap.get(rule.getCategory())));
}
}
// if workflow usage enabled, add chained findings monitor request since there are bucket level requests
if(enabledWorkflowUsage && false == monitorRequests.isEmpty()) {
// if workflow usage enabled, add chained findings monitor request if there are bucket level requests and if the detector triggers have any group by rules configured to trigger
if (enabledWorkflowUsage && !monitorRequests.isEmpty() && !DetectorUtils.getAggRuleIdsConfiguredToTrigger(detector, queries).isEmpty()) {
monitorRequests.add(createDocLevelMonitorMatchAllRequest(detector, RefreshPolicy.IMMEDIATE, detector.getId()+"_chained_findings", Method.POST));
}
listener.onResponse(monitorRequests);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@
*/
package org.opensearch.securityanalytics.util;

import org.apache.commons.lang3.tuple.Pair;
import org.apache.lucene.search.TotalHits;
import org.opensearch.cluster.routing.Preference;
import org.opensearch.commons.alerting.action.IndexMonitorResponse;
import org.opensearch.commons.alerting.model.Monitor;
import org.opensearch.core.action.ActionListener;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
Expand All @@ -25,13 +28,15 @@
import org.opensearch.search.suggest.Suggest;
import org.opensearch.securityanalytics.model.Detector;
import org.opensearch.securityanalytics.model.DetectorInput;
import org.opensearch.securityanalytics.model.Rule;

import java.io.IOException;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

public class DetectorUtils {

Expand Down Expand Up @@ -95,4 +100,36 @@ public void onFailure(Exception e) {
}
});
}

public static List<String> getBucketLevelMonitorIdsWhoseRulesAreConfiguredToTrigger(
Detector detector,
List<Pair<String, Rule>> rulesById,
List<IndexMonitorResponse> monitorResponses
) {
List<String> aggRuleIdsConfiguredToTrigger = getAggRuleIdsConfiguredToTrigger(detector, rulesById);
return monitorResponses.stream().filter(
// In the case of bucket level monitors rule id is trigger id
it -> Monitor.MonitorType.BUCKET_LEVEL_MONITOR == it.getMonitor().getMonitorType()
&& !it.getMonitor().getTriggers().isEmpty()
&& aggRuleIdsConfiguredToTrigger.contains(it.getMonitor().getTriggers().get(0).getId())
).map(IndexMonitorResponse::getId).collect(Collectors.toList());
}
public static List<String> getAggRuleIdsConfiguredToTrigger(Detector detector, List<Pair<String, Rule>> rulesById) {
Set<String> ruleIdsConfiguredToTrigger = detector.getTriggers().stream().flatMap(t -> t.getRuleIds().stream()).collect(Collectors.toSet());
Set<String> tagsConfiguredToTrigger = detector.getTriggers().stream().flatMap(t -> t.getTags().stream()).collect(Collectors.toSet());
return rulesById.stream()
.filter(it -> checkIfRuleIsAggAndTriggerable( it.getRight(), ruleIdsConfiguredToTrigger, tagsConfiguredToTrigger))
.map(stringRulePair -> stringRulePair.getRight().getId())
.collect(Collectors.toList());
}

private static boolean checkIfRuleIsAggAndTriggerable(Rule rule, Set<String> ruleIdsConfiguredToTrigger, Set<String> tagsConfiguredToTrigger) {
if (rule.isAggregationRule()) {
return ruleIdsConfiguredToTrigger.contains(rule.getId())
|| rule.getTags().stream().anyMatch(tag -> tagsConfiguredToTrigger.contains(tag.getValue()));
}
return false;
}


}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
*/
package org.opensearch.securityanalytics.util;

import org.apache.commons.lang3.tuple.Pair;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.OpenSearchException;
Expand All @@ -28,6 +29,7 @@
import org.opensearch.index.seqno.SequenceNumbers;
import org.opensearch.rest.RestRequest.Method;
import org.opensearch.securityanalytics.model.Detector;
import org.opensearch.securityanalytics.model.Rule;

import java.util.ArrayList;
import java.util.Collections;
Expand All @@ -37,6 +39,8 @@
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;

import static org.opensearch.securityanalytics.util.DetectorUtils.getBucketLevelMonitorIdsWhoseRulesAreConfiguredToTrigger;

/**
* Alerting common clas used for workflow manipulation
*/
Expand Down Expand Up @@ -67,6 +71,7 @@ public WorkflowService(Client client, MonitorService monitorService) {
* @param listener
*/
public void upsertWorkflow(
List<Pair<String, Rule>> rulesById,
List<IndexMonitorResponse> addedMonitorResponses,
List<IndexMonitorResponse> updatedMonitorResponses,
Detector detector,
Expand All @@ -90,13 +95,13 @@ public void upsertWorkflow(
}
ChainedMonitorFindings chainedMonitorFindings = null;
String cmfMonitorId = null;
if(addedMonitorResponses.stream().anyMatch(res -> (detector.getName() + "_chained_findings").equals(res.getMonitor().getName()))) {
List<String> bucketMonitorIds = addedMonitorResponses.stream().filter(res -> res.getMonitor().getMonitorType().equals(MonitorType.BUCKET_LEVEL_MONITOR)).map(IndexMonitorResponse::getId).collect(Collectors.toList());
if(!updatedMonitors.isEmpty()) {
bucketMonitorIds.addAll(updatedMonitorResponses.stream().filter(res -> res.getMonitor().getMonitorType().equals(MonitorType.BUCKET_LEVEL_MONITOR)).map(IndexMonitorResponse::getId).collect(Collectors.toList()));
if (addedMonitorResponses.stream().anyMatch(res -> (detector.getName() + "_chained_findings").equals(res.getMonitor().getName()))) {
List<IndexMonitorResponse> monitorResponses = new ArrayList<>(addedMonitorResponses);
if (updatedMonitorResponses != null) {
monitorResponses.addAll(updatedMonitorResponses);
}
cmfMonitorId = addedMonitorResponses.stream().filter(res -> (detector.getName() + "_chained_findings").equals(res.getMonitor().getName())).findFirst().get().getId();
chainedMonitorFindings = new ChainedMonitorFindings(null, bucketMonitorIds);
chainedMonitorFindings = new ChainedMonitorFindings(null, getBucketLevelMonitorIdsWhoseRulesAreConfiguredToTrigger(detector, rulesById, monitorResponses));
}

IndexWorkflowRequest indexWorkflowRequest = createWorkflowRequest(monitorIds,
Expand Down Expand Up @@ -154,7 +159,7 @@ private IndexWorkflowRequest createWorkflowRequest(List<String> monitorIds, Dete
return delegate;
}
).collect(Collectors.toList());

Sequence sequence = new Sequence(delegates);
CompositeInput compositeInput = new CompositeInput(sequence);

Expand Down Expand Up @@ -185,21 +190,5 @@ private IndexWorkflowRequest createWorkflowRequest(List<String> monitorIds, Dete
null
);
}

private Map<String, String> mapMonitorIds(List<IndexMonitorResponse> monitorResponses) {
return monitorResponses.stream().collect(
Collectors.toMap(
// In the case of bucket level monitors rule id is trigger id
it -> {
if (MonitorType.BUCKET_LEVEL_MONITOR == it.getMonitor().getMonitorType()) {
return it.getMonitor().getTriggers().get(0).getId();
} else {
return Detector.DOC_LEVEL_MONITOR;
}
},
IndexMonitorResponse::getId
)
);
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,6 @@ protected void createRuleTopicIndex(String detectorType, String additionalMappin
assertEquals(RestStatus.OK, restStatus(response));
}
}

protected void verifyWorkflow(Map<String, Object> detectorMap, List<String> monitorIds, int expectedDelegatesNum) throws IOException{
String workflowId = ((List<String>) detectorMap.get("workflow_ids")).get(0);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ public static Detector randomDetector(List<String> rules, String detectorType) {
public static Detector randomDetectorWithInputs(List<DetectorInput> inputs) {
return randomDetector(null, null, null, inputs, List.of(), null, null, null, null);
}

public static Detector randomDetectorWithInputsAndTriggers(List<DetectorInput> inputs, List<DetectorTrigger> triggers) {
return randomDetector(null, null, null, inputs, List.of(), null, null, null, null);
}
public static Detector randomDetectorWithInputs(List<DetectorInput> inputs, String detectorType) {
return randomDetector(null, detectorType, null, inputs, List.of(), null, null, null, null);
}
Expand All @@ -84,9 +88,6 @@ public static Detector randomDetectorWithTriggers(List<String> rules, List<Detec
rules.stream().map(DetectorRule::new).collect(Collectors.toList()));
return randomDetector(null, null, null, List.of(input), triggers, null, null, null, null);
}
public static Detector randomDetectorWithInputsAndTriggers(List<DetectorInput> inputs, List<DetectorTrigger> triggers) {
return randomDetector(null, null, null, inputs, triggers, null, null, null, null);
}

public static Detector randomDetectorWithTriggers(List<String> rules, List<DetectorTrigger> triggers, String detectorType, DetectorInput input) {
return randomDetector(null, detectorType, null, List.of(input), triggers, null, null, null, null);
Expand Down
Loading

0 comments on commit 655064b

Please sign in to comment.