From 2f0abe6d85b0e527e04c44c5df33960c2876c344 Mon Sep 17 00:00:00 2001 From: Stevan Buzejic <30922513+stevanbz@users.noreply.github.com> Date: Wed, 9 Nov 2022 19:13:15 +0100 Subject: [PATCH] Creates bucket level monitors for rules containing aggregations (#92) Signed-off-by: Stevan Buzejic --- .../securityanalytics/model/Detector.java | 36 +- .../securityanalytics/model/Rule.java | 55 ++- .../rules/backend/AggregationBuilders.java | 56 +++ .../rules/backend/OSQueryBackend.java | 153 +++++- .../rules/backend/QueryBackend.java | 5 +- .../TransportIndexDetectorAction.java | 453 +++++++++++++++--- .../transport/TransportIndexRuleAction.java | 2 +- .../SecurityAnalyticsRestTestCase.java | 15 + .../securityanalytics/TestHelpers.java | 104 +++- .../action/IndexDetectorResponseTests.java | 3 +- .../alerts/AlertingServiceTests.java | 7 +- .../findings/FindingServiceTests.java | 8 +- .../resthandler/DetectorRestApiIT.java | 275 +++++++++++ .../resthandler/RuleRestApiIT.java | 52 ++ .../aggregation/AggregationBackendTests.java | 15 +- 15 files changed, 1133 insertions(+), 106 deletions(-) create mode 100644 src/main/java/org/opensearch/securityanalytics/rules/backend/AggregationBuilders.java diff --git a/src/main/java/org/opensearch/securityanalytics/model/Detector.java b/src/main/java/org/opensearch/securityanalytics/model/Detector.java index b8e83801c..5cc391e22 100644 --- a/src/main/java/org/opensearch/securityanalytics/model/Detector.java +++ b/src/main/java/org/opensearch/securityanalytics/model/Detector.java @@ -4,6 +4,8 @@ */ package org.opensearch.securityanalytics.model; +import java.util.HashMap; +import java.util.Map; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.common.ParseField; @@ -49,6 +51,8 @@ public class Detector implements Writeable, ToXContentObject { public static final String LAST_UPDATE_TIME_FIELD = "last_update_time"; public static final String ENABLED_TIME_FIELD = "enabled_time"; public static final String ALERTING_MONITOR_ID = "monitor_id"; + + public static final String BUCKET_MONITOR_ID_RULE_ID = "bucket_monitor_id_rule_id"; private static final String RULE_TOPIC_INDEX = "rule_topic_index"; private static final String ALERTS_INDEX = "alert_index"; @@ -59,6 +63,9 @@ public class Detector implements Writeable, ToXContentObject { public static final String DETECTORS_INDEX = ".opensearch-sap-detectors-config"; + // Used as a key in rule-monitor map for the purpose of easy detection of the doc level monitor + public static final String DOC_LEVEL_MONITOR = "-1"; + public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry( Detector.class, new ParseField(DETECTOR_TYPE), @@ -90,6 +97,8 @@ public class Detector implements Writeable, ToXContentObject { private List monitorIds; + private Map ruleIdMonitorIdMap; + private String ruleIndex; private String alertsIndex; @@ -108,7 +117,7 @@ public Detector(String id, Long version, String name, Boolean enabled, Schedule Instant lastUpdateTime, Instant enabledTime, DetectorType detectorType, User user, List inputs, List triggers, List monitorIds, String ruleIndex, String alertsIndex, String alertsHistoryIndex, String alertsHistoryIndexPattern, - String findingsIndex, String findingsIndexPattern) { + String findingsIndex, String findingsIndexPattern, Map rulePerMonitor) { this.type = DETECTOR_TYPE; this.id = id != null ? id : NO_ID; @@ -129,6 +138,7 @@ public Detector(String id, Long version, String name, Boolean enabled, Schedule this.alertsHistoryIndexPattern = alertsHistoryIndexPattern; this.findingsIndex = findingsIndex; this.findingsIndexPattern = findingsIndexPattern; + this.ruleIdMonitorIdMap = rulePerMonitor; if (enabled) { Objects.requireNonNull(enabledTime); @@ -154,7 +164,9 @@ public Detector(StreamInput sin) throws IOException { sin.readString(), sin.readString(), sin.readString(), - sin.readString()); + sin.readString(), + sin.readMap(StreamInput::readString, StreamInput::readString) + ); } @Override @@ -186,6 +198,8 @@ public void writeTo(StreamOutput out) throws IOException { } out.writeStringCollection(monitorIds); out.writeString(ruleIndex); + + out.writeMap(ruleIdMonitorIdMap, StreamOutput::writeString, StreamOutput::writeString); } public XContentBuilder toXContentWithUser(XContentBuilder builder, Params params) throws IOException { @@ -269,6 +283,7 @@ private XContentBuilder createXContentBuilder(XContentBuilder builder, ToXConten } builder.field(ALERTING_MONITOR_ID, monitorIds); + builder.field(BUCKET_MONITOR_ID_RULE_ID, ruleIdMonitorIdMap); builder.field(RULE_TOPIC_INDEX, ruleIndex); builder.field(ALERTS_INDEX, alertsIndex); builder.field(ALERTS_HISTORY_INDEX, alertsHistoryIndex); @@ -313,6 +328,8 @@ public static Detector parse(XContentParser xcp, String id, Long version) throws List inputs = new ArrayList<>(); List triggers = new ArrayList<>(); List monitorIds = new ArrayList<>(); + Map rulePerMonitor = new HashMap<>(); + String ruleIndex = null; String alertsIndex = null; String alertsHistoryIndex = null; @@ -391,6 +408,9 @@ public static Detector parse(XContentParser xcp, String id, Long version) throws monitorIds.add(monitorId); } break; + case BUCKET_MONITOR_ID_RULE_ID: + rulePerMonitor= xcp.mapStrings(); + break; case RULE_TOPIC_INDEX: ruleIndex = xcp.text(); break; @@ -438,7 +458,8 @@ public static Detector parse(XContentParser xcp, String id, Long version) throws alertsHistoryIndex, alertsHistoryIndexPattern, findingsIndex, - findingsIndexPattern); + findingsIndexPattern, + rulePerMonitor); } public static Detector readFrom(StreamInput sin) throws IOException { @@ -521,6 +542,8 @@ public void setUser(User user) { this.user = user; } + public Map getRuleIdMonitorIdMap() {return ruleIdMonitorIdMap; } + public void setId(String id) { this.id = id; } @@ -568,6 +591,13 @@ public void setInputs(List inputs) { public void setMonitorIds(List monitorIds) { this.monitorIds = monitorIds; } + public void setRuleIdMonitorIdMap(Map ruleIdMonitorIdMap) { + this.ruleIdMonitorIdMap = ruleIdMonitorIdMap; + } + + public String getDocLevelMonitorId() { + return ruleIdMonitorIdMap.get(DOC_LEVEL_MONITOR); + } @Override public boolean equals(Object o) { diff --git a/src/main/java/org/opensearch/securityanalytics/model/Rule.java b/src/main/java/org/opensearch/securityanalytics/model/Rule.java index 5fa63eeef..e54c2ccce 100644 --- a/src/main/java/org/opensearch/securityanalytics/model/Rule.java +++ b/src/main/java/org/opensearch/securityanalytics/model/Rule.java @@ -4,6 +4,7 @@ */ package org.opensearch.securityanalytics.model; +import org.apache.commons.lang3.tuple.Pair; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.common.ParseField; @@ -16,6 +17,11 @@ import org.opensearch.common.xcontent.XContentBuilder; import org.opensearch.common.xcontent.XContentParser; import org.opensearch.common.xcontent.XContentParserUtils; +import org.opensearch.securityanalytics.rules.aggregation.AggregationItem; +import org.opensearch.securityanalytics.rules.backend.OSQueryBackend.AggregationQueries; +import org.opensearch.securityanalytics.rules.condition.ConditionItem; +import org.opensearch.securityanalytics.rules.exceptions.SigmaError; +import org.opensearch.securityanalytics.rules.objects.SigmaCondition; import org.opensearch.securityanalytics.rules.objects.SigmaRule; import java.io.IOException; @@ -56,6 +62,7 @@ public class Rule implements Writeable, ToXContentObject { public static final String PRE_PACKAGED_RULES_INDEX = ".opensearch-sap-pre-packaged-rules-config"; public static final String CUSTOM_RULES_INDEX = ".opensearch-sap-custom-rules-config"; + public static final String AGGREGATION_QUERIES = "aggregationQueries"; public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry( Rule.class, @@ -95,10 +102,12 @@ public class Rule implements Writeable, ToXContentObject { private String rule; + private List aggregationQueries; + public Rule(String id, Long version, String title, String category, String logSource, String description, List references, List tags, String level, List falsePositives, String author, String status, Instant date, - List queries, List queryFieldNames, String rule) { + List queries, List queryFieldNames, String rule, List aggregationQueries) { this.id = id != null? id: NO_ID; this.version = version != null? version: NO_VERSION; @@ -121,10 +130,11 @@ public Rule(String id, Long version, String title, String category, String logSo this.queries = queries; this.queryFieldNames = queryFieldNames; this.rule = rule; + this.aggregationQueries = aggregationQueries; } public Rule(String id, Long version, SigmaRule rule, String category, - List queries, List queryFieldNames, String original) { + List queries, List queryFieldNames, String original) { this( id, version, @@ -141,9 +151,11 @@ public Rule(String id, Long version, SigmaRule rule, String category, rule.getAuthor(), rule.getStatus().toString(), Instant.ofEpochMilli(rule.getDate().getTime()), - queries.stream().map(Value::new).collect(Collectors.toList()), + queries.stream().filter(query -> !(query instanceof AggregationQueries)).map(query -> new Value(query.toString())).collect(Collectors.toList()), queryFieldNames.stream().map(Value::new).collect(Collectors.toList()), - original); + original, + // If one of the queries is AggregationQuery -> the whole rule can be considered as Agg + queries.stream().filter(query -> query instanceof AggregationQueries).map(it -> new Value(it.toString())).collect(Collectors.toList())); } public Rule(StreamInput sin) throws IOException { @@ -163,7 +175,9 @@ public Rule(StreamInput sin) throws IOException { sin.readInstant(), sin.readList(Value::readFrom), sin.readList(Value::readFrom), - sin.readString()); + sin.readString(), + sin.readList(Value::readFrom) + ); } @Override @@ -190,6 +204,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeCollection(queryFieldNames); out.writeString(rule); + out.writeCollection(aggregationQueries); } @Override @@ -233,6 +248,10 @@ private XContentBuilder createXContentBuilder(XContentBuilder builder, ToXConten queryFieldNamesArray = queryFieldNames.toArray(queryFieldNamesArray); builder.field(QUERY_FIELD_NAMES, queryFieldNamesArray); + Value[] aggregationsArray = new Value[]{}; + aggregationsArray = aggregationQueries.toArray(aggregationsArray); + builder.field(AGGREGATION_QUERIES, aggregationsArray); + builder.field(RULE, rule); if (params.paramAsBoolean("with_type", false)) { builder.endObject(); @@ -278,6 +297,7 @@ public static Rule parse(XContentParser xcp, String id, Long version) throws IOE List queries = new ArrayList<>(); List queryFields = new ArrayList<>(); String original = null; + List aggregationQueries = new ArrayList<>(); XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, xcp.currentToken(), xcp); while (xcp.nextToken() != XContentParser.Token.END_OBJECT) { @@ -342,6 +362,11 @@ public static Rule parse(XContentParser xcp, String id, Long version) throws IOE case RULE: original = xcp.text(); break; + case AGGREGATION_QUERIES: + XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, xcp.currentToken(), xcp); + while (xcp.nextToken() != XContentParser.Token.END_ARRAY) { + aggregationQueries.add(Value.parse(xcp)); + } default: xcp.skipChildren(); } @@ -363,7 +388,8 @@ public static Rule parse(XContentParser xcp, String id, Long version) throws IOE date, queries, queryFields, - Objects.requireNonNull(original, "Rule String is null") + Objects.requireNonNull(original, "Rule String is null"), + aggregationQueries ); } @@ -442,4 +468,21 @@ public List getQueries() { public List getQueryFieldNames() { return queryFieldNames; } + + public List getAggregationQueries() { return aggregationQueries; } + + public boolean isAggregationRule() { + return aggregationQueries != null && !aggregationQueries.isEmpty(); + } + + public List getAggregationItemsFromRule () throws SigmaError { + SigmaRule sigmaRule = SigmaRule.fromYaml(rule, true); + List aggregationItems = new ArrayList<>(); + for (SigmaCondition condition: sigmaRule.getDetection().getParsedCondition()) { + Pair parsedItems = condition.parsed(); + AggregationItem aggItem = parsedItems.getRight(); + aggregationItems.add(aggItem); + } + return aggregationItems; + } } \ No newline at end of file diff --git a/src/main/java/org/opensearch/securityanalytics/rules/backend/AggregationBuilders.java b/src/main/java/org/opensearch/securityanalytics/rules/backend/AggregationBuilders.java new file mode 100644 index 000000000..3927186fb --- /dev/null +++ b/src/main/java/org/opensearch/securityanalytics/rules/backend/AggregationBuilders.java @@ -0,0 +1,56 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.securityanalytics.rules.backend; + +import java.util.Locale; +import org.apache.commons.lang3.NotImplementedException; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder; +import org.opensearch.search.aggregations.metrics.AvgAggregationBuilder; +import org.opensearch.search.aggregations.metrics.MaxAggregationBuilder; +import org.opensearch.search.aggregations.metrics.MedianAbsoluteDeviationAggregationBuilder; +import org.opensearch.search.aggregations.metrics.MinAggregationBuilder; +import org.opensearch.search.aggregations.metrics.SumAggregationBuilder; +import org.opensearch.search.aggregations.metrics.ValueCountAggregationBuilder; + +public final class AggregationBuilders { + + /** + * Finds the builder aggregation based on the forwarded function + * + * @param aggregationFunction Aggregation function + * @param name Name of the aggregation + * @return Aggregation builder + */ + public static AggregationBuilder getAggregationBuilderByFunction(String aggregationFunction, String name) { + AggregationBuilder aggregationBuilder; + switch (aggregationFunction.toLowerCase(Locale.ROOT)) { + case AvgAggregationBuilder.NAME: + aggregationBuilder = new AvgAggregationBuilder(name).field(name); + break; + case MaxAggregationBuilder.NAME: + aggregationBuilder = new MaxAggregationBuilder(name).field(name); + break; + case MedianAbsoluteDeviationAggregationBuilder.NAME: + aggregationBuilder = new MedianAbsoluteDeviationAggregationBuilder(name).field(name); + break; + case MinAggregationBuilder.NAME: + aggregationBuilder = new MinAggregationBuilder(name).field(name); + break; + case SumAggregationBuilder.NAME: + aggregationBuilder = new SumAggregationBuilder(name).field(name); + break; + case TermsAggregationBuilder.NAME: + aggregationBuilder = new TermsAggregationBuilder(name).field(name); + break; + case "count": + aggregationBuilder = new ValueCountAggregationBuilder(name).field(name); + break; + default: + throw new NotImplementedException(String.format(Locale.getDefault(), "Aggregation %s not supported by the backend", aggregationFunction)); + } + return aggregationBuilder; + } +} diff --git a/src/main/java/org/opensearch/securityanalytics/rules/backend/OSQueryBackend.java b/src/main/java/org/opensearch/securityanalytics/rules/backend/OSQueryBackend.java index a84c82ae9..b243c884c 100644 --- a/src/main/java/org/opensearch/securityanalytics/rules/backend/OSQueryBackend.java +++ b/src/main/java/org/opensearch/securityanalytics/rules/backend/OSQueryBackend.java @@ -4,6 +4,22 @@ */ package org.opensearch.securityanalytics.rules.backend; +import org.opensearch.OpenSearchParseException; +import org.opensearch.common.UUIDs; +import org.opensearch.common.bytes.BytesReference; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.common.io.stream.Writeable; +import org.opensearch.common.xcontent.ToXContent; +import org.opensearch.common.xcontent.ToXContentObject; +import org.opensearch.common.xcontent.XContentBuilder; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentParser; +import org.opensearch.common.xcontent.XContentParserUtils; +import org.opensearch.commons.alerting.aggregation.bucketselectorext.BucketSelectorExtAggregationBuilder; +import org.opensearch.script.Script; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder; import org.opensearch.securityanalytics.rules.aggregation.AggregationItem; import org.opensearch.securityanalytics.rules.condition.ConditionAND; import org.opensearch.securityanalytics.rules.condition.ConditionFieldEqualsValueExpression; @@ -25,7 +41,6 @@ import org.apache.commons.lang3.NotImplementedException; import java.io.IOException; -import java.io.Serializable; import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -86,6 +101,8 @@ public class OSQueryBackend extends QueryBackend { private String bucketTriggerQuery; + private String bucketTriggerScript; + private static final String groupExpression = "(%s)"; private static final Map compareOperators = Map.of( SigmaCompareExpression.CompareOperators.GT, "gt", @@ -121,9 +138,10 @@ public OSQueryBackend(String ruleCategory, boolean collectErrors, boolean enable this.unboundReExpression = "%s: /%s/"; this.compareOpExpression = "\"%s\" \"%s\" %s"; this.valExpCount = 0; - this.aggQuery = "\"aggs\":{\"%s\":{\"terms\":{\"field\":\"%s\"},\"aggs\":{\"%s\":{\"%s\":{\"field\":\"%s\"}}}}}"; - this.aggCountQuery = "\"aggs\":{\"%s\":{\"terms\":{\"field\":\"%s\"}}}"; + this.aggQuery = "{\"%s\":{\"terms\":{\"field\":\"%s\"},\"aggs\":{\"%s\":{\"%s\":{\"field\":\"%s\"}}}}}"; + this.aggCountQuery = "{\"%s\":{\"terms\":{\"field\":\"%s\"}}}"; this.bucketTriggerQuery = "{\"buckets_path\":{\"%s\":\"%s\"},\"parent_bucket_path\":\"%s\",\"script\":{\"source\":\"params.%s %s %s\",\"lang\":\"painless\"}}"; + this.bucketTriggerScript = "params.%s %s %s"; } @Override @@ -346,24 +364,48 @@ public Object convertConditionValQueryExpr(ConditionValueExpression condition) { }*/ @Override - public Object convertAggregation(AggregationItem aggregation) { + public AggregationQueries convertAggregation(AggregationItem aggregation) { String fmtAggQuery; String fmtBucketTriggerQuery; + TermsAggregationBuilder aggBuilder = new TermsAggregationBuilder("result_agg"); + BucketSelectorExtAggregationBuilder condition; + String bucketTriggerSelectorId = UUIDs.base64UUID(); + if (aggregation.getAggFunction().equals("count")) { + String fieldName; if (aggregation.getAggField().equals("*") && aggregation.getGroupByField() == null) { + fieldName = "_index"; fmtAggQuery = String.format(Locale.getDefault(), aggCountQuery, "result_agg", "_index"); } else { + fieldName = aggregation.getGroupByField(); fmtAggQuery = String.format(Locale.getDefault(), aggCountQuery, "result_agg", aggregation.getGroupByField()); } + aggBuilder.field(fieldName); fmtBucketTriggerQuery = String.format(Locale.getDefault(), bucketTriggerQuery, "_cnt", "_cnt", "result_agg", "_cnt", aggregation.getCompOperator(), aggregation.getThreshold()); + + Script script = new Script(String.format(Locale.getDefault(), bucketTriggerScript, "_cnt", aggregation.getCompOperator(), aggregation.getThreshold())); + condition = new BucketSelectorExtAggregationBuilder(bucketTriggerSelectorId, Collections.singletonMap("_cnt", "_cnt"), script, "result_agg", null); } else { fmtAggQuery = String.format(Locale.getDefault(), aggQuery, "result_agg", aggregation.getGroupByField(), aggregation.getAggField(), aggregation.getAggFunction(), aggregation.getAggField()); fmtBucketTriggerQuery = String.format(Locale.getDefault(), bucketTriggerQuery, aggregation.getAggField(), aggregation.getAggField(), "result_agg", aggregation.getAggField(), aggregation.getCompOperator(), aggregation.getThreshold()); + + // Add subaggregation + AggregationBuilder subAgg = AggregationBuilders.getAggregationBuilderByFunction(aggregation.getAggFunction(), aggregation.getAggField()); + if (subAgg != null) { + aggBuilder.field(aggregation.getGroupByField()).subAggregation(subAgg); + } + + Script script = new Script(String.format(Locale.getDefault(), bucketTriggerScript, aggregation.getAggField(), aggregation.getCompOperator(), aggregation.getThreshold())); + condition = new BucketSelectorExtAggregationBuilder(bucketTriggerSelectorId, Collections.singletonMap(aggregation.getAggField(), aggregation.getAggField()), script, "result_agg", null); } - AggregationQueries aggQueries = new AggregationQueries(); - aggQueries.setAggQuery(fmtAggQuery); - aggQueries.setBucketTriggerQuery(fmtBucketTriggerQuery); - return aggQueries; + + AggregationQueries aggregationQueries = new AggregationQueries(); + aggregationQueries.setAggQuery(fmtAggQuery); + aggregationQueries.setBucketTriggerQuery(fmtBucketTriggerQuery); + aggregationQueries.setAggBuilder(aggBuilder); + aggregationQueries.setCondition(condition); + + return aggregationQueries; } private boolean comparePrecedence(ConditionType outer, ConditionType inner) { @@ -416,26 +458,111 @@ private String getFinalValueField() { return field; } - public static class AggregationQueries implements Serializable { + public static class AggregationQueries implements Writeable, ToXContentObject { + private static final String AGG_QUERY = "aggQuery"; + private static final String BUCKET_TRIGGER_QUERY = "bucketTriggerQuery"; + + public AggregationQueries() { + } + + public AggregationQueries(StreamInput in) throws IOException { + this.aggQuery = in.readString(); + this.bucketTriggerQuery = in.readString(); + } + + public static AggregationQueries docParse(XContentParser xcp) throws IOException{ + XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, xcp.nextToken(), xcp); + return AggregationQueries.parse(xcp); + } + + public static AggregationQueries parse(XContentParser xcp) throws IOException { + String aggQuery = null; + String bucketTriggerQuery = null; + + XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, xcp.currentToken(), xcp); + while (xcp.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = xcp.currentName(); + xcp.nextToken(); + + switch (fieldName) { + case AGG_QUERY: + aggQuery = xcp.text(); + break; + case BUCKET_TRIGGER_QUERY: + bucketTriggerQuery = xcp.text(); + break; + default: + xcp.skipChildren(); + } + } + AggregationQueries aggregationQueries = new AggregationQueries(); + aggregationQueries.setAggQuery(aggQuery); + aggregationQueries.setBucketTriggerQuery(bucketTriggerQuery); + + return aggregationQueries; + } private String aggQuery; + private AggregationBuilder aggBuilder; + private String bucketTriggerQuery; + private BucketSelectorExtAggregationBuilder condition; + + public String getAggQuery() { + return aggQuery; + } + public void setAggQuery(String aggQuery) { this.aggQuery = aggQuery; } - public String getAggQuery() { - return aggQuery; + public AggregationBuilder getAggBuilder() { + return aggBuilder; } - public void setBucketTriggerQuery(String bucketTriggerQuery) { - this.bucketTriggerQuery = bucketTriggerQuery; + public void setAggBuilder(AggregationBuilder aggBuilder) { + this.aggBuilder = aggBuilder; } public String getBucketTriggerQuery() { return bucketTriggerQuery; } + + public void setBucketTriggerQuery(String bucketTriggerQuery) { + this.bucketTriggerQuery = bucketTriggerQuery; + } + + public BucketSelectorExtAggregationBuilder getCondition() { + return condition; + } + + public void setCondition(BucketSelectorExtAggregationBuilder condition) { + this.condition = condition; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return createXContentBuilder(builder); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(aggQuery); + out.writeString(bucketTriggerQuery); + } + + private XContentBuilder createXContentBuilder(XContentBuilder builder) throws IOException { + return builder.startObject().field(AGG_QUERY, aggQuery).field(BUCKET_TRIGGER_QUERY, bucketTriggerQuery).endObject(); + } + + public String toString() { + try { + return BytesReference.bytes(this.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS)).utf8ToString(); + } catch (IOException ex) { + throw new OpenSearchParseException("failed to convert source to a json string", new Object[0]); + } + } } } diff --git a/src/main/java/org/opensearch/securityanalytics/rules/backend/QueryBackend.java b/src/main/java/org/opensearch/securityanalytics/rules/backend/QueryBackend.java index 8a4d00a52..ebb68faf8 100644 --- a/src/main/java/org/opensearch/securityanalytics/rules/backend/QueryBackend.java +++ b/src/main/java/org/opensearch/securityanalytics/rules/backend/QueryBackend.java @@ -4,7 +4,10 @@ */ package org.opensearch.securityanalytics.rules.backend; +import org.opensearch.commons.alerting.aggregation.bucketselectorext.BucketSelectorExtAggregationBuilder; +import org.opensearch.search.aggregations.AggregationBuilder; import org.opensearch.securityanalytics.rules.aggregation.AggregationItem; +import org.opensearch.securityanalytics.rules.backend.OSQueryBackend.AggregationQueries; import org.opensearch.securityanalytics.rules.condition.ConditionAND; import org.opensearch.securityanalytics.rules.condition.ConditionFieldEqualsValueExpression; import org.opensearch.securityanalytics.rules.condition.ConditionItem; @@ -261,5 +264,5 @@ public Object convertConditionVal(ConditionValueExpression condition) throws Sig /* public abstract Object convertConditionValQueryExpr(ConditionValueExpression condition);*/ - public abstract Object convertAggregation(AggregationItem aggregation); + public abstract AggregationQueries convertAggregation(AggregationItem aggregation) throws SigmaError; } diff --git a/src/main/java/org/opensearch/securityanalytics/transport/TransportIndexDetectorAction.java b/src/main/java/org/opensearch/securityanalytics/transport/TransportIndexDetectorAction.java index a6515a32e..949415db7 100644 --- a/src/main/java/org/opensearch/securityanalytics/transport/TransportIndexDetectorAction.java +++ b/src/main/java/org/opensearch/securityanalytics/transport/TransportIndexDetectorAction.java @@ -6,6 +6,10 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Locale; import java.util.Map; @@ -16,9 +20,11 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.lucene.search.join.ScoreMode; +import org.apache.lucene.util.SetOnce; import org.opensearch.OpenSearchStatusException; import org.opensearch.action.ActionListener; import org.opensearch.action.ActionRunnable; +import org.opensearch.action.StepListener; import org.opensearch.action.admin.indices.create.CreateIndexResponse; import org.opensearch.action.bulk.BulkResponse; import org.opensearch.action.get.GetRequest; @@ -28,8 +34,10 @@ import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.GroupedActionListener; import org.opensearch.action.support.HandledTransportAction; import org.opensearch.action.support.WriteRequest; +import org.opensearch.action.support.WriteRequest.RefreshPolicy; import org.opensearch.action.support.master.AcknowledgedResponse; import org.opensearch.client.Client; import org.opensearch.client.node.NodeClient; @@ -46,13 +54,18 @@ import org.opensearch.common.xcontent.XContentParser; import org.opensearch.common.xcontent.XContentType; import org.opensearch.commons.alerting.AlertingPluginInterface; +import org.opensearch.commons.alerting.action.DeleteMonitorRequest; +import org.opensearch.commons.alerting.action.DeleteMonitorResponse; import org.opensearch.commons.alerting.action.IndexMonitorRequest; import org.opensearch.commons.alerting.action.IndexMonitorResponse; +import org.opensearch.commons.alerting.model.BucketLevelTrigger; import org.opensearch.commons.alerting.model.DataSources; import org.opensearch.commons.alerting.model.DocLevelMonitorInput; import org.opensearch.commons.alerting.model.DocLevelQuery; import org.opensearch.commons.alerting.model.DocumentLevelTrigger; import org.opensearch.commons.alerting.model.Monitor; +import org.opensearch.commons.alerting.model.Monitor.MonitorType; +import org.opensearch.commons.alerting.model.SearchInput; import org.opensearch.commons.alerting.model.action.Action; import org.opensearch.commons.authuser.User; import org.opensearch.index.query.QueryBuilder; @@ -60,6 +73,7 @@ import org.opensearch.index.reindex.BulkByScrollResponse; import org.opensearch.index.seqno.SequenceNumbers; import org.opensearch.rest.RestRequest; +import org.opensearch.rest.RestRequest.Method; import org.opensearch.rest.RestStatus; import org.opensearch.script.Script; import org.opensearch.search.SearchHit; @@ -76,6 +90,10 @@ import org.opensearch.securityanalytics.model.DetectorTrigger; import org.opensearch.securityanalytics.model.Rule; import org.opensearch.securityanalytics.model.Value; +import org.opensearch.securityanalytics.rules.backend.OSQueryBackend; +import org.opensearch.securityanalytics.rules.backend.OSQueryBackend.AggregationQueries; +import org.opensearch.securityanalytics.rules.backend.QueryBackend; +import org.opensearch.securityanalytics.rules.exceptions.SigmaError; import org.opensearch.securityanalytics.settings.SecurityAnalyticsSettings; import org.opensearch.securityanalytics.util.DetectorIndices; import org.opensearch.securityanalytics.util.IndexUtils; @@ -159,7 +177,176 @@ protected void doExecute(Task task, IndexDetectorRequest request, ActionListener asyncAction.start(); } - private void createAlertingMonitorFromQueries(Pair>> logIndexToQueries, Detector detector, ActionListener listener, WriteRequest.RefreshPolicy refreshPolicy) { + private void createMonitorFromQueries(String index, List> rulesById, Detector detector, ActionListener>listener, WriteRequest.RefreshPolicy refreshPolicy) throws SigmaError, IOException { + List> docLevelRules = rulesById.stream().filter(it -> !it.getRight().isAggregationRule()).collect( + Collectors.toList()); + List> bucketLevelRules = rulesById.stream().filter(it -> it.getRight().isAggregationRule()).collect( + Collectors.toList()); + + List monitorRequests = new ArrayList<>(); + + if (!docLevelRules.isEmpty()) { + monitorRequests.add(createDocLevelMonitorRequest(Pair.of(index, docLevelRules), detector, refreshPolicy, Monitor.NO_ID, Method.POST)); + } + if (!bucketLevelRules.isEmpty()) { + monitorRequests.addAll(buildBucketLevelMonitorRequests(Pair.of(index, bucketLevelRules), detector, refreshPolicy, Monitor.NO_ID, Method.POST)); + } + // Do nothing if detector doesn't have any monitor + if(monitorRequests.isEmpty()){ + return; + } + + List monitorResponses = new ArrayList<>(); + StepListener addFirstMonitorStep = new StepListener(); + + // Indexing monitors in two steps in order to prevent all shards failed error from alerting + // https://github.com/opensearch-project/alerting/issues/646 + AlertingPluginInterface.INSTANCE.indexMonitor((NodeClient) client, monitorRequests.get(0), namedWriteableRegistry, addFirstMonitorStep); + addFirstMonitorStep.whenComplete(addedFirstMonitorResponse -> { + monitorResponses.add(addedFirstMonitorResponse); + int numberOfUnprocessedResponses = monitorRequests.size() - 1; + if(numberOfUnprocessedResponses == 0){ + listener.onResponse(monitorResponses); + } else { + GroupedActionListener monitorResponseListener = new GroupedActionListener( + new ActionListener>() { + @Override + public void onResponse(Collection indexMonitorResponse) { + monitorResponses.addAll(indexMonitorResponse.stream().collect(Collectors.toList())); + listener.onResponse(monitorResponses); + } + @Override + public void onFailure(Exception e) { + listener.onFailure(e); + } + }, numberOfUnprocessedResponses); + + for(int i = 1; i < monitorRequests.size(); i++){ + AlertingPluginInterface.INSTANCE.indexMonitor((NodeClient) client, monitorRequests.get(i), namedWriteableRegistry, monitorResponseListener); + } + } + }, + listener::onFailure + ); + } + + private void updateMonitorFromQueries(String index, List> rulesById, Detector detector, ActionListener> listener, WriteRequest.RefreshPolicy refreshPolicy) throws SigmaError, IOException { + List monitorsToBeUpdated = new ArrayList<>(); + + List> bucketLevelRules = rulesById.stream().filter(it -> it.getRight().isAggregationRule()).collect( + Collectors.toList()); + List monitorsToBeAdded = new ArrayList<>(); + // Process bucket level monitors + if (!bucketLevelRules.isEmpty()) { + List ruleCategories = bucketLevelRules.stream().map(Pair::getRight).map(Rule::getCategory).distinct().collect( + Collectors.toList()); + Map queryBackendMap = new HashMap<>(); + for(String category: ruleCategories){ + queryBackendMap.put(category, new OSQueryBackend(category, true, true)); + } + + // Pair of RuleId - MonitorId for existing monitors of the detector + Map monitorPerRule = detector.getRuleIdMonitorIdMap(); + + for (Pair query: bucketLevelRules) { + Rule rule = query.getRight(); + if(rule.getAggregationQueries() != null){ + // Detect if the monitor should be added or updated + if (monitorPerRule.containsKey(rule.getId())) { + String monitorId = monitorPerRule.get(rule.getId()); + monitorsToBeUpdated.add(createBucketLevelMonitorRequest(query.getRight(), + index, + detector, + refreshPolicy, + monitorId, + Method.PUT, + queryBackendMap.get(rule.getCategory()))); + } else { + monitorsToBeAdded.add(createBucketLevelMonitorRequest(query.getRight(), + index, + detector, + refreshPolicy, + Monitor.NO_ID, + Method.POST, + queryBackendMap.get(rule.getCategory()))); + } + } + } + } + + List> docLevelRules = rulesById.stream().filter(it -> !it.getRight().isAggregationRule()).collect( + Collectors.toList()); + + // Process doc level monitors + if (!docLevelRules.isEmpty()) { + if (detector.getDocLevelMonitorId() == null) { + monitorsToBeAdded.add(createDocLevelMonitorRequest(Pair.of(index, docLevelRules), detector, refreshPolicy, Monitor.NO_ID, Method.POST)); + } else { + monitorsToBeUpdated.add(createDocLevelMonitorRequest(Pair.of(index, docLevelRules), detector, refreshPolicy, detector.getDocLevelMonitorId(), Method.PUT)); + } + } + + List monitorIdsToBeDeleted = detector.getRuleIdMonitorIdMap().values().stream().collect(Collectors.toList()); + monitorIdsToBeDeleted.removeAll(monitorsToBeUpdated.stream().map(IndexMonitorRequest::getMonitorId).collect( + Collectors.toList())); + + updateAlertingMonitors(monitorsToBeAdded, monitorsToBeUpdated, monitorIdsToBeDeleted, refreshPolicy, listener); + } + + /** + * Update list of monitors for the given detector + * Executed in a steps: + * 1. Add new monitors; + * 2. Update existing monitors; + * 3. Delete the monitors omitted from request + * 4. Respond with updated list of monitors + * @param monitorsToBeAdded Newly added monitors by the user + * @param monitorsToBeUpdated Existing monitors that will be updated + * @param monitorsToBeDeleted Monitors omitted by the user + * @param refreshPolicy + * @param listener Listener that accepts the list of updated monitors if the action was successful + */ + private void updateAlertingMonitors( + List monitorsToBeAdded, + List monitorsToBeUpdated, + List monitorsToBeDeleted, + RefreshPolicy refreshPolicy, + ActionListener> listener + ) { + List updatedMonitors = new ArrayList<>(); + + // Update monitor steps + StepListener> addNewMonitorsStep = new StepListener(); + executeMonitorActionRequest(monitorsToBeAdded, addNewMonitorsStep); + // 1. Add new alerting monitors (for the rules that didn't exist previously) + addNewMonitorsStep.whenComplete(addNewMonitorsResponse -> { + if(addNewMonitorsResponse != null && !addNewMonitorsResponse.isEmpty()) { + updatedMonitors.addAll(addNewMonitorsResponse); + } + + StepListener> updateMonitorsStep = new StepListener<>(); + executeMonitorActionRequest(monitorsToBeUpdated, updateMonitorsStep); + // 2. Update existing alerting monitors (based on the common rules) + updateMonitorsStep.whenComplete(updateMonitorResponse -> { + if(updateMonitorResponse!=null && !updateMonitorResponse.isEmpty()) { + updatedMonitors.addAll(updateMonitorResponse); + } + + StepListener> deleteMonitorStep = new StepListener<>(); + deleteAlertingMonitors(monitorsToBeDeleted, refreshPolicy, deleteMonitorStep); + // 3. Delete alerting monitors (rules that are not provided by the user) + deleteMonitorStep.whenComplete(deleteMonitorResponses -> + // Return list of all updated + newly added monitors + listener.onResponse(updatedMonitors), + // Handle delete monitors (step 3) + listener::onFailure); + }, // Handle update monitor failed (step 2) + listener::onFailure); + // Handle add failed (step 1) + }, listener::onFailure); + } + + private IndexMonitorRequest createDocLevelMonitorRequest(Pair>> logIndexToQueries, Detector detector, WriteRequest.RefreshPolicy refreshPolicy, String monitorId, RestRequest.Method restMethod) { List docLevelMonitorInputs = new ArrayList<>(); List docLevelQueries = new ArrayList<>(); @@ -196,71 +383,175 @@ private void createAlertingMonitorFromQueries(Pair>> logIndexToQueries, Detector detector, ActionListener listener, WriteRequest.RefreshPolicy refreshPolicy) { - List docLevelMonitorInputs = new ArrayList<>(); + private List buildBucketLevelMonitorRequests(Pair>> logIndexToQueries, Detector detector, WriteRequest.RefreshPolicy refreshPolicy, String monitorId, RestRequest.Method restMethod) throws IOException, SigmaError { + List ruleCategories = logIndexToQueries.getRight().stream().map(Pair::getRight).map(Rule::getCategory).distinct().collect( + Collectors.toList()); + Map queryBackendMap = new HashMap<>(); - List docLevelQueries = new ArrayList<>(); + for(String category: ruleCategories){ + queryBackendMap.put(category, new OSQueryBackend(category, true, true)); + } - for (Pair query: logIndexToQueries.getRight()) { - String id = query.getLeft(); + List monitorRequests = new ArrayList<>(); + for (Pair query: logIndexToQueries.getRight()) { Rule rule = query.getRight(); - String name = query.getLeft(); - String actualQuery = rule.getQueries().get(0).getValue(); + // Creating bucket level monitor per each aggregation rule + if(rule.getAggregationQueries() != null){ + monitorRequests.add(createBucketLevelMonitorRequest( + query.getRight(), + logIndexToQueries.getLeft(), + detector, + refreshPolicy, + Monitor.NO_ID, + Method.POST, + queryBackendMap.get(rule.getCategory()))); + } + } + return monitorRequests; + } - List tags = new ArrayList<>(); - tags.add(rule.getLevel()); - tags.add(rule.getCategory()); - tags.addAll(rule.getTags().stream().map(Value::getValue).collect(Collectors.toList())); + private IndexMonitorRequest createBucketLevelMonitorRequest( + Rule rule, + String index, + Detector detector, + WriteRequest.RefreshPolicy refreshPolicy, + String monitorId, + RestRequest.Method restMethod, + QueryBackend queryBackend + ) throws SigmaError { + AggregationQueries aggregationQueries = queryBackend.convertAggregation(rule.getAggregationItemsFromRule().get(0)); + + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder() + .seqNoAndPrimaryTerm(true) + .version(true) + // Build query string filter + .query(QueryBuilders.queryStringQuery(rule.getQueries().get(0).getValue())) + .aggregation(aggregationQueries.getAggBuilder()); + + List bucketLevelMonitorInputs = new ArrayList<>(); + bucketLevelMonitorInputs.add(new SearchInput(Arrays.asList(index), searchSourceBuilder)); + + List triggers = new ArrayList<>(); + BucketLevelTrigger bucketLevelTrigger = new BucketLevelTrigger(rule.getId(), rule.getTitle(), rule.getLevel(), aggregationQueries.getCondition(), + Collections.emptyList()); + triggers.add(bucketLevelTrigger); + + /** TODO - Think how to use detector trigger + List detectorTriggers = detector.getTriggers(); + for (DetectorTrigger detectorTrigger: detectorTriggers) { + String id = detectorTrigger.getId(); + String name = detectorTrigger.getName(); + String severity = detectorTrigger.getSeverity(); + List actions = detectorTrigger.getActions(); + Script condition = detectorTrigger.convertToCondition(); + + BucketLevelTrigger bucketLevelTrigger1 = new BucketLevelTrigger(id, name, severity, condition, actions); + triggers.add(bucketLevelTrigger1); + } **/ + + Monitor monitor = new Monitor(monitorId, Monitor.NO_VERSION, detector.getName(), detector.getEnabled(), detector.getSchedule(), detector.getLastUpdateTime(), detector.getEnabledTime(), + MonitorType.BUCKET_LEVEL_MONITOR, detector.getUser(), 1, bucketLevelMonitorInputs, triggers, Map.of(), + new DataSources(detector.getRuleIndex(), + detector.getFindingsIndex(), + detector.getFindingsIndexPattern(), + detector.getAlertsIndex(), + detector.getAlertsHistoryIndex(), + detector.getAlertsHistoryIndexPattern(), + DetectorMonitorConfig.getRuleIndexMappingsByType(detector.getDetectorType()), + false), PLUGIN_OWNER_FIELD); + + return new IndexMonitorRequest(monitorId, SequenceNumbers.UNASSIGNED_SEQ_NO, SequenceNumbers.UNASSIGNED_PRIMARY_TERM, refreshPolicy, restMethod, monitor, null); + } - DocLevelQuery docLevelQuery = new DocLevelQuery(id, name, actualQuery, tags); - docLevelQueries.add(docLevelQuery); + /** + * Executes monitor related requests (PUT/POST) - returns the response once all the executions are completed + * @param indexMonitors Monitors to be updated/added + * @param listener actionListener for handling updating/creating monitors + */ + public void executeMonitorActionRequest( + List indexMonitors, + ActionListener> listener) { + + // In the case of not provided monitors, just return empty list + if(indexMonitors == null || indexMonitors.isEmpty()) { + listener.onResponse(new ArrayList<>()); + return; } - DocLevelMonitorInput docLevelMonitorInput = new DocLevelMonitorInput(detector.getName(), List.of(logIndexToQueries.getKey()), docLevelQueries); - docLevelMonitorInputs.add(docLevelMonitorInput); - List triggers = new ArrayList<>(); - List detectorTriggers = detector.getTriggers(); + GroupedActionListener monitorResponseListener = new GroupedActionListener( + new ActionListener>() { + @Override + public void onResponse(Collection indexMonitorResponse) { + listener.onResponse(indexMonitorResponse.stream().collect(Collectors.toList())); + } + @Override + public void onFailure(Exception e) { + listener.onFailure(e); + } + }, indexMonitors.size()); - for (DetectorTrigger detectorTrigger: detectorTriggers) { - String id = detectorTrigger.getId(); - String name = detectorTrigger.getName(); - String severity = detectorTrigger.getSeverity(); - List actions = detectorTrigger.getActions(); - Script condition = detectorTrigger.convertToCondition(); + // Persist monitors sequentially + for (IndexMonitorRequest req: indexMonitors) { + AlertingPluginInterface.INSTANCE.indexMonitor((NodeClient) client, req, namedWriteableRegistry, monitorResponseListener); + } + } - triggers.add(new DocumentLevelTrigger(id, name, severity, actions, condition)); + /** + * Deletes the alerting monitors based on the given ids and notifies the listener that will be notified once all monitors have been deleted + * @param monitorIds monitor ids to be deleted + * @param refreshPolicy + * @param listener listener that will be notified once all the monitors are being deleted + */ + private void deleteAlertingMonitors(List monitorIds, WriteRequest.RefreshPolicy refreshPolicy, ActionListener> listener){ + if (monitorIds == null || monitorIds.isEmpty()) { + listener.onResponse(new ArrayList<>()); + return; } + ActionListener deletesListener = new GroupedActionListener<>(new ActionListener<>() { + @Override + public void onResponse(Collection responses) { + SetOnce errorStatusSupplier = new SetOnce<>(); + if (responses.stream().filter(response -> { + if (response.getStatus() != RestStatus.OK) { + log.error("Monitor [{}] could not be deleted. Status [{}]", response.getId(), response.getStatus()); + errorStatusSupplier.trySet(response.getStatus()); + return true; + } + return false; + }).count() > 0) { + listener.onFailure(new OpenSearchStatusException("Monitor associated with detected could not be deleted", errorStatusSupplier.get())); + } + listener.onResponse(responses.stream().collect(Collectors.toList())); + } + @Override + public void onFailure(Exception e) { + listener.onFailure(e); + } + }, monitorIds.size()); - Monitor monitor = new Monitor(detector.getMonitorIds().get(0), Monitor.NO_VERSION, detector.getName(), detector.getEnabled(), detector.getSchedule(), detector.getLastUpdateTime(), detector.getEnabledTime(), - Monitor.MonitorType.DOC_LEVEL_MONITOR, detector.getUser(), 1, docLevelMonitorInputs, triggers, Map.of(), - new DataSources(detector.getRuleIndex(), - detector.getFindingsIndex(), - detector.getFindingsIndexPattern(), - detector.getAlertsIndex(), - detector.getAlertsHistoryIndex(), - detector.getAlertsHistoryIndexPattern(), - DetectorMonitorConfig.getRuleIndexMappingsByType(detector.getDetectorType()), - true), PLUGIN_OWNER_FIELD); - - IndexMonitorRequest indexMonitorRequest = new IndexMonitorRequest(detector.getMonitorIds().get(0), SequenceNumbers.UNASSIGNED_SEQ_NO, SequenceNumbers.UNASSIGNED_PRIMARY_TERM, refreshPolicy, RestRequest.Method.PUT, monitor, null); - AlertingPluginInterface.INSTANCE.indexMonitor((NodeClient) client, indexMonitorRequest, namedWriteableRegistry, listener); + for (String monitorId : monitorIds) { + deleteAlertingMonitor(monitorId, refreshPolicy, deletesListener); + } + } + private void deleteAlertingMonitor(String monitorId, WriteRequest.RefreshPolicy refreshPolicy, ActionListener listener) { + DeleteMonitorRequest request = new DeleteMonitorRequest(monitorId, refreshPolicy); + AlertingPluginInterface.INSTANCE.deleteMonitor((NodeClient) client, request, listener); } private void onCreateMappingsResponse(CreateIndexResponse response) throws IOException { @@ -384,8 +675,9 @@ public void onResponse(CreateIndexResponse createIndexResponse) { initRuleIndexAndImportRules(request, new ActionListener<>() { @Override - public void onResponse(IndexMonitorResponse indexMonitorResponse) { - request.getDetector().setMonitorIds(List.of(indexMonitorResponse.getId())); + public void onResponse(List monitorResponses) { + request.getDetector().setMonitorIds(getMonitorIds(monitorResponses)); + request.getDetector().setRuleIdMonitorIdMap(mapMonitorIds(monitorResponses)); try { indexDetector(); } catch (IOException e) { @@ -465,6 +757,7 @@ void onGetResponse(Detector currentDetector, User user) { request.getDetector().setEnabledTime(currentDetector.getEnabledTime()); } request.getDetector().setMonitorIds(currentDetector.getMonitorIds()); + request.getDetector().setRuleIdMonitorIdMap(currentDetector.getRuleIdMonitorIdMap()); Detector detector = request.getDetector(); String ruleTopic = detector.getDetectorType(); @@ -487,8 +780,9 @@ void onGetResponse(Detector currentDetector, User user) { public void onResponse(CreateIndexResponse createIndexResponse) { initRuleIndexAndImportRules(request, new ActionListener<>() { @Override - public void onResponse(IndexMonitorResponse indexMonitorResponse) { - request.getDetector().setMonitorIds(List.of(indexMonitorResponse.getId())); + public void onResponse(List monitorResponses) { + request.getDetector().setMonitorIds(getMonitorIds(monitorResponses)); + request.getDetector().setRuleIdMonitorIdMap(mapMonitorIds(monitorResponses)); try { indexDetector(); } catch (IOException e) { @@ -514,7 +808,7 @@ public void onFailure(Exception e) { } } - public void initRuleIndexAndImportRules(IndexDetectorRequest request, ActionListener listener) { + public void initRuleIndexAndImportRules(IndexDetectorRequest request, ActionListener> listener) { ruleIndices.initPrepackagedRulesIndex( new ActionListener<>() { @Override @@ -619,7 +913,7 @@ public void onFailure(Exception e) { } @SuppressWarnings("unchecked") - public void importRules(IndexDetectorRequest request, ActionListener listener) { + public void importRules(IndexDetectorRequest request, ActionListener> listener) { final Detector detector = request.getDetector(); final String ruleTopic = detector.getDetectorType(); final DetectorInput detectorInput = detector.getInputs().get(0); @@ -672,15 +966,13 @@ public void onResponse(SearchResponse response) { } else if (detectorInput.getCustomRules().size() > 0) { onFailures(new OpenSearchStatusException("Custom Rule Index not found", RestStatus.BAD_REQUEST)); } else { - Pair>> logIndexToQueries = Pair.of(logIndex, queries); - if (request.getMethod() == RestRequest.Method.POST) { - createAlertingMonitorFromQueries(logIndexToQueries, detector, listener, request.getRefreshPolicy()); + createMonitorFromQueries(logIndex, queries, detector, listener, request.getRefreshPolicy()); } else if (request.getMethod() == RestRequest.Method.PUT) { - updateAlertingMonitorFromQueries(logIndexToQueries, detector, listener, request.getRefreshPolicy()); + updateMonitorFromQueries(logIndex, queries, detector, listener, request.getRefreshPolicy()); } } - } catch (IOException e) { + } catch (IOException | SigmaError e) { onFailures(e); } } @@ -693,7 +985,7 @@ public void onFailure(Exception e) { } @SuppressWarnings("unchecked") - public void importCustomRules(Detector detector, DetectorInput detectorInput, List> queries, ActionListener listener) { + public void importCustomRules(Detector detector, DetectorInput detectorInput, List> queries, ActionListener> listener) { final String logIndex = detectorInput.getIndices().get(0); List ruleIds = detectorInput.getCustomRules().stream().map(DetectorRule::getId).collect(Collectors.toList()); @@ -727,14 +1019,12 @@ public void onResponse(SearchResponse response) { queries.add(Pair.of(id, rule)); } - Pair>> logIndexToQueries = Pair.of(logIndex, queries); - if (request.getMethod() == RestRequest.Method.POST) { - createAlertingMonitorFromQueries(logIndexToQueries, detector, listener, request.getRefreshPolicy()); + createMonitorFromQueries(logIndex, queries, detector, listener, request.getRefreshPolicy()); } else if (request.getMethod() == RestRequest.Method.PUT) { - updateAlertingMonitorFromQueries(logIndexToQueries, detector, listener, request.getRefreshPolicy()); + updateMonitorFromQueries(logIndex, queries, detector, listener, request.getRefreshPolicy()); } - } catch (IOException ex) { + } catch (IOException | SigmaError ex) { onFailures(ex); } } @@ -798,6 +1088,33 @@ private void finishHim(Detector detector, Exception t) { } })); } + + private List getMonitorIds(List monitorResponses) { + return monitorResponses.stream().map(IndexMonitorResponse::getId).collect( + Collectors.toList()); + } + + /** + * Creates a map of monitor ids. In the case of bucket level monitors pairs are: RuleId - MonitorId + * In the case of doc level monitor pair is DOC_LEVEL_MONITOR(value) - MonitorId + * @param monitorResponses index monitor responses + * @return map of monitor ids + */ + private Map mapMonitorIds(List monitorResponses) { + return monitorResponses.stream().collect( + Collectors.toMap( + // In the case of bucket level monitors rule id is trigger id + it -> { + if (MonitorType.BUCKET_LEVEL_MONITOR == it.getMonitor().getMonitorType()) { + return it.getMonitor().getTriggers().get(0).getId(); + } else { + return Detector.DOC_LEVEL_MONITOR; + } + }, + IndexMonitorResponse::getId + ) + ); + } } private void setFilterByEnabled(boolean filterByEnabled) { diff --git a/src/main/java/org/opensearch/securityanalytics/transport/TransportIndexRuleAction.java b/src/main/java/org/opensearch/securityanalytics/transport/TransportIndexRuleAction.java index 76c5e3b7e..5eb178fe4 100644 --- a/src/main/java/org/opensearch/securityanalytics/transport/TransportIndexRuleAction.java +++ b/src/main/java/org/opensearch/securityanalytics/transport/TransportIndexRuleAction.java @@ -185,7 +185,7 @@ void prepareRuleIndexing() { Set queryFieldNames = backend.getQueryFields().keySet(); Rule ruleDoc = new Rule( NO_ID, NO_VERSION, parsedRule, category, - queries.stream().map(Object::toString).collect(Collectors.toList()), + queries, new ArrayList<>(queryFieldNames), rule ); diff --git a/src/test/java/org/opensearch/securityanalytics/SecurityAnalyticsRestTestCase.java b/src/test/java/org/opensearch/securityanalytics/SecurityAnalyticsRestTestCase.java index a20ed73f9..20111d6eb 100644 --- a/src/test/java/org/opensearch/securityanalytics/SecurityAnalyticsRestTestCase.java +++ b/src/test/java/org/opensearch/securityanalytics/SecurityAnalyticsRestTestCase.java @@ -70,6 +70,9 @@ import java.util.stream.Collectors; import static org.opensearch.action.admin.indices.create.CreateIndexRequest.MAPPINGS; +import static org.opensearch.securityanalytics.TestHelpers.sumAggregationTestRule; +import static org.opensearch.securityanalytics.TestHelpers.productIndexAvgAggRule; +import static org.opensearch.securityanalytics.util.RuleTopicIndices.ruleTopicIndexMappings; import static org.opensearch.securityanalytics.util.RuleTopicIndices.ruleTopicIndexSettings; public class SecurityAnalyticsRestTestCase extends OpenSearchRestTestCase { @@ -253,6 +256,18 @@ protected List getRandomPrePackagedRules() throws IOException { return hits.stream().map(hit -> hit.get("_id").toString()).collect(Collectors.toList()); } + protected List createAggregationRules () throws IOException { + return new ArrayList<>(Arrays.asList(createRule(productIndexAvgAggRule()), createRule(sumAggregationTestRule()))); + } + + protected String createRule(String rule) throws IOException { + Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.RULE_BASE_URI, Collections.singletonMap("category", "windows"), + new StringEntity(rule), new BasicHeader("Content-Type", "application/json")); + Assert.assertEquals("Create rule failed", RestStatus.CREATED, restStatus(createResponse)); + Map responseBody = asMap(createResponse); + return responseBody.get("_id").toString(); + } + protected List getPrePackagedRules(String ruleCategory) throws IOException { String request = "{\n" + " \"from\": 0\n," + diff --git a/src/test/java/org/opensearch/securityanalytics/TestHelpers.java b/src/test/java/org/opensearch/securityanalytics/TestHelpers.java index 22e4f864e..df509fb4e 100644 --- a/src/test/java/org/opensearch/securityanalytics/TestHelpers.java +++ b/src/test/java/org/opensearch/securityanalytics/TestHelpers.java @@ -37,7 +37,6 @@ import java.util.List; import java.util.Locale; import java.util.Map; -import java.util.function.Function; import java.util.stream.Collectors; import static org.opensearch.test.OpenSearchTestCase.randomInt; @@ -121,7 +120,7 @@ public static Detector randomDetector(String name, DetectorTrigger trigger = new DetectorTrigger(null, "windows-trigger", "1", List.of(randomDetectorType()), List.of("QuarksPwDump Clearing Access History"), List.of("high"), List.of("T0008"), List.of()); triggers.add(trigger); } - return new Detector(null, null, name, enabled, schedule, lastUpdateTime, enabledTime, detectorType, user, inputs, triggers, Collections.singletonList(""), "", "", "", "", "", ""); + return new Detector(null, null, name, enabled, schedule, lastUpdateTime, enabledTime, detectorType, user, inputs, triggers, Collections.singletonList(""), "", "", "", "", "", "", Collections.emptyMap()); } public static Detector randomDetectorWithNoUser() { @@ -133,7 +132,7 @@ public static Detector randomDetectorWithNoUser() { Instant enabledTime = enabled ? Instant.now().truncatedTo(ChronoUnit.MILLIS) : null; Instant lastUpdateTime = Instant.now().truncatedTo(ChronoUnit.MILLIS); - return new Detector(null, null, name, enabled, schedule, lastUpdateTime, enabledTime, detectorType, null, inputs, Collections.emptyList(),Collections.singletonList(""), "", "", "", "", "", ""); + return new Detector(null, null, name, enabled, schedule, lastUpdateTime, enabledTime, detectorType, null, inputs, Collections.emptyList(),Collections.singletonList(""), "", "", "", "", "", "", Collections.emptyMap()); } public static String randomRule() { @@ -165,6 +164,71 @@ public static String randomRule() { "level: high"; } + public static String countAggregationTestRule() { + return " title: Test\n" + + " id: 39f919f3-980b-4e6f-a975-8af7e507ef2b\n" + + " status: test\n" + + " level: critical\n" + + " description: Detects QuarksPwDump clearing access history in hive\n" + + " author: Florian Roth\n" + + " date: 2017/05/15\n" + + " logsource:\n" + + " category: test_category\n" + + " product: test_product\n" + + " detection:\n" + + " sel:\n" + + " fieldA: valueA\n" + + " fieldB: valueB\n" + + " fieldC: valueC\n" + + " condition: sel | count(*) > 1"; + } + + public static String sumAggregationTestRule() { + return " title: Test\n" + + " id: 39f919f3-980b-4e6f-a975-8af7e507ef2b\n" + + " status: test\n" + + " level: critical\n" + + " description: Detects QuarksPwDump clearing access history in hive\n" + + " author: Florian Roth\n" + + " date: 2017/05/15\n" + + " logsource:\n" + + " category: test_category\n" + + " product: test_product\n" + + " detection:\n" + + " sel:\n" + + " fieldA: 123\n" + + " fieldB: 111\n" + + " fieldC: valueC\n" + + " condition: sel | sum(fieldA) by fieldB > 110"; + } + + public static String productIndexMaxAggRule() { + return " title: Test\n" + + " id: 5f92fff9-82e3-48eb-8fc1-8b133556a551\n" + + " status: test\n" + + " level: critical\n" + + " description: Detects QuarksPwDump clearing access history in hive\n" + + " author: Florian Roth\n" + + " date: 2017/05/15\n" + + " logsource:\n" + + " category: test_category\n" + + " product: test_product\n" + + " detection:\n" + + " sel:\n" + + " fieldA: 123\n" + + " fieldB: 111\n" + + " fieldC: valueC\n" + + " condition: sel | max(fieldA) by fieldB > 110"; + } + + public static String randomProductDocument(){ + return "{\n" + + " \"fieldA\": 123,\n" + + " \"mappedB\": 111,\n" + + " \"fieldC\": \"valueC\"\n" + + "}\n"; + } + public static String randomEditedRule() { return "title: Remote Encrypting File System Abuse\n" + "id: 5f92fff9-82e2-48eb-8fc1-8b133556a551\n" + @@ -359,6 +423,40 @@ public static String netFlowMappings() { " }"; } + public static String productIndexMapping(){ + return "\"properties\":{\n" + + " \"fieldA\":{\n" + + " \"type\":\"long\"\n" + + " },\n" + + " \"mappedB\":{\n" + + " \"type\":\"long\"\n" + + " },\n" + + " \"fieldC\":{\n" + + " \"type\":\"keyword\"\n" + + " }\n" + + "}\n" + + "}"; + } + + public static String productIndexAvgAggRule(){ + return " title: Test\n" + + " id: 39f918f3-981b-4e6f-a975-8af7e507ef2b\n" + + " status: test\n" + + " level: critical\n" + + " description: Detects QuarksPwDump clearing access history in hive\n" + + " author: Florian Roth\n" + + " date: 2017/05/15\n" + + " logsource:\n" + + " category: test_category\n" + + " product: test_product\n" + + " detection:\n" + + " sel:\n" + + " fieldA: 123\n" + + " fieldB: 111\n" + + " fieldC: valueC\n" + + " condition: sel | avg(fieldA) by fieldC > 110"; + } + public static String windowsIndexMapping() { return "\"properties\": {\n" + " \"AccessList\": {\n" + diff --git a/src/test/java/org/opensearch/securityanalytics/action/IndexDetectorResponseTests.java b/src/test/java/org/opensearch/securityanalytics/action/IndexDetectorResponseTests.java index 84f930d1b..ad6a110e2 100644 --- a/src/test/java/org/opensearch/securityanalytics/action/IndexDetectorResponseTests.java +++ b/src/test/java/org/opensearch/securityanalytics/action/IndexDetectorResponseTests.java @@ -49,7 +49,8 @@ public void testIndexDetectorPostResponse() throws IOException { DetectorMonitorConfig.getAlertsIndex(Detector.DetectorType.OTHERS_APPLICATION.getDetectorType()), null, null, - DetectorMonitorConfig.getFindingsIndex(Detector.DetectorType.OTHERS_APPLICATION.getDetectorType()) + DetectorMonitorConfig.getFindingsIndex(Detector.DetectorType.OTHERS_APPLICATION.getDetectorType()), + Collections.emptyMap() ); IndexDetectorResponse response = new IndexDetectorResponse("1234", 1L, RestStatus.OK, detector); Assert.assertNotNull(response); diff --git a/src/test/java/org/opensearch/securityanalytics/alerts/AlertingServiceTests.java b/src/test/java/org/opensearch/securityanalytics/alerts/AlertingServiceTests.java index 13cd48a04..b6df74548 100644 --- a/src/test/java/org/opensearch/securityanalytics/alerts/AlertingServiceTests.java +++ b/src/test/java/org/opensearch/securityanalytics/alerts/AlertingServiceTests.java @@ -7,6 +7,7 @@ import java.time.Instant; import java.time.ZoneId; +import java.util.Collections; import java.util.List; import java.util.Map; import org.opensearch.action.ActionListener; @@ -62,7 +63,8 @@ public void testGetAlerts_success() { DetectorMonitorConfig.getAlertsIndex(Detector.DetectorType.OTHERS_APPLICATION.getDetectorType()), null, null, - DetectorMonitorConfig.getFindingsIndex(Detector.DetectorType.OTHERS_APPLICATION.getDetectorType()) + DetectorMonitorConfig.getFindingsIndex(Detector.DetectorType.OTHERS_APPLICATION.getDetectorType()), + Collections.emptyMap() ); GetDetectorResponse getDetectorResponse = new GetDetectorResponse("detector_id123", 1L, RestStatus.OK, detector); @@ -233,7 +235,8 @@ public void testGetFindings_getFindingsByMonitorIdFailures() { DetectorMonitorConfig.getAlertsIndex(Detector.DetectorType.OTHERS_APPLICATION.getDetectorType()), null, null, - DetectorMonitorConfig.getFindingsIndex(Detector.DetectorType.OTHERS_APPLICATION.getDetectorType()) + DetectorMonitorConfig.getFindingsIndex(Detector.DetectorType.OTHERS_APPLICATION.getDetectorType()), + Collections.emptyMap() ); GetDetectorResponse getDetectorResponse = new GetDetectorResponse("detector_id123", 1L, RestStatus.OK, detector); diff --git a/src/test/java/org/opensearch/securityanalytics/findings/FindingServiceTests.java b/src/test/java/org/opensearch/securityanalytics/findings/FindingServiceTests.java index c5c0cb425..6ad0b5a14 100644 --- a/src/test/java/org/opensearch/securityanalytics/findings/FindingServiceTests.java +++ b/src/test/java/org/opensearch/securityanalytics/findings/FindingServiceTests.java @@ -8,8 +8,10 @@ import java.time.Instant; import java.time.ZoneId; import java.util.ArrayDeque; +import java.util.Collections; import java.util.List; import java.util.Queue; +import java.util.stream.Collectors; import org.opensearch.action.ActionListener; import org.opensearch.client.Client; import org.opensearch.commons.alerting.model.CronSchedule; @@ -61,7 +63,8 @@ public void testGetFindings_success() { DetectorMonitorConfig.getAlertsIndex(Detector.DetectorType.OTHERS_APPLICATION.getDetectorType()), null, null, - DetectorMonitorConfig.getFindingsIndex(Detector.DetectorType.OTHERS_APPLICATION.getDetectorType()) + DetectorMonitorConfig.getFindingsIndex(Detector.DetectorType.OTHERS_APPLICATION.getDetectorType()), + Collections.emptyMap() ); GetDetectorResponse getDetectorResponse = new GetDetectorResponse("detector_id123", 1L, RestStatus.OK, detector); @@ -176,7 +179,8 @@ public void testGetFindings_getFindingsByMonitorIdFailure() { DetectorMonitorConfig.getAlertsIndex(Detector.DetectorType.OTHERS_APPLICATION.getDetectorType()), null, null, - DetectorMonitorConfig.getFindingsIndex(Detector.DetectorType.OTHERS_APPLICATION.getDetectorType()) + DetectorMonitorConfig.getFindingsIndex(Detector.DetectorType.OTHERS_APPLICATION.getDetectorType()), + Collections.emptyMap() ); GetDetectorResponse getDetectorResponse = new GetDetectorResponse("detector_id123", 1L, RestStatus.OK, detector); diff --git a/src/test/java/org/opensearch/securityanalytics/resthandler/DetectorRestApiIT.java b/src/test/java/org/opensearch/securityanalytics/resthandler/DetectorRestApiIT.java index 0e596e443..444a765bb 100644 --- a/src/test/java/org/opensearch/securityanalytics/resthandler/DetectorRestApiIT.java +++ b/src/test/java/org/opensearch/securityanalytics/resthandler/DetectorRestApiIT.java @@ -4,6 +4,9 @@ */ package org.opensearch.securityanalytics.resthandler; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; import org.apache.http.HttpEntity; import org.apache.http.HttpStatus; import org.apache.http.entity.ContentType; @@ -14,6 +17,7 @@ import org.opensearch.action.search.SearchResponse; import org.opensearch.client.Request; import org.opensearch.client.Response; +import org.opensearch.commons.alerting.model.Monitor.MonitorType; import org.opensearch.rest.RestStatus; import org.opensearch.search.SearchHit; import org.opensearch.securityanalytics.SecurityAnalyticsPlugin; @@ -29,13 +33,19 @@ import java.util.Locale; import java.util.Map; import java.util.stream.Collectors; +import org.opensearch.securityanalytics.model.Rule; +import static org.opensearch.securityanalytics.TestHelpers.productIndexMaxAggRule; +import static org.opensearch.securityanalytics.TestHelpers.productIndexAvgAggRule; +import static org.opensearch.securityanalytics.TestHelpers.productIndexMapping; import static org.opensearch.securityanalytics.TestHelpers.randomDetector; import static org.opensearch.securityanalytics.TestHelpers.randomDetectorType; import static org.opensearch.securityanalytics.TestHelpers.randomDetectorWithInputs; import static org.opensearch.securityanalytics.TestHelpers.randomDoc; import static org.opensearch.securityanalytics.TestHelpers.randomIndex; +import static org.opensearch.securityanalytics.TestHelpers.randomProductDocument; import static org.opensearch.securityanalytics.TestHelpers.randomRule; +import static org.opensearch.securityanalytics.TestHelpers.sumAggregationTestRule; import static org.opensearch.securityanalytics.TestHelpers.windowsIndexMapping; public class DetectorRestApiIT extends SecurityAnalyticsRestTestCase { @@ -225,6 +235,91 @@ public void testCreatingADetectorWithCustomRules() throws IOException { Assert.assertEquals(6, noOfSigmaRuleMatches); } + public void testCreatingADetectorWithAggregationRules() throws IOException { + String index = createTestIndex(randomIndex(), productIndexMapping()); + + // Execute CreateMappingsAction to add alias mapping for index + Request createMappingRequest = new Request("POST", SecurityAnalyticsPlugin.MAPPER_BASE_URI); + // both req params and req body are supported + createMappingRequest.setJsonEntity( + "{ \"index_name\":\"" + index + "\"," + + " \"rule_topic\":\"windows\", " + + " \"partial\":true" + + "}" + ); + + Response response = client().performRequest(createMappingRequest); + assertEquals(HttpStatus.SC_OK, response.getStatusLine().getStatusCode()); + + String rule = productIndexAvgAggRule(); + + Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.RULE_BASE_URI, Collections.singletonMap("category", "windows"), + new StringEntity(rule), new BasicHeader("Content-Type", "application/json")); + Assert.assertEquals("Create rule failed", RestStatus.CREATED, restStatus(createResponse)); + + Map responseBody = asMap(createResponse); + + String createdRuleId = responseBody.get("_id").toString(); + + DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), List.of(new DetectorRule(createdRuleId)), + getRandomPrePackagedRules().stream().map(DetectorRule::new).collect(Collectors.toList())); + Detector detector = randomDetectorWithInputs(List.of(input)); + + createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); + Assert.assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); + + responseBody = asMap(createResponse); + + createdRuleId = responseBody.get("_id").toString(); + int createdVersion = Integer.parseInt(responseBody.get("_version").toString()); + Assert.assertNotEquals("response is missing Id", Detector.NO_ID, createdRuleId); + Assert.assertTrue("incorrect version", createdVersion > 0); + Assert.assertEquals("Incorrect Location header", String.format(Locale.getDefault(), "%s/%s", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, createdRuleId), createResponse.getHeader("Location")); + Assert.assertFalse(((Map) responseBody.get("detector")).containsKey("rule_topic_index")); + Assert.assertFalse(((Map) responseBody.get("detector")).containsKey("findings_index")); + Assert.assertFalse(((Map) responseBody.get("detector")).containsKey("alert_index")); + + String request = "{\n" + + " \"query\" : {\n" + + " \"match\":{\n" + + " \"_id\": \"" + createdRuleId + "\"\n" + + " }\n" + + " }\n" + + "}"; + List hits = executeSearch(Detector.DETECTORS_INDEX, request); + SearchHit hit = hits.get(0); + + List monitorTypes = new ArrayList<>(); + + Map detectorAsMap = (Map) hit.getSourceAsMap().get("detector"); + + String bucketLevelMonitorId = ""; + + // Verify that doc level monitor is created + List monitorIds = (List) (detectorAsMap).get("monitor_id"); + + String firstMonitorId = monitorIds.get(0); + String firstMonitorType = ((Map) entityAsMap(client().performRequest(new Request("GET", "/_plugins/_alerting/monitors/" + firstMonitorId))).get("monitor")).get("monitor_type"); + + if(MonitorType.BUCKET_LEVEL_MONITOR.getValue().equals(firstMonitorType)){ + bucketLevelMonitorId = firstMonitorId; + } + monitorTypes.add(firstMonitorType); + + String secondMonitorId = monitorIds.get(1); + String secondMonitorType = ((Map) entityAsMap(client().performRequest(new Request("GET", "/_plugins/_alerting/monitors/" + secondMonitorId))).get("monitor")).get("monitor_type"); + monitorTypes.add(secondMonitorType); + if(MonitorType.BUCKET_LEVEL_MONITOR.getValue().equals(secondMonitorType)){ + bucketLevelMonitorId = secondMonitorId; + } + Assert.assertTrue(Arrays.asList(MonitorType.BUCKET_LEVEL_MONITOR.getValue(), MonitorType.DOC_LEVEL_MONITOR.getValue()).containsAll(monitorTypes)); + + indexDoc(index, "1", randomProductDocument()); + + Response executeResponse = executeAlertingMonitor(bucketLevelMonitorId, Collections.emptyMap()); + Map executeResults = entityAsMap(executeResponse); + // TODO - check findings + } public void testUpdateADetector() throws IOException { String index = createTestIndex(randomIndex(), windowsIndexMapping()); @@ -286,6 +381,186 @@ public void testUpdateADetector() throws IOException { Assert.assertEquals(6, response.getHits().getTotalHits().value); } + public void testUpdateDetectorAddingNewAggregationRule() throws IOException { + String index = createTestIndex(randomIndex(), productIndexMapping()); + + // Execute CreateMappingsAction to add alias mapping for index + Request createMappingRequest = new Request("POST", SecurityAnalyticsPlugin.MAPPER_BASE_URI); + // both req params and req body are supported + createMappingRequest.setJsonEntity( + "{ \"index_name\":\"" + index + "\"," + + " \"rule_topic\":\"windows\", " + + " \"partial\":true" + + "}" + ); + + Response createMappingResponse = client().performRequest(createMappingRequest); + assertEquals(HttpStatus.SC_OK, createMappingResponse.getStatusLine().getStatusCode()); + + String sumRuleId = createRule(sumAggregationTestRule()); + List detectorRules = List.of(new DetectorRule(sumRuleId)); + + DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), detectorRules, + Collections.emptyList()); + + Detector detector = randomDetectorWithInputs(List.of(input)); + + Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); + Assert.assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); + + Map responseBody = asMap(createResponse); + + String detectorId = responseBody.get("_id").toString(); + + String request = "{\n" + + " \"query\" : {\n" + + " \"match_all\":{\n" + + " }\n" + + " }\n" + + "}"; + SearchResponse response = executeSearchAndGetResponse(Rule.CUSTOM_RULES_INDEX, request, true); + Assert.assertEquals(1, response.getHits().getTotalHits().value); + + // Test adding the new max monitor and updating the existing sum monitor + String maxRuleId = createRule(productIndexMaxAggRule()); + DetectorInput newInput = new DetectorInput("windows detector for security analytics", List.of("windows"), List.of(new DetectorRule(maxRuleId), new DetectorRule(sumRuleId)), + Collections.emptyList()); + Detector firstUpdatedDetector = randomDetectorWithInputs(List.of(newInput)); + Response updateResponse = makeRequest(client(), "PUT", SecurityAnalyticsPlugin.DETECTOR_BASE_URI + "/" + detectorId, Collections.emptyMap(), toHttpEntity(firstUpdatedDetector)); + Assert.assertEquals("Update detector failed", RestStatus.OK, restStatus(updateResponse)); + request = "{\n" + + " \"query\" : {\n" + + " \"match_all\":{\n" + + " }\n" + + " }\n" + + "}"; + List hits = executeSearch(Detector.DETECTORS_INDEX, request); + SearchHit hit = hits.get(0); + Map firstUpdateDetectorMap = (HashMap)(hit.getSourceAsMap().get("detector")); + List inputArr = firstUpdateDetectorMap.get("inputs"); + Assert.assertEquals(2, ((Map>) inputArr.get(0)).get("detector_input").get("custom_rules").size()); + } + + public void testUpdateDetectorDeletingExistingAggregationRule() throws IOException { + String index = createTestIndex(randomIndex(), productIndexMapping()); + + // Execute CreateMappingsAction to add alias mapping for index + Request createMappingRequest = new Request("POST", SecurityAnalyticsPlugin.MAPPER_BASE_URI); + // both req params and req body are supported + createMappingRequest.setJsonEntity( + "{ \"index_name\":\"" + index + "\"," + + " \"rule_topic\":\"windows\", " + + " \"partial\":true" + + "}" + ); + + Response createMappingResponse = client().performRequest(createMappingRequest); + assertEquals(HttpStatus.SC_OK, createMappingResponse.getStatusLine().getStatusCode()); + + List aggRuleIds = createAggregationRules(); + List detectorRules = aggRuleIds.stream().map(DetectorRule::new).collect(Collectors.toList()); + + DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), detectorRules, + Collections.emptyList()); + + Detector detector = randomDetectorWithInputs(List.of(input)); + + Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); + Assert.assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); + + Map responseBody = asMap(createResponse); + + String detectorId = responseBody.get("_id").toString(); + + String request = "{\n" + + " \"query\" : {\n" + + " \"match_all\":{\n" + + " }\n" + + " }\n" + + "}"; + SearchResponse response = executeSearchAndGetResponse(Rule.CUSTOM_RULES_INDEX, request, true); + Assert.assertEquals(2, response.getHits().getTotalHits().value); + + // Test deleting the aggregation rule + DetectorInput newInput = new DetectorInput("windows detector for security analytics", List.of("windows"), List.of(new DetectorRule(aggRuleIds.get(0))), + Collections.emptyList()); + Detector firstUpdatedDetector = randomDetectorWithInputs(List.of(newInput)); + Response updateResponse = makeRequest(client(), "PUT", SecurityAnalyticsPlugin.DETECTOR_BASE_URI + "/" + detectorId, Collections.emptyMap(), toHttpEntity(firstUpdatedDetector)); + Assert.assertEquals("Update detector failed", RestStatus.OK, restStatus(updateResponse)); + request = "{\n" + + " \"query\" : {\n" + + " \"match_all\":{\n" + + " }\n" + + " }\n" + + "}"; + List hits = executeSearch(Detector.DETECTORS_INDEX, request); + SearchHit hit = hits.get(0); + Map firstUpdateDetectorMap = (HashMap)(hit.getSourceAsMap().get("detector")); + List inputArr = firstUpdateDetectorMap.get("inputs"); + Assert.assertEquals(1, ((Map>) inputArr.get(0)).get("detector_input").get("custom_rules").size()); + } + + public void testUpdateDetectorWithAggregationAndDocLevelRules() throws IOException { + String index = createTestIndex(randomIndex(), productIndexMapping()); + + // Execute CreateMappingsAction to add alias mapping for index + Request createMappingRequest = new Request("POST", SecurityAnalyticsPlugin.MAPPER_BASE_URI); + // both req params and req body are supported + createMappingRequest.setJsonEntity( + "{ \"index_name\":\"" + index + "\"," + + " \"rule_topic\":\"windows\", " + + " \"partial\":true" + + "}" + ); + + Response createMappingResponse = client().performRequest(createMappingRequest); + assertEquals(HttpStatus.SC_OK, createMappingResponse.getStatusLine().getStatusCode()); + + List aggRuleIds = createAggregationRules(); + List detectorRules = aggRuleIds.stream().map(DetectorRule::new).collect(Collectors.toList()); + + DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), detectorRules, + getRandomPrePackagedRules().stream().map(DetectorRule::new).collect(Collectors.toList())); + + Detector detector = randomDetectorWithInputs(List.of(input)); + + Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); + Assert.assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); + + Map responseBody = asMap(createResponse); + + String detectorId = responseBody.get("_id").toString(); + + String request = "{\n" + + " \"query\" : {\n" + + " \"match_all\":{\n" + + " }\n" + + " }\n" + + "}"; + SearchResponse response = executeSearchAndGetResponse(Rule.CUSTOM_RULES_INDEX, request, true); + Assert.assertEquals(2, response.getHits().getTotalHits().value); + + String maxRuleId = createRule(productIndexMaxAggRule()); + + DetectorInput newInput = new DetectorInput("windows detector for security analytics", List.of("windows"), List.of(new DetectorRule(aggRuleIds.get(0)), new DetectorRule(maxRuleId)), + Collections.emptyList()); + + detector = randomDetectorWithInputs(List.of(newInput)); + createResponse = makeRequest(client(), "PUT", SecurityAnalyticsPlugin.DETECTOR_BASE_URI + "/" + detectorId, Collections.emptyMap(), toHttpEntity(detector)); + Assert.assertEquals("Update detector failed", RestStatus.OK, restStatus(createResponse)); + request = "{\n" + + " \"query\" : {\n" + + " \"match_all\":{\n" + + " }\n" + + " }\n" + + "}"; + List hits = executeSearch(Detector.DETECTORS_INDEX, request); + SearchHit hit = hits.get(0); + Map firstUpdateDetectorMap = (HashMap)(hit.getSourceAsMap().get("detector")); + List inputArr = firstUpdateDetectorMap.get("inputs"); + Assert.assertEquals(2, ((Map>) inputArr.get(0)).get("detector_input").get("custom_rules").size()); + } + @SuppressWarnings("unchecked") public void testDeletingADetector() throws IOException { String index = createTestIndex(randomIndex(), windowsIndexMapping()); diff --git a/src/test/java/org/opensearch/securityanalytics/resthandler/RuleRestApiIT.java b/src/test/java/org/opensearch/securityanalytics/resthandler/RuleRestApiIT.java index 4424fff14..bc395951e 100644 --- a/src/test/java/org/opensearch/securityanalytics/resthandler/RuleRestApiIT.java +++ b/src/test/java/org/opensearch/securityanalytics/resthandler/RuleRestApiIT.java @@ -11,8 +11,17 @@ import org.opensearch.client.Request; import org.opensearch.client.Response; import org.opensearch.client.ResponseException; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.NamedXContentRegistry; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentParser; +import org.opensearch.common.xcontent.XContentParser.Token; +import org.opensearch.common.xcontent.XContentParserUtils; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.commons.alerting.aggregation.bucketselectorext.BucketSelectorExtAggregationBuilder; import org.opensearch.rest.RestStatus; import org.opensearch.search.SearchHit; +import org.opensearch.search.aggregations.AggregatorFactories; import org.opensearch.securityanalytics.SecurityAnalyticsPlugin; import org.opensearch.securityanalytics.SecurityAnalyticsRestTestCase; import org.opensearch.securityanalytics.config.monitors.DetectorMonitorConfig; @@ -27,8 +36,11 @@ import java.util.Locale; import java.util.Map; import java.util.stream.Collectors; +import org.opensearch.securityanalytics.rules.backend.OSQueryBackend.AggregationQueries; +import org.opensearch.securityanalytics.rules.exceptions.SigmaError; import static org.opensearch.securityanalytics.TestHelpers.randomDetectorType; +import static org.opensearch.securityanalytics.TestHelpers.countAggregationTestRule; import static org.opensearch.securityanalytics.TestHelpers.randomDetectorWithInputs; import static org.opensearch.securityanalytics.TestHelpers.randomDoc; import static org.opensearch.securityanalytics.TestHelpers.randomEditedRule; @@ -90,6 +102,46 @@ public void testCreatingARule() throws IOException { Assert.assertEquals(0, hits.size()); } + public void testCreatingAggregationRule() throws SigmaError, IOException { + Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.RULE_BASE_URI, Collections.singletonMap("category", "windows"), + new StringEntity(countAggregationTestRule()), new BasicHeader("Content-Type", "application/json")); + Assert.assertEquals("Create rule failed", RestStatus.CREATED, restStatus(createResponse)); + + Map responseBody = asMap(createResponse); + + String createdId = responseBody.get("_id").toString(); + int createdVersion = Integer.parseInt(responseBody.get("_version").toString()); + Assert.assertNotEquals("response is missing Id", Detector.NO_ID, createdId); + Assert.assertTrue("incorrect version", createdVersion > 0); + Assert.assertEquals("Incorrect Location header", String.format(Locale.getDefault(), "%s/%s", SecurityAnalyticsPlugin.RULE_BASE_URI, createdId), createResponse.getHeader("Location")); + + String index = Rule.CUSTOM_RULES_INDEX; + String request = "{\n" + + " \"query\": {\n" + + " \"nested\": {\n" + + " \"path\": \"rule\",\n" + + " \"query\": {\n" + + " \"bool\": {\n" + + " \"must\": [\n" + + " { \"match\": {\"rule.category\": \"windows\"}}\n" + + " ]\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + "}"; + + List hits = executeSearch(index, request); + + XContentParser xcp = XContentFactory.xContent(XContentType.JSON) + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, hits.get(0).getSourceAsString()); + Rule result = Rule.docParse(xcp, null, null); + + Assert.assertEquals(1, result.getAggregationQueries().size()); + String expected = "{\"aggQuery\":\"{\\\"result_agg\\\":{\\\"terms\\\":{\\\"field\\\":\\\"_index\\\"}}}\",\"bucketTriggerQuery\":\"{\\\"buckets_path\\\":{\\\"_cnt\\\":\\\"_cnt\\\"},\\\"parent_bucket_path\\\":\\\"result_agg\\\",\\\"script\\\":{\\\"source\\\":\\\"params._cnt > 1.0\\\",\\\"lang\\\":\\\"painless\\\"}}\"}"; + Assert.assertEquals(expected, result.getAggregationQueries().get(0).getValue()); + } + @SuppressWarnings("unchecked") public void testCreatingARuleWithWrongSyntax() throws IOException { String rule = randomRuleWithErrors(); diff --git a/src/test/java/org/opensearch/securityanalytics/rules/aggregation/AggregationBackendTests.java b/src/test/java/org/opensearch/securityanalytics/rules/aggregation/AggregationBackendTests.java index b25b276d8..4d394ed36 100644 --- a/src/test/java/org/opensearch/securityanalytics/rules/aggregation/AggregationBackendTests.java +++ b/src/test/java/org/opensearch/securityanalytics/rules/aggregation/AggregationBackendTests.java @@ -42,7 +42,7 @@ public void testCountAggregation() throws SigmaError, IOException { String aggQuery = aggQueries.getAggQuery(); String bucketTriggerQuery = aggQueries.getBucketTriggerQuery(); - Assert.assertEquals("\"aggs\":{\"result_agg\":{\"terms\":{\"field\":\"_index\"}}}", aggQuery); + Assert.assertEquals("{\"result_agg\":{\"terms\":{\"field\":\"_index\"}}}", aggQuery); Assert.assertEquals("{\"buckets_path\":{\"_cnt\":\"_cnt\"},\"parent_bucket_path\":\"result_agg\",\"script\":{\"source\":\"params._cnt > 1.0\",\"lang\":\"painless\"}}", bucketTriggerQuery); } @@ -73,7 +73,7 @@ public void testCountAggregationWithGroupBy() throws IOException, SigmaError { String aggQuery = aggQueries.getAggQuery(); String bucketTriggerQuery = aggQueries.getBucketTriggerQuery(); - Assert.assertEquals("\"aggs\":{\"result_agg\":{\"terms\":{\"field\":\"fieldB\"}}}", aggQuery); + Assert.assertEquals("{\"result_agg\":{\"terms\":{\"field\":\"fieldB\"}}}", aggQuery); Assert.assertEquals("{\"buckets_path\":{\"_cnt\":\"_cnt\"},\"parent_bucket_path\":\"result_agg\",\"script\":{\"source\":\"params._cnt > 1.0\",\"lang\":\"painless\"}}", bucketTriggerQuery); } @@ -104,7 +104,10 @@ public void testSumAggregationWithGroupBy() throws IOException, SigmaError { String aggQuery = aggQueries.getAggQuery(); String bucketTriggerQuery = aggQueries.getBucketTriggerQuery(); - Assert.assertEquals("\"aggs\":{\"result_agg\":{\"terms\":{\"field\":\"fieldB\"},\"aggs\":{\"fieldA\":{\"sum\":{\"field\":\"fieldA\"}}}}}", aggQuery); + + // inputs.query.aggregations -> Query + Assert.assertEquals("{\"result_agg\":{\"terms\":{\"field\":\"fieldB\"},\"aggs\":{\"fieldA\":{\"sum\":{\"field\":\"fieldA\"}}}}}", aggQuery); + // triggers.bucket_level_trigger.condition -> Condition Assert.assertEquals("{\"buckets_path\":{\"fieldA\":\"fieldA\"},\"parent_bucket_path\":\"result_agg\",\"script\":{\"source\":\"params.fieldA > 110.0\",\"lang\":\"painless\"}}", bucketTriggerQuery); } @@ -135,7 +138,7 @@ public void testMinAggregationWithGroupBy() throws IOException, SigmaError { String aggQuery = aggQueries.getAggQuery(); String bucketTriggerQuery = aggQueries.getBucketTriggerQuery(); - Assert.assertEquals("\"aggs\":{\"result_agg\":{\"terms\":{\"field\":\"fieldB\"},\"aggs\":{\"fieldA\":{\"min\":{\"field\":\"fieldA\"}}}}}", aggQuery); + Assert.assertEquals("{\"result_agg\":{\"terms\":{\"field\":\"fieldB\"},\"aggs\":{\"fieldA\":{\"min\":{\"field\":\"fieldA\"}}}}}", aggQuery); Assert.assertEquals("{\"buckets_path\":{\"fieldA\":\"fieldA\"},\"parent_bucket_path\":\"result_agg\",\"script\":{\"source\":\"params.fieldA > 110.0\",\"lang\":\"painless\"}}", bucketTriggerQuery); } @@ -166,7 +169,7 @@ public void testMaxAggregationWithGroupBy() throws IOException, SigmaError { String aggQuery = aggQueries.getAggQuery(); String bucketTriggerQuery = aggQueries.getBucketTriggerQuery(); - Assert.assertEquals("\"aggs\":{\"result_agg\":{\"terms\":{\"field\":\"fieldB\"},\"aggs\":{\"fieldA\":{\"max\":{\"field\":\"fieldA\"}}}}}", aggQuery); + Assert.assertEquals("{\"result_agg\":{\"terms\":{\"field\":\"fieldB\"},\"aggs\":{\"fieldA\":{\"max\":{\"field\":\"fieldA\"}}}}}", aggQuery); Assert.assertEquals("{\"buckets_path\":{\"fieldA\":\"fieldA\"},\"parent_bucket_path\":\"result_agg\",\"script\":{\"source\":\"params.fieldA > 110.0\",\"lang\":\"painless\"}}", bucketTriggerQuery); } @@ -197,7 +200,7 @@ public void testAvgAggregationWithGroupBy() throws IOException, SigmaError { String aggQuery = aggQueries.getAggQuery(); String bucketTriggerQuery = aggQueries.getBucketTriggerQuery(); - Assert.assertEquals("\"aggs\":{\"result_agg\":{\"terms\":{\"field\":\"fieldB\"},\"aggs\":{\"fieldA\":{\"avg\":{\"field\":\"fieldA\"}}}}}", aggQuery); + Assert.assertEquals("{\"result_agg\":{\"terms\":{\"field\":\"fieldB\"},\"aggs\":{\"fieldA\":{\"avg\":{\"field\":\"fieldA\"}}}}}", aggQuery); Assert.assertEquals("{\"buckets_path\":{\"fieldA\":\"fieldA\"},\"parent_bucket_path\":\"result_agg\",\"script\":{\"source\":\"params.fieldA > 110.0\",\"lang\":\"painless\"}}", bucketTriggerQuery); } }