diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorPlugin.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorPlugin.java index 7f5053d7..254511ae 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorPlugin.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorPlugin.java @@ -421,7 +421,8 @@ public List> getSettings() { List> systemSetting = ImmutableList .of( - AnomalyDetectorSettings.MAX_ANOMALY_DETECTORS, + AnomalyDetectorSettings.MAX_SINGLE_ENTITY_ANOMALY_DETECTORS, + AnomalyDetectorSettings.MAX_MULTI_ENTITY_ANOMALY_DETECTORS, AnomalyDetectorSettings.MAX_ANOMALY_FEATURES, AnomalyDetectorSettings.REQUEST_TIMEOUT, AnomalyDetectorSettings.DETECTION_INTERVAL, diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/constant/CommonName.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/constant/CommonName.java index c6336fd6..0e844d0e 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/constant/CommonName.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/constant/CommonName.java @@ -22,6 +22,9 @@ public class CommonName { // index name for anomaly checkpoint of each model. One model one document. public static final String CHECKPOINT_INDEX_NAME = ".opendistro-anomaly-checkpoints"; + // The alias of the index in which to write AD result history + public static final String ANOMALY_RESULT_INDEX_ALIAS = ".opendistro-anomaly-results"; + // ====================================== // Format name // ====================================== @@ -55,4 +58,14 @@ public class CommonName { public static final String TOTAL_SIZE_IN_BYTES = "total_size_in_bytes"; public static final String MODELS = "models"; public static final String INIT_PROGRESS = "init_progress"; + + // Elastic mapping type + public static final String MAPPING_TYPE = "_doc"; + + // Used to fetch mapping + public static final String TYPE = "type"; + + public static final String KEYWORD_TYPE = "keyword"; + + public static final String IP_TYPE = "ip"; } diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/model/AnomalyDetector.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/model/AnomalyDetector.java index 50083b02..54d0b13b 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/model/AnomalyDetector.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/model/AnomalyDetector.java @@ -15,11 +15,14 @@ package com.amazon.opendistroforelasticsearch.ad.model; +import static com.amazon.opendistroforelasticsearch.ad.settings.AnomalyDetectorSettings.CATEGORY_FIELD_LIMIT; +import static com.amazon.opendistroforelasticsearch.ad.settings.AnomalyDetectorSettings.DEFAULT_MULTI_ENTITY_SHINGLE; import static com.amazon.opendistroforelasticsearch.ad.settings.AnomalyDetectorSettings.DEFAULT_SHINGLE_SIZE; import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; import static org.elasticsearch.index.query.AbstractQueryBuilder.parseInnerQueryBuilder; import java.io.IOException; +import java.time.Duration; import java.time.Instant; import java.time.temporal.ChronoUnit; import java.util.ArrayList; @@ -76,6 +79,7 @@ public class AnomalyDetector implements Writeable, ToXContentObject { private static final String SHINGLE_SIZE_FIELD = "shingle_size"; private static final String LAST_UPDATE_TIME_FIELD = "last_update_time"; public static final String UI_METADATA_FIELD = "ui_metadata"; + public static final String CATEGORY_FIELD = "category_field"; private final String detectorId; private final Long version; @@ -91,6 +95,7 @@ public class AnomalyDetector implements Writeable, ToXContentObject { private final Map uiMetadata; private final Integer schemaVersion; private final Instant lastUpdateTime; + private final List categoryFields; /** * Constructor function. @@ -109,6 +114,7 @@ public class AnomalyDetector implements Writeable, ToXContentObject { * @param uiMetadata metadata used by Kibana * @param schemaVersion anomaly detector index mapping version * @param lastUpdateTime detector's last update time + * @param categoryField a list of partition fields */ public AnomalyDetector( String detectorId, @@ -124,7 +130,8 @@ public AnomalyDetector( Integer shingleSize, Map uiMetadata, Integer schemaVersion, - Instant lastUpdateTime + Instant lastUpdateTime, + List categoryField ) { if (Strings.isBlank(name)) { throw new IllegalArgumentException("Detector name should be set"); @@ -141,6 +148,9 @@ public AnomalyDetector( if (shingleSize != null && shingleSize < 1) { throw new IllegalArgumentException("Shingle size must be a positive integer"); } + if (categoryField != null && categoryField.size() > CATEGORY_FIELD_LIMIT) { + throw new IllegalArgumentException("We only support filtering data by one categorical variable"); + } this.detectorId = detectorId; this.version = version; this.name = name; @@ -155,6 +165,44 @@ public AnomalyDetector( this.uiMetadata = uiMetadata; this.schemaVersion = schemaVersion; this.lastUpdateTime = lastUpdateTime; + this.categoryFields = categoryField; + } + + // TODO: remove after complete code merges. Created to not to touch too + // many places in one PR. + public AnomalyDetector( + String detectorId, + Long version, + String name, + String description, + String timeField, + List indices, + List features, + QueryBuilder filterQuery, + TimeConfiguration detectionInterval, + TimeConfiguration windowDelay, + Integer shingleSize, + Map uiMetadata, + Integer schemaVersion, + Instant lastUpdateTime + ) { + this( + detectorId, + version, + name, + description, + timeField, + indices, + features, + filterQuery, + detectionInterval, + windowDelay, + shingleSize, + uiMetadata, + schemaVersion, + lastUpdateTime, + null + ); } public AnomalyDetector(StreamInput input) throws IOException { @@ -188,6 +236,7 @@ public AnomalyDetector(StreamInput input) throws IOException { uiMetadata = input.readMap(); schemaVersion = input.readInt(); lastUpdateTime = input.readInstant(); + this.categoryFields = input.readStringList(); } public XContentBuilder toXContent(XContentBuilder builder) throws IOException { @@ -210,6 +259,7 @@ public void writeTo(StreamOutput output) throws IOException { output.writeMap(uiMetadata); output.writeInt(schemaVersion); output.writeInstant(lastUpdateTime); + output.writeStringCollection(categoryFields); } @Override @@ -236,6 +286,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (lastUpdateTime != null) { xContentBuilder.timeField(LAST_UPDATE_TIME_FIELD, LAST_UPDATE_TIME_FIELD, lastUpdateTime.toEpochMilli()); } + if (categoryFields != null) { + xContentBuilder.field(CATEGORY_FIELD, categoryFields.toArray()); + } return xContentBuilder.endObject(); } @@ -264,7 +317,7 @@ public static AnomalyDetector parse(XContentParser parser, String detectorId) th * @throws IOException IOException if content can't be parsed correctly */ public static AnomalyDetector parse(XContentParser parser, String detectorId, Long version) throws IOException { - return parse(parser, detectorId, version, null, null, null); + return parse(parser, detectorId, version, null, null); } /** @@ -275,7 +328,6 @@ public static AnomalyDetector parse(XContentParser parser, String detectorId, Lo * @param version detector document version * @param defaultDetectionInterval default detection interval * @param defaultDetectionWindowDelay default detection window delay - * @param defaultShingleSize default number of intervals in shingle * @return anomaly detector instance * @throws IOException IOException if content can't be parsed correctly */ @@ -284,8 +336,7 @@ public static AnomalyDetector parse( String detectorId, Long version, TimeValue defaultDetectionInterval, - TimeValue defaultDetectionWindowDelay, - Integer defaultShingleSize + TimeValue defaultDetectionWindowDelay ) throws IOException { String name = null; String description = null; @@ -298,12 +349,14 @@ public static AnomalyDetector parse( TimeConfiguration windowDelay = defaultDetectionWindowDelay == null ? null : new IntervalTimeConfiguration(defaultDetectionWindowDelay.getSeconds(), ChronoUnit.SECONDS); - Integer shingleSize = defaultShingleSize; + Integer shingleSize = null; List features = new ArrayList<>(); int schemaVersion = 0; Map uiMetadata = null; Instant lastUpdateTime = null; + List categoryField = null; + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser::getTokenLocation); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { String fieldName = parser.currentName(); @@ -359,6 +412,9 @@ public static AnomalyDetector parse( case LAST_UPDATE_TIME_FIELD: lastUpdateTime = ParseUtils.toInstant(parser); break; + case CATEGORY_FIELD: + categoryField = (List) parser.list(); + break; default: parser.skipChildren(); break; @@ -375,10 +431,11 @@ public static AnomalyDetector parse( filterQuery, detectionInterval, windowDelay, - shingleSize, + getShingleSize(shingleSize, categoryField), uiMetadata, schemaVersion, - lastUpdateTime + lastUpdateTime, + categoryField ); } @@ -483,7 +540,20 @@ public TimeConfiguration getWindowDelay() { } public Integer getShingleSize() { - return shingleSize == null ? DEFAULT_SHINGLE_SIZE : shingleSize; + return getShingleSize(shingleSize, categoryFields); + } + + /** + * If the given shingle size is null, return default based on the kind of detector; + * otherwise, return the given shingle size. + * @param customShingleSize Given shingle size + * @param categoryField Used to verify if this is a multi-entity or single-entity detector + * @return Shingle size + */ + private static Integer getShingleSize(Integer customShingleSize, List categoryField) { + return customShingleSize == null + ? (categoryField != null && categoryField.size() > 0 ? DEFAULT_MULTI_ENTITY_SHINGLE : DEFAULT_SHINGLE_SIZE) + : customShingleSize; } public Map getUiMetadata() { @@ -498,4 +568,19 @@ public Instant getLastUpdateTime() { return lastUpdateTime; } + public List getCategoryField() { + return this.categoryFields; + } + + public long getDetectorIntervalInMilliseconds() { + return ((IntervalTimeConfiguration) getDetectionInterval()).toDuration().toMillis(); + } + + public long getDetectorIntervalInSeconds() { + return getDetectorIntervalInMilliseconds() / 1000; + } + + public Duration getDetectionIntervalDuration() { + return ((IntervalTimeConfiguration) getDetectionInterval()).toDuration(); + } } diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/rest/RestIndexAnomalyDetectorAction.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/rest/RestIndexAnomalyDetectorAction.java index 44292927..e5dcdeab 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/rest/RestIndexAnomalyDetectorAction.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/rest/RestIndexAnomalyDetectorAction.java @@ -15,11 +15,11 @@ package com.amazon.opendistroforelasticsearch.ad.rest; -import static com.amazon.opendistroforelasticsearch.ad.settings.AnomalyDetectorSettings.DEFAULT_SHINGLE_SIZE; import static com.amazon.opendistroforelasticsearch.ad.settings.AnomalyDetectorSettings.DETECTION_INTERVAL; import static com.amazon.opendistroforelasticsearch.ad.settings.AnomalyDetectorSettings.DETECTION_WINDOW_DELAY; -import static com.amazon.opendistroforelasticsearch.ad.settings.AnomalyDetectorSettings.MAX_ANOMALY_DETECTORS; import static com.amazon.opendistroforelasticsearch.ad.settings.AnomalyDetectorSettings.MAX_ANOMALY_FEATURES; +import static com.amazon.opendistroforelasticsearch.ad.settings.AnomalyDetectorSettings.MAX_MULTI_ENTITY_ANOMALY_DETECTORS; +import static com.amazon.opendistroforelasticsearch.ad.settings.AnomalyDetectorSettings.MAX_SINGLE_ENTITY_ANOMALY_DETECTORS; import static com.amazon.opendistroforelasticsearch.ad.settings.AnomalyDetectorSettings.REQUEST_TIMEOUT; import static com.amazon.opendistroforelasticsearch.ad.util.RestHandlerUtils.DETECTOR_ID; import static com.amazon.opendistroforelasticsearch.ad.util.RestHandlerUtils.IF_PRIMARY_TERM; @@ -65,15 +65,13 @@ public class RestIndexAnomalyDetectorAction extends BaseRestHandler { private static final String INDEX_ANOMALY_DETECTOR_ACTION = "index_anomaly_detector_action"; - private final AnomalyDetectionIndices anomalyDetectionIndices; private final Logger logger = LogManager.getLogger(RestIndexAnomalyDetectorAction.class); - private final ClusterService clusterService; - private final Settings settings; private volatile TimeValue requestTimeout; private volatile TimeValue detectionInterval; private volatile TimeValue detectionWindowDelay; - private volatile Integer maxAnomalyDetectors; + private volatile Integer maxSingleEntityDetectors; + private volatile Integer maxMultiEntityDetectors; private volatile Integer maxAnomalyFeatures; public RestIndexAnomalyDetectorAction( @@ -81,20 +79,23 @@ public RestIndexAnomalyDetectorAction( ClusterService clusterService, AnomalyDetectionIndices anomalyDetectionIndices ) { - this.settings = settings; - this.anomalyDetectionIndices = anomalyDetectionIndices; this.requestTimeout = REQUEST_TIMEOUT.get(settings); this.detectionInterval = DETECTION_INTERVAL.get(settings); this.detectionWindowDelay = DETECTION_WINDOW_DELAY.get(settings); - this.maxAnomalyDetectors = MAX_ANOMALY_DETECTORS.get(settings); + this.maxSingleEntityDetectors = MAX_SINGLE_ENTITY_ANOMALY_DETECTORS.get(settings); + this.maxMultiEntityDetectors = MAX_MULTI_ENTITY_ANOMALY_DETECTORS.get(settings); this.maxAnomalyFeatures = MAX_ANOMALY_FEATURES.get(settings); - this.clusterService = clusterService; // TODO: will add more cluster setting consumer later // TODO: inject ClusterSettings only if clusterService is only used to get ClusterSettings clusterService.getClusterSettings().addSettingsUpdateConsumer(REQUEST_TIMEOUT, it -> requestTimeout = it); clusterService.getClusterSettings().addSettingsUpdateConsumer(DETECTION_INTERVAL, it -> detectionInterval = it); clusterService.getClusterSettings().addSettingsUpdateConsumer(DETECTION_WINDOW_DELAY, it -> detectionWindowDelay = it); - clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_ANOMALY_DETECTORS, it -> maxAnomalyDetectors = it); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(MAX_SINGLE_ENTITY_ANOMALY_DETECTORS, it -> maxSingleEntityDetectors = it); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(MAX_MULTI_ENTITY_ANOMALY_DETECTORS, it -> maxMultiEntityDetectors = it); clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_ANOMALY_FEATURES, it -> maxAnomalyFeatures = it); } @@ -115,8 +116,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli XContentParser parser = request.contentParser(); ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser::getTokenLocation); // TODO: check detection interval < modelTTL - AnomalyDetector detector = AnomalyDetector - .parse(parser, detectorId, null, detectionInterval, detectionWindowDelay, DEFAULT_SHINGLE_SIZE); + AnomalyDetector detector = AnomalyDetector.parse(parser, detectorId, null, detectionInterval, detectionWindowDelay); long seqNo = request.paramAsLong(IF_SEQ_NO, SequenceNumbers.UNASSIGNED_SEQ_NO); long primaryTerm = request.paramAsLong(IF_PRIMARY_TERM, SequenceNumbers.UNASSIGNED_PRIMARY_TERM); @@ -131,7 +131,11 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli primaryTerm, refreshPolicy, detector, - method + method, + requestTimeout, + maxSingleEntityDetectors, + maxMultiEntityDetectors, + maxAnomalyFeatures ); return channel -> client diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/rest/handler/IndexAnomalyDetectorActionHandler.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/rest/handler/IndexAnomalyDetectorActionHandler.java index 942e9ae1..296fe6f0 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/rest/handler/IndexAnomalyDetectorActionHandler.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/rest/handler/IndexAnomalyDetectorActionHandler.java @@ -21,6 +21,8 @@ import java.io.IOException; import java.time.Instant; import java.util.Arrays; +import java.util.List; +import java.util.Map; import java.util.stream.Collectors; import org.apache.commons.lang.StringUtils; @@ -29,17 +31,21 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.admin.indices.create.CreateIndexResponse; +import org.elasticsearch.action.admin.indices.mapping.get.GetFieldMappingsAction; +import org.elasticsearch.action.admin.indices.mapping.get.GetFieldMappingsRequest; +import org.elasticsearch.action.admin.indices.mapping.get.GetFieldMappingsResponse; +import org.elasticsearch.action.admin.indices.mapping.get.GetFieldMappingsResponse.FieldMappingMetadata; import org.elasticsearch.action.get.GetRequest; import org.elasticsearch.action.get.GetResponse; import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.index.IndexResponse; import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.action.support.replication.ReplicationResponse; import org.elasticsearch.client.Client; import org.elasticsearch.cluster.service.ClusterService; -import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.XContentFactory; @@ -50,6 +56,7 @@ import org.elasticsearch.rest.RestStatus; import org.elasticsearch.search.builder.SearchSourceBuilder; +import com.amazon.opendistroforelasticsearch.ad.constant.CommonName; import com.amazon.opendistroforelasticsearch.ad.indices.AnomalyDetectionIndices; import com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetector; import com.amazon.opendistroforelasticsearch.ad.transport.IndexAnomalyDetectorResponse; @@ -61,6 +68,12 @@ * PUT request is for updating anomaly detector. */ public class IndexAnomalyDetectorActionHandler { + public static final String EXCEEDED_MAX_MULTI_ENTITY_DETECTORS_PREFIX_MSG = "Can't create multi-entity anomaly detectors more than "; + public static final String EXCEEDED_MAX_SINGLE_ENTITY_DETECTORS_PREFIX_MSG = "Can't create single-entity anomaly detectors more than "; + public static final String NO_DOCS_IN_USER_INDEX_MSG = "Can't create anomaly detector as no document found in indices: "; + public static final String ONLY_ONE_CATEGORICAL_FIELD_ERR_MSG = "We can have only one categorical field."; + public static final String CATEGORICAL_FIELD_TYPE_ERR_MSG = "A categorical field must be of type keyword or ip."; + public static final String NOT_FOUND_ERR_MSG = "Cannot found the categorical field %s"; private final AnomalyDetectionIndices anomalyDetectionIndices; private final String detectorId; @@ -72,19 +85,18 @@ public class IndexAnomalyDetectorActionHandler { private final Logger logger = LogManager.getLogger(IndexAnomalyDetectorActionHandler.class); private final TimeValue requestTimeout; - private final Integer maxAnomalyDetectors; + private final Integer maxSingleEntityAnomalyDetectors; + private final Integer maxMultiEntityAnomalyDetectors; private final Integer maxAnomalyFeatures; private final AnomalyDetectorActionHandler handler = new AnomalyDetectorActionHandler(); private final RestRequest.Method method; private final Client client; private final NamedXContentRegistry xContentRegistry; - private final Settings settings; private final ActionListener listener; /** * Constructor function. * - * @param settings ES settings * @param clusterService ClusterService * @param client ES node client that executes actions on the local node * @param listener ES channel used to construct bytes / builder based outputs, and send responses @@ -95,13 +107,13 @@ public class IndexAnomalyDetectorActionHandler { * @param refreshPolicy refresh policy * @param anomalyDetector anomaly detector instance * @param requestTimeout request time out configuration - * @param maxAnomalyDetectors max anomaly detector allowed + * @param maxSingleEntityAnomalyDetectors max single-entity anomaly detectors allowed + * @param maxMultiEntityAnomalyDetectors max multi-entity detectors allowed * @param maxAnomalyFeatures max features allowed per detector * @param method Rest Method type * @param xContentRegistry Registry which is used for XContentParser */ public IndexAnomalyDetectorActionHandler( - Settings settings, ClusterService clusterService, Client client, ActionListener listener, @@ -112,12 +124,12 @@ public IndexAnomalyDetectorActionHandler( WriteRequest.RefreshPolicy refreshPolicy, AnomalyDetector anomalyDetector, TimeValue requestTimeout, - Integer maxAnomalyDetectors, + Integer maxSingleEntityAnomalyDetectors, + Integer maxMultiEntityAnomalyDetectors, Integer maxAnomalyFeatures, RestRequest.Method method, NamedXContentRegistry xContentRegistry ) { - this.settings = settings; this.clusterService = clusterService; this.client = client; this.anomalyDetectionIndices = anomalyDetectionIndices; @@ -128,7 +140,8 @@ public IndexAnomalyDetectorActionHandler( this.refreshPolicy = refreshPolicy; this.anomalyDetector = anomalyDetector; this.requestTimeout = requestTimeout; - this.maxAnomalyDetectors = maxAnomalyDetectors; + this.maxSingleEntityAnomalyDetectors = maxSingleEntityAnomalyDetectors; + this.maxMultiEntityAnomalyDetectors = maxMultiEntityAnomalyDetectors; this.maxAnomalyFeatures = maxAnomalyFeatures; this.method = method; this.xContentRegistry = xContentRegistry; @@ -191,29 +204,45 @@ private void onGetAnomalyDetectorResponse(GetResponse response) throws IOExcepti return; } - searchAdInputIndices(detectorId); + validateCategoricalField(detectorId); } private void createAnomalyDetector() { try { - QueryBuilder query = QueryBuilders.matchAllQuery(); - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query).size(0).timeout(requestTimeout); + List categoricalFields = anomalyDetector.getCategoryField(); + if (categoricalFields != null && categoricalFields.size() > 0) { + QueryBuilder query = QueryBuilders.boolQuery().filter(QueryBuilders.existsQuery(AnomalyDetector.CATEGORY_FIELD)); - SearchRequest searchRequest = new SearchRequest(ANOMALY_DETECTORS_INDEX).source(searchSourceBuilder); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query).size(0).timeout(requestTimeout); - client - .search( - searchRequest, - ActionListener.wrap(response -> onSearchAdResponse(response), exception -> listener.onFailure(exception)) - ); + SearchRequest searchRequest = new SearchRequest(ANOMALY_DETECTORS_INDEX).source(searchSourceBuilder); + + client + .search( + searchRequest, + ActionListener.wrap(response -> onSearchMultiEntityAdResponse(response), exception -> listener.onFailure(exception)) + ); + } else { + QueryBuilder query = QueryBuilders.matchAllQuery(); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query).size(0).timeout(requestTimeout); + + SearchRequest searchRequest = new SearchRequest(ANOMALY_DETECTORS_INDEX).source(searchSourceBuilder); + + client + .search( + searchRequest, + ActionListener + .wrap(response -> onSearchSingleEntityAdResponse(response), exception -> listener.onFailure(exception)) + ); + } } catch (Exception e) { listener.onFailure(e); } } - private void onSearchAdResponse(SearchResponse response) throws IOException { - if (response.getHits().getTotalHits().value >= maxAnomalyDetectors) { - String errorMsg = "Can't create anomaly detector more than " + maxAnomalyDetectors; + private void onSearchSingleEntityAdResponse(SearchResponse response) throws IOException { + if (response.getHits().getTotalHits().value >= maxSingleEntityAnomalyDetectors) { + String errorMsg = EXCEEDED_MAX_SINGLE_ENTITY_DETECTORS_PREFIX_MSG + maxSingleEntityAnomalyDetectors; logger.error(errorMsg); listener.onFailure(new IllegalArgumentException(errorMsg)); } else { @@ -221,6 +250,82 @@ private void onSearchAdResponse(SearchResponse response) throws IOException { } } + private void onSearchMultiEntityAdResponse(SearchResponse response) throws IOException { + if (response.getHits().getTotalHits().value >= maxMultiEntityAnomalyDetectors) { + String errorMsg = EXCEEDED_MAX_MULTI_ENTITY_DETECTORS_PREFIX_MSG + maxMultiEntityAnomalyDetectors; + logger.error(errorMsg); + listener.onFailure(new IllegalArgumentException(errorMsg)); + } else { + validateCategoricalField(null); + } + } + + @SuppressWarnings("unchecked") + private void validateCategoricalField(String detectorId) { + List categoryField = anomalyDetector.getCategoryField(); + + if (categoryField == null) { + searchAdInputIndices(detectorId); + return; + } + + // we only support one categorical field + // If there is more than 1 field or none, AnomalyDetector's constructor + // throws IllegalArgumentException before reaching this line + if (categoryField.size() != 1) { + listener.onFailure(new IllegalArgumentException(ONLY_ONE_CATEGORICAL_FIELD_ERR_MSG)); + return; + } + + String categoryField0 = categoryField.get(0); + + GetFieldMappingsRequest getMappingsRequest = new GetFieldMappingsRequest(); + getMappingsRequest.indices(anomalyDetector.getIndices().toArray(new String[0])).fields(categoryField.toArray(new String[0])); + getMappingsRequest.indicesOptions(IndicesOptions.strictExpand()); + + ActionListener mappingsListener = ActionListener.wrap(getMappingsResponse -> { + // example getMappingsResponse: + // GetFieldMappingsResponse{mappings={server-metrics={_doc={service=FieldMappingMetadata{fullName='service', + // source=org.elasticsearch.common.bytes.BytesArray@7ba87dbd}}}}} + boolean foundField = false; + Map>> mappingsByIndex = getMappingsResponse.mappings(); + + for (Map> mappingsByType : mappingsByIndex.values()) { + for (Map mappingsByField : mappingsByType.values()) { + for (Map.Entry field2Metadata : mappingsByField.entrySet()) { + FieldMappingMetadata fieldMetadata = field2Metadata.getValue(); + + if (fieldMetadata != null) { + Object metadata = fieldMetadata.sourceAsMap().get(categoryField0); + if (metadata != null && metadata instanceof Map) { + foundField = true; + Map metadataMap = (Map) metadata; + String typeName = (String) metadataMap.get(CommonName.TYPE); + if (!typeName.equals(CommonName.KEYWORD_TYPE) && !typeName.equals(CommonName.IP_TYPE)) { + listener.onFailure(new IllegalArgumentException(CATEGORICAL_FIELD_TYPE_ERR_MSG)); + return; + } + } + } + } + } + } + + if (foundField == false) { + listener.onFailure(new IllegalArgumentException(String.format(NOT_FOUND_ERR_MSG, categoryField0))); + return; + } + + searchAdInputIndices(detectorId); + }, error -> { + String message = String.format("Fail to get the index mapping of %s", anomalyDetector.getIndices()); + logger.error(message, error); + listener.onFailure(new IllegalArgumentException(message)); + }); + + client.execute(GetFieldMappingsAction.INSTANCE, getMappingsRequest, mappingsListener); + } + private void searchAdInputIndices(String detectorId) { SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder() .query(QueryBuilders.matchAllQuery()) @@ -242,8 +347,7 @@ private void searchAdInputIndices(String detectorId) { private void onSearchAdInputIndicesResponse(SearchResponse response, String detectorId) throws IOException { if (response.getHits().getTotalHits().value == 0) { - String errorMsg = "Can't create anomaly detector as no document found in indices: " - + Arrays.toString(anomalyDetector.getIndices().toArray(new String[0])); + String errorMsg = NO_DOCS_IN_USER_INDEX_MSG + Arrays.toString(anomalyDetector.getIndices().toArray(new String[0])); logger.error(errorMsg); listener.onFailure(new IllegalArgumentException(errorMsg)); } else { @@ -302,7 +406,8 @@ private void indexAnomalyDetector(String detectorId) throws IOException { anomalyDetector.getShingleSize(), anomalyDetector.getUiMetadata(), anomalyDetector.getSchemaVersion(), - Instant.now() + Instant.now(), + anomalyDetector.getCategoryField() ); IndexRequest indexRequest = new IndexRequest(ANOMALY_DETECTORS_INDEX) .setRefreshPolicy(refreshPolicy) diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/settings/AnomalyDetectorSettings.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/settings/AnomalyDetectorSettings.java index aa617b47..b3eaca7b 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/settings/AnomalyDetectorSettings.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/settings/AnomalyDetectorSettings.java @@ -27,9 +27,17 @@ public final class AnomalyDetectorSettings { private AnomalyDetectorSettings() {} - public static final Setting MAX_ANOMALY_DETECTORS = Setting + public static final Setting MAX_SINGLE_ENTITY_ANOMALY_DETECTORS = Setting .intSetting("opendistro.anomaly_detection.max_anomaly_detectors", 1000, Setting.Property.NodeScope, Setting.Property.Dynamic); + public static final Setting MAX_MULTI_ENTITY_ANOMALY_DETECTORS = Setting + .intSetting( + "opendistro.anomaly_detection.max_multi_entity_anomaly_detectors", + 10, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + public static final Setting MAX_ANOMALY_FEATURES = Setting .intSetting("opendistro.anomaly_detection.max_anomaly_features", 5, Setting.Property.NodeScope, Setting.Property.Dynamic); @@ -214,4 +222,11 @@ private AnomalyDetectorSettings() {} // Thread pool public static final int AD_THEAD_POOL_QUEUE_SIZE = 1000; + + // Multi-entity detector model setting: + // TODO (kaituo): change to 4 + public static final int DEFAULT_MULTI_ENTITY_SHINGLE = 1; + + // how many categorical fields we support + public static final int CATEGORY_FIELD_LIMIT = 1; } diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/IndexAnomalyDetectorRequest.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/IndexAnomalyDetectorRequest.java index d1c98214..6bbdc293 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/IndexAnomalyDetectorRequest.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/IndexAnomalyDetectorRequest.java @@ -22,6 +22,7 @@ import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.rest.RestRequest; import com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetector; @@ -34,6 +35,10 @@ public class IndexAnomalyDetectorRequest extends ActionRequest { private WriteRequest.RefreshPolicy refreshPolicy; private AnomalyDetector detector; private RestRequest.Method method; + private TimeValue requestTimeout; + private Integer maxSingleEntityAnomalyDetectors; + private Integer maxMultiEntityAnomalyDetectors; + private Integer maxAnomalyFeatures; public IndexAnomalyDetectorRequest(StreamInput in) throws IOException { super(in); @@ -43,6 +48,10 @@ public IndexAnomalyDetectorRequest(StreamInput in) throws IOException { refreshPolicy = in.readEnum(WriteRequest.RefreshPolicy.class); detector = new AnomalyDetector(in); method = in.readEnum(RestRequest.Method.class); + requestTimeout = in.readTimeValue(); + maxSingleEntityAnomalyDetectors = in.readInt(); + maxMultiEntityAnomalyDetectors = in.readInt(); + maxAnomalyFeatures = in.readInt(); } public IndexAnomalyDetectorRequest( @@ -51,7 +60,11 @@ public IndexAnomalyDetectorRequest( long primaryTerm, WriteRequest.RefreshPolicy refreshPolicy, AnomalyDetector detector, - RestRequest.Method method + RestRequest.Method method, + TimeValue requestTimeout, + Integer maxSingleEntityAnomalyDetectors, + Integer maxMultiEntityAnomalyDetectors, + Integer maxAnomalyFeatures ) { super(); this.detectorID = detectorID; @@ -60,6 +73,10 @@ public IndexAnomalyDetectorRequest( this.refreshPolicy = refreshPolicy; this.detector = detector; this.method = method; + this.requestTimeout = requestTimeout; + this.maxSingleEntityAnomalyDetectors = maxSingleEntityAnomalyDetectors; + this.maxMultiEntityAnomalyDetectors = maxMultiEntityAnomalyDetectors; + this.maxAnomalyFeatures = maxAnomalyFeatures; } public String getDetectorID() { @@ -86,6 +103,22 @@ public RestRequest.Method getMethod() { return method; } + public TimeValue getRequestTimeout() { + return requestTimeout; + } + + public Integer getMaxSingleEntityAnomalyDetectors() { + return maxSingleEntityAnomalyDetectors; + } + + public Integer getMaxMultiEntityAnomalyDetectors() { + return maxMultiEntityAnomalyDetectors; + } + + public Integer getMaxAnomalyFeatures() { + return maxAnomalyFeatures; + } + @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); @@ -95,6 +128,10 @@ public void writeTo(StreamOutput out) throws IOException { out.writeEnum(refreshPolicy); detector.writeTo(out); out.writeEnum(method); + out.writeTimeValue(requestTimeout); + out.writeInt(maxSingleEntityAnomalyDetectors); + out.writeInt(maxMultiEntityAnomalyDetectors); + out.writeInt(maxAnomalyFeatures); } @Override diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/IndexAnomalyDetectorTransportAction.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/IndexAnomalyDetectorTransportAction.java index fae053f6..4e0648ec 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/IndexAnomalyDetectorTransportAction.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/IndexAnomalyDetectorTransportAction.java @@ -15,10 +15,8 @@ package com.amazon.opendistroforelasticsearch.ad.transport; -import static com.amazon.opendistroforelasticsearch.ad.settings.AnomalyDetectorSettings.MAX_ANOMALY_DETECTORS; -import static com.amazon.opendistroforelasticsearch.ad.settings.AnomalyDetectorSettings.MAX_ANOMALY_FEATURES; -import static com.amazon.opendistroforelasticsearch.ad.settings.AnomalyDetectorSettings.REQUEST_TIMEOUT; - +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.HandledTransportAction; @@ -38,9 +36,8 @@ import com.amazon.opendistroforelasticsearch.ad.rest.handler.IndexAnomalyDetectorActionHandler; public class IndexAnomalyDetectorTransportAction extends HandledTransportAction { - + private static final Logger LOG = LogManager.getLogger(IndexAnomalyDetectorTransportAction.class); private final Client client; - private final Settings settings; private final AnomalyDetectionIndices anomalyDetectionIndices; private final ClusterService clusterService; private final NamedXContentRegistry xContentRegistry; @@ -58,7 +55,6 @@ public IndexAnomalyDetectorTransportAction( super(IndexAnomalyDetectorAction.NAME, transportService, actionFilters, IndexAnomalyDetectorRequest::new); this.client = client; this.clusterService = clusterService; - this.settings = settings; this.anomalyDetectionIndices = anomalyDetectionIndices; this.xContentRegistry = xContentRegistry; } @@ -71,12 +67,12 @@ protected void doExecute(Task task, IndexAnomalyDetectorRequest request, ActionL WriteRequest.RefreshPolicy refreshPolicy = request.getRefreshPolicy(); AnomalyDetector detector = request.getDetector(); RestRequest.Method method = request.getMethod(); - TimeValue requestTimeout = REQUEST_TIMEOUT.get(settings); - Integer maxAnomalyDetectors = MAX_ANOMALY_DETECTORS.get(settings); - Integer maxAnomalyFeatures = MAX_ANOMALY_FEATURES.get(settings); + TimeValue requestTimeout = request.getRequestTimeout(); + Integer maxSingleEntityAnomalyDetectors = request.getMaxSingleEntityAnomalyDetectors(); + Integer maxMultiEntityAnomalyDetectors = request.getMaxMultiEntityAnomalyDetectors(); + Integer maxAnomalyFeatures = request.getMaxAnomalyFeatures(); IndexAnomalyDetectorActionHandler indexAnomalyDetectorActionHandler = new IndexAnomalyDetectorActionHandler( - settings, clusterService, client, listener, @@ -87,7 +83,8 @@ protected void doExecute(Task task, IndexAnomalyDetectorRequest request, ActionL refreshPolicy, detector, requestTimeout, - maxAnomalyDetectors, + maxSingleEntityAnomalyDetectors, + maxMultiEntityAnomalyDetectors, maxAnomalyFeatures, method, xContentRegistry @@ -95,7 +92,7 @@ protected void doExecute(Task task, IndexAnomalyDetectorRequest request, ActionL try { indexAnomalyDetectorActionHandler.start(); } catch (Exception e) { - logger.error(e); + LOG.error(e); } } } diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/AbstractADTest.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/AbstractADTest.java index f15833c8..ded7bffb 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/AbstractADTest.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/AbstractADTest.java @@ -18,7 +18,11 @@ import static org.hamcrest.Matchers.containsString; import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.concurrent.TimeUnit; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -32,7 +36,15 @@ import org.apache.logging.log4j.util.StackLocatorUtil; import org.elasticsearch.action.ActionResponse; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.http.HttpRequest; +import org.elasticsearch.http.HttpResponse; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.RestRequest.Method; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.search.SearchModule; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.FixedExecutorBuilder; import org.elasticsearch.threadpool.TestThreadPool; @@ -42,6 +54,11 @@ import test.com.amazon.opendistroforelasticsearch.ad.util.FakeNode; +import com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetector; +import com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetectorJob; +import com.amazon.opendistroforelasticsearch.ad.model.AnomalyResult; +import com.amazon.opendistroforelasticsearch.ad.model.DetectorInternalState; + public class AbstractADTest extends ESTestCase { protected static final Logger LOG = (Logger) LogManager.getLogger(AbstractADTest.class); @@ -216,4 +233,89 @@ public void assertException( Exception e = expectThrows(exceptionType, () -> listener.actionGet()); assertThat(e.getMessage(), containsString(msg)); } + + @Override + protected NamedXContentRegistry xContentRegistry() { + SearchModule searchModule = new SearchModule(Settings.EMPTY, false, Collections.emptyList()); + List entries = searchModule.getNamedXContents(); + entries + .addAll( + Arrays + .asList( + AnomalyDetector.XCONTENT_REGISTRY, + AnomalyResult.XCONTENT_REGISTRY, + DetectorInternalState.XCONTENT_REGISTRY, + AnomalyDetectorJob.XCONTENT_REGISTRY + ) + ); + return new NamedXContentRegistry(entries); + } + + protected RestRequest createRestRequest(Method method) { + return RestRequest.request(xContentRegistry(), new HttpRequest() { + + @Override + public Method method() { + return method; + } + + @Override + public String uri() { + return "/"; + } + + @Override + public BytesReference content() { + // TODO Auto-generated method stub + return null; + } + + @Override + public Map> getHeaders() { + return new HashMap<>(); + } + + @Override + public List strictCookies() { + // TODO Auto-generated method stub + return null; + } + + @Override + public HttpVersion protocolVersion() { + return HttpRequest.HttpVersion.HTTP_1_1; + } + + @Override + public HttpRequest removeHeader(String header) { + // TODO Auto-generated method stub + return null; + } + + @Override + public HttpResponse createResponse(RestStatus status, BytesReference content) { + // TODO Auto-generated method stub + return null; + } + + @Override + public Exception getInboundException() { + // TODO Auto-generated method stub + return null; + } + + @Override + public void release() { + // TODO Auto-generated method stub + + } + + @Override + public HttpRequest releaseAndCopy() { + // TODO Auto-generated method stub + return null; + } + + }, null); + } } diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorRestTestCase.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorRestTestCase.java index b6786968..868e4278 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorRestTestCase.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorRestTestCase.java @@ -106,7 +106,8 @@ protected AnomalyDetector createAnomalyDetector(AnomalyDetector detector, Boolea detector.getShingleSize(), detector.getUiMetadata(), detector.getSchemaVersion(), - detector.getLastUpdateTime() + detector.getLastUpdateTime(), + null ); } @@ -176,7 +177,8 @@ public ToXContentObject[] getAnomalyDetector(String detectorId, BasicHeader head detector.getShingleSize(), detector.getUiMetadata(), detector.getSchemaVersion(), - detector.getLastUpdateTime() + detector.getLastUpdateTime(), + null ), detectorJob }; } diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/TestHelpers.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/TestHelpers.java index d7c394b9..6f7a66b7 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/TestHelpers.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/TestHelpers.java @@ -30,6 +30,7 @@ import java.time.Instant; import java.time.temporal.ChronoUnit; import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.Callable; @@ -47,6 +48,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.admin.cluster.health.ClusterHealthResponse; import org.elasticsearch.action.admin.indices.create.CreateIndexResponse; +import org.elasticsearch.action.admin.indices.mapping.get.GetFieldMappingsResponse.FieldMappingMetadata; import org.elasticsearch.action.get.GetResponse; import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.search.ShardSearchFailure; @@ -66,6 +68,7 @@ import org.elasticsearch.common.CheckedConsumer; import org.elasticsearch.common.Priority; import org.elasticsearch.common.UUIDs; +import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.Settings; @@ -96,6 +99,7 @@ import org.elasticsearch.test.rest.ESRestTestCase; import org.elasticsearch.threadpool.ThreadPool; +import com.amazon.opendistroforelasticsearch.ad.constant.CommonName; import com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetector; import com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetectorExecutionInput; import com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetectorJob; @@ -177,6 +181,11 @@ public static AnomalyDetector randomAnomalyDetector(Map uiMetada return randomAnomalyDetector(ImmutableList.of(randomFeature()), uiMetadata, lastUpdateTime); } + public static AnomalyDetector randomAnomalyDetector(Map uiMetadata, Instant lastUpdateTime, boolean featureEnabled) + throws IOException { + return randomAnomalyDetector(ImmutableList.of(randomFeature(featureEnabled)), uiMetadata, lastUpdateTime); + } + public static AnomalyDetector randomAnomalyDetector(List features, Map uiMetadata, Instant lastUpdateTime) throws IOException { return new AnomalyDetector( @@ -193,7 +202,29 @@ public static AnomalyDetector randomAnomalyDetector(List features, Map< randomIntBetween(1, 2000), uiMetadata, randomInt(), - lastUpdateTime + lastUpdateTime, + null + ); + } + + public static AnomalyDetector randomAnomalyDetectorUsingCategoryFields(String detectorId, List categoryFields) + throws IOException { + return new AnomalyDetector( + detectorId, + randomLong(), + randomAlphaOfLength(20), + randomAlphaOfLength(30), + randomAlphaOfLength(5), + ImmutableList.of(randomAlphaOfLength(10).toLowerCase()), + ImmutableList.of(randomFeature()), + randomQuery(), + randomIntervalTimeConfiguration(), + randomIntervalTimeConfiguration(), + randomIntBetween(1, 2000), + null, + randomInt(), + Instant.now(), + categoryFields ); } @@ -212,7 +243,8 @@ public static AnomalyDetector randomAnomalyDetector(List features) thro randomIntBetween(1, 2000), null, randomInt(), - Instant.now() + Instant.now(), + null ); } @@ -231,7 +263,8 @@ public static AnomalyDetector randomAnomalyDetectorWithEmptyFeature() throws IOE randomIntBetween(1, 2000), null, randomInt(), - Instant.now().truncatedTo(ChronoUnit.SECONDS) + Instant.now().truncatedTo(ChronoUnit.SECONDS), + null ); } @@ -250,7 +283,8 @@ public static AnomalyDetector randomAnomalyDetectorWithInterval(TimeConfiguratio randomIntBetween(1, 2000), null, randomInt(), - Instant.now().truncatedTo(ChronoUnit.SECONDS) + Instant.now().truncatedTo(ChronoUnit.SECONDS), + null ); } @@ -317,6 +351,21 @@ public static Feature randomFeature(String featureName, String aggregationName) return new Feature(randomAlphaOfLength(5), featureName, ESRestTestCase.randomBoolean(), testAggregation); } + public static Feature randomFeature(boolean enabled) { + return randomFeature(randomAlphaOfLength(5), randomAlphaOfLength(5), enabled); + } + + public static Feature randomFeature(String featureName, String aggregationName, boolean enabled) { + AggregationBuilder testAggregation = null; + try { + testAggregation = randomAggregation(aggregationName); + } catch (IOException e) { + logger.error("Fail to generate test aggregation"); + throw new RuntimeException(); + } + return new Feature(randomAlphaOfLength(5), featureName, enabled, testAggregation); + } + public static void assertFailWith(Class clazz, Callable callable) throws Exception { assertFailWith(clazz, null, callable); } @@ -559,4 +608,18 @@ public static DetectorInternalState randomDetectState(Instant lastUpdateTime) { public static DetectorInternalState randomDetectState(String error, Instant lastUpdateTime) { return new DetectorInternalState.Builder().lastUpdateTime(lastUpdateTime).error(error).build(); } + + public static Map>> createFieldMappings( + String index, + String fieldName, + String fieldType + ) throws IOException { + Map>> mappings = new HashMap<>(); + FieldMappingMetadata fieldMappingMetadata = new FieldMappingMetadata( + fieldName, + new BytesArray("{\"" + fieldName + "\":{\"type\":\"" + fieldType + "\"}}") + ); + mappings.put(index, Collections.singletonMap(CommonName.MAPPING_TYPE, Collections.singletonMap(fieldName, fieldMappingMetadata))); + return mappings; + } } diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/model/AnomalyDetectorTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/model/AnomalyDetectorTests.java index da3ddcd9..2e20b3ec 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/model/AnomalyDetectorTests.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/model/AnomalyDetectorTests.java @@ -89,8 +89,7 @@ public void testParseAnomalyDetectorWithoutOptionalParams() throws IOException { + "{\"period\":{\"interval\":425,\"unit\":\"Minutes\"}},\"schema_version\":-1203962153,\"ui_metadata\":" + "{\"JbAaV\":{\"feature_id\":\"rIFjS\",\"feature_name\":\"QXCmS\",\"feature_enabled\":false," + "\"aggregation_query\":{\"aa\":{\"value_count\":{\"field\":\"ok\"}}}}},\"last_update_time\":1568396089028}"; - AnomalyDetector parsedDetector = AnomalyDetector - .parse(TestHelpers.parser(detectorString), "id", 1L, null, null, AnomalyDetectorSettings.DEFAULT_SHINGLE_SIZE); + AnomalyDetector parsedDetector = AnomalyDetector.parse(TestHelpers.parser(detectorString), "id", 1L, null, null); assertTrue(parsedDetector.getFilterQuery() instanceof MatchAllQueryBuilder); assertEquals((long) parsedDetector.getShingleSize(), (long) AnomalyDetectorSettings.DEFAULT_SHINGLE_SIZE); } diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/IndexAnomalyDetectorActionTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/IndexAnomalyDetectorActionTests.java index cf21c859..a1905e13 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/IndexAnomalyDetectorActionTests.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/IndexAnomalyDetectorActionTests.java @@ -20,6 +20,7 @@ import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.common.io.stream.BytesStreamOutput; import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.RestStatus; import org.junit.Assert; @@ -52,7 +53,12 @@ public void testIndexRequest() throws Exception { 5678, WriteRequest.RefreshPolicy.NONE, detector, - RestRequest.Method.PUT + RestRequest.Method.PUT, + TimeValue.timeValueSeconds(60), + 1000, + 10, + 5 + ); request.writeTo(out); StreamInput input = out.bytes().streamInput(); diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/IndexAnomalyDetectorTransportActionTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/IndexAnomalyDetectorTransportActionTests.java index 44b23dd5..c2e4b644 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/IndexAnomalyDetectorTransportActionTests.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/IndexAnomalyDetectorTransportActionTests.java @@ -20,6 +20,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.WriteRequest; +import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.tasks.Task; import org.elasticsearch.test.ESIntegTestCase; @@ -58,7 +59,11 @@ public void setUp() throws Exception { 7890, WriteRequest.RefreshPolicy.IMMEDIATE, mock(AnomalyDetector.class), - RestRequest.Method.PUT + RestRequest.Method.PUT, + TimeValue.timeValueSeconds(60), + 1000, + 10, + 5 ); response = new ActionListener() { @Override diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/SearchAnomalyDetectorActionTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/SearchAnomalyDetectorActionTests.java index d547d3d7..c49802b5 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/SearchAnomalyDetectorActionTests.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/SearchAnomalyDetectorActionTests.java @@ -59,7 +59,7 @@ public void onFailure(Exception e) { @Test public void testSearchResponse() { // Will call response.onResponse as Index exists - Settings indexSettings = Settings.builder().put("number_of_shards", 5).put("number_of_replicas", 1).build(); + Settings indexSettings = Settings.builder().put("index.number_of_shards", 5).put("index.number_of_replicas", 1).build(); CreateIndexRequest indexRequest = new CreateIndexRequest("my-test-index", indexSettings); client().admin().indices().create(indexRequest).actionGet(); SearchRequest searchRequest = new SearchRequest("my-test-index"); diff --git a/src/test/java/org/elasticsearch/action/admin/indices/mapping/get/IndexAnomalyDetectorActionHandlerTests.java b/src/test/java/org/elasticsearch/action/admin/indices/mapping/get/IndexAnomalyDetectorActionHandlerTests.java new file mode 100644 index 00000000..d17f33b0 --- /dev/null +++ b/src/test/java/org/elasticsearch/action/admin/indices/mapping/get/IndexAnomalyDetectorActionHandlerTests.java @@ -0,0 +1,526 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package org.elasticsearch.action.admin.indices.mapping.get; + +import static com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetector.ANOMALY_DETECTORS_INDEX; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.TimeUnit; +import java.util.stream.IntStream; + +import org.apache.lucene.search.TotalHits; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.ActionRequest; +import org.elasticsearch.action.ActionResponse; +import org.elasticsearch.action.ActionType; +import org.elasticsearch.action.get.GetAction; +import org.elasticsearch.action.get.GetRequest; +import org.elasticsearch.action.get.GetResponse; +import org.elasticsearch.action.search.SearchAction; +import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.action.support.WriteRequest; +import org.elasticsearch.client.node.NodeClient; +import org.elasticsearch.cluster.ClusterName; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.metadata.Metadata; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.search.SearchHit; +import org.elasticsearch.search.SearchHits; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.mockito.ArgumentCaptor; + +import com.amazon.opendistroforelasticsearch.ad.AbstractADTest; +import com.amazon.opendistroforelasticsearch.ad.TestHelpers; +import com.amazon.opendistroforelasticsearch.ad.constant.CommonName; +import com.amazon.opendistroforelasticsearch.ad.indices.AnomalyDetectionIndices; +import com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetector; +import com.amazon.opendistroforelasticsearch.ad.rest.handler.IndexAnomalyDetectorActionHandler; +import com.amazon.opendistroforelasticsearch.ad.transport.IndexAnomalyDetectorResponse; + +/** + * + * we need to put the test in the same package of GetFieldMappingsResponse + * (org.elasticsearch.action.admin.indices.mapping.get) since its constructor is + * package private + * + */ +public class IndexAnomalyDetectorActionHandlerTests extends AbstractADTest { + static ThreadPool threadPool; + private String TEXT_FIELD_TYPE = "text"; + private IndexAnomalyDetectorActionHandler handler; + private ClusterService clusterService; + private NodeClient clientMock; + private ActionListener channel; + private AnomalyDetectionIndices anomalyDetectionIndices; + private String detectorId; + private Long seqNo; + private Long primaryTerm; + private AnomalyDetector detector; + private WriteRequest.RefreshPolicy refreshPolicy; + private TimeValue requestTimeout; + private Integer maxSingleEntityAnomalyDetectors; + private Integer maxMultiEntityAnomalyDetectors; + private Integer maxAnomalyFeatures; + private Settings settings; + private RestRequest.Method method; + + /** + * Mockito does not allow mock final methods. Make my own delegates and mock them. + * + */ + class NodeClientDelegate extends NodeClient { + + NodeClientDelegate(Settings settings, ThreadPool threadPool) { + super(settings, threadPool); + } + + public void execute2( + ActionType action, + Request request, + ActionListener listener + ) { + super.execute(action, request, listener); + } + + } + + @BeforeClass + public static void beforeClass() { + threadPool = new TestThreadPool("IndexAnomalyDetectorJobActionHandlerTests"); + } + + @AfterClass + public static void afterClass() { + ThreadPool.terminate(threadPool, 30, TimeUnit.SECONDS); + threadPool = null; + } + + @SuppressWarnings("unchecked") + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + + settings = Settings.EMPTY; + clusterService = mock(ClusterService.class); + clientMock = spy(new NodeClient(settings, null)); + + channel = mock(ActionListener.class); + + // final RestRequest restRequest = createRestRequest(Method.POST); + + // when(channel.request()).thenReturn(restRequest); + // when(channel.newErrorBuilder()).thenReturn(JsonXContent.contentBuilder()); + // when(channel.detailedErrorsEnabled()).thenReturn(true); + + anomalyDetectionIndices = mock(AnomalyDetectionIndices.class); + when(anomalyDetectionIndices.doesAnomalyDetectorIndexExist()).thenReturn(true); + + detectorId = "123"; + seqNo = 0L; + primaryTerm = 0L; + + WriteRequest.RefreshPolicy refreshPolicy = WriteRequest.RefreshPolicy.IMMEDIATE; + + String field = "a"; + detector = TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, Arrays.asList(field)); + + requestTimeout = new TimeValue(1000L); + + maxSingleEntityAnomalyDetectors = 1000; + + maxMultiEntityAnomalyDetectors = 10; + + maxAnomalyFeatures = 5; + + method = RestRequest.Method.POST; + + handler = new IndexAnomalyDetectorActionHandler( + clusterService, + clientMock, + channel, + anomalyDetectionIndices, + detectorId, + seqNo, + primaryTerm, + refreshPolicy, + detector, + requestTimeout, + maxSingleEntityAnomalyDetectors, + maxMultiEntityAnomalyDetectors, + maxAnomalyFeatures, + method, + xContentRegistry() + ); + } + + private SearchHits createSearchHits(int totalHits) { + List hitList = new ArrayList<>(); + IntStream.range(0, totalHits).forEach(i -> hitList.add(new SearchHit(i))); + SearchHit[] hitArray = new SearchHit[hitList.size()]; + return new SearchHits(hitList.toArray(hitArray), new TotalHits(totalHits, TotalHits.Relation.EQUAL_TO), 1.0F); + } + + public void testTwoCategoricalFields() throws IOException { + expectThrows( + IllegalArgumentException.class, + () -> TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, Arrays.asList("a", "b")) + ); + ; + } + + @SuppressWarnings("unchecked") + public void testNoCategoricalField() throws IOException { + SearchResponse mockResponse = mock(SearchResponse.class); + int totalHits = 1001; + when(mockResponse.getHits()).thenReturn(createSearchHits(totalHits)); + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + assertTrue(String.format("The size of args is %d. Its content is %s", args.length, Arrays.toString(args)), args.length == 2); + + assertTrue(args[0] instanceof SearchRequest); + assertTrue(args[1] instanceof ActionListener); + + ActionListener listener = (ActionListener) args[1]; + listener.onResponse(mockResponse); + + return null; + }).when(clientMock).search(any(SearchRequest.class), any()); + + handler = new IndexAnomalyDetectorActionHandler( + clusterService, + clientMock, + channel, + anomalyDetectionIndices, + detectorId, + seqNo, + primaryTerm, + refreshPolicy, + // no categorical feature + TestHelpers.randomAnomalyDetector(TestHelpers.randomUiMetadata(), null, true), + requestTimeout, + maxSingleEntityAnomalyDetectors, + maxMultiEntityAnomalyDetectors, + maxAnomalyFeatures, + method, + xContentRegistry() + ); + + handler.start(); + ArgumentCaptor response = ArgumentCaptor.forClass(Exception.class); + verify(clientMock, never()).execute(eq(GetMappingsAction.INSTANCE), any(), any()); + verify(channel).onFailure(response.capture()); + Exception value = response.getValue(); + assertTrue(value instanceof IllegalArgumentException); + assertTrue(value.getMessage().contains(IndexAnomalyDetectorActionHandler.EXCEEDED_MAX_SINGLE_ENTITY_DETECTORS_PREFIX_MSG)); + } + + @SuppressWarnings("unchecked") + public void testTextField() throws IOException { + String field = "a"; + AnomalyDetector detector = TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, Arrays.asList(field)); + + SearchResponse detectorResponse = mock(SearchResponse.class); + int totalHits = 9; + when(detectorResponse.getHits()).thenReturn(createSearchHits(totalHits)); + + // extend NodeClient since its execute method is final and mockito does not allow to mock final methods + // we can also use spy to overstep the final methods + NodeClient client = new NodeClient(Settings.EMPTY, threadPool) { + @Override + public void doExecute( + ActionType action, + Request request, + ActionListener listener + ) { + try { + if (action.equals(SearchAction.INSTANCE)) { + listener.onResponse((Response) detectorResponse); + } else { + // we need to put the test in the same package of GetFieldMappingsResponse since its constructor is package private + GetFieldMappingsResponse response = new GetFieldMappingsResponse( + TestHelpers.createFieldMappings(detector.getIndices().get(0), field, TEXT_FIELD_TYPE) + ); + listener.onResponse((Response) response); + } + } catch (IOException e) { + e.printStackTrace(); + } + } + }; + + handler = new IndexAnomalyDetectorActionHandler( + clusterService, + client, + channel, + anomalyDetectionIndices, + detectorId, + seqNo, + primaryTerm, + refreshPolicy, + detector, + requestTimeout, + maxSingleEntityAnomalyDetectors, + maxMultiEntityAnomalyDetectors, + maxAnomalyFeatures, + method, + xContentRegistry() + ); + + ArgumentCaptor response = ArgumentCaptor.forClass(Exception.class); + + handler.start(); + + verify(channel).onFailure(response.capture()); + Exception value = response.getValue(); + assertTrue(value instanceof Exception); + assertTrue(value.getMessage().contains(IndexAnomalyDetectorActionHandler.CATEGORICAL_FIELD_TYPE_ERR_MSG)); + } + + @SuppressWarnings("unchecked") + private void testValidTypeTepmlate(String filedTypeName) throws IOException { + String field = "a"; + AnomalyDetector detector = TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, Arrays.asList(field)); + + SearchResponse detectorResponse = mock(SearchResponse.class); + int totalHits = 9; + when(detectorResponse.getHits()).thenReturn(createSearchHits(totalHits)); + + SearchResponse userIndexResponse = mock(SearchResponse.class); + int userIndexHits = 0; + when(userIndexResponse.getHits()).thenReturn(createSearchHits(userIndexHits)); + + // extend NodeClient since its execute method is final and mockito does not allow to mock final methods + // we can also use spy to overstep the final methods + NodeClient client = new NodeClient(Settings.EMPTY, threadPool) { + @Override + public void doExecute( + ActionType action, + Request request, + ActionListener listener + ) { + try { + if (action.equals(SearchAction.INSTANCE)) { + assertTrue(request instanceof SearchRequest); + SearchRequest searchRequest = (SearchRequest) request; + if (searchRequest.indices()[0].equals(ANOMALY_DETECTORS_INDEX)) { + listener.onResponse((Response) detectorResponse); + } else { + listener.onResponse((Response) userIndexResponse); + } + } else { + + GetFieldMappingsResponse response = new GetFieldMappingsResponse( + TestHelpers.createFieldMappings(detector.getIndices().get(0), field, filedTypeName) + ); + listener.onResponse((Response) response); + } + } catch (IOException e) { + e.printStackTrace(); + } + } + }; + + NodeClient clientSpy = spy(client); + + handler = new IndexAnomalyDetectorActionHandler( + clusterService, + clientSpy, + channel, + anomalyDetectionIndices, + detectorId, + seqNo, + primaryTerm, + refreshPolicy, + detector, + requestTimeout, + maxSingleEntityAnomalyDetectors, + maxMultiEntityAnomalyDetectors, + maxAnomalyFeatures, + method, + xContentRegistry() + ); + + ArgumentCaptor response = ArgumentCaptor.forClass(Exception.class); + + handler.start(); + + verify(clientSpy, times(1)).execute(eq(GetFieldMappingsAction.INSTANCE), any(), any()); + verify(channel).onFailure(response.capture()); + Exception value = response.getValue(); + assertTrue(value instanceof IllegalArgumentException); + assertTrue(value.getMessage().contains(IndexAnomalyDetectorActionHandler.NO_DOCS_IN_USER_INDEX_MSG)); + } + + public void testIpField() throws IOException { + testValidTypeTepmlate(CommonName.IP_TYPE); + } + + public void testKeywordField() throws IOException { + testValidTypeTepmlate(CommonName.KEYWORD_TYPE); + } + + @SuppressWarnings("unchecked") + private void testUpdateTepmlate(String fieldTypeName) throws IOException { + String field = "a"; + AnomalyDetector detector = TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, Arrays.asList(field)); + + SearchResponse detectorResponse = mock(SearchResponse.class); + int totalHits = 9; + when(detectorResponse.getHits()).thenReturn(createSearchHits(totalHits)); + + GetResponse getDetectorResponse = mock(GetResponse.class); + when(getDetectorResponse.isExists()).thenReturn(true); + + SearchResponse userIndexResponse = mock(SearchResponse.class); + int userIndexHits = 0; + when(userIndexResponse.getHits()).thenReturn(createSearchHits(userIndexHits)); + + // extend NodeClient since its execute method is final and mockito does not allow to mock final methods + // we can also use spy to overstep the final methods + NodeClient client = new NodeClient(Settings.EMPTY, threadPool) { + @Override + public void doExecute( + ActionType action, + Request request, + ActionListener listener + ) { + try { + if (action.equals(SearchAction.INSTANCE)) { + assertTrue(request instanceof SearchRequest); + SearchRequest searchRequest = (SearchRequest) request; + if (searchRequest.indices()[0].equals(ANOMALY_DETECTORS_INDEX)) { + listener.onResponse((Response) detectorResponse); + } else { + listener.onResponse((Response) userIndexResponse); + } + } else if (action.equals(GetAction.INSTANCE)) { + assertTrue(request instanceof GetRequest); + listener.onResponse((Response) getDetectorResponse); + } else { + GetFieldMappingsResponse response = new GetFieldMappingsResponse( + TestHelpers.createFieldMappings(detector.getIndices().get(0), field, fieldTypeName) + ); + listener.onResponse((Response) response); + } + } catch (IOException e) { + e.printStackTrace(); + } + } + }; + + NodeClient clientSpy = spy(client); + ClusterName clusterName = new ClusterName("test"); + ClusterState clusterState = ClusterState.builder(clusterName).metadata(Metadata.builder().build()).build(); + when(clusterService.state()).thenReturn(clusterState); + + handler = new IndexAnomalyDetectorActionHandler( + clusterService, + clientSpy, + channel, + anomalyDetectionIndices, + detectorId, + seqNo, + primaryTerm, + refreshPolicy, + detector, + requestTimeout, + maxSingleEntityAnomalyDetectors, + maxMultiEntityAnomalyDetectors, + maxAnomalyFeatures, + RestRequest.Method.PUT, + xContentRegistry() + ); + + ArgumentCaptor response = ArgumentCaptor.forClass(Exception.class); + + handler.start(); + + verify(clientSpy, times(1)).execute(eq(GetFieldMappingsAction.INSTANCE), any(), any()); + verify(channel).onFailure(response.capture()); + Exception value = response.getValue(); + if (fieldTypeName.equals(CommonName.IP_TYPE) || fieldTypeName.equals(CommonName.KEYWORD_TYPE)) { + assertTrue(value.getMessage().contains(IndexAnomalyDetectorActionHandler.NO_DOCS_IN_USER_INDEX_MSG)); + } else { + assertTrue(value.getMessage().contains(IndexAnomalyDetectorActionHandler.CATEGORICAL_FIELD_TYPE_ERR_MSG)); + } + + } + + public void testUpdateIpField() throws IOException { + testUpdateTepmlate(CommonName.IP_TYPE); + } + + public void testUpdateKeywordField() throws IOException { + testUpdateTepmlate(CommonName.KEYWORD_TYPE); + } + + public void testUpdateTextField() throws IOException { + testUpdateTepmlate(TEXT_FIELD_TYPE); + } + + @SuppressWarnings("unchecked") + public void testMoreThanTenMultiEntityDetectors() throws IOException { + SearchResponse mockResponse = mock(SearchResponse.class); + + int totalHits = 11; + + when(mockResponse.getHits()).thenReturn(createSearchHits(totalHits)); + + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + assertTrue(String.format("The size of args is %d. Its content is %s", args.length, Arrays.toString(args)), args.length == 2); + + assertTrue(args[0] instanceof SearchRequest); + assertTrue(args[1] instanceof ActionListener); + + ActionListener listener = (ActionListener) args[1]; + + listener.onResponse(mockResponse); + + return null; + }).when(clientMock).search(any(SearchRequest.class), any()); + + handler.start(); + + ArgumentCaptor response = ArgumentCaptor.forClass(Exception.class); + verify(clientMock, times(1)).search(any(SearchRequest.class), any()); + verify(channel).onFailure(response.capture()); + Exception value = response.getValue(); + assertTrue(value instanceof IllegalArgumentException); + assertTrue(value.getMessage().contains(IndexAnomalyDetectorActionHandler.EXCEEDED_MAX_MULTI_ENTITY_DETECTORS_PREFIX_MSG)); + } +}