diff --git a/build.gradle b/build.gradle index a6e89166..a3bde1cb 100644 --- a/build.gradle +++ b/build.gradle @@ -207,23 +207,17 @@ List jacocoExclusions = [ 'com.amazon.opendistroforelasticsearch.ad.common.exception.AnomalyDetectionException', 'com.amazon.opendistroforelasticsearch.ad.util.ClientUtil', - 'com.amazon.opendistroforelasticsearch.ad.ml.*', - 'com.amazon.opendistroforelasticsearch.ad.feature.*', - 'com.amazon.opendistroforelasticsearch.ad.dataprocessor.*', - 'com.amazon.opendistroforelasticsearch.ad.AnomalyDetectorRunner', - 'com.amazon.opendistroforelasticsearch.ad.resthandler.RestGetAnomalyResultAction', - 'com.amazon.opendistroforelasticsearch.ad.metrics.MetricFactory', - 'com.amazon.opendistroforelasticsearch.ad.indices.AnomalyDetectionIndices', - 'com.amazon.opendistroforelasticsearch.ad.transport.ForwardAction', - 'com.amazon.opendistroforelasticsearch.ad.transport.ForwardTransportAction', 'com.amazon.opendistroforelasticsearch.ad.transport.StopDetectorAction', 'com.amazon.opendistroforelasticsearch.ad.transport.StopDetectorRequest', 'com.amazon.opendistroforelasticsearch.ad.transport.StopDetectorResponse', 'com.amazon.opendistroforelasticsearch.ad.transport.StopDetectorTransportAction', - 'com.amazon.opendistroforelasticsearch.ad.transport.ADStatsAction', - 'com.amazon.opendistroforelasticsearch.ad.transport.CronRequest', 'com.amazon.opendistroforelasticsearch.ad.transport.DeleteDetectorAction', - 'com.amazon.opendistroforelasticsearch.ad.util.ParseUtils' + 'com.amazon.opendistroforelasticsearch.ad.transport.CronTransportAction', + 'com.amazon.opendistroforelasticsearch.ad.transport.CronRequest', + 'com.amazon.opendistroforelasticsearch.ad.transport.ADStatsAction', + 'com.amazon.opendistroforelasticsearch.ad.AnomalyDetectorRunner', + 'com.amazon.opendistroforelasticsearch.ad.indices.AnomalyDetectionIndices', + 'com.amazon.opendistroforelasticsearch.ad.util.ParseUtils', ] jacocoTestCoverageVerification { diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorJobRunner.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorJobRunner.java index e59bf605..82f6d502 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorJobRunner.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorJobRunner.java @@ -28,6 +28,7 @@ import com.amazon.opendistroforelasticsearch.ad.transport.AnomalyResultResponse; import com.amazon.opendistroforelasticsearch.ad.transport.AnomalyResultTransportAction; import com.amazon.opendistroforelasticsearch.ad.transport.handler.AnomalyResultHandler; +import com.amazon.opendistroforelasticsearch.ad.util.ClientUtil; import com.amazon.opendistroforelasticsearch.jobscheduler.spi.JobExecutionContext; import com.amazon.opendistroforelasticsearch.jobscheduler.spi.LockModel; import com.amazon.opendistroforelasticsearch.jobscheduler.spi.ScheduledJobParameter; @@ -39,7 +40,9 @@ import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.get.GetRequest; +import org.elasticsearch.action.get.GetResponse; import org.elasticsearch.action.index.IndexRequest; +import org.elasticsearch.action.index.IndexResponse; import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.client.Client; import org.elasticsearch.common.settings.Settings; @@ -71,6 +74,7 @@ public class AnomalyDetectorJobRunner implements ScheduledJobRunner { private Settings settings; private int maxRetryForEndRunException; private Client client; + private ClientUtil clientUtil; private ThreadPool threadPool; private AnomalyResultHandler anomalyResultHandler; private ConcurrentHashMap detectorEndRunExceptionCount; @@ -97,6 +101,10 @@ public void setClient(Client client) { this.client = client; } + public void setClientUtil(ClientUtil clientUtil) { + this.clientUtil = clientUtil; + } + public void setThreadPool(ThreadPool threadPool) { this.threadPool = threadPool; } @@ -258,7 +266,7 @@ protected void handleAdException( ) { String detectorId = jobParameter.getName(); if (exception instanceof EndRunException) { - log.error("EndRunException happened when executed anomaly result action for " + detectorId, exception); + log.error("EndRunException happened when executing anomaly result action for " + detectorId, exception); if (((EndRunException) exception).isEndNow()) { // Stop AD job if EndRunException shows we should end job now. @@ -349,9 +357,8 @@ private void stopAdJob(String detectorId) { try { GetRequest getRequest = new GetRequest(AnomalyDetectorJob.ANOMALY_DETECTOR_JOB_INDEX).id(detectorId); - client.get(getRequest, ActionListener.wrap(response -> { + clientUtil.asyncRequest(getRequest, client::get, ActionListener.wrap(response -> { if (response.isExists()) { - String s = response.getSourceAsString(); try ( XContentParser parser = XContentType.JSON .xContent() @@ -374,14 +381,19 @@ private void stopAdJob(String detectorId) { .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) .source(newJob.toXContent(XContentBuilder.builder(XContentType.JSON.xContent()), XCONTENT_WITH_TYPE)) .id(detectorId); - client.index(indexRequest, ActionListener.wrap(indexResponse -> { - if (indexResponse != null - && (indexResponse.getResult() == CREATED || indexResponse.getResult() == UPDATED)) { - log.info("AD Job was disabled by JobRunner for " + detectorId); - } else { - log.warn("Failed to disable AD job for " + detectorId); - } - }, exception -> log.error("JobRunner failed to update AD job as disabled for " + detectorId, exception))); + clientUtil + .asyncRequest( + indexRequest, + client::index, + ActionListener.wrap(indexResponse -> { + if (indexResponse != null + && (indexResponse.getResult() == CREATED || indexResponse.getResult() == UPDATED)) { + log.info("AD Job was disabled by JobRunner for " + detectorId); + } else { + log.warn("Failed to disable AD job for " + detectorId); + } + }, exception -> log.error("JobRunner failed to update AD job as disabled for " + detectorId, exception)) + ); } } catch (IOException e) { log.error("JobRunner failed to stop detector job " + detectorId, e); diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorPlugin.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorPlugin.java index 5f93e936..285ab608 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorPlugin.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorPlugin.java @@ -1,5 +1,5 @@ /* - * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). * You may not use this file except in compliance with the License. @@ -36,6 +36,7 @@ import com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetector; import com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetectorJob; import com.amazon.opendistroforelasticsearch.ad.model.AnomalyResult; +import com.amazon.opendistroforelasticsearch.ad.model.ProfileName; import com.amazon.opendistroforelasticsearch.ad.rest.RestAnomalyDetectorJobAction; import com.amazon.opendistroforelasticsearch.ad.rest.RestDeleteAnomalyDetectorAction; import com.amazon.opendistroforelasticsearch.ad.rest.RestExecuteAnomalyDetectorAction; @@ -141,8 +142,9 @@ public class AnomalyDetectorPlugin extends Plugin implements ActionPlugin, Scrip private Client client; private ClusterService clusterService; private ThreadPool threadPool; - private IndexNameExpressionResolver indexNameExpressionResolver; private ADStats adStats; + private NamedXContentRegistry xContentRegistry; + private ClientUtil clientUtil; static { SpecialPermission.check(); @@ -163,7 +165,6 @@ public List getRestHandlers( IndexNameExpressionResolver indexNameExpressionResolver, Supplier nodesInCluster ) { - this.indexNameExpressionResolver = indexNameExpressionResolver; AnomalyResultHandler anomalyResultHandler = new AnomalyResultHandler( client, settings, @@ -174,11 +175,17 @@ public List getRestHandlers( ); AnomalyDetectorJobRunner jobRunner = AnomalyDetectorJobRunner.getJobRunnerInstance(); jobRunner.setClient(client); + jobRunner.setClientUtil(clientUtil); jobRunner.setThreadPool(threadPool); jobRunner.setAnomalyResultHandler(anomalyResultHandler); jobRunner.setSettings(settings); - RestGetAnomalyDetectorAction restGetAnomalyDetectorAction = new RestGetAnomalyDetectorAction(restController); + AnomalyDetectorProfileRunner profileRunner = new AnomalyDetectorProfileRunner(client, this.xContentRegistry); + RestGetAnomalyDetectorAction restGetAnomalyDetectorAction = new RestGetAnomalyDetectorAction( + restController, + profileRunner, + ProfileName.getNames() + ); RestIndexAnomalyDetectorAction restIndexAnomalyDetectorAction = new RestIndexAnomalyDetectorAction( settings, restController, @@ -237,10 +244,11 @@ public Collection createComponents( Settings settings = environment.settings(); Clock clock = Clock.systemUTC(); Throttler throttler = new Throttler(clock); - ClientUtil clientUtil = new ClientUtil(settings, client, throttler, threadPool); + this.clientUtil = new ClientUtil(settings, client, throttler, threadPool); IndexUtils indexUtils = new IndexUtils(client, clientUtil, clusterService); anomalyDetectionIndices = new AnomalyDetectionIndices(client, clusterService, threadPool, settings, clientUtil); this.clusterService = clusterService; + this.xContentRegistry = xContentRegistry; SingleFeatureLinearUniformInterpolator singleFeatureLinearUniformInterpolator = new IntegerSensitiveSingleFeatureLinearUniformInterpolator(); @@ -272,7 +280,8 @@ public Collection createComponents( HybridThresholdingModel.class, AnomalyDetectorSettings.MIN_PREVIEW_SIZE, AnomalyDetectorSettings.HOURLY_MAINTENANCE, - AnomalyDetectorSettings.HOURLY_MAINTENANCE + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + AnomalyDetectorSettings.SHINGLE_SIZE ); HashRing hashRing = new HashRing(clusterService, clock, settings); @@ -389,7 +398,7 @@ public List> getSettings() { @Override public List getNamedXContent() { - return ImmutableList.of(AnomalyDetector.XCONTENT_REGISTRY, ADMetaData.XCONTENT_REGISTRY); + return ImmutableList.of(AnomalyDetector.XCONTENT_REGISTRY, ADMetaData.XCONTENT_REGISTRY, AnomalyResult.XCONTENT_REGISTRY); } @Override diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorProfileRunner.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorProfileRunner.java new file mode 100644 index 00000000..1cf75cda --- /dev/null +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorProfileRunner.java @@ -0,0 +1,287 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package com.amazon.opendistroforelasticsearch.ad; + +import static com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetector.ANOMALY_DETECTORS_INDEX; +import static com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetectorJob.ANOMALY_DETECTOR_JOB_INDEX; +import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.util.Set; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.get.GetRequest; +import org.elasticsearch.action.get.GetResponse; +import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.client.Client; +import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentParseException; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.index.IndexNotFoundException; +import org.elasticsearch.index.query.BoolQueryBuilder; +import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.search.SearchHit; +import org.elasticsearch.search.SearchHits; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.sort.FieldSortBuilder; +import org.elasticsearch.search.sort.SortOrder; + +import com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetectorJob; +import com.amazon.opendistroforelasticsearch.ad.model.AnomalyResult; +import com.amazon.opendistroforelasticsearch.ad.model.DetectorProfile; +import com.amazon.opendistroforelasticsearch.ad.model.DetectorState; +import com.amazon.opendistroforelasticsearch.ad.model.ProfileName; +import com.amazon.opendistroforelasticsearch.ad.util.MultiResponsesDelegateActionListener; + +public class AnomalyDetectorProfileRunner { + private final Logger logger = LogManager.getLogger(AnomalyDetectorProfileRunner.class); + private Client client; + private NamedXContentRegistry xContentRegistry; + static String FAIL_TO_FIND_DETECTOR_MSG = "Fail to find detector with id: "; + static String FAIL_TO_GET_PROFILE_MSG = "Fail to get profile for detector "; + + public AnomalyDetectorProfileRunner(Client client, NamedXContentRegistry xContentRegistry) { + this.client = client; + this.xContentRegistry = xContentRegistry; + } + + public void profile(String detectorId, ActionListener listener, Set profiles) { + + if (profiles.isEmpty()) { + listener.onFailure(new RuntimeException("Unsupported profile types.")); + return; + } + + MultiResponsesDelegateActionListener delegateListener = new MultiResponsesDelegateActionListener( + listener, + profiles.size(), + "Fail to fetch profile for " + detectorId + ); + + prepareProfile(detectorId, delegateListener, profiles); + } + + private void prepareProfile( + String detectorId, + MultiResponsesDelegateActionListener listener, + Set profiles + ) { + GetRequest getRequest = new GetRequest(ANOMALY_DETECTOR_JOB_INDEX, detectorId); + client.get(getRequest, ActionListener.wrap(getResponse -> { + if (getResponse != null && getResponse.isExists()) { + try ( + XContentParser parser = XContentType.JSON + .xContent() + .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, getResponse.getSourceAsString()) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser::getTokenLocation); + AnomalyDetectorJob job = AnomalyDetectorJob.parse(parser); + long enabledTimeMs = job.getEnabledTime().toEpochMilli(); + + if (profiles.contains(ProfileName.STATE)) { + profileState(detectorId, enabledTimeMs, listener, job.isEnabled()); + } + if (profiles.contains(ProfileName.ERROR)) { + profileError(detectorId, enabledTimeMs, listener); + } + } catch (IOException | XContentParseException | NullPointerException e) { + logger.error(e); + listener.failImmediately(FAIL_TO_GET_PROFILE_MSG, e); + } + } else { + GetRequest getDetectorRequest = new GetRequest(ANOMALY_DETECTORS_INDEX, detectorId); + client.get(getDetectorRequest, onGetDetectorResponse(listener, detectorId, profiles)); + } + }, exception -> { + if (exception instanceof IndexNotFoundException) { + logger.info(exception.getMessage()); + GetRequest getDetectorRequest = new GetRequest(ANOMALY_DETECTORS_INDEX, detectorId); + client.get(getDetectorRequest, onGetDetectorResponse(listener, detectorId, profiles)); + } else { + logger.error(FAIL_TO_GET_PROFILE_MSG + detectorId); + listener.onFailure(exception); + } + })); + } + + private ActionListener onGetDetectorResponse( + MultiResponsesDelegateActionListener listener, + String detectorId, + Set profiles + ) { + return ActionListener.wrap(getResponse -> { + if (getResponse != null && getResponse.isExists()) { + DetectorProfile profile = new DetectorProfile(); + if (profiles.contains(ProfileName.STATE)) { + profile.setState(DetectorState.DISABLED); + } + listener.respondImmediately(profile); + } else { + listener.failImmediately(FAIL_TO_FIND_DETECTOR_MSG + detectorId); + } + }, exception -> { listener.failImmediately(FAIL_TO_FIND_DETECTOR_MSG + detectorId, exception); }); + } + + /** + * We expect three kinds of states: + * -Disabled: if get ad job api says the job is disabled; + * -Init: if anomaly score after the last update time of the detector is larger than 0 + * -Running: if neither of the above applies and no exceptions. + * @param detectorId detector id + * @param enabledTime the time when AD job is enabled in milliseconds + * @param listener listener to process the returned state or exception + * @param enabled whether the detector job is enabled or not + */ + private void profileState( + String detectorId, + long enabledTime, + MultiResponsesDelegateActionListener listener, + boolean enabled + ) { + if (enabled) { + SearchRequest searchLatestResult = createInittedEverRequest(detectorId, enabledTime); + client.search(searchLatestResult, onInittedEver(listener, detectorId, enabledTime)); + } else { + DetectorProfile profile = new DetectorProfile(); + profile.setState(DetectorState.DISABLED); + listener.onResponse(profile); + } + } + + private ActionListener onInittedEver( + MultiResponsesDelegateActionListener listener, + String detectorId, + long lastUpdateTimeMs + ) { + return ActionListener.wrap(searchResponse -> { + SearchHits hits = searchResponse.getHits(); + DetectorProfile profile = new DetectorProfile(); + if (hits.getTotalHits().value == 0L) { + profile.setState(DetectorState.INIT); + } else { + profile.setState(DetectorState.RUNNING); + } + + listener.onResponse(profile); + + }, exception -> { + if (exception instanceof IndexNotFoundException) { + DetectorProfile profile = new DetectorProfile(); + // anomaly result index is not created yet + profile.setState(DetectorState.INIT); + listener.onResponse(profile); + } else { + logger + .error( + "Fail to find any anomaly result with anomaly score larger than 0 after AD job enabled time for detector {}", + detectorId + ); + listener.onFailure(new RuntimeException("Fail to find detector state: " + detectorId, exception)); + } + }); + } + + /** + * Error is populated if error of the latest anomaly result is not empty. + * @param detectorId detector id + * @param enabledTime the time when AD job is enabled in milliseconds + * @param listener listener to process the returned error or exception + */ + private void profileError(String detectorId, long enabledTime, MultiResponsesDelegateActionListener listener) { + SearchRequest searchLatestResult = createLatestAnomalyResultRequest(detectorId, enabledTime); + client.search(searchLatestResult, onGetLatestAnomalyResult(listener, detectorId)); + } + + private ActionListener onGetLatestAnomalyResult(ActionListener listener, String detectorId) { + return ActionListener.wrap(searchResponse -> { + SearchHits hits = searchResponse.getHits(); + if (hits.getTotalHits().value == 0L) { + listener.onResponse(new DetectorProfile()); + } else { + SearchHit hit = hits.getAt(0); + + try ( + XContentParser parser = XContentType.JSON + .xContent() + .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, hit.getSourceAsString()) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser::getTokenLocation); + AnomalyResult result = parser.namedObject(AnomalyResult.class, AnomalyResult.PARSE_FIELD_NAME, null); + + DetectorProfile profile = new DetectorProfile(); + if (result.getError() != null) { + profile.setError(result.getError()); + } + listener.onResponse(profile); + } catch (IOException | XContentParseException | NullPointerException e) { + logger.error("Fail to parse anomaly result with " + hit.toString()); + listener.onFailure(new RuntimeException("Fail to find detector error: " + detectorId, e)); + } + } + }, exception -> { + if (exception instanceof IndexNotFoundException) { + listener.onResponse(new DetectorProfile()); + } else { + logger.error("Fail to find any anomaly result after AD job enabled time for detector {}", detectorId); + listener.onFailure(new RuntimeException("Fail to find detector error: " + detectorId, exception)); + } + }); + } + + /** + * Create search request to check if we have at least 1 anomaly score larger than 0 after AD job enabled time + * @param detectorId detector id + * @param enabledTime the time when AD job is enabled in milliseconds + * @return the search request + */ + private SearchRequest createInittedEverRequest(String detectorId, long enabledTime) { + BoolQueryBuilder filterQuery = new BoolQueryBuilder(); + filterQuery.filter(QueryBuilders.termQuery(AnomalyResult.DETECTOR_ID_FIELD, detectorId)); + filterQuery.filter(QueryBuilders.rangeQuery(AnomalyResult.EXECUTION_END_TIME_FIELD).gte(enabledTime)); + filterQuery.filter(QueryBuilders.rangeQuery(AnomalyResult.ANOMALY_SCORE_FIELD).gt(0)); + + SearchSourceBuilder source = new SearchSourceBuilder().query(filterQuery).size(1); + + SearchRequest request = new SearchRequest(AnomalyResult.ANOMALY_RESULT_INDEX); + request.source(source); + return request; + } + + /** + * Create search request to get the latest anomaly result after AD job enabled time + * @param detectorId detector id + * @param enabledTime the time when AD job is enabled in milliseconds + * @return the search request + */ + private SearchRequest createLatestAnomalyResultRequest(String detectorId, long enabledTime) { + BoolQueryBuilder filterQuery = new BoolQueryBuilder(); + filterQuery.filter(QueryBuilders.termQuery(AnomalyResult.DETECTOR_ID_FIELD, detectorId)); + filterQuery.filter(QueryBuilders.rangeQuery(AnomalyResult.EXECUTION_END_TIME_FIELD).gte(enabledTime)); + + FieldSortBuilder sortQuery = new FieldSortBuilder(AnomalyResult.EXECUTION_END_TIME_FIELD).order(SortOrder.DESC); + + SearchSourceBuilder source = new SearchSourceBuilder().query(filterQuery).size(1).sort(sortQuery); + + SearchRequest request = new SearchRequest(AnomalyResult.ANOMALY_RESULT_INDEX); + request.source(source); + return request; + } +} diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorRunner.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorRunner.java index 5998b08f..5a613b50 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorRunner.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorRunner.java @@ -58,6 +58,7 @@ public AnomalyDetectorRunner(ModelManager modelManager, FeatureManager featureMa * @param startTime detection period start time * @param endTime detection period end time * @param listener handle anomaly result + * @throws IOException - if a user gives wrong query input when defining a detector */ public void executeDetector(AnomalyDetector detector, Instant startTime, Instant endTime, ActionListener> listener) throws IOException { diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManager.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManager.java index dbf74b16..e334645e 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManager.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManager.java @@ -1,5 +1,5 @@ /* - * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). * You may not use this file except in compliance with the License. @@ -23,6 +23,7 @@ import java.util.AbstractMap.SimpleImmutableEntry; import java.util.Arrays; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Map.Entry; import java.util.Optional; @@ -108,6 +109,7 @@ public String getName() { private final CheckpointDao checkpointDao; private final Gson gson; private final Clock clock; + private final int shingleSize; // A tree of N samples has 2N nodes, with one bounding box for each node. private static final long BOUNDING_BOXES = 2L; @@ -160,7 +162,8 @@ public ModelManager( Class thresholdingModelClass, int minPreviewSize, Duration modelTtl, - Duration checkpointInterval + Duration checkpointInterval, + int shingleSize ) { this.clusterService = clusterService; @@ -188,6 +191,7 @@ public ModelManager( this.forests = new ConcurrentHashMap<>(); this.thresholds = new ConcurrentHashMap<>(); + this.shingleSize = shingleSize; } /** @@ -272,6 +276,36 @@ public Entry getPartitionedForestSizes(RandomCutForest forest, return new SimpleImmutableEntry<>(numPartitions, forestSize); } + /** + * Construct a RCF model and then partition it by forest size. + * + * A RCF model is constructed based on the number of input features. + * + * Then a RCF model is first partitioned into desired size based on heap. + * If there are more partitions than the number of nodes in the cluster, + * the model is partitioned by the number of nodes and verified to + * ensure the size of a partition does not exceed the max size limit based on heap. + * + * @param detector detector object + * @return a pair of number of partitions and size of a parition (number of trees) + * @throws LimitExceededException when there is no sufficient resouce available + */ + public Entry getPartitionedForestSizes(AnomalyDetector detector) { + String detectorId = detector.getDetectorId(); + int rcfNumFeatures = detector.getEnabledFeatureIds().size() * shingleSize; + return getPartitionedForestSizes( + RandomCutForest + .builder() + .dimensions(rcfNumFeatures) + .sampleSize(rcfNumSamplesInTree) + .numberOfTrees(rcfNumTrees) + .outputAfter(rcfNumSamplesInTree) + .parallelExecutionEnabled(false) + .build(), + detectorId + ); + } + /** * Gets the estimated size of a RCF model. * @@ -542,20 +576,22 @@ public void trainModel(AnomalyDetector anomalyDetector, double[][] dataPoints) { if (dataPoints.length == 0 || dataPoints[0].length == 0) { throw new IllegalArgumentException("Data points must not be empty."); } + if (dataPoints[0].length != anomalyDetector.getEnabledFeatureIds().size() * shingleSize) { + throw new IllegalArgumentException( + String + .format( + Locale.ROOT, + "Feature dimension is not correct, we expect %s but get %d", + anomalyDetector.getEnabledFeatureIds().size() * shingleSize, + dataPoints[0].length + ) + ); + } int rcfNumFeatures = dataPoints[0].length; // Create partitioned RCF models - Entry partitionResults = getPartitionedForestSizes( - RandomCutForest - .builder() - .dimensions(rcfNumFeatures) - .sampleSize(rcfNumSamplesInTree) - .numberOfTrees(rcfNumTrees) - .outputAfter(rcfNumSamplesInTree) - .parallelExecutionEnabled(false) - .build(), - anomalyDetector.getDetectorId() - ); + Entry partitionResults = getPartitionedForestSizes(anomalyDetector); + int numForests = partitionResults.getKey(); int forestSize = partitionResults.getValue(); double[] scores = new double[dataPoints.length]; diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/model/AnomalyDetectorJob.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/model/AnomalyDetectorJob.java index 62a42cf7..bf552941 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/model/AnomalyDetectorJob.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/model/AnomalyDetectorJob.java @@ -96,7 +96,8 @@ public static AnomalyDetectorJob parse(XContentParser parser) throws IOException String name = null; Schedule schedule = null; TimeConfiguration windowDelay = null; - Boolean isEnabled = null; + // we cannot set it to null as isEnabled() would do the unboxing and results in null pointer exception + Boolean isEnabled = Boolean.FALSE; Instant enabledTime = null; Instant disabledTime = null; Instant lastUpdateTime = null; diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/model/AnomalyResult.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/model/AnomalyResult.java index 077ba33d..d0881e3b 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/model/AnomalyResult.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/model/AnomalyResult.java @@ -18,6 +18,9 @@ import com.amazon.opendistroforelasticsearch.ad.annotation.Generated; import com.amazon.opendistroforelasticsearch.ad.util.ParseUtils; import com.google.common.base.Objects; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; @@ -35,10 +38,17 @@ */ public class AnomalyResult implements ToXContentObject { + public static final String PARSE_FIELD_NAME = "AnomalyResult"; + public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry( + AnomalyResult.class, + new ParseField(PARSE_FIELD_NAME), + it -> parse(it) + ); + public static final String ANOMALY_RESULT_INDEX = ".opendistro-anomaly-results"; public static final String DETECTOR_ID_FIELD = "detector_id"; - private static final String ANOMALY_SCORE_FIELD = "anomaly_score"; + public static final String ANOMALY_SCORE_FIELD = "anomaly_score"; private static final String ANOMALY_GRADE_FIELD = "anomaly_grade"; private static final String CONFIDENCE_FIELD = "confidence"; private static final String FEATURE_DATA_FIELD = "feature_data"; diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/model/DetectorProfile.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/model/DetectorProfile.java new file mode 100644 index 00000000..30650cbe --- /dev/null +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/model/DetectorProfile.java @@ -0,0 +1,107 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package com.amazon.opendistroforelasticsearch.ad.model; + +import java.io.IOException; + +import org.apache.commons.lang.builder.EqualsBuilder; +import org.apache.commons.lang.builder.HashCodeBuilder; +import org.apache.commons.lang.builder.ToStringBuilder; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; + +public class DetectorProfile implements ToXContentObject, Mergeable { + private DetectorState state; + private String error; + + private static final String STATE_FIELD = "state"; + private static final String ERROR_FIELD = "error"; + + public XContentBuilder toXContent(XContentBuilder builder) throws IOException { + return toXContent(builder, ToXContent.EMPTY_PARAMS); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + XContentBuilder xContentBuilder = builder.startObject(); + + if (state != null) { + xContentBuilder.field(STATE_FIELD, state); + } + if (error != null) { + xContentBuilder.field(ERROR_FIELD, error); + } + return xContentBuilder.endObject(); + } + + public DetectorState getState() { + return state; + } + + public void setState(DetectorState state) { + this.state = state; + } + + public String getError() { + return error; + } + + public void setError(String error) { + this.error = error; + } + + @Override + public void merge(Mergeable other) { + if (this == other || other == null || getClass() != other.getClass()) { + return; + } + DetectorProfile otherProfile = (DetectorProfile) other; + if (otherProfile.getState() != null) { + this.state = otherProfile.getState(); + } + if (otherProfile.getError() != null) { + this.error = otherProfile.getError(); + } + + } + + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null) + return false; + if (getClass() != obj.getClass()) + return false; + if (obj instanceof DetectorProfile) { + DetectorProfile other = (DetectorProfile) obj; + + return new EqualsBuilder().append(state, other.state).append(error, other.error).isEquals(); + } + return false; + } + + @Override + public int hashCode() { + return new HashCodeBuilder().append(state).append(error).toHashCode(); + } + + @Override + public String toString() { + return new ToStringBuilder(this).append("state", state).append("error", error).toString(); + } +} diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/model/DetectorState.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/model/DetectorState.java new file mode 100644 index 00000000..08307942 --- /dev/null +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/model/DetectorState.java @@ -0,0 +1,22 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package com.amazon.opendistroforelasticsearch.ad.model; + +public enum DetectorState { + DISABLED, + INIT, + RUNNING +} diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/model/Mergeable.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/model/Mergeable.java new file mode 100644 index 00000000..7093af99 --- /dev/null +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/model/Mergeable.java @@ -0,0 +1,20 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package com.amazon.opendistroforelasticsearch.ad.model; + +public interface Mergeable { + void merge(Mergeable other); +} diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/model/ProfileName.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/model/ProfileName.java new file mode 100644 index 00000000..ea0be275 --- /dev/null +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/model/ProfileName.java @@ -0,0 +1,73 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package com.amazon.opendistroforelasticsearch.ad.model; + +import java.util.Collection; +import java.util.HashSet; +import java.util.Set; + +public enum ProfileName { + STATE("state"), + ERROR("error"); + + private String name; + + ProfileName(String name) { + this.name = name; + } + + /** + * Get profile name + * + * @return name + */ + public String getName() { + return name; + } + + /** + * Get set of profile names + * + * @return set of profile names + */ + public static Set getNames() { + Set names = new HashSet<>(); + + for (ProfileName statName : ProfileName.values()) { + names.add(statName.getName()); + } + return names; + } + + public static ProfileName getName(String name) { + switch (name) { + case "state": + return STATE; + case "error": + return ERROR; + default: + throw new IllegalArgumentException("Unsupported profile types"); + } + } + + public static Set getNames(Collection names) { + Set res = new HashSet<>(); + for (String name : names) { + res.add(getName(name)); + } + return res; + } +} diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/rest/AbstractSearchAction.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/rest/AbstractSearchAction.java index ef4a4137..3f98befd 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/rest/AbstractSearchAction.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/rest/AbstractSearchAction.java @@ -80,13 +80,18 @@ public RestResponse buildResponse(SearchResponse response) throws Exception { return new BytesRestResponse(RestStatus.REQUEST_TIMEOUT, response.toString()); } - for (SearchHit hit : response.getHits()) { - XContentParser parser = XContentType.JSON - .xContent() - .createParser(channel.request().getXContentRegistry(), LoggingDeprecationHandler.INSTANCE, hit.getSourceAsString()); - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser::getTokenLocation); + if (clazz == AnomalyDetector.class) { + for (SearchHit hit : response.getHits()) { + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + channel.request().getXContentRegistry(), + LoggingDeprecationHandler.INSTANCE, + hit.getSourceAsString() + ); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser::getTokenLocation); - if (clazz == AnomalyDetector.class) { + // write back id and version to anomaly detector object ToXContentObject xContentObject = AnomalyDetector.parse(parser, hit.getId(), hit.getVersion()); XContentBuilder builder = xContentObject.toXContent(jsonBuilder(), EMPTY_PARAMS); hit.sourceRef(BytesReference.bytes(builder)); diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/rest/RestGetAnomalyDetectorAction.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/rest/RestGetAnomalyDetectorAction.java index 82582ad7..21e142e6 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/rest/RestGetAnomalyDetectorAction.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/rest/RestGetAnomalyDetectorAction.java @@ -1,5 +1,5 @@ /* - * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). * You may not use this file except in compliance with the License. @@ -17,8 +17,13 @@ import com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetector; import com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetectorJob; +import com.amazon.opendistroforelasticsearch.ad.model.DetectorProfile; +import com.amazon.opendistroforelasticsearch.ad.model.ProfileName; import com.amazon.opendistroforelasticsearch.ad.util.RestHandlerUtils; +import com.google.common.collect.Sets; import com.amazon.opendistroforelasticsearch.ad.AnomalyDetectorPlugin; +import com.amazon.opendistroforelasticsearch.ad.AnomalyDetectorProfileRunner; + import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -27,6 +32,7 @@ import org.elasticsearch.action.get.MultiGetRequest; import org.elasticsearch.action.get.MultiGetResponse; import org.elasticsearch.client.node.NodeClient; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.rest.BaseRestHandler; @@ -40,11 +46,16 @@ import org.elasticsearch.rest.action.RestResponseListener; import java.io.IOException; +import java.util.Arrays; +import java.util.HashSet; import java.util.Locale; +import java.util.Set; import static com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetector.ANOMALY_DETECTORS_INDEX; import static com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetectorJob.ANOMALY_DETECTOR_JOB_INDEX; import static com.amazon.opendistroforelasticsearch.ad.util.RestHandlerUtils.DETECTOR_ID; +import static com.amazon.opendistroforelasticsearch.ad.util.RestHandlerUtils.PROFILE; +import static com.amazon.opendistroforelasticsearch.ad.util.RestHandlerUtils.TYPE; import static com.amazon.opendistroforelasticsearch.ad.util.RestHandlerUtils.createXContentParser; import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; @@ -55,11 +66,36 @@ public class RestGetAnomalyDetectorAction extends BaseRestHandler { private static final String GET_ANOMALY_DETECTOR_ACTION = "get_anomaly_detector"; private static final Logger logger = LogManager.getLogger(RestGetAnomalyDetectorAction.class); + private final AnomalyDetectorProfileRunner profileRunner; + private final Set allProfileTypeStrs; + private final Set allProfileTypes; + + public RestGetAnomalyDetectorAction( + RestController controller, + AnomalyDetectorProfileRunner profileRunner, + Set allProfileTypeStrs + ) { + this.profileRunner = profileRunner; + this.allProfileTypes = new HashSet(Arrays.asList(ProfileName.values())); + this.allProfileTypeStrs = ProfileName.getNames(); - public RestGetAnomalyDetectorAction(RestController controller) { String path = String.format(Locale.ROOT, "%s/{%s}", AnomalyDetectorPlugin.AD_BASE_DETECTORS_URI, DETECTOR_ID); controller.registerHandler(RestRequest.Method.GET, path, this); controller.registerHandler(RestRequest.Method.HEAD, path, this); + controller + .registerHandler( + RestRequest.Method.GET, + String.format(Locale.ROOT, "%s/{%s}/%s", AnomalyDetectorPlugin.AD_BASE_DETECTORS_URI, DETECTOR_ID, PROFILE), + this + ); + // types is a profile names. See a complete list of supported profiles names in + // com.amazon.opendistroforelasticsearch.ad.model.ProfileName. + controller + .registerHandler( + RestRequest.Method.GET, + String.format(Locale.ROOT, "%s/{%s}/%s/{%s}", AnomalyDetectorPlugin.AD_BASE_DETECTORS_URI, DETECTOR_ID, PROFILE, TYPE), + this + ); } @Override @@ -71,16 +107,23 @@ public String getName() { protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { String detectorId = request.param(DETECTOR_ID); boolean returnJob = request.paramAsBoolean("job", false); - MultiGetRequest.Item adItem = new MultiGetRequest.Item(ANOMALY_DETECTORS_INDEX, detectorId) - .version(RestActions.parseVersion(request)); - MultiGetRequest multiGetRequest = new MultiGetRequest().add(adItem); - if (returnJob) { - MultiGetRequest.Item adJobItem = new MultiGetRequest.Item(ANOMALY_DETECTOR_JOB_INDEX, detectorId) + String typesStr = request.param(TYPE); + String rawPath = request.rawPath(); + if (!Strings.isEmpty(typesStr) || rawPath.endsWith(PROFILE) || rawPath.endsWith(PROFILE + "/")) { + return channel -> profileRunner + .profile(detectorId, getProfileActionListener(channel, detectorId), getProfilesToCollect(typesStr)); + } else { + MultiGetRequest.Item adItem = new MultiGetRequest.Item(ANOMALY_DETECTORS_INDEX, detectorId) .version(RestActions.parseVersion(request)); - multiGetRequest.add(adJobItem); - } + MultiGetRequest multiGetRequest = new MultiGetRequest().add(adItem); + if (returnJob) { + MultiGetRequest.Item adJobItem = new MultiGetRequest.Item(ANOMALY_DETECTOR_JOB_INDEX, detectorId) + .version(RestActions.parseVersion(request)); + multiGetRequest.add(adJobItem); + } - return channel -> client.multiGet(multiGetRequest, onMultiGetResponse(channel, returnJob, detectorId)); + return channel -> client.multiGet(multiGetRequest, onMultiGetResponse(channel, returnJob, detectorId)); + } } private ActionListener onMultiGetResponse(RestChannel channel, boolean returnJob, String detectorId) { @@ -110,12 +153,8 @@ public RestResponse buildResponse(MultiGetResponse multiGetResponse) throws Exce ) { ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser::getTokenLocation); detector = parser.namedObject(AnomalyDetector.class, AnomalyDetector.PARSE_FIELD_NAME, null); - } catch (Throwable t) { - logger.error("Fail to parse detector", t); - return new BytesRestResponse( - RestStatus.INTERNAL_SERVER_ERROR, - "Failed to parse detector with id: " + detectorId - ); + } catch (Exception e) { + return buildInternalServerErrorResponse(e, "Failed to parse detector with id: " + detectorId); } } } @@ -127,12 +166,8 @@ public RestResponse buildResponse(MultiGetResponse multiGetResponse) throws Exce try (XContentParser parser = createXContentParser(channel, response.getResponse().getSourceAsBytesRef())) { ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser::getTokenLocation); adJob = AnomalyDetectorJob.parse(parser); - } catch (Throwable t) { - logger.error("Fail to parse detector job ", t); - return new BytesRestResponse( - RestStatus.INTERNAL_SERVER_ERROR, - "Failed to parse detector job with id: " + detectorId - ); + } catch (Exception e) { + return buildInternalServerErrorResponse(e, "Failed to parse detector job with id: " + detectorId); } } } @@ -148,4 +183,25 @@ public RestResponse buildResponse(MultiGetResponse multiGetResponse) throws Exce }; } + private ActionListener getProfileActionListener(RestChannel channel, String detectorId) { + return ActionListener + .wrap( + profile -> { channel.sendResponse(new BytesRestResponse(RestStatus.OK, profile.toXContent(channel.newBuilder()))); }, + exception -> { channel.sendResponse(buildInternalServerErrorResponse(exception, exception.getMessage())); } + ); + } + + private RestResponse buildInternalServerErrorResponse(Exception e, String errorMsg) { + logger.error(errorMsg, e); + return new BytesRestResponse(RestStatus.INTERNAL_SERVER_ERROR, errorMsg); + } + + private Set getProfilesToCollect(String typesStr) { + if (Strings.isEmpty(typesStr)) { + return this.allProfileTypes; + } else { + Set typesInRequest = new HashSet<>(Arrays.asList(typesStr.split(","))); + return ProfileName.getNames(Sets.intersection(this.allProfileTypeStrs, typesInRequest)); + } + } } diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/ADStateManager.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/ADStateManager.java index eb57123c..4e993c50 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/ADStateManager.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/ADStateManager.java @@ -1,5 +1,5 @@ /* - * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). * You may not use this file except in compliance with the License. @@ -22,19 +22,18 @@ import java.time.Instant; import java.util.Map; import java.util.Optional; -import java.util.Random; import java.util.AbstractMap.SimpleEntry; import java.util.Map.Entry; import java.util.concurrent.ConcurrentHashMap; -import com.amazon.opendistroforelasticsearch.ad.common.exception.AnomalyDetectionException; import com.amazon.opendistroforelasticsearch.ad.common.exception.LimitExceededException; import com.amazon.opendistroforelasticsearch.ad.ml.ModelManager; import com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetector; -import com.amazon.opendistroforelasticsearch.ad.settings.AnomalyDetectorSettings; import com.amazon.opendistroforelasticsearch.ad.util.ClientUtil; + import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.get.GetRequest; import org.elasticsearch.action.get.GetResponse; import org.elasticsearch.client.Client; @@ -44,8 +43,6 @@ import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentType; -import com.amazon.randomcutforest.RandomCutForest; - /** * ADStateManager is used by transport layer to manage AnomalyDetector object * and the number of partitions for a detector id. @@ -56,7 +53,6 @@ public class ADStateManager { private ConcurrentHashMap> currentDetectors; private ConcurrentHashMap> partitionNumber; private Client client; - private Random random; private ModelManager modelManager; private NamedXContentRegistry xContentRegistry; private ClientUtil clientUtil; @@ -77,7 +73,6 @@ public ADStateManager( ) { this.currentDetectors = new ConcurrentHashMap<>(); this.client = client; - this.random = new Random(); this.modelManager = modelManager; this.xContentRegistry = xContentRegistry; this.partitionNumber = new ConcurrentHashMap<>(); @@ -91,67 +86,58 @@ public ADStateManager( /** * Get the number of RCF model's partition number for detector adID * @param adID detector id + * @param detector object * @return the number of RCF model's partition number for adID - * @throws InterruptedException when we cannot get anomaly detector object for adID before timeout * @throws LimitExceededException when there is no sufficient resource available */ - public int getPartitionNumber(String adID) throws InterruptedException { + public int getPartitionNumber(String adID, AnomalyDetector detector) { Entry partitonAndTime = partitionNumber.get(adID); if (partitonAndTime != null) { partitonAndTime.setValue(clock.instant()); return partitonAndTime.getKey(); } - Optional detector = getAnomalyDetector(adID); - if (!detector.isPresent()) { - throw new AnomalyDetectionException(adID, "AnomalyDetector is not found"); - } - - RandomCutForest forest = RandomCutForest - .builder() - .dimensions(detector.get().getFeatureAttributes().size() * AnomalyDetectorSettings.SHINGLE_SIZE) - .sampleSize(AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE) - .numberOfTrees(AnomalyDetectorSettings.NUM_TREES) - .parallelExecutionEnabled(false) - .build(); - int partitionNum = modelManager.getPartitionedForestSizes(forest, adID).getKey(); + int partitionNum = modelManager.getPartitionedForestSizes(detector).getKey(); partitionNumber.putIfAbsent(adID, new SimpleEntry<>(partitionNum, clock.instant())); return partitionNum; } - public Optional getAnomalyDetector(String adID) { + public void getAnomalyDetector(String adID, ActionListener> listener) { Entry detectorAndTime = currentDetectors.get(adID); if (detectorAndTime != null) { detectorAndTime.setValue(clock.instant()); - return Optional.of(detectorAndTime.getKey()); + listener.onResponse(Optional.of(detectorAndTime.getKey())); + return; } GetRequest request = new GetRequest(AnomalyDetector.ANOMALY_DETECTORS_INDEX, adID); - Optional getResponse = clientUtil.timedRequest(request, LOG, client::get); - - return onGetResponse(getResponse, adID); + clientUtil.asyncRequest(request, client::get, onGetResponse(adID, listener)); } - private Optional onGetResponse(Optional asResponse, String adID) { - if (!asResponse.isPresent() || !asResponse.get().isExists()) { - return Optional.empty(); - } + private ActionListener onGetResponse(String adID, ActionListener> listener) { + return ActionListener.wrap(response -> { + if (response == null || !response.isExists()) { + listener.onResponse(Optional.empty()); + return; + } - GetResponse response = asResponse.get(); - String xc = response.getSourceAsString(); - LOG.debug("Fetched anomaly detector: {}", xc); - - try (XContentParser parser = XContentType.JSON.xContent().createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, xc)) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser::getTokenLocation); - AnomalyDetector detector = AnomalyDetector.parse(parser, response.getId()); - currentDetectors.put(adID, new SimpleEntry<>(detector, clock.instant())); - return Optional.of(detector); - } catch (Exception t) { - LOG.error("Fail to parse detector {}", adID); - LOG.error("Stack trace:", t); - return Optional.empty(); - } + String xc = response.getSourceAsString(); + LOG.info("Fetched anomaly detector: {}", xc); + + try ( + XContentParser parser = XContentType.JSON.xContent().createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, xc) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser::getTokenLocation); + AnomalyDetector detector = AnomalyDetector.parse(parser, response.getId()); + currentDetectors.put(adID, new SimpleEntry<>(detector, clock.instant())); + listener.onResponse(Optional.of(detector)); + } catch (Exception t) { + LOG.error("Fail to parse detector {}", adID); + LOG.error("Stack trace:", t); + listener.onResponse(Optional.empty()); + } + }, listener::onFailure); } /** diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/AnomalyResultTransportAction.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/AnomalyResultTransportAction.java index 4d3f62a9..f130ab9c 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/AnomalyResultTransportAction.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/AnomalyResultTransportAction.java @@ -1,5 +1,5 @@ /* - * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). * You may not use this file except in compliance with the License. @@ -32,9 +32,9 @@ import com.amazon.opendistroforelasticsearch.ad.common.exception.LimitExceededException; import com.amazon.opendistroforelasticsearch.ad.common.exception.ResourceNotFoundException; import com.amazon.opendistroforelasticsearch.ad.constant.CommonErrorMessages; +import com.amazon.opendistroforelasticsearch.ad.constant.CommonName; import com.amazon.opendistroforelasticsearch.ad.feature.FeatureManager; import com.amazon.opendistroforelasticsearch.ad.feature.SinglePointFeatures; -import com.amazon.opendistroforelasticsearch.ad.indices.AnomalyDetectionIndices; import com.amazon.opendistroforelasticsearch.ad.ml.ModelManager; import com.amazon.opendistroforelasticsearch.ad.ml.RcfResult; import com.amazon.opendistroforelasticsearch.ad.ml.rcf.CombinedRcfResult; @@ -45,6 +45,7 @@ import com.amazon.opendistroforelasticsearch.ad.stats.ADStats; import com.amazon.opendistroforelasticsearch.ad.stats.StatNames; import com.amazon.opendistroforelasticsearch.ad.util.ColdStartRunner; + import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; @@ -54,11 +55,9 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListenerResponseHandler; import org.elasticsearch.action.ActionRequest; -import org.elasticsearch.action.bulk.BackoffPolicy; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.HandledTransportAction; import org.elasticsearch.action.support.IndicesOptions; -import org.elasticsearch.client.Client; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.block.ClusterBlockLevel; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; @@ -69,12 +68,12 @@ import org.elasticsearch.common.io.stream.NotSerializableExceptionWrapper; import org.elasticsearch.node.NodeClosedException; import org.elasticsearch.tasks.Task; -import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.ConnectTransportException; import org.elasticsearch.transport.ReceiveTimeoutTransportException; import org.elasticsearch.transport.TransportRequestOptions; import org.elasticsearch.transport.TransportService; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.IndexNotFoundException; public class AnomalyResultTransportAction extends HandledTransportAction { @@ -83,6 +82,7 @@ public class AnomalyResultTransportAction extends HandledTransportAction getFeatureData(double[] currentFeature, AnomalyDetecto * + unknown prediction error * * Known cause of EndRunException with endNow returning true: - * + anomaly detector is not available - * + a models' memory size reached limit + * + a model's memory size reached limit * + models' total memory size reached limit + * + Having trouble querying feature data due to + * * index does not exist + * * all features have been disabled + * + anomaly detector is not available + * * Known cause of InternalFailure: * + threshold model node is not available * + cluster read/write is blocked - * + interrupted while waiting for rcf/threshold model nodes' responses * + cold start hasn't been finished * + fail to get all of rcf model nodes' responses * + fail to get threshold model node's response * + RCF/Threshold model node failing to get checkpoint to restore model before timeout + * + Detection is throttle because previous detection query is running * */ @Override @@ -225,7 +214,18 @@ protected void doExecute(Task task, ActionRequest actionRequest, ActionListener< } try { - Optional detector = stateManager.getAnomalyDetector(adID); + stateManager.getAnomalyDetector(adID, onGetDetector(listener, adID, request)); + } catch (Exception ex) { + handleExecuteException(ex, listener, adID); + } + } + + private ActionListener> onGetDetector( + ActionListener listener, + String adID, + AnomalyResultRequest request + ) { + return ActionListener.wrap(detector -> { if (!detector.isPresent()) { listener.onFailure(new EndRunException(adID, "AnomalyDetector is not available.", true)); return; @@ -233,13 +233,15 @@ protected void doExecute(Task task, ActionRequest actionRequest, ActionListener< AnomalyDetector anomalyDetector = detector.get(); String thresholdModelID = modelManager.getThresholdModelId(adID); - Optional thresholdNode = hashRing.getOwningNode(thresholdModelID); - if (!thresholdNode.isPresent()) { + Optional asThresholdNode = hashRing.getOwningNode(thresholdModelID); + if (!asThresholdNode.isPresent()) { listener.onFailure(new InternalFailure(adID, "Threshold model node is not available.")); return; } - if (!shouldStart(listener, adID, detector.get(), thresholdNode.get().getId(), thresholdModelID)) { + DiscoveryNode thresholdNode = asThresholdNode.get(); + + if (!shouldStart(listener, adID, anomalyDetector, thresholdNode.getId(), thresholdModelID)) { return; } @@ -250,12 +252,31 @@ protected void doExecute(Task task, ActionRequest actionRequest, ActionListener< long dataStartTime = request.getStart() - delayMillis; long dataEndTime = request.getEnd() - delayMillis; - SinglePointFeatures featureOptional = featureManager.getCurrentFeatures(anomalyDetector, dataStartTime, dataEndTime); + featureManager + .getCurrentFeatures( + anomalyDetector, + dataStartTime, + dataEndTime, + onFeatureResponse(adID, anomalyDetector, listener, thresholdModelID, thresholdNode, dataStartTime, dataEndTime) + ); + }, exception -> handleExecuteException(exception, listener, adID)); + + } + private ActionListener onFeatureResponse( + String adID, + AnomalyDetector detector, + ActionListener listener, + String thresholdModelID, + DiscoveryNode thresholdNode, + long dataStartTime, + long dataEndTime + ) { + return ActionListener.wrap(featureOptional -> { List featureInResponse = null; if (featureOptional.getUnprocessedFeatures().isPresent()) { - featureInResponse = getFeatureData(featureOptional.getUnprocessedFeatures().get(), detector.get()); + featureInResponse = getFeatureData(featureOptional.getUnprocessedFeatures().get(), detector); } if (!featureOptional.getProcessedFeatures().isPresent()) { @@ -293,7 +314,7 @@ protected void doExecute(Task task, ActionRequest actionRequest, ActionListener< // Can throw LimitExceededException when a single partition is more than X% of heap memory. // Compute this number once and the value won't change unless the coordinating AD node for an // detector changes or the cluster size changes. - int rcfPartitionNum = stateManager.getPartitionNumber(adID); + int rcfPartitionNum = stateManager.getPartitionNumber(adID, detector); List rcfResults = new ArrayList<>(); @@ -326,8 +347,6 @@ protected void doExecute(Task task, ActionRequest actionRequest, ActionListener< thresholdModelID, thresholdNode, featureInResponse, - dataStartTime, - dataEndTime, rcfPartitionNum, responseCount, adID @@ -342,9 +361,15 @@ protected void doExecute(Task task, ActionRequest actionRequest, ActionListener< new ActionListenerResponseHandler<>(rcfListener, RCFResultResponse::new) ); } - } catch (Exception ex) { - handleExecuteException(ex, listener, adID); - } + }, exception -> { + if (exception instanceof IndexNotFoundException) { + listener.onFailure(new EndRunException(adID, "Having trouble querying data: " + exception.getMessage(), true)); + } else if (exception instanceof IllegalArgumentException && detector.getEnabledFeatureIds().isEmpty()) { + listener.onFailure(new EndRunException(adID, ALL_FEATURES_DISABLED_ERR_MSG, true)); + } else { + handleExecuteException(exception, listener, adID); + } + }); } /** @@ -390,7 +415,9 @@ private void findException(Throwable cause, String adID, AtomicReference previousException = globalRunner.fetchException(adID); @@ -468,13 +495,11 @@ class RCFActionListener implements ActionListener { private String modelID; private AtomicReference failure; private String rcfNodeID; - private Optional detector; + private AnomalyDetector detector; private ActionListener listener; private String thresholdModelID; - private Optional thresholdNode; + private DiscoveryNode thresholdNode; private List featureInResponse; - private long startTime; - private long endTime; private int nodeCount; private final AtomicInteger responseCount; private final String adID; @@ -484,13 +509,11 @@ class RCFActionListener implements ActionListener { String modelID, AtomicReference failure, String rcfNodeID, - Optional detector, + AnomalyDetector detector, ActionListener listener, String thresholdModelID, - Optional thresholdNode, + DiscoveryNode thresholdNode, List features, - long startTime, - long endTime, int nodeCount, AtomicInteger responseCount, String adID @@ -504,8 +527,6 @@ class RCFActionListener implements ActionListener { this.thresholdModelID = thresholdModelID; this.featureInResponse = features; this.failure = failure; - this.startTime = startTime; - this.endTime = endTime; this.nodeCount = nodeCount; this.responseCount = responseCount; this.adID = adID; @@ -544,7 +565,7 @@ public void onFailure(Exception e) { private void handleRCFResults() { try { - if (coldStartIfNoModel(failure, detector.get()) || rcfResults.isEmpty()) { + if (coldStartIfNoModel(failure, detector) || rcfResults.isEmpty()) { listener.onFailure(new InternalFailure(adID, NO_MODEL_ERR_MSG)); return; } @@ -554,7 +575,7 @@ private void handleRCFResults() { final AtomicReference anomalyResultResponse = new AtomicReference<>(); - String thresholdNodeId = thresholdNode.get().getId(); + String thresholdNodeId = thresholdNode.getId(); LOG.info("Sending threshold request to {} for model {}", thresholdNodeId, thresholdModelID); ThresholdActionListener thresholdListener = new ThresholdActionListener( anomalyResultResponse, @@ -563,15 +584,12 @@ private void handleRCFResults() { thresholdNodeId, detector, combinedResult, - featureInResponse, listener, - startTime, - endTime, adID ); transportService .sendRequest( - thresholdNode.get(), + thresholdNode, ThresholdResultAction.NAME, new ThresholdResultRequest(adID, thresholdModelID, combinedScore), option, @@ -590,11 +608,8 @@ class ThresholdActionListener implements ActionListener private AtomicReference failure; private String thresholdNodeID; private ActionListener listener; - private Optional detector; + private AnomalyDetector detector; private CombinedRcfResult combinedResult; - private List featureInResponse; - private long startTime; - private long endTime; private String adID; ThresholdActionListener( @@ -602,12 +617,9 @@ class ThresholdActionListener implements ActionListener List features, String modelID, String thresholdNodeID, - Optional detector, + AnomalyDetector detector, CombinedRcfResult combinedResult, - List featureInResponse, ActionListener listener, - long startTime, - long endTime, String adID ) { this.anomalyResultResponse = anomalyResultResponse; @@ -616,9 +628,6 @@ class ThresholdActionListener implements ActionListener this.thresholdNodeID = thresholdNodeID; this.detector = detector; this.combinedResult = combinedResult; - this.featureInResponse = featureInResponse; - this.startTime = startTime; - this.endTime = endTime; this.failure = new AtomicReference(); this.listener = listener; this.adID = adID; @@ -650,7 +659,7 @@ public void onFailure(Exception e) { private void handleThresholdResult() { try { - if (coldStartIfNoModel(failure, detector.get())) { + if (coldStartIfNoModel(failure, detector)) { listener.onFailure(new InternalFailure(adID, NO_MODEL_ERR_MSG)); return; } diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/RCFResultTransportAction.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/RCFResultTransportAction.java index 41e11052..382c632c 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/RCFResultTransportAction.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/RCFResultTransportAction.java @@ -1,5 +1,5 @@ /* - * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). * You may not use this file except in compliance with the License. @@ -19,7 +19,7 @@ import com.amazon.opendistroforelasticsearch.ad.common.exception.LimitExceededException; import com.amazon.opendistroforelasticsearch.ad.constant.CommonErrorMessages; import com.amazon.opendistroforelasticsearch.ad.ml.ModelManager; -import com.amazon.opendistroforelasticsearch.ad.ml.RcfResult; + import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; @@ -57,8 +57,21 @@ protected void doExecute(Task task, RCFResultRequest request, ActionListener listener + .onResponse(new RCFResultResponse(result.getScore(), result.getConfidence(), result.getForestSize())), + exception -> { + LOG.warn(exception); + listener.onFailure(exception); + } + ) + ); } catch (Exception e) { LOG.error(e); listener.onFailure(e); diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/StopDetectorTransportAction.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/StopDetectorTransportAction.java index fb140a8a..d434837e 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/StopDetectorTransportAction.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/StopDetectorTransportAction.java @@ -15,7 +15,6 @@ package com.amazon.opendistroforelasticsearch.ad.transport; -import com.amazon.opendistroforelasticsearch.ad.cluster.DeleteDetector; import com.amazon.opendistroforelasticsearch.ad.common.exception.InternalFailure; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -39,20 +38,17 @@ public class StopDetectorTransportAction extends HandledTransportAction { private static final Logger LOG = LogManager.getLogger(ThresholdResultTransportAction.class); @@ -42,8 +42,17 @@ protected void doExecute(Task task, ThresholdResultRequest request, ActionListen try { LOG.info("Serve threshold request for {}", request.getModelID()); - ThresholdingResult result = manager.getThresholdingResult(request.getAdID(), request.getModelID(), request.getRCFScore()); - listener.onResponse(new ThresholdResultResponse(result.getGrade(), result.getConfidence())); + manager + .getThresholdingResult( + request.getAdID(), + request.getModelID(), + request.getRCFScore(), + ActionListener + .wrap( + result -> listener.onResponse(new ThresholdResultResponse(result.getGrade(), result.getConfidence())), + exception -> listener.onFailure(exception) + ) + ); } catch (Exception e) { LOG.error(e); listener.onFailure(e); diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/handler/AnomalyResultHandler.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/handler/AnomalyResultHandler.java index 742a0689..e9a970cc 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/handler/AnomalyResultHandler.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/handler/AnomalyResultHandler.java @@ -172,35 +172,27 @@ private void saveDetectorResult(AnomalyResult anomalyResult) { } void saveDetectorResult(IndexRequest indexRequest, String context, Iterator backoff) { - client - .index( - indexRequest, - ActionListener - .wrap( - response -> LOG.debug(SUCCESS_SAVING_MSG + context), - exception -> { - // Elasticsearch has a thread pool and a queue for write per node. A thread - // pool will have N number of workers ready to handle the requests. When a - // request comes and if a worker is free , this is handled by the worker. Now by - // default the number of workers is equal to the number of cores on that CPU. - // When the workers are full and there are more write requests, the request - // will go to queue. The size of queue is also limited. If by default size is, - // say, 200 and if there happens more parallel requests than this, then those - // requests would be rejected as you can see EsRejectedExecutionException. - // So EsRejectedExecutionException is the way that Elasticsearch tells us that - // it cannot keep up with the current indexing rate. - // When it happens, we should pause indexing a bit before trying again, ideally - // with randomized exponential backoff. - if (!(exception instanceof EsRejectedExecutionException) || !backoff.hasNext()) { - LOG.error(FAIL_TO_SAVE_ERR_MSG + context, exception); - } else { - TimeValue nextDelay = backoff.next(); - LOG.warn(RETRY_SAVING_ERR_MSG + context, exception); - threadPool - .schedule(() -> saveDetectorResult(indexRequest, context, backoff), nextDelay, ThreadPool.Names.SAME); - } - } - ) - ); + client.index(indexRequest, ActionListener.wrap(response -> LOG.debug(SUCCESS_SAVING_MSG + context), exception -> { + // Elasticsearch has a thread pool and a queue for write per node. A thread + // pool will have N number of workers ready to handle the requests. When a + // request comes and if a worker is free , this is handled by the worker. Now by + // default the number of workers is equal to the number of cores on that CPU. + // When the workers are full and there are more write requests, the request + // will go to queue. The size of queue is also limited. If by default size is, + // say, 200 and if there happens more parallel requests than this, then those + // requests would be rejected as you can see EsRejectedExecutionException. + // So EsRejectedExecutionException is the way that Elasticsearch tells us that + // it cannot keep up with the current indexing rate. + // When it happens, we should pause indexing a bit before trying again, ideally + // with randomized exponential backoff. + Throwable cause = ExceptionsHelper.unwrapCause(exception); + if (!(cause instanceof EsRejectedExecutionException) || !backoff.hasNext()) { + LOG.error(FAIL_TO_SAVE_ERR_MSG + context, cause); + } else { + TimeValue nextDelay = backoff.next(); + LOG.warn(RETRY_SAVING_ERR_MSG + context, cause); + threadPool.schedule(() -> saveDetectorResult(indexRequest, context, backoff), nextDelay, ThreadPool.Names.SAME); + } + })); } } diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/util/MultiResponsesDelegateActionListener.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/util/MultiResponsesDelegateActionListener.java new file mode 100644 index 00000000..3f42a18c --- /dev/null +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/util/MultiResponsesDelegateActionListener.java @@ -0,0 +1,112 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package com.amazon.opendistroforelasticsearch.ad.util; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Locale; +import java.util.concurrent.atomic.AtomicInteger; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; + +import com.amazon.opendistroforelasticsearch.ad.model.Mergeable; + +/** + * A listener wrapper to help send multiple requests asynchronously and return one final responses together + */ +public class MultiResponsesDelegateActionListener implements ActionListener { + private static final Logger LOG = LogManager.getLogger(MultiResponsesDelegateActionListener.class); + private final ActionListener delegate; + private final AtomicInteger collectedResponseCount; + private final int maxResponseCount; + // save responses from multiple requests + private final List savedResponses; + private List exceptions; + private String finalErrorMsg; + + public MultiResponsesDelegateActionListener(ActionListener delegate, int maxResponseCount, String finalErrorMsg) { + this.delegate = delegate; + this.collectedResponseCount = new AtomicInteger(0); + this.maxResponseCount = maxResponseCount; + this.savedResponses = Collections.synchronizedList(new ArrayList()); + this.exceptions = Collections.synchronizedList(new ArrayList()); + this.finalErrorMsg = finalErrorMsg; + } + + @Override + public void onResponse(T response) { + try { + if (response != null) { + this.savedResponses.add(response); + } + } finally { + // If expectedResponseCount == 0 , collectedResponseCount.incrementAndGet() will be greater than expectedResponseCount + if (collectedResponseCount.incrementAndGet() >= maxResponseCount) { + finish(); + } + } + + } + + @Override + public void onFailure(Exception e) { + LOG.error(e); + try { + this.exceptions.add(e.getMessage()); + } finally { + // no matter the asynchronous request is a failure or success, we need to increment the count. + // We need finally here to increment the count when there is a failure. + if (collectedResponseCount.incrementAndGet() >= maxResponseCount) { + finish(); + } + } + } + + private void finish() { + if (this.exceptions.size() == 0) { + if (savedResponses.size() == 0) { + this.delegate.onFailure(new RuntimeException("No response collected")); + } else { + T response0 = savedResponses.get(0); + for (int i = 1; i < savedResponses.size(); i++) { + response0.merge(savedResponses.get(i)); + } + this.delegate.onResponse(response0); + } + } else { + this.delegate.onFailure(new RuntimeException(String.format(Locale.ROOT, finalErrorMsg + " Exceptions: %s", exceptions))); + } + } + + public void failImmediately(Exception e) { + this.delegate.onFailure(new RuntimeException(finalErrorMsg, e)); + } + + public void failImmediately(String errMsg) { + this.delegate.onFailure(new RuntimeException(errMsg)); + } + + public void failImmediately(String errMsg, Exception e) { + this.delegate.onFailure(new RuntimeException(errMsg, e)); + } + + public void respondImmediately(T o) { + this.delegate.onResponse(o); + } +} diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/util/RestHandlerUtils.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/util/RestHandlerUtils.java index 2d09ea2d..057eeec7 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/util/RestHandlerUtils.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/util/RestHandlerUtils.java @@ -53,6 +53,8 @@ public final class RestHandlerUtils { public static final String PREVIEW = "_preview"; public static final String START_JOB = "_start"; public static final String STOP_JOB = "_stop"; + public static final String PROFILE = "_profile"; + public static final String TYPE = "type"; public static final ToXContent.MapParams XCONTENT_WITH_TYPE = new ToXContent.MapParams(ImmutableMap.of("with_type", "true")); private static final String KIBANA_USER_AGENT = "Kibana"; diff --git a/src/main/resources/mappings/anomaly-results.json b/src/main/resources/mappings/anomaly-results.json index 2ad310d1..80ee69e4 100644 --- a/src/main/resources/mappings/anomaly-results.json +++ b/src/main/resources/mappings/anomaly-results.json @@ -45,6 +45,9 @@ "execution_end_time": { "type": "date", "format": "strict_date_time||epoch_millis" + }, + "error": { + "type": "text" } } } diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorJobRunnerTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorJobRunnerTests.java index 0f93ed59..45b23ea0 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorJobRunnerTests.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorJobRunnerTests.java @@ -19,6 +19,7 @@ import com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetectorJob; import com.amazon.opendistroforelasticsearch.ad.model.IntervalTimeConfiguration; import com.amazon.opendistroforelasticsearch.ad.transport.handler.AnomalyResultHandler; +import com.amazon.opendistroforelasticsearch.ad.util.ClientUtil; import com.amazon.opendistroforelasticsearch.jobscheduler.spi.JobExecutionContext; import com.amazon.opendistroforelasticsearch.jobscheduler.spi.LockModel; import com.amazon.opendistroforelasticsearch.jobscheduler.spi.ScheduledJobParameter; @@ -26,6 +27,7 @@ import com.amazon.opendistroforelasticsearch.jobscheduler.spi.schedule.Schedule; import com.amazon.opendistroforelasticsearch.jobscheduler.spi.utils.LockService; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.get.GetRequest; import org.elasticsearch.action.get.GetResponse; import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.index.IndexResponse; @@ -74,6 +76,10 @@ public class AnomalyDetectorJobRunnerTests extends AbstractADTest { @Mock private Client client; + + @Mock + private ClientUtil clientUtil; + @Mock private ClusterService clusterService; @@ -119,6 +125,7 @@ public void setup() throws Exception { doReturn(executorService).when(mockedThreadPool).executor(anyString()); runner.setThreadPool(mockedThreadPool); runner.setClient(client); + runner.setClientUtil(clientUtil); runner.setAnomalyResultHandler(anomalyResultHandler); setUpJobParameter(); @@ -214,7 +221,7 @@ public void testRunAdJobWithEndRunExceptionNow() { public void testRunAdJobWithEndRunExceptionNowAndExistingAdJob() { testRunAdJobWithEndRunExceptionNowAndStopAdJob(true, true, true); verify(anomalyResultHandler).indexAnomalyResult(any()); - verify(client).index(any(), any()); + verify(clientUtil).asyncRequest(any(IndexRequest.class), any(), any()); assertTrue(testAppender.containsMessage("AD Job was disabled by JobRunner for")); } @@ -222,7 +229,7 @@ public void testRunAdJobWithEndRunExceptionNowAndExistingAdJob() { public void testRunAdJobWithEndRunExceptionNowAndExistingAdJobAndIndexException() { testRunAdJobWithEndRunExceptionNowAndStopAdJob(true, true, false); verify(anomalyResultHandler).indexAnomalyResult(any()); - verify(client).index(any(), any()); + verify(clientUtil).asyncRequest(any(IndexRequest.class), any(), any()); assertTrue(testAppender.containsMessage("Failed to disable AD job for")); } @@ -256,7 +263,7 @@ private void testRunAdJobWithEndRunExceptionNowAndStopAdJob(boolean jobExists, b Exception exception = new EndRunException(jobParameter.getName(), randomAlphaOfLength(5), true); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(2); GetResponse response = new GetResponse( new GetResult( AnomalyDetectorJob.ANOMALY_DETECTOR_JOB_INDEX, @@ -286,11 +293,11 @@ private void testRunAdJobWithEndRunExceptionNowAndStopAdJob(boolean jobExists, b listener.onResponse(response); return null; - }).when(client).get(any(), any()); + }).when(clientUtil).asyncRequest(any(GetRequest.class), any(), any()); doAnswer(invocation -> { IndexRequest request = invocation.getArgument(0); - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(2); ShardId shardId = new ShardId(new Index(AnomalyDetectorJob.ANOMALY_DETECTOR_JOB_INDEX, randomAlphaOfLength(10)), 0); if (disableSuccessfully) { listener.onResponse(new IndexResponse(shardId, randomAlphaOfLength(10), request.id(), 1, 1, 1, true)); @@ -298,7 +305,7 @@ private void testRunAdJobWithEndRunExceptionNowAndStopAdJob(boolean jobExists, b listener.onResponse(null); } return null; - }).when(client).index(any(), any()); + }).when(clientUtil).asyncRequest(any(IndexRequest.class), any(), any()); runner.handleAdException(jobParameter, lockService, lock, Instant.now().minusMillis(1000 * 60), Instant.now(), exception); } @@ -309,10 +316,10 @@ public void testRunAdJobWithEndRunExceptionNowAndGetJobException() { Exception exception = new EndRunException(jobParameter.getName(), randomAlphaOfLength(5), true); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(2); listener.onFailure(new RuntimeException("test")); return null; - }).when(client).get(any(), any()); + }).when(clientUtil).asyncRequest(any(GetRequest.class), any(), any()); runner.handleAdException(jobParameter, lockService, lock, Instant.now().minusMillis(1000 * 60), Instant.now(), exception); verify(anomalyResultHandler).indexAnomalyResult(any()); @@ -324,7 +331,7 @@ public void testRunAdJobWithEndRunExceptionNowAndFailToGetJob() { LockModel lock = new LockModel("indexName", "jobId", Instant.now(), 10, false); Exception exception = new EndRunException(jobParameter.getName(), randomAlphaOfLength(5), true); - doThrow(new RuntimeException("fail to get AD job")).when(client).get(any(), any()); + doThrow(new RuntimeException("fail to get AD job")).when(clientUtil).asyncRequest(any(GetRequest.class), any(), any()); runner.handleAdException(jobParameter, lockService, lock, Instant.now().minusMillis(1000 * 60), Instant.now(), exception); verify(anomalyResultHandler).indexAnomalyResult(any()); diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorProfileRunnerTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorProfileRunnerTests.java new file mode 100644 index 00000000..8f23f7ea --- /dev/null +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorProfileRunnerTests.java @@ -0,0 +1,317 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package com.amazon.opendistroforelasticsearch.ad; + +import static com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetector.ANOMALY_DETECTORS_INDEX; +import static com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetectorJob.ANOMALY_DETECTOR_JOB_INDEX; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; + +import java.io.IOException; +import java.time.Instant; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.get.GetRequest; +import org.elasticsearch.action.get.GetResponse; +import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.client.Client; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.index.IndexNotFoundException; +import org.elasticsearch.search.SearchModule; +import org.elasticsearch.test.ESTestCase; +import org.junit.Before; +import org.junit.BeforeClass; + +import com.amazon.opendistroforelasticsearch.ad.cluster.ADMetaData; +import com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetector; +import com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetectorJob; +import com.amazon.opendistroforelasticsearch.ad.model.AnomalyResult; +import com.amazon.opendistroforelasticsearch.ad.model.DetectorProfile; +import com.amazon.opendistroforelasticsearch.ad.model.DetectorState; +import com.amazon.opendistroforelasticsearch.ad.model.ProfileName; + +public class AnomalyDetectorProfileRunnerTests extends ESTestCase { + private static final Logger LOG = LogManager.getLogger(AnomalyDetectorProfileRunnerTests.class); + private AnomalyDetectorProfileRunner runner; + private Client client; + private AnomalyDetector detector; + private static Set stateOnly; + private static Set stateNError; + private static String error = "No full shingle in current detection window"; + + @Override + protected NamedXContentRegistry xContentRegistry() { + SearchModule searchModule = new SearchModule(Settings.EMPTY, false, Collections.emptyList()); + List entries = searchModule.getNamedXContents(); + entries.addAll(Arrays.asList(AnomalyDetector.XCONTENT_REGISTRY, ADMetaData.XCONTENT_REGISTRY, AnomalyResult.XCONTENT_REGISTRY)); + return new NamedXContentRegistry(entries); + } + + @BeforeClass + public static void setUpOnce() { + stateOnly = new HashSet(); + stateOnly.add(ProfileName.STATE); + stateNError = new HashSet(); + stateNError.add(ProfileName.ERROR); + stateNError.add(ProfileName.STATE); + } + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + client = mock(Client.class); + runner = new AnomalyDetectorProfileRunner(client, xContentRegistry()); + } + + enum JobStatus { + INDEX_NOT_EXIT, + DISABLED, + ENABLED + } + + enum InittedEverResultStatus { + INDEX_NOT_EXIT, + GREATER_THAN_ZERO, + EMPTY, + EXCEPTION + } + + enum ErrorResultStatus { + INDEX_NOT_EXIT, + NO_ERROR, + ERROR + } + + @SuppressWarnings("unchecked") + private void setUpClientGet(boolean detectorExists, JobStatus jobStatus) throws IOException { + detector = TestHelpers.randomAnomalyDetector(TestHelpers.randomUiMetadata(), Instant.now()); + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + GetRequest request = (GetRequest) args[0]; + ActionListener listener = (ActionListener) args[1]; + + if (request.index().equals(ANOMALY_DETECTORS_INDEX)) { + if (detectorExists) { + listener.onResponse(TestHelpers.createGetResponse(detector, detector.getDetectorId())); + } else { + listener.onFailure(new IndexNotFoundException(ANOMALY_DETECTORS_INDEX)); + } + } else { + AnomalyDetectorJob job = null; + switch (jobStatus) { + case INDEX_NOT_EXIT: + listener.onFailure(new IndexNotFoundException(ANOMALY_DETECTOR_JOB_INDEX)); + break; + case DISABLED: + job = TestHelpers.randomAnomalyDetectorJob(false); + listener.onResponse(TestHelpers.createGetResponse(job, detector.getDetectorId())); + break; + case ENABLED: + job = TestHelpers.randomAnomalyDetectorJob(true); + listener.onResponse(TestHelpers.createGetResponse(job, detector.getDetectorId())); + break; + default: + assertTrue("should not reach here", false); + break; + } + } + + return null; + }).when(client).get(any(), any()); + } + + @SuppressWarnings("unchecked") + private void setUpClientSearch(InittedEverResultStatus inittedEverResultStatus, ErrorResultStatus errorResultStatus) { + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + SearchRequest request = (SearchRequest) args[0]; + ActionListener listener = (ActionListener) args[1]; + if (errorResultStatus == ErrorResultStatus.INDEX_NOT_EXIT + || inittedEverResultStatus == InittedEverResultStatus.INDEX_NOT_EXIT) { + listener.onFailure(new IndexNotFoundException(AnomalyResult.ANOMALY_RESULT_INDEX)); + return null; + } + AnomalyResult result = null; + if (request.source().query().toString().contains(AnomalyResult.ANOMALY_SCORE_FIELD)) { + switch (inittedEverResultStatus) { + case GREATER_THAN_ZERO: + result = TestHelpers.randomAnomalyDetectResult(0.87); + listener.onResponse(TestHelpers.createSearchResponse(result)); + break; + case EMPTY: + listener.onResponse(TestHelpers.createEmptySearchResponse()); + break; + case EXCEPTION: + listener.onFailure(new RuntimeException()); + break; + default: + assertTrue("should not reach here", false); + break; + } + } else { + switch (errorResultStatus) { + case NO_ERROR: + result = TestHelpers.randomAnomalyDetectResult(null); + listener.onResponse(TestHelpers.createSearchResponse(result)); + break; + case ERROR: + result = TestHelpers.randomAnomalyDetectResult(error); + listener.onResponse(TestHelpers.createSearchResponse(result)); + break; + default: + assertTrue("should not reach here", false); + break; + } + } + + return null; + }).when(client).search(any(), any()); + + } + + public void testDetectorNotExist() throws IOException, InterruptedException { + setUpClientGet(false, JobStatus.INDEX_NOT_EXIT); + final CountDownLatch inProgressLatch = new CountDownLatch(1); + + runner.profile("x123", ActionListener.wrap(response -> { + assertTrue("Should not reach here", false); + inProgressLatch.countDown(); + }, exception -> { + assertTrue(exception.getMessage().contains(AnomalyDetectorProfileRunner.FAIL_TO_FIND_DETECTOR_MSG)); + inProgressLatch.countDown(); + }), stateNError); + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); + } + + public void testDisabledJobIndexTemplate(JobStatus status) throws IOException, InterruptedException { + setUpClientGet(true, status); + DetectorProfile expectedProfile = new DetectorProfile(); + expectedProfile.setState(DetectorState.DISABLED); + final CountDownLatch inProgressLatch = new CountDownLatch(1); + + runner.profile(detector.getDetectorId(), ActionListener.wrap(response -> { + assertEquals(expectedProfile, response); + inProgressLatch.countDown(); + }, exception -> { + assertTrue("Should not reach here ", false); + inProgressLatch.countDown(); + }), stateOnly); + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); + } + + public void testNoJobIndex() throws IOException, InterruptedException { + testDisabledJobIndexTemplate(JobStatus.INDEX_NOT_EXIT); + } + + public void testJobDisabled() throws IOException, InterruptedException { + testDisabledJobIndexTemplate(JobStatus.DISABLED); + } + + public void testInitOrRunningStateTemplate(InittedEverResultStatus status, DetectorState expectedState) throws IOException, + InterruptedException { + setUpClientGet(true, JobStatus.ENABLED); + setUpClientSearch(status, ErrorResultStatus.NO_ERROR); + DetectorProfile expectedProfile = new DetectorProfile(); + expectedProfile.setState(expectedState); + final CountDownLatch inProgressLatch = new CountDownLatch(1); + + runner.profile(detector.getDetectorId(), ActionListener.wrap(response -> { + assertEquals(expectedProfile, response); + inProgressLatch.countDown(); + }, exception -> { + assertTrue("Should not reach here ", false); + inProgressLatch.countDown(); + }), stateOnly); + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); + } + + public void testResultNotExist() throws IOException, InterruptedException { + testInitOrRunningStateTemplate(InittedEverResultStatus.INDEX_NOT_EXIT, DetectorState.INIT); + } + + public void testResultEmpty() throws IOException, InterruptedException { + testInitOrRunningStateTemplate(InittedEverResultStatus.EMPTY, DetectorState.INIT); + } + + public void testResultGreaterThanZero() throws IOException, InterruptedException { + testInitOrRunningStateTemplate(InittedEverResultStatus.GREATER_THAN_ZERO, DetectorState.RUNNING); + } + + public void testErrorStateTemplate(InittedEverResultStatus initStatus, ErrorResultStatus status, DetectorState state, String error) + throws IOException, + InterruptedException { + setUpClientGet(true, JobStatus.ENABLED); + setUpClientSearch(initStatus, status); + DetectorProfile expectedProfile = new DetectorProfile(); + expectedProfile.setState(state); + expectedProfile.setError(error); + final CountDownLatch inProgressLatch = new CountDownLatch(1); + + runner.profile(detector.getDetectorId(), ActionListener.wrap(response -> { + assertEquals(expectedProfile, response); + inProgressLatch.countDown(); + }, exception -> { + assertTrue("Should not reach here ", false); + inProgressLatch.countDown(); + }), stateNError); + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); + } + + public void testInitNoError() throws IOException, InterruptedException { + testErrorStateTemplate(InittedEverResultStatus.INDEX_NOT_EXIT, ErrorResultStatus.INDEX_NOT_EXIT, DetectorState.INIT, null); + } + + public void testRunningNoError() throws IOException, InterruptedException { + testErrorStateTemplate(InittedEverResultStatus.GREATER_THAN_ZERO, ErrorResultStatus.NO_ERROR, DetectorState.RUNNING, null); + } + + public void testRunningWithError() throws IOException, InterruptedException { + testErrorStateTemplate(InittedEverResultStatus.GREATER_THAN_ZERO, ErrorResultStatus.ERROR, DetectorState.RUNNING, error); + } + + public void testInitWithError() throws IOException, InterruptedException { + testErrorStateTemplate(InittedEverResultStatus.EMPTY, ErrorResultStatus.ERROR, DetectorState.INIT, error); + } + + public void testExceptionOnStateFetching() throws IOException, InterruptedException { + setUpClientGet(true, JobStatus.ENABLED); + setUpClientSearch(InittedEverResultStatus.EXCEPTION, ErrorResultStatus.NO_ERROR); + + final CountDownLatch inProgressLatch = new CountDownLatch(1); + + runner.profile(detector.getDetectorId(), ActionListener.wrap(response -> { + assertTrue("Should not reach here ", false); + inProgressLatch.countDown(); + }, exception -> { + assertTrue("Unexcpeted exception " + exception.getMessage(), exception instanceof RuntimeException); + inProgressLatch.countDown(); + }), stateOnly); + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); + } +} diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/TestHelpers.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/TestHelpers.java index b85103c5..3af20295 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/TestHelpers.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/TestHelpers.java @@ -33,9 +33,13 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.util.Strings; +import org.apache.lucene.search.TotalHits; import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.admin.indices.create.CreateIndexResponse; +import org.elasticsearch.action.get.GetResponse; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.action.search.ShardSearchFailure; import org.elasticsearch.client.Client; import org.elasticsearch.client.Request; import org.elasticsearch.client.RequestOptions; @@ -58,14 +62,25 @@ import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentFactory; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.index.get.GetResult; +import org.elasticsearch.index.mapper.MapperService; import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.search.SearchHit; +import org.elasticsearch.search.SearchHits; import org.elasticsearch.search.SearchModule; import org.elasticsearch.search.aggregations.AggregationBuilder; import org.elasticsearch.search.aggregations.AggregatorFactories; +import org.elasticsearch.search.aggregations.InternalAggregations; import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.internal.InternalSearchResponse; +import org.elasticsearch.search.profile.SearchProfileShardResults; +import org.elasticsearch.search.suggest.Suggest; import org.elasticsearch.test.ClusterServiceUtils; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.rest.ESRestTestCase; @@ -82,6 +97,7 @@ import static org.elasticsearch.cluster.node.DiscoveryNodeRole.BUILT_IN_ROLES; import static org.elasticsearch.index.query.AbstractQueryBuilder.parseInnerQueryBuilder; +import static org.elasticsearch.index.seqno.SequenceNumbers.UNASSIGNED_SEQ_NO; import static org.elasticsearch.test.ESTestCase.randomAlphaOfLength; import static org.elasticsearch.test.ESTestCase.randomDouble; import static org.elasticsearch.test.ESTestCase.randomInt; @@ -297,9 +313,21 @@ public static FeatureData randomFeatureData() { } public static AnomalyResult randomAnomalyDetectResult() { + return randomAnomalyDetectResult(randomDouble(), randomAlphaOfLength(5)); + } + + public static AnomalyResult randomAnomalyDetectResult(double score) { + return randomAnomalyDetectResult(randomDouble(), null); + } + + public static AnomalyResult randomAnomalyDetectResult(String error) { + return randomAnomalyDetectResult(Double.NaN, error); + } + + public static AnomalyResult randomAnomalyDetectResult(double score, String error) { return new AnomalyResult( randomAlphaOfLength(5), - randomDouble(), + score, randomDouble(), randomDouble(), ImmutableList.of(randomFeatureData(), randomFeatureData()), @@ -307,16 +335,20 @@ public static AnomalyResult randomAnomalyDetectResult() { Instant.now().truncatedTo(ChronoUnit.SECONDS), Instant.now().truncatedTo(ChronoUnit.SECONDS), Instant.now().truncatedTo(ChronoUnit.SECONDS), - randomAlphaOfLength(5) + error ); } public static AnomalyDetectorJob randomAnomalyDetectorJob() { + return randomAnomalyDetectorJob(true); + } + + public static AnomalyDetectorJob randomAnomalyDetectorJob(boolean enabled) { return new AnomalyDetectorJob( randomAlphaOfLength(10), randomIntervalSchedule(), randomIntervalTimeConfiguration(), - true, + enabled, Instant.now().truncatedTo(ChronoUnit.SECONDS), Instant.now().truncatedTo(ChronoUnit.SECONDS), Instant.now().truncatedTo(ChronoUnit.SECONDS), @@ -406,4 +438,70 @@ public static void createIndex(RestClient client, String indexName, HttpEntity d null ); } + + public static GetResponse createGetResponse(ToXContentObject o, String id) throws IOException { + XContentBuilder content = o.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); + + return new GetResponse( + new GetResult( + AnomalyDetector.ANOMALY_DETECTORS_INDEX, + MapperService.SINGLE_MAPPING_NAME, + id, + UNASSIGNED_SEQ_NO, + 0, + -1, + true, + BytesReference.bytes(content), + Collections.emptyMap(), + Collections.emptyMap() + ) + ); + } + + public static SearchResponse createSearchResponse(ToXContentObject o) throws IOException { + XContentBuilder content = o.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); + + SearchHit[] hits = new SearchHit[1]; + hits[0] = new SearchHit(0).sourceRef(BytesReference.bytes(content)); + + return new SearchResponse( + new InternalSearchResponse( + new SearchHits(hits, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1.0f), + new InternalAggregations(Collections.emptyList()), + new Suggest(Collections.emptyList()), + new SearchProfileShardResults(Collections.emptyMap()), + false, + false, + 1 + ), + "", + 5, + 5, + 0, + 100, + ShardSearchFailure.EMPTY_ARRAY, + SearchResponse.Clusters.EMPTY + ); + } + + public static SearchResponse createEmptySearchResponse() throws IOException { + return new SearchResponse( + new InternalSearchResponse( + new SearchHits(new SearchHit[0], new TotalHits(0, TotalHits.Relation.EQUAL_TO), 1.0f), + new InternalAggregations(Collections.emptyList()), + new Suggest(Collections.emptyList()), + new SearchProfileShardResults(Collections.emptyMap()), + false, + false, + 1 + ), + "", + 5, + 5, + 0, + 100, + ShardSearchFailure.EMPTY_ARRAY, + SearchResponse.Clusters.EMPTY + ); + } } diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/cluster/ADClusterEventListenerTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/cluster/ADClusterEventListenerTests.java index 8b222e87..ff8c54b9 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/cluster/ADClusterEventListenerTests.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/cluster/ADClusterEventListenerTests.java @@ -33,6 +33,7 @@ import com.amazon.opendistroforelasticsearch.ad.AbstractADTest; import com.amazon.opendistroforelasticsearch.ad.ml.ModelManager; + import org.elasticsearch.cluster.ClusterChangedEvent; import org.elasticsearch.cluster.ClusterName; import org.elasticsearch.cluster.ClusterState; diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManagerTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManagerTests.java index 679ae0cb..81242a49 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManagerTests.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManagerTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). * You may not use this file except in compliance with the License. @@ -135,6 +135,7 @@ public class ModelManagerTests { private String rcfModelId; private String thresholdModelId; private String checkpoint; + private int shingleSize; @Before public void setup() { @@ -156,6 +157,7 @@ public void setup() { minPreviewSize = 500; modelTtl = Duration.ofHours(1); checkpointInterval = Duration.ofHours(1); + shingleSize = 1; rcf = RandomCutForest.builder().dimensions(numFeatures).sampleSize(numSamples).numberOfTrees(numTrees).build(); @@ -185,7 +187,8 @@ public void setup() { thresholdingModelClass, minPreviewSize, modelTtl, - checkpointInterval + checkpointInterval, + shingleSize ) ); @@ -591,7 +594,7 @@ public void clear_deleteThresholdCheckpoint() { public void trainModel_putTrainedModels() { double[][] trainData = new Random().doubles().limit(100).mapToObj(d -> new double[] { d }).toArray(double[][]::new); doReturn(new SimpleEntry<>(1, 10)).when(modelManager).getPartitionedForestSizes(anyObject(), anyObject()); - + doReturn(asList("feature1")).when(anomalyDetector).getEnabledFeatureIds(); modelManager.trainModel(anomalyDetector, trainData); verify(checkpointDao).putModelCheckpoint(eq(modelManager.getRcfModelId(anomalyDetector.getDetectorId(), 0)), anyObject()); diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/ADStateManagerTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/ADStateManagerTests.java index 5804216d..1cce2aae 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/ADStateManagerTests.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/ADStateManagerTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). * You may not use this file except in compliance with the License. @@ -34,7 +34,6 @@ import java.util.concurrent.ConcurrentHashMap; import com.amazon.opendistroforelasticsearch.ad.TestHelpers; -import com.amazon.opendistroforelasticsearch.ad.common.exception.AnomalyDetectionException; import com.amazon.opendistroforelasticsearch.ad.ml.ModelManager; import com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetector; import com.amazon.opendistroforelasticsearch.ad.util.ClientUtil; @@ -60,16 +59,17 @@ import org.junit.After; import org.junit.Before; -import com.amazon.randomcutforest.RandomCutForest; - public class ADStateManagerTests extends ESTestCase { private ADStateManager stateManager; private ModelManager modelManager; private Client client; + private ClientUtil clientUtil; private Clock clock; private Duration duration; private Throttler throttler; private ThreadPool context; + private AnomalyDetector detectorToCheck; + private Settings settings; @Override protected NamedXContentRegistry xContentRegistry() { @@ -82,10 +82,10 @@ protected NamedXContentRegistry xContentRegistry() { public void setUp() throws Exception { super.setUp(); modelManager = mock(ModelManager.class); - when(modelManager.getPartitionedForestSizes(any(RandomCutForest.class), any(String.class))) - .thenReturn(new SimpleImmutableEntry<>(2, 20)); + when(modelManager.getPartitionedForestSizes(any(AnomalyDetector.class))).thenReturn(new SimpleImmutableEntry<>(2, 20)); client = mock(Client.class); - Settings settings = Settings + clientUtil = mock(ClientUtil.class); + settings = Settings .builder() .put("opendistro.anomaly_detection.max_retry_for_unresponsive_node", 3) .put("opendistro.anomaly_detection.ad_mute_minutes", TimeValue.timeValueMinutes(10)) @@ -95,15 +95,7 @@ public void setUp() throws Exception { context = TestHelpers.createThreadPool(); throttler = new Throttler(clock); - stateManager = new ADStateManager( - client, - xContentRegistry(), - modelManager, - settings, - new ClientUtil(settings, client, throttler, context), - clock, - duration - ); + stateManager = new ADStateManager(client, xContentRegistry(), modelManager, settings, clientUtil, clock, duration); } @@ -114,12 +106,14 @@ public void tearDown() throws Exception { stateManager = null; modelManager = null; client = null; + clientUtil = null; + detectorToCheck = null; } @SuppressWarnings("unchecked") private String setupDetector(boolean responseExists) throws IOException { - AnomalyDetector detector = TestHelpers.randomAnomalyDetector(TestHelpers.randomUiMetadata(), null); - XContentBuilder content = detector.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); + detectorToCheck = TestHelpers.randomAnomalyDetector(TestHelpers.randomUiMetadata(), null); + XContentBuilder content = detectorToCheck.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); doAnswer(invocation -> { Object[] args = invocation.getArguments(); @@ -130,8 +124,8 @@ private String setupDetector(boolean responseExists) throws IOException { if (args[0] instanceof GetRequest) { request = (GetRequest) args[0]; } - if (args[1] instanceof ActionListener) { - listener = (ActionListener) args[1]; + if (args[2] instanceof ActionListener) { + listener = (ActionListener) args[2]; } assertTrue(request != null && listener != null); @@ -141,7 +135,7 @@ private String setupDetector(boolean responseExists) throws IOException { new GetResult( AnomalyDetector.ANOMALY_DETECTORS_INDEX, MapperService.SINGLE_MAPPING_NAME, - detector.getDetectorId(), + detectorToCheck.getDetectorId(), UNASSIGNED_SEQ_NO, 0, -1, @@ -154,21 +148,17 @@ private String setupDetector(boolean responseExists) throws IOException { ); return null; - }).when(client).get(any(), any()); - return detector.getDetectorId(); + }).when(clientUtil).asyncRequest(any(GetRequest.class), any(), any(ActionListener.class)); + return detectorToCheck.getDetectorId(); } public void testGetPartitionNumber() throws IOException, InterruptedException { String detectorId = setupDetector(true); - int partitionNumber = stateManager.getPartitionNumber(detectorId); + int partitionNumber = stateManager + .getPartitionNumber(detectorId, TestHelpers.randomAnomalyDetector(TestHelpers.randomUiMetadata(), null)); assertEquals(2, partitionNumber); } - public void testGetResponseNotFound() throws IOException, InterruptedException { - String detectorId = setupDetector(false); - expectThrows(AnomalyDetectionException.class, () -> stateManager.getPartitionNumber(detectorId)); - } - public void testShouldMute() { String nodeId = "123"; assertTrue(!stateManager.isMuted(nodeId)); @@ -214,10 +204,29 @@ public void testMaintenancRemove() throws IOException { } public void testHasRunningQuery() throws IOException { + stateManager = new ADStateManager( + client, + xContentRegistry(), + modelManager, + settings, + new ClientUtil(settings, client, throttler, context), + clock, + duration + ); + AnomalyDetector detector = TestHelpers.randomAnomalyDetector(ImmutableMap.of(), null); SearchRequest dummySearchRequest = new SearchRequest(); assertFalse(stateManager.hasRunningQuery(detector)); throttler.insertFilteredQuery(detector.getDetectorId(), dummySearchRequest); assertTrue(stateManager.hasRunningQuery(detector)); } + + public void testGetAnomalyDetector() throws IOException { + String detectorId = setupDetector(true); + stateManager + .getAnomalyDetector( + detectorId, + ActionListener.wrap(asDetector -> { assertEquals(detectorToCheck, asDetector.get()); }, exception -> assertTrue(false)) + ); + } } diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/AnomalyResultTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/AnomalyResultTests.java index 9832863f..f0a929c8 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/AnomalyResultTests.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/AnomalyResultTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). * You may not use this file except in compliance with the License. @@ -24,7 +24,7 @@ import static org.hamcrest.Matchers.hasItem; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.nullValue; -import static org.mockito.Mockito.any; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.anyInt; import static org.mockito.Mockito.anyLong; import static org.mockito.Mockito.anyDouble; @@ -53,12 +53,12 @@ import java.util.function.Function; import com.amazon.opendistroforelasticsearch.ad.breaker.ADCircuitBreakerService; -import com.amazon.opendistroforelasticsearch.ad.indices.AnomalyDetectionIndices; import com.amazon.opendistroforelasticsearch.ad.AbstractADTest; import com.amazon.opendistroforelasticsearch.ad.cluster.HashRing; import com.amazon.opendistroforelasticsearch.ad.common.exception.AnomalyDetectionException; import com.amazon.opendistroforelasticsearch.ad.common.exception.ClientException; +import com.amazon.opendistroforelasticsearch.ad.common.exception.EndRunException; import com.amazon.opendistroforelasticsearch.ad.common.exception.InternalFailure; import com.amazon.opendistroforelasticsearch.ad.common.exception.JsonPathNotFoundException; import com.amazon.opendistroforelasticsearch.ad.common.exception.LimitExceededException; @@ -85,7 +85,6 @@ import org.elasticsearch.ElasticsearchTimeoutException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionRequestValidationException; -import org.elasticsearch.action.admin.indices.create.CreateIndexResponse; import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.index.IndexResponse; import org.elasticsearch.action.support.ActionFilters; @@ -135,7 +134,6 @@ public class AnomalyResultTests extends AbstractADTest { private FeatureManager featureQuery; private ModelManager normalModelManager; private Client client; - private AnomalyDetectionIndices anomalyDetectionIndices; private AnomalyDetector detector; private HashRing hashRing; private IndexNameExpressionResolver indexNameResolver; @@ -171,7 +169,7 @@ public void setUp() throws Exception { clusterService = testNodes[0].clusterService; stateManager = mock(ADStateManager.class); // return 2 RCF partitions - when(stateManager.getPartitionNumber(any(String.class))).thenReturn(2); + when(stateManager.getPartitionNumber(any(String.class), any(AnomalyDetector.class))).thenReturn(2); when(stateManager.isMuted(any(String.class))).thenReturn(false); detector = mock(AnomalyDetector.class); @@ -184,19 +182,35 @@ public void setUp() throws Exception { userIndex.add("test*"); when(detector.getIndices()).thenReturn(userIndex); when(detector.getDetectorId()).thenReturn("testDetectorId"); - when(stateManager.getAnomalyDetector(any(String.class))).thenReturn(Optional.of(detector)); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(detector)); + return null; + }).when(stateManager).getAnomalyDetector(any(String.class), any(ActionListener.class)); hashRing = mock(HashRing.class); when(hashRing.getOwningNode(any(String.class))).thenReturn(Optional.of(clusterService.state().nodes().getLocalNode())); when(hashRing.build()).thenReturn(true); featureQuery = mock(FeatureManager.class); - when(featureQuery.getCurrentFeatures(any(AnomalyDetector.class), anyLong(), anyLong())) - .thenReturn(new SinglePointFeatures(Optional.of(new double[] { 0.0d }), Optional.of(new double[] { 0 }))); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(new SinglePointFeatures(Optional.of(new double[] { 0.0d }), Optional.of(new double[] { 0 }))); + return null; + }).when(featureQuery).getCurrentFeatures(any(AnomalyDetector.class), anyLong(), anyLong(), any(ActionListener.class)); + normalModelManager = mock(ModelManager.class); - when(normalModelManager.getThresholdingResult(any(String.class), any(String.class), anyDouble())) - .thenReturn(new ThresholdingResult(0, 1.0d)); - when(normalModelManager.getRcfResult(any(String.class), any(String.class), any(double[].class))) - .thenReturn(new RcfResult(0.2, 0, 100)); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(new ThresholdingResult(0, 1.0d)); + return null; + }).when(normalModelManager).getThresholdingResult(any(String.class), any(String.class), anyDouble(), any(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(new RcfResult(0.2, 0, 100)); + return null; + }).when(normalModelManager).getRcfResult(any(String.class), any(String.class), any(double[].class), any(ActionListener.class)); when(normalModelManager.combineRcfResults(any())).thenReturn(new CombinedRcfResult(0, 1.0d)); adID = "123"; rcfModelID = "123-rcf-1"; @@ -253,37 +267,6 @@ public void setupTestNodes(Settings settings) { runner = new ColdStartRunner(); } - @SuppressWarnings("unchecked") - public void setUpSavingAnomalyResultIndex(boolean anomalyResultIndexExists) throws IOException { - anomalyDetectionIndices = mock(AnomalyDetectionIndices.class); - doAnswer(invocation -> { - Object[] args = invocation.getArguments(); - assertTrue(String.format("The size of args is %d. Its content is %s", args.length, Arrays.toString(args)), args.length >= 1); - - ActionListener listener = null; - - if (args[0] instanceof ActionListener) { - listener = (ActionListener) args[0]; - } - - assertTrue(listener != null); - - listener.onResponse(new CreateIndexResponse(true, true, AnomalyResult.ANOMALY_RESULT_INDEX) { - }); - - return null; - }).when(anomalyDetectionIndices).initAnomalyResultIndexDirectly(any()); - - when(anomalyDetectionIndices.doesAnomalyResultIndexExist()).thenReturn(anomalyResultIndexExists); - } - - public void setupInitResultIndexException(Class exceptionType) throws IOException { - anomalyDetectionIndices = mock(AnomalyDetectionIndices.class); - doThrow(exceptionType).when(anomalyDetectionIndices).initAnomalyResultIndexDirectly(any()); - - when(anomalyDetectionIndices.doesAnomalyResultIndexExist()).thenReturn(false); - } - @Override @After public final void tearDown() throws Exception { @@ -293,7 +276,6 @@ public final void tearDown() throws Exception { runner.shutDown(); runner = null; client = null; - anomalyDetectionIndices = null; super.tearDownLog4jForJUnit(); super.tearDown(); } @@ -309,7 +291,6 @@ private void assertException(PlainActionFuture listener, public void testNormal() throws IOException { - setUpSavingAnomalyResultIndex(false); // These constructors register handler in transport service new RCFResultTransportAction( new ActionFilters(Collections.emptySet()), @@ -322,17 +303,14 @@ public void testNormal() throws IOException { AnomalyResultTransportAction action = new AnomalyResultTransportAction( new ActionFilters(Collections.emptySet()), transportService, - client, settings, stateManager, runner, - anomalyDetectionIndices, featureQuery, normalModelManager, hashRing, clusterService, indexNameResolver, - threadPool, adCircuitBreakerService, adStats ); @@ -377,17 +355,14 @@ public Throwable noModelExceptionTemplate( AnomalyResultTransportAction action = new AnomalyResultTransportAction( new ActionFilters(Collections.emptySet()), transportService, - client, settings, stateManager, globalRunner, - anomalyDetectionIndices, featureQuery, normalModelManager, hashRing, clusterService, indexNameResolver, - threadPool, adCircuitBreakerService, adStats ); @@ -441,10 +416,13 @@ public void testADExceptionWhenColdStart() { noModelExceptionTemplate(new AnomalyDetectionException(adID, ""), mockRunner, adID, error); } + @SuppressWarnings("unchecked") public void testInsufficientCapacityExceptionDuringColdStart() { ModelManager rcfManager = mock(ModelManager.class); - doThrow(ResourceNotFoundException.class).when(rcfManager).getRcfResult(any(String.class), any(String.class), any(double[].class)); + doThrow(ResourceNotFoundException.class) + .when(rcfManager) + .getRcfResult(any(String.class), any(String.class), any(double[].class), any(ActionListener.class)); when(rcfManager.getRcfModelId(any(String.class), anyInt())).thenReturn(rcfModelID); ColdStartRunner mockRunner = mock(ColdStartRunner.class); @@ -458,17 +436,14 @@ public void testInsufficientCapacityExceptionDuringColdStart() { AnomalyResultTransportAction action = new AnomalyResultTransportAction( new ActionFilters(Collections.emptySet()), transportService, - client, settings, stateManager, mockRunner, - anomalyDetectionIndices, featureQuery, normalModelManager, hashRing, clusterService, indexNameResolver, - threadPool, adCircuitBreakerService, adStats ); @@ -480,12 +455,13 @@ public void testInsufficientCapacityExceptionDuringColdStart() { assertException(listener, LimitExceededException.class); } + @SuppressWarnings("unchecked") public void testInsufficientCapacityExceptionDuringRestoringModel() { ModelManager rcfManager = mock(ModelManager.class); doThrow(new NotSerializableExceptionWrapper(new LimitExceededException(adID, CommonErrorMessages.MEMORY_LIMIT_EXCEEDED_ERR_MSG))) .when(rcfManager) - .getRcfResult(any(String.class), any(String.class), any(double[].class)); + .getRcfResult(any(String.class), any(String.class), any(double[].class), any(ActionListener.class)); // These constructors register handler in transport service new RCFResultTransportAction(new ActionFilters(Collections.emptySet()), transportService, rcfManager, adCircuitBreakerService); @@ -494,17 +470,14 @@ public void testInsufficientCapacityExceptionDuringRestoringModel() { AnomalyResultTransportAction action = new AnomalyResultTransportAction( new ActionFilters(Collections.emptySet()), transportService, - client, settings, stateManager, runner, - anomalyDetectionIndices, featureQuery, normalModelManager, hashRing, clusterService, indexNameResolver, - threadPool, adCircuitBreakerService, adStats ); @@ -535,17 +508,14 @@ public void testThresholdException() { AnomalyResultTransportAction action = new AnomalyResultTransportAction( new ActionFilters(Collections.emptySet()), transportService, - client, settings, stateManager, runner, - anomalyDetectionIndices, featureQuery, normalModelManager, hashRing, clusterService, indexNameResolver, - threadPool, adCircuitBreakerService, adStats ); @@ -569,17 +539,14 @@ public void testCircuitBreaker() { AnomalyResultTransportAction action = new AnomalyResultTransportAction( new ActionFilters(Collections.emptySet()), transportService, - client, settings, stateManager, runner, - anomalyDetectionIndices, featureQuery, normalModelManager, hashRing, clusterService, indexNameResolver, - threadPool, breakerService, adStats ); @@ -636,17 +603,14 @@ private void nodeNotConnectedExceptionTemplate(boolean isRCF, boolean temporary, AnomalyResultTransportAction action = new AnomalyResultTransportAction( new ActionFilters(Collections.emptySet()), exceptionTransportService, - client, settings, stateManager, runner, - anomalyDetectionIndices, featureQuery, normalModelManager, hashRing, hackedClusterService, indexNameResolver, - threadPool, adCircuitBreakerService, adStats ); @@ -692,24 +656,26 @@ public void testTemporaryThresholdNodeNotConnectedException() { nodeNotConnectedExceptionTemplate(false, true, 1); } + @SuppressWarnings("unchecked") public void testMute() { ADStateManager muteStateManager = mock(ADStateManager.class); when(muteStateManager.isMuted(any(String.class))).thenReturn(true); - when(muteStateManager.getAnomalyDetector(any(String.class))).thenReturn(Optional.of(detector)); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(detector)); + return null; + }).when(muteStateManager).getAnomalyDetector(any(String.class), any(ActionListener.class)); AnomalyResultTransportAction action = new AnomalyResultTransportAction( new ActionFilters(Collections.emptySet()), transportService, - client, settings, muteStateManager, runner, - anomalyDetectionIndices, featureQuery, normalModelManager, hashRing, clusterService, indexNameResolver, - threadPool, adCircuitBreakerService, adStats ); @@ -722,7 +688,6 @@ public void testMute() { } public void alertingRequestTemplate(boolean anomalyResultIndexExists) throws IOException { - setUpSavingAnomalyResultIndex(anomalyResultIndexExists); // These constructors register handler in transport service new RCFResultTransportAction( new ActionFilters(Collections.emptySet()), @@ -735,17 +700,14 @@ public void alertingRequestTemplate(boolean anomalyResultIndexExists) throws IOE new AnomalyResultTransportAction( new ActionFilters(Collections.emptySet()), transportService, - client, settings, stateManager, runner, - anomalyDetectionIndices, featureQuery, normalModelManager, hashRing, clusterService, indexNameResolver, - threadPool, adCircuitBreakerService, adStats ); @@ -890,22 +852,19 @@ public void testOnFailureNull() throws IOException { AnomalyResultTransportAction action = new AnomalyResultTransportAction( new ActionFilters(Collections.emptySet()), transportService, - client, settings, stateManager, new ColdStartRunner(), - anomalyDetectionIndices, featureQuery, normalModelManager, hashRing, clusterService, indexNameResolver, - threadPool, adCircuitBreakerService, adStats ); AnomalyResultTransportAction.RCFActionListener listener = action.new RCFActionListener( - null, null, null, null, null, null, null, null, null, 0, 0, 0, new AtomicInteger(), null + null, null, null, null, null, null, null, null, null, 0, new AtomicInteger(), null ); listener.onFailure(null); } @@ -915,17 +874,14 @@ public void testColdStartNoTrainingData() throws Exception { AnomalyResultTransportAction action = new AnomalyResultTransportAction( new ActionFilters(Collections.emptySet()), transportService, - client, settings, stateManager, runner, - anomalyDetectionIndices, featureQuery, normalModelManager, hashRing, clusterService, indexNameResolver, - threadPool, adCircuitBreakerService, adStats ); @@ -943,17 +899,14 @@ public void testColdStartTimeoutPutCheckpoint() throws Exception { AnomalyResultTransportAction action = new AnomalyResultTransportAction( new ActionFilters(Collections.emptySet()), transportService, - client, settings, stateManager, runner, - anomalyDetectionIndices, featureQuery, normalModelManager, hashRing, clusterService, indexNameResolver, - threadPool, adCircuitBreakerService, adStats ); @@ -968,32 +921,35 @@ enum FeatureTestMode { AD_EXCEPTION } + @SuppressWarnings("unchecked") public void featureTestTemplate(FeatureTestMode mode) { if (mode == FeatureTestMode.FEATURE_NOT_AVAILABLE) { - when(featureQuery.getCurrentFeatures(any(AnomalyDetector.class), anyLong(), anyLong())) - .thenReturn(new SinglePointFeatures(Optional.empty(), Optional.empty())); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(new SinglePointFeatures(Optional.empty(), Optional.empty())); + return null; + }).when(featureQuery).getCurrentFeatures(any(AnomalyDetector.class), anyLong(), anyLong(), any(ActionListener.class)); } else if (mode == FeatureTestMode.ILLEGAL_STATE) { - doThrow(IllegalArgumentException.class).when(featureQuery).getCurrentFeatures(any(AnomalyDetector.class), anyLong(), anyLong()); + doThrow(IllegalArgumentException.class) + .when(featureQuery) + .getCurrentFeatures(any(AnomalyDetector.class), anyLong(), anyLong(), any(ActionListener.class)); } else if (mode == FeatureTestMode.AD_EXCEPTION) { doThrow(AnomalyDetectionException.class) .when(featureQuery) - .getCurrentFeatures(any(AnomalyDetector.class), anyLong(), anyLong()); + .getCurrentFeatures(any(AnomalyDetector.class), anyLong(), anyLong(), any(ActionListener.class)); } AnomalyResultTransportAction action = new AnomalyResultTransportAction( new ActionFilters(Collections.emptySet()), transportService, - client, settings, stateManager, runner, - anomalyDetectionIndices, featureQuery, normalModelManager, hashRing, clusterService, indexNameResolver, - threadPool, adCircuitBreakerService, adStats ); @@ -1069,17 +1025,14 @@ private void globalBlockTemplate(BlockType type, String errLogMsg, Settings inde AnomalyResultTransportAction action = new AnomalyResultTransportAction( new ActionFilters(Collections.emptySet()), transportService, - client, settings, stateManager, runner, - anomalyDetectionIndices, featureQuery, normalModelManager, hashRing, hackedClusterService, indexNameResolver, - threadPool, adCircuitBreakerService, adStats ); @@ -1116,24 +1069,55 @@ public void testNullRCFResult() { AnomalyResultTransportAction action = new AnomalyResultTransportAction( new ActionFilters(Collections.emptySet()), transportService, - client, settings, stateManager, runner, - anomalyDetectionIndices, featureQuery, normalModelManager, hashRing, clusterService, indexNameResolver, - threadPool, adCircuitBreakerService, adStats ); AnomalyResultTransportAction.RCFActionListener listener = action.new RCFActionListener( - null, "123-rcf-0", null, "123", null, null, null, null, null, 0, 0, 0, new AtomicInteger(), null + null, "123-rcf-0", null, "123", null, null, null, null, null, 0, new AtomicInteger(), null ); listener.onResponse(null); assertTrue(testAppender.containsMessage(AnomalyResultTransportAction.NULL_RESPONSE)); } + + @SuppressWarnings("unchecked") + public void testAllFeaturesDisabled() { + // doThrow(IllegalArgumentException.class).when(featureQuery) + // .getCurrentFeatures(any(AnomalyDetector.class), anyLong(), anyLong(), any(ActionListener.class)); + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + ActionListener listener = (ActionListener) args[3]; + listener.onFailure(new IllegalArgumentException()); + return null; + }).when(featureQuery).getCurrentFeatures(any(AnomalyDetector.class), anyLong(), anyLong(), any(ActionListener.class)); + when(detector.getEnabledFeatureIds()).thenReturn(Collections.emptyList()); + + AnomalyResultTransportAction action = new AnomalyResultTransportAction( + new ActionFilters(Collections.emptySet()), + transportService, + settings, + stateManager, + runner, + featureQuery, + normalModelManager, + hashRing, + clusterService, + indexNameResolver, + adCircuitBreakerService, + adStats + ); + + AnomalyResultRequest request = new AnomalyResultRequest(adID, 100, 200); + PlainActionFuture listener = new PlainActionFuture<>(); + action.doExecute(null, request, listener); + + assertException(listener, EndRunException.class, AnomalyResultTransportAction.ALL_FEATURES_DISABLED_ERR_MSG); + } } diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/CronTransportActionTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/CronTransportActionTests.java index bc6c6141..a1e556dc 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/CronTransportActionTests.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/CronTransportActionTests.java @@ -27,6 +27,7 @@ import com.amazon.opendistroforelasticsearch.ad.common.exception.JsonPathNotFoundException; import com.amazon.opendistroforelasticsearch.ad.feature.FeatureManager; import com.amazon.opendistroforelasticsearch.ad.ml.ModelManager; + import org.elasticsearch.Version; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.cluster.ClusterName; diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/DeleteModelTransportActionTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/DeleteModelTransportActionTests.java index 09f2ddee..7d60ffce 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/DeleteModelTransportActionTests.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/DeleteModelTransportActionTests.java @@ -30,6 +30,7 @@ import com.amazon.opendistroforelasticsearch.ad.constant.CommonErrorMessages; import com.amazon.opendistroforelasticsearch.ad.feature.FeatureManager; import com.amazon.opendistroforelasticsearch.ad.ml.ModelManager; + import org.elasticsearch.Version; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.support.ActionFilters; diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/RCFResultTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/RCFResultTests.java index 95766caa..9084ee6a 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/RCFResultTests.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/RCFResultTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). * You may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasItem; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -33,6 +34,8 @@ import com.amazon.opendistroforelasticsearch.ad.constant.CommonMessageAttributes; import com.amazon.opendistroforelasticsearch.ad.ml.ModelManager; import com.amazon.opendistroforelasticsearch.ad.ml.RcfResult; + +import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.PlainActionFuture; @@ -56,6 +59,7 @@ public class RCFResultTests extends ESTestCase { Gson gson = new GsonBuilder().create(); + @SuppressWarnings("unchecked") public void testNormal() { TransportService transportService = new TransportService( Settings.EMPTY, @@ -75,7 +79,12 @@ public void testNormal() { manager, adCircuitBreakerService ); - when(manager.getRcfResult(any(String.class), any(String.class), any(double[].class))).thenReturn(new RcfResult(0, 0, 25)); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(new RcfResult(0, 0, 25)); + return null; + }).when(manager).getRcfResult(any(String.class), any(String.class), any(double[].class), any(ActionListener.class)); + when(adCircuitBreakerService.isOpen()).thenReturn(false); final PlainActionFuture future = new PlainActionFuture<>(); @@ -87,6 +96,7 @@ public void testNormal() { assertEquals(25, response.getForestSize(), 0.001); } + @SuppressWarnings("unchecked") public void testExecutionException() { TransportService transportService = new TransportService( Settings.EMPTY, @@ -106,7 +116,9 @@ public void testExecutionException() { manager, adCircuitBreakerService ); - doThrow(NullPointerException.class).when(manager).getRcfResult(any(String.class), any(String.class), any(double[].class)); + doThrow(NullPointerException.class) + .when(manager) + .getRcfResult(any(String.class), any(String.class), any(double[].class), any(ActionListener.class)); when(adCircuitBreakerService.isOpen()).thenReturn(false); final PlainActionFuture future = new PlainActionFuture<>(); @@ -172,6 +184,7 @@ public void testJsonRequest() throws IOException, JsonPathNotFoundException { ); } + @SuppressWarnings("unchecked") public void testCircuitBreaker() { TransportService transportService = new TransportService( Settings.EMPTY, @@ -191,7 +204,11 @@ public void testCircuitBreaker() { manager, breakerService ); - when(manager.getRcfResult(any(String.class), any(String.class), any(double[].class))).thenReturn(new RcfResult(0, 0, 25)); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(new RcfResult(0, 0, 25)); + return null; + }).when(manager).getRcfResult(any(String.class), any(String.class), any(double[].class), any(ActionListener.class)); when(breakerService.isOpen()).thenReturn(true); final PlainActionFuture future = new PlainActionFuture<>(); diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/ThresholdResultTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/ThresholdResultTests.java index 4c2ca200..4a992126 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/ThresholdResultTests.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/ThresholdResultTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). * You may not use this file except in compliance with the License. @@ -19,9 +19,9 @@ import static org.hamcrest.Matchers.equalTo; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyDouble; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; import java.io.IOException; import java.util.Collections; @@ -31,6 +31,8 @@ import com.amazon.opendistroforelasticsearch.ad.constant.CommonMessageAttributes; import com.amazon.opendistroforelasticsearch.ad.ml.ModelManager; import com.amazon.opendistroforelasticsearch.ad.ml.ThresholdingResult; + +import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.PlainActionFuture; @@ -50,6 +52,7 @@ public class ThresholdResultTests extends ESTestCase { + @SuppressWarnings("unchecked") public void testNormal() { TransportService transportService = new TransportService( Settings.EMPTY, @@ -63,7 +66,11 @@ public void testNormal() { ModelManager manager = mock(ModelManager.class); ThresholdResultTransportAction action = new ThresholdResultTransportAction(mock(ActionFilters.class), transportService, manager); - when(manager.getThresholdingResult(any(String.class), any(String.class), anyDouble())).thenReturn(new ThresholdingResult(0, 1.0d)); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(new ThresholdingResult(0, 1.0d)); + return null; + }).when(manager).getThresholdingResult(any(String.class), any(String.class), anyDouble(), any(ActionListener.class)); final PlainActionFuture future = new PlainActionFuture<>(); ThresholdResultRequest request = new ThresholdResultRequest("123", "123-threshold", 2); @@ -74,6 +81,7 @@ public void testNormal() { assertEquals(1, response.getConfidence(), 0.001); } + @SuppressWarnings("unchecked") public void testExecutionException() { TransportService transportService = new TransportService( Settings.EMPTY, @@ -87,7 +95,9 @@ public void testExecutionException() { ModelManager manager = mock(ModelManager.class); ThresholdResultTransportAction action = new ThresholdResultTransportAction(mock(ActionFilters.class), transportService, manager); - doThrow(NullPointerException.class).when(manager).getThresholdingResult(any(String.class), any(String.class), anyDouble()); + doThrow(NullPointerException.class) + .when(manager) + .getThresholdingResult(any(String.class), any(String.class), anyDouble(), any(ActionListener.class)); final PlainActionFuture future = new PlainActionFuture<>(); ThresholdResultRequest request = new ThresholdResultRequest("123", "123-threshold", 2);