From 0c3305000ffe6e61c7c2104314a555e69d060023 Mon Sep 17 00:00:00 2001 From: Kaituo Li Date: Tue, 14 Apr 2020 11:52:20 -0700 Subject: [PATCH] Use callbacks and bug fix (#83) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Use callbacks and bug fix This PR includes the following changes: 1. remove classes that are not needed in jacocoExclusions since we have enough coverage for those classes. 2. Use ClientUtil instead of Elasticsearch’s client in AD job runner 3. Use one function to get the number of partitioned forests. Previously, we have redundant code in both ModelManager and ADStateManager. 4. Change ADStateManager.getAnomalyDetector to use callback. 5. Change AnomalyResultTransportAction to use callback to get features. 6. Add in AnomalyResultTransportAction to handle the case where all features have been disabled, and users' index does not exist. 7. Change get RCF and threshold result methods to use callback and add exception handling of IndexNotFoundException due to the change. Previously, getting RCF and threshold result methods won’t throw IndexNotFoundException. 8. Remove unused fields in StopDetectorTransportAction and AnomalyResultTransportAction 9. Unwrap EsRejectedExecutionException as it can be nested inside RemoteTransportException. Previously, we would not recognize EsRejectedExecutionException and thus miss anomaly results write retrying. 10. Add error in anomaly result schema.11. Fix broken tests due to my changes. Testing done: 1. unit/integration tests pass 2. do end-to-end testing and make sure my fix achieves the purpose  * timeout issue is gone  * when all features have been disabled or index does not exist, we will retry a few more times and disable AD jobs. --- build.gradle | 18 +- .../ad/AnomalyDetectorJobRunner.java | 34 ++-- .../ad/AnomalyDetectorPlugin.java | 7 +- .../ad/AnomalyDetectorRunner.java | 1 + .../ad/ml/ModelManager.java | 62 ++++-- .../ad/transport/ADStateManager.java | 78 +++----- .../AnomalyResultTransportAction.java | 137 +++++++------ .../transport/RCFResultTransportAction.java | 21 +- .../StopDetectorTransportAction.java | 6 +- .../ThresholdResultTransportAction.java | 19 +- .../handler/AnomalyResultHandler.java | 52 +++-- .../resources/mappings/anomaly-results.json | 3 + .../ad/AnomalyDetectorJobRunnerTests.java | 25 ++- .../cluster/ADClusterEventListenerTests.java | 1 + .../ad/ml/ModelManagerTests.java | 9 +- .../ad/transport/ADStateManagerTests.java | 67 ++++--- .../ad/transport/AnomalyResultTests.java | 186 ++++++++---------- .../transport/CronTransportActionTests.java | 1 + .../DeleteModelTransportActionTests.java | 1 + .../ad/transport/RCFResultTests.java | 25 ++- .../ad/transport/ThresholdResultTests.java | 18 +- 21 files changed, 429 insertions(+), 342 deletions(-) diff --git a/build.gradle b/build.gradle index 9c6fef64..68607e52 100644 --- a/build.gradle +++ b/build.gradle @@ -255,23 +255,17 @@ List jacocoExclusions = [ 'com.amazon.opendistroforelasticsearch.ad.common.exception.AnomalyDetectionException', 'com.amazon.opendistroforelasticsearch.ad.util.ClientUtil', - 'com.amazon.opendistroforelasticsearch.ad.ml.*', - 'com.amazon.opendistroforelasticsearch.ad.feature.*', - 'com.amazon.opendistroforelasticsearch.ad.dataprocessor.*', - 'com.amazon.opendistroforelasticsearch.ad.AnomalyDetectorRunner', - 'com.amazon.opendistroforelasticsearch.ad.resthandler.RestGetAnomalyResultAction', - 'com.amazon.opendistroforelasticsearch.ad.metrics.MetricFactory', - 'com.amazon.opendistroforelasticsearch.ad.indices.AnomalyDetectionIndices', - 'com.amazon.opendistroforelasticsearch.ad.transport.ForwardAction', - 'com.amazon.opendistroforelasticsearch.ad.transport.ForwardTransportAction', 'com.amazon.opendistroforelasticsearch.ad.transport.StopDetectorAction', 'com.amazon.opendistroforelasticsearch.ad.transport.StopDetectorRequest', 'com.amazon.opendistroforelasticsearch.ad.transport.StopDetectorResponse', 'com.amazon.opendistroforelasticsearch.ad.transport.StopDetectorTransportAction', - 'com.amazon.opendistroforelasticsearch.ad.transport.ADStatsAction', - 'com.amazon.opendistroforelasticsearch.ad.transport.CronRequest', 'com.amazon.opendistroforelasticsearch.ad.transport.DeleteDetectorAction', - 'com.amazon.opendistroforelasticsearch.ad.util.ParseUtils' + 'com.amazon.opendistroforelasticsearch.ad.transport.CronTransportAction', + 'com.amazon.opendistroforelasticsearch.ad.transport.CronRequest', + 'com.amazon.opendistroforelasticsearch.ad.transport.ADStatsAction', + 'com.amazon.opendistroforelasticsearch.ad.AnomalyDetectorRunner', + 'com.amazon.opendistroforelasticsearch.ad.indices.AnomalyDetectionIndices', + 'com.amazon.opendistroforelasticsearch.ad.util.ParseUtils', ] jacocoTestCoverageVerification { diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorJobRunner.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorJobRunner.java index e59bf605..82f6d502 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorJobRunner.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorJobRunner.java @@ -28,6 +28,7 @@ import com.amazon.opendistroforelasticsearch.ad.transport.AnomalyResultResponse; import com.amazon.opendistroforelasticsearch.ad.transport.AnomalyResultTransportAction; import com.amazon.opendistroforelasticsearch.ad.transport.handler.AnomalyResultHandler; +import com.amazon.opendistroforelasticsearch.ad.util.ClientUtil; import com.amazon.opendistroforelasticsearch.jobscheduler.spi.JobExecutionContext; import com.amazon.opendistroforelasticsearch.jobscheduler.spi.LockModel; import com.amazon.opendistroforelasticsearch.jobscheduler.spi.ScheduledJobParameter; @@ -39,7 +40,9 @@ import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.get.GetRequest; +import org.elasticsearch.action.get.GetResponse; import org.elasticsearch.action.index.IndexRequest; +import org.elasticsearch.action.index.IndexResponse; import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.client.Client; import org.elasticsearch.common.settings.Settings; @@ -71,6 +74,7 @@ public class AnomalyDetectorJobRunner implements ScheduledJobRunner { private Settings settings; private int maxRetryForEndRunException; private Client client; + private ClientUtil clientUtil; private ThreadPool threadPool; private AnomalyResultHandler anomalyResultHandler; private ConcurrentHashMap detectorEndRunExceptionCount; @@ -97,6 +101,10 @@ public void setClient(Client client) { this.client = client; } + public void setClientUtil(ClientUtil clientUtil) { + this.clientUtil = clientUtil; + } + public void setThreadPool(ThreadPool threadPool) { this.threadPool = threadPool; } @@ -258,7 +266,7 @@ protected void handleAdException( ) { String detectorId = jobParameter.getName(); if (exception instanceof EndRunException) { - log.error("EndRunException happened when executed anomaly result action for " + detectorId, exception); + log.error("EndRunException happened when executing anomaly result action for " + detectorId, exception); if (((EndRunException) exception).isEndNow()) { // Stop AD job if EndRunException shows we should end job now. @@ -349,9 +357,8 @@ private void stopAdJob(String detectorId) { try { GetRequest getRequest = new GetRequest(AnomalyDetectorJob.ANOMALY_DETECTOR_JOB_INDEX).id(detectorId); - client.get(getRequest, ActionListener.wrap(response -> { + clientUtil.asyncRequest(getRequest, client::get, ActionListener.wrap(response -> { if (response.isExists()) { - String s = response.getSourceAsString(); try ( XContentParser parser = XContentType.JSON .xContent() @@ -374,14 +381,19 @@ private void stopAdJob(String detectorId) { .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) .source(newJob.toXContent(XContentBuilder.builder(XContentType.JSON.xContent()), XCONTENT_WITH_TYPE)) .id(detectorId); - client.index(indexRequest, ActionListener.wrap(indexResponse -> { - if (indexResponse != null - && (indexResponse.getResult() == CREATED || indexResponse.getResult() == UPDATED)) { - log.info("AD Job was disabled by JobRunner for " + detectorId); - } else { - log.warn("Failed to disable AD job for " + detectorId); - } - }, exception -> log.error("JobRunner failed to update AD job as disabled for " + detectorId, exception))); + clientUtil + .asyncRequest( + indexRequest, + client::index, + ActionListener.wrap(indexResponse -> { + if (indexResponse != null + && (indexResponse.getResult() == CREATED || indexResponse.getResult() == UPDATED)) { + log.info("AD Job was disabled by JobRunner for " + detectorId); + } else { + log.warn("Failed to disable AD job for " + detectorId); + } + }, exception -> log.error("JobRunner failed to update AD job as disabled for " + detectorId, exception)) + ); } } catch (IOException e) { log.error("JobRunner failed to stop detector job " + detectorId, e); diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorPlugin.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorPlugin.java index 5f93e936..9744ed76 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorPlugin.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorPlugin.java @@ -143,6 +143,7 @@ public class AnomalyDetectorPlugin extends Plugin implements ActionPlugin, Scrip private ThreadPool threadPool; private IndexNameExpressionResolver indexNameExpressionResolver; private ADStats adStats; + private ClientUtil clientUtil; static { SpecialPermission.check(); @@ -174,6 +175,7 @@ public List getRestHandlers( ); AnomalyDetectorJobRunner jobRunner = AnomalyDetectorJobRunner.getJobRunnerInstance(); jobRunner.setClient(client); + jobRunner.setClientUtil(clientUtil); jobRunner.setThreadPool(threadPool); jobRunner.setAnomalyResultHandler(anomalyResultHandler); jobRunner.setSettings(settings); @@ -237,7 +239,7 @@ public Collection createComponents( Settings settings = environment.settings(); Clock clock = Clock.systemUTC(); Throttler throttler = new Throttler(clock); - ClientUtil clientUtil = new ClientUtil(settings, client, throttler, threadPool); + this.clientUtil = new ClientUtil(settings, client, throttler, threadPool); IndexUtils indexUtils = new IndexUtils(client, clientUtil, clusterService); anomalyDetectionIndices = new AnomalyDetectionIndices(client, clusterService, threadPool, settings, clientUtil); this.clusterService = clusterService; @@ -272,7 +274,8 @@ public Collection createComponents( HybridThresholdingModel.class, AnomalyDetectorSettings.MIN_PREVIEW_SIZE, AnomalyDetectorSettings.HOURLY_MAINTENANCE, - AnomalyDetectorSettings.HOURLY_MAINTENANCE + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + AnomalyDetectorSettings.SHINGLE_SIZE ); HashRing hashRing = new HashRing(clusterService, clock, settings); diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorRunner.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorRunner.java index 5998b08f..5a613b50 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorRunner.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorRunner.java @@ -58,6 +58,7 @@ public AnomalyDetectorRunner(ModelManager modelManager, FeatureManager featureMa * @param startTime detection period start time * @param endTime detection period end time * @param listener handle anomaly result + * @throws IOException - if a user gives wrong query input when defining a detector */ public void executeDetector(AnomalyDetector detector, Instant startTime, Instant endTime, ActionListener> listener) throws IOException { diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManager.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManager.java index dbf74b16..e334645e 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManager.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManager.java @@ -1,5 +1,5 @@ /* - * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). * You may not use this file except in compliance with the License. @@ -23,6 +23,7 @@ import java.util.AbstractMap.SimpleImmutableEntry; import java.util.Arrays; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Map.Entry; import java.util.Optional; @@ -108,6 +109,7 @@ public String getName() { private final CheckpointDao checkpointDao; private final Gson gson; private final Clock clock; + private final int shingleSize; // A tree of N samples has 2N nodes, with one bounding box for each node. private static final long BOUNDING_BOXES = 2L; @@ -160,7 +162,8 @@ public ModelManager( Class thresholdingModelClass, int minPreviewSize, Duration modelTtl, - Duration checkpointInterval + Duration checkpointInterval, + int shingleSize ) { this.clusterService = clusterService; @@ -188,6 +191,7 @@ public ModelManager( this.forests = new ConcurrentHashMap<>(); this.thresholds = new ConcurrentHashMap<>(); + this.shingleSize = shingleSize; } /** @@ -272,6 +276,36 @@ public Entry getPartitionedForestSizes(RandomCutForest forest, return new SimpleImmutableEntry<>(numPartitions, forestSize); } + /** + * Construct a RCF model and then partition it by forest size. + * + * A RCF model is constructed based on the number of input features. + * + * Then a RCF model is first partitioned into desired size based on heap. + * If there are more partitions than the number of nodes in the cluster, + * the model is partitioned by the number of nodes and verified to + * ensure the size of a partition does not exceed the max size limit based on heap. + * + * @param detector detector object + * @return a pair of number of partitions and size of a parition (number of trees) + * @throws LimitExceededException when there is no sufficient resouce available + */ + public Entry getPartitionedForestSizes(AnomalyDetector detector) { + String detectorId = detector.getDetectorId(); + int rcfNumFeatures = detector.getEnabledFeatureIds().size() * shingleSize; + return getPartitionedForestSizes( + RandomCutForest + .builder() + .dimensions(rcfNumFeatures) + .sampleSize(rcfNumSamplesInTree) + .numberOfTrees(rcfNumTrees) + .outputAfter(rcfNumSamplesInTree) + .parallelExecutionEnabled(false) + .build(), + detectorId + ); + } + /** * Gets the estimated size of a RCF model. * @@ -542,20 +576,22 @@ public void trainModel(AnomalyDetector anomalyDetector, double[][] dataPoints) { if (dataPoints.length == 0 || dataPoints[0].length == 0) { throw new IllegalArgumentException("Data points must not be empty."); } + if (dataPoints[0].length != anomalyDetector.getEnabledFeatureIds().size() * shingleSize) { + throw new IllegalArgumentException( + String + .format( + Locale.ROOT, + "Feature dimension is not correct, we expect %s but get %d", + anomalyDetector.getEnabledFeatureIds().size() * shingleSize, + dataPoints[0].length + ) + ); + } int rcfNumFeatures = dataPoints[0].length; // Create partitioned RCF models - Entry partitionResults = getPartitionedForestSizes( - RandomCutForest - .builder() - .dimensions(rcfNumFeatures) - .sampleSize(rcfNumSamplesInTree) - .numberOfTrees(rcfNumTrees) - .outputAfter(rcfNumSamplesInTree) - .parallelExecutionEnabled(false) - .build(), - anomalyDetector.getDetectorId() - ); + Entry partitionResults = getPartitionedForestSizes(anomalyDetector); + int numForests = partitionResults.getKey(); int forestSize = partitionResults.getValue(); double[] scores = new double[dataPoints.length]; diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/ADStateManager.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/ADStateManager.java index eb57123c..4e993c50 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/ADStateManager.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/ADStateManager.java @@ -1,5 +1,5 @@ /* - * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). * You may not use this file except in compliance with the License. @@ -22,19 +22,18 @@ import java.time.Instant; import java.util.Map; import java.util.Optional; -import java.util.Random; import java.util.AbstractMap.SimpleEntry; import java.util.Map.Entry; import java.util.concurrent.ConcurrentHashMap; -import com.amazon.opendistroforelasticsearch.ad.common.exception.AnomalyDetectionException; import com.amazon.opendistroforelasticsearch.ad.common.exception.LimitExceededException; import com.amazon.opendistroforelasticsearch.ad.ml.ModelManager; import com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetector; -import com.amazon.opendistroforelasticsearch.ad.settings.AnomalyDetectorSettings; import com.amazon.opendistroforelasticsearch.ad.util.ClientUtil; + import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.get.GetRequest; import org.elasticsearch.action.get.GetResponse; import org.elasticsearch.client.Client; @@ -44,8 +43,6 @@ import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentType; -import com.amazon.randomcutforest.RandomCutForest; - /** * ADStateManager is used by transport layer to manage AnomalyDetector object * and the number of partitions for a detector id. @@ -56,7 +53,6 @@ public class ADStateManager { private ConcurrentHashMap> currentDetectors; private ConcurrentHashMap> partitionNumber; private Client client; - private Random random; private ModelManager modelManager; private NamedXContentRegistry xContentRegistry; private ClientUtil clientUtil; @@ -77,7 +73,6 @@ public ADStateManager( ) { this.currentDetectors = new ConcurrentHashMap<>(); this.client = client; - this.random = new Random(); this.modelManager = modelManager; this.xContentRegistry = xContentRegistry; this.partitionNumber = new ConcurrentHashMap<>(); @@ -91,67 +86,58 @@ public ADStateManager( /** * Get the number of RCF model's partition number for detector adID * @param adID detector id + * @param detector object * @return the number of RCF model's partition number for adID - * @throws InterruptedException when we cannot get anomaly detector object for adID before timeout * @throws LimitExceededException when there is no sufficient resource available */ - public int getPartitionNumber(String adID) throws InterruptedException { + public int getPartitionNumber(String adID, AnomalyDetector detector) { Entry partitonAndTime = partitionNumber.get(adID); if (partitonAndTime != null) { partitonAndTime.setValue(clock.instant()); return partitonAndTime.getKey(); } - Optional detector = getAnomalyDetector(adID); - if (!detector.isPresent()) { - throw new AnomalyDetectionException(adID, "AnomalyDetector is not found"); - } - - RandomCutForest forest = RandomCutForest - .builder() - .dimensions(detector.get().getFeatureAttributes().size() * AnomalyDetectorSettings.SHINGLE_SIZE) - .sampleSize(AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE) - .numberOfTrees(AnomalyDetectorSettings.NUM_TREES) - .parallelExecutionEnabled(false) - .build(); - int partitionNum = modelManager.getPartitionedForestSizes(forest, adID).getKey(); + int partitionNum = modelManager.getPartitionedForestSizes(detector).getKey(); partitionNumber.putIfAbsent(adID, new SimpleEntry<>(partitionNum, clock.instant())); return partitionNum; } - public Optional getAnomalyDetector(String adID) { + public void getAnomalyDetector(String adID, ActionListener> listener) { Entry detectorAndTime = currentDetectors.get(adID); if (detectorAndTime != null) { detectorAndTime.setValue(clock.instant()); - return Optional.of(detectorAndTime.getKey()); + listener.onResponse(Optional.of(detectorAndTime.getKey())); + return; } GetRequest request = new GetRequest(AnomalyDetector.ANOMALY_DETECTORS_INDEX, adID); - Optional getResponse = clientUtil.timedRequest(request, LOG, client::get); - - return onGetResponse(getResponse, adID); + clientUtil.asyncRequest(request, client::get, onGetResponse(adID, listener)); } - private Optional onGetResponse(Optional asResponse, String adID) { - if (!asResponse.isPresent() || !asResponse.get().isExists()) { - return Optional.empty(); - } + private ActionListener onGetResponse(String adID, ActionListener> listener) { + return ActionListener.wrap(response -> { + if (response == null || !response.isExists()) { + listener.onResponse(Optional.empty()); + return; + } - GetResponse response = asResponse.get(); - String xc = response.getSourceAsString(); - LOG.debug("Fetched anomaly detector: {}", xc); - - try (XContentParser parser = XContentType.JSON.xContent().createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, xc)) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser::getTokenLocation); - AnomalyDetector detector = AnomalyDetector.parse(parser, response.getId()); - currentDetectors.put(adID, new SimpleEntry<>(detector, clock.instant())); - return Optional.of(detector); - } catch (Exception t) { - LOG.error("Fail to parse detector {}", adID); - LOG.error("Stack trace:", t); - return Optional.empty(); - } + String xc = response.getSourceAsString(); + LOG.info("Fetched anomaly detector: {}", xc); + + try ( + XContentParser parser = XContentType.JSON.xContent().createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, xc) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser::getTokenLocation); + AnomalyDetector detector = AnomalyDetector.parse(parser, response.getId()); + currentDetectors.put(adID, new SimpleEntry<>(detector, clock.instant())); + listener.onResponse(Optional.of(detector)); + } catch (Exception t) { + LOG.error("Fail to parse detector {}", adID); + LOG.error("Stack trace:", t); + listener.onResponse(Optional.empty()); + } + }, listener::onFailure); } /** diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/AnomalyResultTransportAction.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/AnomalyResultTransportAction.java index 4d3f62a9..f130ab9c 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/AnomalyResultTransportAction.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/AnomalyResultTransportAction.java @@ -1,5 +1,5 @@ /* - * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). * You may not use this file except in compliance with the License. @@ -32,9 +32,9 @@ import com.amazon.opendistroforelasticsearch.ad.common.exception.LimitExceededException; import com.amazon.opendistroforelasticsearch.ad.common.exception.ResourceNotFoundException; import com.amazon.opendistroforelasticsearch.ad.constant.CommonErrorMessages; +import com.amazon.opendistroforelasticsearch.ad.constant.CommonName; import com.amazon.opendistroforelasticsearch.ad.feature.FeatureManager; import com.amazon.opendistroforelasticsearch.ad.feature.SinglePointFeatures; -import com.amazon.opendistroforelasticsearch.ad.indices.AnomalyDetectionIndices; import com.amazon.opendistroforelasticsearch.ad.ml.ModelManager; import com.amazon.opendistroforelasticsearch.ad.ml.RcfResult; import com.amazon.opendistroforelasticsearch.ad.ml.rcf.CombinedRcfResult; @@ -45,6 +45,7 @@ import com.amazon.opendistroforelasticsearch.ad.stats.ADStats; import com.amazon.opendistroforelasticsearch.ad.stats.StatNames; import com.amazon.opendistroforelasticsearch.ad.util.ColdStartRunner; + import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; @@ -54,11 +55,9 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListenerResponseHandler; import org.elasticsearch.action.ActionRequest; -import org.elasticsearch.action.bulk.BackoffPolicy; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.HandledTransportAction; import org.elasticsearch.action.support.IndicesOptions; -import org.elasticsearch.client.Client; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.block.ClusterBlockLevel; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; @@ -69,12 +68,12 @@ import org.elasticsearch.common.io.stream.NotSerializableExceptionWrapper; import org.elasticsearch.node.NodeClosedException; import org.elasticsearch.tasks.Task; -import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.ConnectTransportException; import org.elasticsearch.transport.ReceiveTimeoutTransportException; import org.elasticsearch.transport.TransportRequestOptions; import org.elasticsearch.transport.TransportService; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.IndexNotFoundException; public class AnomalyResultTransportAction extends HandledTransportAction { @@ -83,6 +82,7 @@ public class AnomalyResultTransportAction extends HandledTransportAction getFeatureData(double[] currentFeature, AnomalyDetecto * + unknown prediction error * * Known cause of EndRunException with endNow returning true: - * + anomaly detector is not available - * + a models' memory size reached limit + * + a model's memory size reached limit * + models' total memory size reached limit + * + Having trouble querying feature data due to + * * index does not exist + * * all features have been disabled + * + anomaly detector is not available + * * Known cause of InternalFailure: * + threshold model node is not available * + cluster read/write is blocked - * + interrupted while waiting for rcf/threshold model nodes' responses * + cold start hasn't been finished * + fail to get all of rcf model nodes' responses * + fail to get threshold model node's response * + RCF/Threshold model node failing to get checkpoint to restore model before timeout + * + Detection is throttle because previous detection query is running * */ @Override @@ -225,7 +214,18 @@ protected void doExecute(Task task, ActionRequest actionRequest, ActionListener< } try { - Optional detector = stateManager.getAnomalyDetector(adID); + stateManager.getAnomalyDetector(adID, onGetDetector(listener, adID, request)); + } catch (Exception ex) { + handleExecuteException(ex, listener, adID); + } + } + + private ActionListener> onGetDetector( + ActionListener listener, + String adID, + AnomalyResultRequest request + ) { + return ActionListener.wrap(detector -> { if (!detector.isPresent()) { listener.onFailure(new EndRunException(adID, "AnomalyDetector is not available.", true)); return; @@ -233,13 +233,15 @@ protected void doExecute(Task task, ActionRequest actionRequest, ActionListener< AnomalyDetector anomalyDetector = detector.get(); String thresholdModelID = modelManager.getThresholdModelId(adID); - Optional thresholdNode = hashRing.getOwningNode(thresholdModelID); - if (!thresholdNode.isPresent()) { + Optional asThresholdNode = hashRing.getOwningNode(thresholdModelID); + if (!asThresholdNode.isPresent()) { listener.onFailure(new InternalFailure(adID, "Threshold model node is not available.")); return; } - if (!shouldStart(listener, adID, detector.get(), thresholdNode.get().getId(), thresholdModelID)) { + DiscoveryNode thresholdNode = asThresholdNode.get(); + + if (!shouldStart(listener, adID, anomalyDetector, thresholdNode.getId(), thresholdModelID)) { return; } @@ -250,12 +252,31 @@ protected void doExecute(Task task, ActionRequest actionRequest, ActionListener< long dataStartTime = request.getStart() - delayMillis; long dataEndTime = request.getEnd() - delayMillis; - SinglePointFeatures featureOptional = featureManager.getCurrentFeatures(anomalyDetector, dataStartTime, dataEndTime); + featureManager + .getCurrentFeatures( + anomalyDetector, + dataStartTime, + dataEndTime, + onFeatureResponse(adID, anomalyDetector, listener, thresholdModelID, thresholdNode, dataStartTime, dataEndTime) + ); + }, exception -> handleExecuteException(exception, listener, adID)); + + } + private ActionListener onFeatureResponse( + String adID, + AnomalyDetector detector, + ActionListener listener, + String thresholdModelID, + DiscoveryNode thresholdNode, + long dataStartTime, + long dataEndTime + ) { + return ActionListener.wrap(featureOptional -> { List featureInResponse = null; if (featureOptional.getUnprocessedFeatures().isPresent()) { - featureInResponse = getFeatureData(featureOptional.getUnprocessedFeatures().get(), detector.get()); + featureInResponse = getFeatureData(featureOptional.getUnprocessedFeatures().get(), detector); } if (!featureOptional.getProcessedFeatures().isPresent()) { @@ -293,7 +314,7 @@ protected void doExecute(Task task, ActionRequest actionRequest, ActionListener< // Can throw LimitExceededException when a single partition is more than X% of heap memory. // Compute this number once and the value won't change unless the coordinating AD node for an // detector changes or the cluster size changes. - int rcfPartitionNum = stateManager.getPartitionNumber(adID); + int rcfPartitionNum = stateManager.getPartitionNumber(adID, detector); List rcfResults = new ArrayList<>(); @@ -326,8 +347,6 @@ protected void doExecute(Task task, ActionRequest actionRequest, ActionListener< thresholdModelID, thresholdNode, featureInResponse, - dataStartTime, - dataEndTime, rcfPartitionNum, responseCount, adID @@ -342,9 +361,15 @@ protected void doExecute(Task task, ActionRequest actionRequest, ActionListener< new ActionListenerResponseHandler<>(rcfListener, RCFResultResponse::new) ); } - } catch (Exception ex) { - handleExecuteException(ex, listener, adID); - } + }, exception -> { + if (exception instanceof IndexNotFoundException) { + listener.onFailure(new EndRunException(adID, "Having trouble querying data: " + exception.getMessage(), true)); + } else if (exception instanceof IllegalArgumentException && detector.getEnabledFeatureIds().isEmpty()) { + listener.onFailure(new EndRunException(adID, ALL_FEATURES_DISABLED_ERR_MSG, true)); + } else { + handleExecuteException(exception, listener, adID); + } + }); } /** @@ -390,7 +415,9 @@ private void findException(Throwable cause, String adID, AtomicReference previousException = globalRunner.fetchException(adID); @@ -468,13 +495,11 @@ class RCFActionListener implements ActionListener { private String modelID; private AtomicReference failure; private String rcfNodeID; - private Optional detector; + private AnomalyDetector detector; private ActionListener listener; private String thresholdModelID; - private Optional thresholdNode; + private DiscoveryNode thresholdNode; private List featureInResponse; - private long startTime; - private long endTime; private int nodeCount; private final AtomicInteger responseCount; private final String adID; @@ -484,13 +509,11 @@ class RCFActionListener implements ActionListener { String modelID, AtomicReference failure, String rcfNodeID, - Optional detector, + AnomalyDetector detector, ActionListener listener, String thresholdModelID, - Optional thresholdNode, + DiscoveryNode thresholdNode, List features, - long startTime, - long endTime, int nodeCount, AtomicInteger responseCount, String adID @@ -504,8 +527,6 @@ class RCFActionListener implements ActionListener { this.thresholdModelID = thresholdModelID; this.featureInResponse = features; this.failure = failure; - this.startTime = startTime; - this.endTime = endTime; this.nodeCount = nodeCount; this.responseCount = responseCount; this.adID = adID; @@ -544,7 +565,7 @@ public void onFailure(Exception e) { private void handleRCFResults() { try { - if (coldStartIfNoModel(failure, detector.get()) || rcfResults.isEmpty()) { + if (coldStartIfNoModel(failure, detector) || rcfResults.isEmpty()) { listener.onFailure(new InternalFailure(adID, NO_MODEL_ERR_MSG)); return; } @@ -554,7 +575,7 @@ private void handleRCFResults() { final AtomicReference anomalyResultResponse = new AtomicReference<>(); - String thresholdNodeId = thresholdNode.get().getId(); + String thresholdNodeId = thresholdNode.getId(); LOG.info("Sending threshold request to {} for model {}", thresholdNodeId, thresholdModelID); ThresholdActionListener thresholdListener = new ThresholdActionListener( anomalyResultResponse, @@ -563,15 +584,12 @@ private void handleRCFResults() { thresholdNodeId, detector, combinedResult, - featureInResponse, listener, - startTime, - endTime, adID ); transportService .sendRequest( - thresholdNode.get(), + thresholdNode, ThresholdResultAction.NAME, new ThresholdResultRequest(adID, thresholdModelID, combinedScore), option, @@ -590,11 +608,8 @@ class ThresholdActionListener implements ActionListener private AtomicReference failure; private String thresholdNodeID; private ActionListener listener; - private Optional detector; + private AnomalyDetector detector; private CombinedRcfResult combinedResult; - private List featureInResponse; - private long startTime; - private long endTime; private String adID; ThresholdActionListener( @@ -602,12 +617,9 @@ class ThresholdActionListener implements ActionListener List features, String modelID, String thresholdNodeID, - Optional detector, + AnomalyDetector detector, CombinedRcfResult combinedResult, - List featureInResponse, ActionListener listener, - long startTime, - long endTime, String adID ) { this.anomalyResultResponse = anomalyResultResponse; @@ -616,9 +628,6 @@ class ThresholdActionListener implements ActionListener this.thresholdNodeID = thresholdNodeID; this.detector = detector; this.combinedResult = combinedResult; - this.featureInResponse = featureInResponse; - this.startTime = startTime; - this.endTime = endTime; this.failure = new AtomicReference(); this.listener = listener; this.adID = adID; @@ -650,7 +659,7 @@ public void onFailure(Exception e) { private void handleThresholdResult() { try { - if (coldStartIfNoModel(failure, detector.get())) { + if (coldStartIfNoModel(failure, detector)) { listener.onFailure(new InternalFailure(adID, NO_MODEL_ERR_MSG)); return; } diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/RCFResultTransportAction.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/RCFResultTransportAction.java index 41e11052..382c632c 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/RCFResultTransportAction.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/RCFResultTransportAction.java @@ -1,5 +1,5 @@ /* - * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). * You may not use this file except in compliance with the License. @@ -19,7 +19,7 @@ import com.amazon.opendistroforelasticsearch.ad.common.exception.LimitExceededException; import com.amazon.opendistroforelasticsearch.ad.constant.CommonErrorMessages; import com.amazon.opendistroforelasticsearch.ad.ml.ModelManager; -import com.amazon.opendistroforelasticsearch.ad.ml.RcfResult; + import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; @@ -57,8 +57,21 @@ protected void doExecute(Task task, RCFResultRequest request, ActionListener listener + .onResponse(new RCFResultResponse(result.getScore(), result.getConfidence(), result.getForestSize())), + exception -> { + LOG.warn(exception); + listener.onFailure(exception); + } + ) + ); } catch (Exception e) { LOG.error(e); listener.onFailure(e); diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/StopDetectorTransportAction.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/StopDetectorTransportAction.java index fb140a8a..d434837e 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/StopDetectorTransportAction.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/StopDetectorTransportAction.java @@ -15,7 +15,6 @@ package com.amazon.opendistroforelasticsearch.ad.transport; -import com.amazon.opendistroforelasticsearch.ad.cluster.DeleteDetector; import com.amazon.opendistroforelasticsearch.ad.common.exception.InternalFailure; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -39,20 +38,17 @@ public class StopDetectorTransportAction extends HandledTransportAction { private static final Logger LOG = LogManager.getLogger(ThresholdResultTransportAction.class); @@ -42,8 +42,17 @@ protected void doExecute(Task task, ThresholdResultRequest request, ActionListen try { LOG.info("Serve threshold request for {}", request.getModelID()); - ThresholdingResult result = manager.getThresholdingResult(request.getAdID(), request.getModelID(), request.getRCFScore()); - listener.onResponse(new ThresholdResultResponse(result.getGrade(), result.getConfidence())); + manager + .getThresholdingResult( + request.getAdID(), + request.getModelID(), + request.getRCFScore(), + ActionListener + .wrap( + result -> listener.onResponse(new ThresholdResultResponse(result.getGrade(), result.getConfidence())), + exception -> listener.onFailure(exception) + ) + ); } catch (Exception e) { LOG.error(e); listener.onFailure(e); diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/handler/AnomalyResultHandler.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/handler/AnomalyResultHandler.java index 742a0689..e9a970cc 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/handler/AnomalyResultHandler.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/handler/AnomalyResultHandler.java @@ -172,35 +172,27 @@ private void saveDetectorResult(AnomalyResult anomalyResult) { } void saveDetectorResult(IndexRequest indexRequest, String context, Iterator backoff) { - client - .index( - indexRequest, - ActionListener - .wrap( - response -> LOG.debug(SUCCESS_SAVING_MSG + context), - exception -> { - // Elasticsearch has a thread pool and a queue for write per node. A thread - // pool will have N number of workers ready to handle the requests. When a - // request comes and if a worker is free , this is handled by the worker. Now by - // default the number of workers is equal to the number of cores on that CPU. - // When the workers are full and there are more write requests, the request - // will go to queue. The size of queue is also limited. If by default size is, - // say, 200 and if there happens more parallel requests than this, then those - // requests would be rejected as you can see EsRejectedExecutionException. - // So EsRejectedExecutionException is the way that Elasticsearch tells us that - // it cannot keep up with the current indexing rate. - // When it happens, we should pause indexing a bit before trying again, ideally - // with randomized exponential backoff. - if (!(exception instanceof EsRejectedExecutionException) || !backoff.hasNext()) { - LOG.error(FAIL_TO_SAVE_ERR_MSG + context, exception); - } else { - TimeValue nextDelay = backoff.next(); - LOG.warn(RETRY_SAVING_ERR_MSG + context, exception); - threadPool - .schedule(() -> saveDetectorResult(indexRequest, context, backoff), nextDelay, ThreadPool.Names.SAME); - } - } - ) - ); + client.index(indexRequest, ActionListener.wrap(response -> LOG.debug(SUCCESS_SAVING_MSG + context), exception -> { + // Elasticsearch has a thread pool and a queue for write per node. A thread + // pool will have N number of workers ready to handle the requests. When a + // request comes and if a worker is free , this is handled by the worker. Now by + // default the number of workers is equal to the number of cores on that CPU. + // When the workers are full and there are more write requests, the request + // will go to queue. The size of queue is also limited. If by default size is, + // say, 200 and if there happens more parallel requests than this, then those + // requests would be rejected as you can see EsRejectedExecutionException. + // So EsRejectedExecutionException is the way that Elasticsearch tells us that + // it cannot keep up with the current indexing rate. + // When it happens, we should pause indexing a bit before trying again, ideally + // with randomized exponential backoff. + Throwable cause = ExceptionsHelper.unwrapCause(exception); + if (!(cause instanceof EsRejectedExecutionException) || !backoff.hasNext()) { + LOG.error(FAIL_TO_SAVE_ERR_MSG + context, cause); + } else { + TimeValue nextDelay = backoff.next(); + LOG.warn(RETRY_SAVING_ERR_MSG + context, cause); + threadPool.schedule(() -> saveDetectorResult(indexRequest, context, backoff), nextDelay, ThreadPool.Names.SAME); + } + })); } } diff --git a/src/main/resources/mappings/anomaly-results.json b/src/main/resources/mappings/anomaly-results.json index 2ad310d1..80ee69e4 100644 --- a/src/main/resources/mappings/anomaly-results.json +++ b/src/main/resources/mappings/anomaly-results.json @@ -45,6 +45,9 @@ "execution_end_time": { "type": "date", "format": "strict_date_time||epoch_millis" + }, + "error": { + "type": "text" } } } diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorJobRunnerTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorJobRunnerTests.java index 0f93ed59..45b23ea0 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorJobRunnerTests.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorJobRunnerTests.java @@ -19,6 +19,7 @@ import com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetectorJob; import com.amazon.opendistroforelasticsearch.ad.model.IntervalTimeConfiguration; import com.amazon.opendistroforelasticsearch.ad.transport.handler.AnomalyResultHandler; +import com.amazon.opendistroforelasticsearch.ad.util.ClientUtil; import com.amazon.opendistroforelasticsearch.jobscheduler.spi.JobExecutionContext; import com.amazon.opendistroforelasticsearch.jobscheduler.spi.LockModel; import com.amazon.opendistroforelasticsearch.jobscheduler.spi.ScheduledJobParameter; @@ -26,6 +27,7 @@ import com.amazon.opendistroforelasticsearch.jobscheduler.spi.schedule.Schedule; import com.amazon.opendistroforelasticsearch.jobscheduler.spi.utils.LockService; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.get.GetRequest; import org.elasticsearch.action.get.GetResponse; import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.index.IndexResponse; @@ -74,6 +76,10 @@ public class AnomalyDetectorJobRunnerTests extends AbstractADTest { @Mock private Client client; + + @Mock + private ClientUtil clientUtil; + @Mock private ClusterService clusterService; @@ -119,6 +125,7 @@ public void setup() throws Exception { doReturn(executorService).when(mockedThreadPool).executor(anyString()); runner.setThreadPool(mockedThreadPool); runner.setClient(client); + runner.setClientUtil(clientUtil); runner.setAnomalyResultHandler(anomalyResultHandler); setUpJobParameter(); @@ -214,7 +221,7 @@ public void testRunAdJobWithEndRunExceptionNow() { public void testRunAdJobWithEndRunExceptionNowAndExistingAdJob() { testRunAdJobWithEndRunExceptionNowAndStopAdJob(true, true, true); verify(anomalyResultHandler).indexAnomalyResult(any()); - verify(client).index(any(), any()); + verify(clientUtil).asyncRequest(any(IndexRequest.class), any(), any()); assertTrue(testAppender.containsMessage("AD Job was disabled by JobRunner for")); } @@ -222,7 +229,7 @@ public void testRunAdJobWithEndRunExceptionNowAndExistingAdJob() { public void testRunAdJobWithEndRunExceptionNowAndExistingAdJobAndIndexException() { testRunAdJobWithEndRunExceptionNowAndStopAdJob(true, true, false); verify(anomalyResultHandler).indexAnomalyResult(any()); - verify(client).index(any(), any()); + verify(clientUtil).asyncRequest(any(IndexRequest.class), any(), any()); assertTrue(testAppender.containsMessage("Failed to disable AD job for")); } @@ -256,7 +263,7 @@ private void testRunAdJobWithEndRunExceptionNowAndStopAdJob(boolean jobExists, b Exception exception = new EndRunException(jobParameter.getName(), randomAlphaOfLength(5), true); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(2); GetResponse response = new GetResponse( new GetResult( AnomalyDetectorJob.ANOMALY_DETECTOR_JOB_INDEX, @@ -286,11 +293,11 @@ private void testRunAdJobWithEndRunExceptionNowAndStopAdJob(boolean jobExists, b listener.onResponse(response); return null; - }).when(client).get(any(), any()); + }).when(clientUtil).asyncRequest(any(GetRequest.class), any(), any()); doAnswer(invocation -> { IndexRequest request = invocation.getArgument(0); - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(2); ShardId shardId = new ShardId(new Index(AnomalyDetectorJob.ANOMALY_DETECTOR_JOB_INDEX, randomAlphaOfLength(10)), 0); if (disableSuccessfully) { listener.onResponse(new IndexResponse(shardId, randomAlphaOfLength(10), request.id(), 1, 1, 1, true)); @@ -298,7 +305,7 @@ private void testRunAdJobWithEndRunExceptionNowAndStopAdJob(boolean jobExists, b listener.onResponse(null); } return null; - }).when(client).index(any(), any()); + }).when(clientUtil).asyncRequest(any(IndexRequest.class), any(), any()); runner.handleAdException(jobParameter, lockService, lock, Instant.now().minusMillis(1000 * 60), Instant.now(), exception); } @@ -309,10 +316,10 @@ public void testRunAdJobWithEndRunExceptionNowAndGetJobException() { Exception exception = new EndRunException(jobParameter.getName(), randomAlphaOfLength(5), true); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(2); listener.onFailure(new RuntimeException("test")); return null; - }).when(client).get(any(), any()); + }).when(clientUtil).asyncRequest(any(GetRequest.class), any(), any()); runner.handleAdException(jobParameter, lockService, lock, Instant.now().minusMillis(1000 * 60), Instant.now(), exception); verify(anomalyResultHandler).indexAnomalyResult(any()); @@ -324,7 +331,7 @@ public void testRunAdJobWithEndRunExceptionNowAndFailToGetJob() { LockModel lock = new LockModel("indexName", "jobId", Instant.now(), 10, false); Exception exception = new EndRunException(jobParameter.getName(), randomAlphaOfLength(5), true); - doThrow(new RuntimeException("fail to get AD job")).when(client).get(any(), any()); + doThrow(new RuntimeException("fail to get AD job")).when(clientUtil).asyncRequest(any(GetRequest.class), any(), any()); runner.handleAdException(jobParameter, lockService, lock, Instant.now().minusMillis(1000 * 60), Instant.now(), exception); verify(anomalyResultHandler).indexAnomalyResult(any()); diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/cluster/ADClusterEventListenerTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/cluster/ADClusterEventListenerTests.java index 8b222e87..ff8c54b9 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/cluster/ADClusterEventListenerTests.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/cluster/ADClusterEventListenerTests.java @@ -33,6 +33,7 @@ import com.amazon.opendistroforelasticsearch.ad.AbstractADTest; import com.amazon.opendistroforelasticsearch.ad.ml.ModelManager; + import org.elasticsearch.cluster.ClusterChangedEvent; import org.elasticsearch.cluster.ClusterName; import org.elasticsearch.cluster.ClusterState; diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManagerTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManagerTests.java index 679ae0cb..81242a49 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManagerTests.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManagerTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). * You may not use this file except in compliance with the License. @@ -135,6 +135,7 @@ public class ModelManagerTests { private String rcfModelId; private String thresholdModelId; private String checkpoint; + private int shingleSize; @Before public void setup() { @@ -156,6 +157,7 @@ public void setup() { minPreviewSize = 500; modelTtl = Duration.ofHours(1); checkpointInterval = Duration.ofHours(1); + shingleSize = 1; rcf = RandomCutForest.builder().dimensions(numFeatures).sampleSize(numSamples).numberOfTrees(numTrees).build(); @@ -185,7 +187,8 @@ public void setup() { thresholdingModelClass, minPreviewSize, modelTtl, - checkpointInterval + checkpointInterval, + shingleSize ) ); @@ -591,7 +594,7 @@ public void clear_deleteThresholdCheckpoint() { public void trainModel_putTrainedModels() { double[][] trainData = new Random().doubles().limit(100).mapToObj(d -> new double[] { d }).toArray(double[][]::new); doReturn(new SimpleEntry<>(1, 10)).when(modelManager).getPartitionedForestSizes(anyObject(), anyObject()); - + doReturn(asList("feature1")).when(anomalyDetector).getEnabledFeatureIds(); modelManager.trainModel(anomalyDetector, trainData); verify(checkpointDao).putModelCheckpoint(eq(modelManager.getRcfModelId(anomalyDetector.getDetectorId(), 0)), anyObject()); diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/ADStateManagerTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/ADStateManagerTests.java index 5804216d..1cce2aae 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/ADStateManagerTests.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/ADStateManagerTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). * You may not use this file except in compliance with the License. @@ -34,7 +34,6 @@ import java.util.concurrent.ConcurrentHashMap; import com.amazon.opendistroforelasticsearch.ad.TestHelpers; -import com.amazon.opendistroforelasticsearch.ad.common.exception.AnomalyDetectionException; import com.amazon.opendistroforelasticsearch.ad.ml.ModelManager; import com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetector; import com.amazon.opendistroforelasticsearch.ad.util.ClientUtil; @@ -60,16 +59,17 @@ import org.junit.After; import org.junit.Before; -import com.amazon.randomcutforest.RandomCutForest; - public class ADStateManagerTests extends ESTestCase { private ADStateManager stateManager; private ModelManager modelManager; private Client client; + private ClientUtil clientUtil; private Clock clock; private Duration duration; private Throttler throttler; private ThreadPool context; + private AnomalyDetector detectorToCheck; + private Settings settings; @Override protected NamedXContentRegistry xContentRegistry() { @@ -82,10 +82,10 @@ protected NamedXContentRegistry xContentRegistry() { public void setUp() throws Exception { super.setUp(); modelManager = mock(ModelManager.class); - when(modelManager.getPartitionedForestSizes(any(RandomCutForest.class), any(String.class))) - .thenReturn(new SimpleImmutableEntry<>(2, 20)); + when(modelManager.getPartitionedForestSizes(any(AnomalyDetector.class))).thenReturn(new SimpleImmutableEntry<>(2, 20)); client = mock(Client.class); - Settings settings = Settings + clientUtil = mock(ClientUtil.class); + settings = Settings .builder() .put("opendistro.anomaly_detection.max_retry_for_unresponsive_node", 3) .put("opendistro.anomaly_detection.ad_mute_minutes", TimeValue.timeValueMinutes(10)) @@ -95,15 +95,7 @@ public void setUp() throws Exception { context = TestHelpers.createThreadPool(); throttler = new Throttler(clock); - stateManager = new ADStateManager( - client, - xContentRegistry(), - modelManager, - settings, - new ClientUtil(settings, client, throttler, context), - clock, - duration - ); + stateManager = new ADStateManager(client, xContentRegistry(), modelManager, settings, clientUtil, clock, duration); } @@ -114,12 +106,14 @@ public void tearDown() throws Exception { stateManager = null; modelManager = null; client = null; + clientUtil = null; + detectorToCheck = null; } @SuppressWarnings("unchecked") private String setupDetector(boolean responseExists) throws IOException { - AnomalyDetector detector = TestHelpers.randomAnomalyDetector(TestHelpers.randomUiMetadata(), null); - XContentBuilder content = detector.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); + detectorToCheck = TestHelpers.randomAnomalyDetector(TestHelpers.randomUiMetadata(), null); + XContentBuilder content = detectorToCheck.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); doAnswer(invocation -> { Object[] args = invocation.getArguments(); @@ -130,8 +124,8 @@ private String setupDetector(boolean responseExists) throws IOException { if (args[0] instanceof GetRequest) { request = (GetRequest) args[0]; } - if (args[1] instanceof ActionListener) { - listener = (ActionListener) args[1]; + if (args[2] instanceof ActionListener) { + listener = (ActionListener) args[2]; } assertTrue(request != null && listener != null); @@ -141,7 +135,7 @@ private String setupDetector(boolean responseExists) throws IOException { new GetResult( AnomalyDetector.ANOMALY_DETECTORS_INDEX, MapperService.SINGLE_MAPPING_NAME, - detector.getDetectorId(), + detectorToCheck.getDetectorId(), UNASSIGNED_SEQ_NO, 0, -1, @@ -154,21 +148,17 @@ private String setupDetector(boolean responseExists) throws IOException { ); return null; - }).when(client).get(any(), any()); - return detector.getDetectorId(); + }).when(clientUtil).asyncRequest(any(GetRequest.class), any(), any(ActionListener.class)); + return detectorToCheck.getDetectorId(); } public void testGetPartitionNumber() throws IOException, InterruptedException { String detectorId = setupDetector(true); - int partitionNumber = stateManager.getPartitionNumber(detectorId); + int partitionNumber = stateManager + .getPartitionNumber(detectorId, TestHelpers.randomAnomalyDetector(TestHelpers.randomUiMetadata(), null)); assertEquals(2, partitionNumber); } - public void testGetResponseNotFound() throws IOException, InterruptedException { - String detectorId = setupDetector(false); - expectThrows(AnomalyDetectionException.class, () -> stateManager.getPartitionNumber(detectorId)); - } - public void testShouldMute() { String nodeId = "123"; assertTrue(!stateManager.isMuted(nodeId)); @@ -214,10 +204,29 @@ public void testMaintenancRemove() throws IOException { } public void testHasRunningQuery() throws IOException { + stateManager = new ADStateManager( + client, + xContentRegistry(), + modelManager, + settings, + new ClientUtil(settings, client, throttler, context), + clock, + duration + ); + AnomalyDetector detector = TestHelpers.randomAnomalyDetector(ImmutableMap.of(), null); SearchRequest dummySearchRequest = new SearchRequest(); assertFalse(stateManager.hasRunningQuery(detector)); throttler.insertFilteredQuery(detector.getDetectorId(), dummySearchRequest); assertTrue(stateManager.hasRunningQuery(detector)); } + + public void testGetAnomalyDetector() throws IOException { + String detectorId = setupDetector(true); + stateManager + .getAnomalyDetector( + detectorId, + ActionListener.wrap(asDetector -> { assertEquals(detectorToCheck, asDetector.get()); }, exception -> assertTrue(false)) + ); + } } diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/AnomalyResultTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/AnomalyResultTests.java index 9832863f..f0a929c8 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/AnomalyResultTests.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/AnomalyResultTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). * You may not use this file except in compliance with the License. @@ -24,7 +24,7 @@ import static org.hamcrest.Matchers.hasItem; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.nullValue; -import static org.mockito.Mockito.any; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.anyInt; import static org.mockito.Mockito.anyLong; import static org.mockito.Mockito.anyDouble; @@ -53,12 +53,12 @@ import java.util.function.Function; import com.amazon.opendistroforelasticsearch.ad.breaker.ADCircuitBreakerService; -import com.amazon.opendistroforelasticsearch.ad.indices.AnomalyDetectionIndices; import com.amazon.opendistroforelasticsearch.ad.AbstractADTest; import com.amazon.opendistroforelasticsearch.ad.cluster.HashRing; import com.amazon.opendistroforelasticsearch.ad.common.exception.AnomalyDetectionException; import com.amazon.opendistroforelasticsearch.ad.common.exception.ClientException; +import com.amazon.opendistroforelasticsearch.ad.common.exception.EndRunException; import com.amazon.opendistroforelasticsearch.ad.common.exception.InternalFailure; import com.amazon.opendistroforelasticsearch.ad.common.exception.JsonPathNotFoundException; import com.amazon.opendistroforelasticsearch.ad.common.exception.LimitExceededException; @@ -85,7 +85,6 @@ import org.elasticsearch.ElasticsearchTimeoutException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionRequestValidationException; -import org.elasticsearch.action.admin.indices.create.CreateIndexResponse; import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.index.IndexResponse; import org.elasticsearch.action.support.ActionFilters; @@ -135,7 +134,6 @@ public class AnomalyResultTests extends AbstractADTest { private FeatureManager featureQuery; private ModelManager normalModelManager; private Client client; - private AnomalyDetectionIndices anomalyDetectionIndices; private AnomalyDetector detector; private HashRing hashRing; private IndexNameExpressionResolver indexNameResolver; @@ -171,7 +169,7 @@ public void setUp() throws Exception { clusterService = testNodes[0].clusterService; stateManager = mock(ADStateManager.class); // return 2 RCF partitions - when(stateManager.getPartitionNumber(any(String.class))).thenReturn(2); + when(stateManager.getPartitionNumber(any(String.class), any(AnomalyDetector.class))).thenReturn(2); when(stateManager.isMuted(any(String.class))).thenReturn(false); detector = mock(AnomalyDetector.class); @@ -184,19 +182,35 @@ public void setUp() throws Exception { userIndex.add("test*"); when(detector.getIndices()).thenReturn(userIndex); when(detector.getDetectorId()).thenReturn("testDetectorId"); - when(stateManager.getAnomalyDetector(any(String.class))).thenReturn(Optional.of(detector)); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(detector)); + return null; + }).when(stateManager).getAnomalyDetector(any(String.class), any(ActionListener.class)); hashRing = mock(HashRing.class); when(hashRing.getOwningNode(any(String.class))).thenReturn(Optional.of(clusterService.state().nodes().getLocalNode())); when(hashRing.build()).thenReturn(true); featureQuery = mock(FeatureManager.class); - when(featureQuery.getCurrentFeatures(any(AnomalyDetector.class), anyLong(), anyLong())) - .thenReturn(new SinglePointFeatures(Optional.of(new double[] { 0.0d }), Optional.of(new double[] { 0 }))); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(new SinglePointFeatures(Optional.of(new double[] { 0.0d }), Optional.of(new double[] { 0 }))); + return null; + }).when(featureQuery).getCurrentFeatures(any(AnomalyDetector.class), anyLong(), anyLong(), any(ActionListener.class)); + normalModelManager = mock(ModelManager.class); - when(normalModelManager.getThresholdingResult(any(String.class), any(String.class), anyDouble())) - .thenReturn(new ThresholdingResult(0, 1.0d)); - when(normalModelManager.getRcfResult(any(String.class), any(String.class), any(double[].class))) - .thenReturn(new RcfResult(0.2, 0, 100)); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(new ThresholdingResult(0, 1.0d)); + return null; + }).when(normalModelManager).getThresholdingResult(any(String.class), any(String.class), anyDouble(), any(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(new RcfResult(0.2, 0, 100)); + return null; + }).when(normalModelManager).getRcfResult(any(String.class), any(String.class), any(double[].class), any(ActionListener.class)); when(normalModelManager.combineRcfResults(any())).thenReturn(new CombinedRcfResult(0, 1.0d)); adID = "123"; rcfModelID = "123-rcf-1"; @@ -253,37 +267,6 @@ public void setupTestNodes(Settings settings) { runner = new ColdStartRunner(); } - @SuppressWarnings("unchecked") - public void setUpSavingAnomalyResultIndex(boolean anomalyResultIndexExists) throws IOException { - anomalyDetectionIndices = mock(AnomalyDetectionIndices.class); - doAnswer(invocation -> { - Object[] args = invocation.getArguments(); - assertTrue(String.format("The size of args is %d. Its content is %s", args.length, Arrays.toString(args)), args.length >= 1); - - ActionListener listener = null; - - if (args[0] instanceof ActionListener) { - listener = (ActionListener) args[0]; - } - - assertTrue(listener != null); - - listener.onResponse(new CreateIndexResponse(true, true, AnomalyResult.ANOMALY_RESULT_INDEX) { - }); - - return null; - }).when(anomalyDetectionIndices).initAnomalyResultIndexDirectly(any()); - - when(anomalyDetectionIndices.doesAnomalyResultIndexExist()).thenReturn(anomalyResultIndexExists); - } - - public void setupInitResultIndexException(Class exceptionType) throws IOException { - anomalyDetectionIndices = mock(AnomalyDetectionIndices.class); - doThrow(exceptionType).when(anomalyDetectionIndices).initAnomalyResultIndexDirectly(any()); - - when(anomalyDetectionIndices.doesAnomalyResultIndexExist()).thenReturn(false); - } - @Override @After public final void tearDown() throws Exception { @@ -293,7 +276,6 @@ public final void tearDown() throws Exception { runner.shutDown(); runner = null; client = null; - anomalyDetectionIndices = null; super.tearDownLog4jForJUnit(); super.tearDown(); } @@ -309,7 +291,6 @@ private void assertException(PlainActionFuture listener, public void testNormal() throws IOException { - setUpSavingAnomalyResultIndex(false); // These constructors register handler in transport service new RCFResultTransportAction( new ActionFilters(Collections.emptySet()), @@ -322,17 +303,14 @@ public void testNormal() throws IOException { AnomalyResultTransportAction action = new AnomalyResultTransportAction( new ActionFilters(Collections.emptySet()), transportService, - client, settings, stateManager, runner, - anomalyDetectionIndices, featureQuery, normalModelManager, hashRing, clusterService, indexNameResolver, - threadPool, adCircuitBreakerService, adStats ); @@ -377,17 +355,14 @@ public Throwable noModelExceptionTemplate( AnomalyResultTransportAction action = new AnomalyResultTransportAction( new ActionFilters(Collections.emptySet()), transportService, - client, settings, stateManager, globalRunner, - anomalyDetectionIndices, featureQuery, normalModelManager, hashRing, clusterService, indexNameResolver, - threadPool, adCircuitBreakerService, adStats ); @@ -441,10 +416,13 @@ public void testADExceptionWhenColdStart() { noModelExceptionTemplate(new AnomalyDetectionException(adID, ""), mockRunner, adID, error); } + @SuppressWarnings("unchecked") public void testInsufficientCapacityExceptionDuringColdStart() { ModelManager rcfManager = mock(ModelManager.class); - doThrow(ResourceNotFoundException.class).when(rcfManager).getRcfResult(any(String.class), any(String.class), any(double[].class)); + doThrow(ResourceNotFoundException.class) + .when(rcfManager) + .getRcfResult(any(String.class), any(String.class), any(double[].class), any(ActionListener.class)); when(rcfManager.getRcfModelId(any(String.class), anyInt())).thenReturn(rcfModelID); ColdStartRunner mockRunner = mock(ColdStartRunner.class); @@ -458,17 +436,14 @@ public void testInsufficientCapacityExceptionDuringColdStart() { AnomalyResultTransportAction action = new AnomalyResultTransportAction( new ActionFilters(Collections.emptySet()), transportService, - client, settings, stateManager, mockRunner, - anomalyDetectionIndices, featureQuery, normalModelManager, hashRing, clusterService, indexNameResolver, - threadPool, adCircuitBreakerService, adStats ); @@ -480,12 +455,13 @@ public void testInsufficientCapacityExceptionDuringColdStart() { assertException(listener, LimitExceededException.class); } + @SuppressWarnings("unchecked") public void testInsufficientCapacityExceptionDuringRestoringModel() { ModelManager rcfManager = mock(ModelManager.class); doThrow(new NotSerializableExceptionWrapper(new LimitExceededException(adID, CommonErrorMessages.MEMORY_LIMIT_EXCEEDED_ERR_MSG))) .when(rcfManager) - .getRcfResult(any(String.class), any(String.class), any(double[].class)); + .getRcfResult(any(String.class), any(String.class), any(double[].class), any(ActionListener.class)); // These constructors register handler in transport service new RCFResultTransportAction(new ActionFilters(Collections.emptySet()), transportService, rcfManager, adCircuitBreakerService); @@ -494,17 +470,14 @@ public void testInsufficientCapacityExceptionDuringRestoringModel() { AnomalyResultTransportAction action = new AnomalyResultTransportAction( new ActionFilters(Collections.emptySet()), transportService, - client, settings, stateManager, runner, - anomalyDetectionIndices, featureQuery, normalModelManager, hashRing, clusterService, indexNameResolver, - threadPool, adCircuitBreakerService, adStats ); @@ -535,17 +508,14 @@ public void testThresholdException() { AnomalyResultTransportAction action = new AnomalyResultTransportAction( new ActionFilters(Collections.emptySet()), transportService, - client, settings, stateManager, runner, - anomalyDetectionIndices, featureQuery, normalModelManager, hashRing, clusterService, indexNameResolver, - threadPool, adCircuitBreakerService, adStats ); @@ -569,17 +539,14 @@ public void testCircuitBreaker() { AnomalyResultTransportAction action = new AnomalyResultTransportAction( new ActionFilters(Collections.emptySet()), transportService, - client, settings, stateManager, runner, - anomalyDetectionIndices, featureQuery, normalModelManager, hashRing, clusterService, indexNameResolver, - threadPool, breakerService, adStats ); @@ -636,17 +603,14 @@ private void nodeNotConnectedExceptionTemplate(boolean isRCF, boolean temporary, AnomalyResultTransportAction action = new AnomalyResultTransportAction( new ActionFilters(Collections.emptySet()), exceptionTransportService, - client, settings, stateManager, runner, - anomalyDetectionIndices, featureQuery, normalModelManager, hashRing, hackedClusterService, indexNameResolver, - threadPool, adCircuitBreakerService, adStats ); @@ -692,24 +656,26 @@ public void testTemporaryThresholdNodeNotConnectedException() { nodeNotConnectedExceptionTemplate(false, true, 1); } + @SuppressWarnings("unchecked") public void testMute() { ADStateManager muteStateManager = mock(ADStateManager.class); when(muteStateManager.isMuted(any(String.class))).thenReturn(true); - when(muteStateManager.getAnomalyDetector(any(String.class))).thenReturn(Optional.of(detector)); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(detector)); + return null; + }).when(muteStateManager).getAnomalyDetector(any(String.class), any(ActionListener.class)); AnomalyResultTransportAction action = new AnomalyResultTransportAction( new ActionFilters(Collections.emptySet()), transportService, - client, settings, muteStateManager, runner, - anomalyDetectionIndices, featureQuery, normalModelManager, hashRing, clusterService, indexNameResolver, - threadPool, adCircuitBreakerService, adStats ); @@ -722,7 +688,6 @@ public void testMute() { } public void alertingRequestTemplate(boolean anomalyResultIndexExists) throws IOException { - setUpSavingAnomalyResultIndex(anomalyResultIndexExists); // These constructors register handler in transport service new RCFResultTransportAction( new ActionFilters(Collections.emptySet()), @@ -735,17 +700,14 @@ public void alertingRequestTemplate(boolean anomalyResultIndexExists) throws IOE new AnomalyResultTransportAction( new ActionFilters(Collections.emptySet()), transportService, - client, settings, stateManager, runner, - anomalyDetectionIndices, featureQuery, normalModelManager, hashRing, clusterService, indexNameResolver, - threadPool, adCircuitBreakerService, adStats ); @@ -890,22 +852,19 @@ public void testOnFailureNull() throws IOException { AnomalyResultTransportAction action = new AnomalyResultTransportAction( new ActionFilters(Collections.emptySet()), transportService, - client, settings, stateManager, new ColdStartRunner(), - anomalyDetectionIndices, featureQuery, normalModelManager, hashRing, clusterService, indexNameResolver, - threadPool, adCircuitBreakerService, adStats ); AnomalyResultTransportAction.RCFActionListener listener = action.new RCFActionListener( - null, null, null, null, null, null, null, null, null, 0, 0, 0, new AtomicInteger(), null + null, null, null, null, null, null, null, null, null, 0, new AtomicInteger(), null ); listener.onFailure(null); } @@ -915,17 +874,14 @@ public void testColdStartNoTrainingData() throws Exception { AnomalyResultTransportAction action = new AnomalyResultTransportAction( new ActionFilters(Collections.emptySet()), transportService, - client, settings, stateManager, runner, - anomalyDetectionIndices, featureQuery, normalModelManager, hashRing, clusterService, indexNameResolver, - threadPool, adCircuitBreakerService, adStats ); @@ -943,17 +899,14 @@ public void testColdStartTimeoutPutCheckpoint() throws Exception { AnomalyResultTransportAction action = new AnomalyResultTransportAction( new ActionFilters(Collections.emptySet()), transportService, - client, settings, stateManager, runner, - anomalyDetectionIndices, featureQuery, normalModelManager, hashRing, clusterService, indexNameResolver, - threadPool, adCircuitBreakerService, adStats ); @@ -968,32 +921,35 @@ enum FeatureTestMode { AD_EXCEPTION } + @SuppressWarnings("unchecked") public void featureTestTemplate(FeatureTestMode mode) { if (mode == FeatureTestMode.FEATURE_NOT_AVAILABLE) { - when(featureQuery.getCurrentFeatures(any(AnomalyDetector.class), anyLong(), anyLong())) - .thenReturn(new SinglePointFeatures(Optional.empty(), Optional.empty())); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(new SinglePointFeatures(Optional.empty(), Optional.empty())); + return null; + }).when(featureQuery).getCurrentFeatures(any(AnomalyDetector.class), anyLong(), anyLong(), any(ActionListener.class)); } else if (mode == FeatureTestMode.ILLEGAL_STATE) { - doThrow(IllegalArgumentException.class).when(featureQuery).getCurrentFeatures(any(AnomalyDetector.class), anyLong(), anyLong()); + doThrow(IllegalArgumentException.class) + .when(featureQuery) + .getCurrentFeatures(any(AnomalyDetector.class), anyLong(), anyLong(), any(ActionListener.class)); } else if (mode == FeatureTestMode.AD_EXCEPTION) { doThrow(AnomalyDetectionException.class) .when(featureQuery) - .getCurrentFeatures(any(AnomalyDetector.class), anyLong(), anyLong()); + .getCurrentFeatures(any(AnomalyDetector.class), anyLong(), anyLong(), any(ActionListener.class)); } AnomalyResultTransportAction action = new AnomalyResultTransportAction( new ActionFilters(Collections.emptySet()), transportService, - client, settings, stateManager, runner, - anomalyDetectionIndices, featureQuery, normalModelManager, hashRing, clusterService, indexNameResolver, - threadPool, adCircuitBreakerService, adStats ); @@ -1069,17 +1025,14 @@ private void globalBlockTemplate(BlockType type, String errLogMsg, Settings inde AnomalyResultTransportAction action = new AnomalyResultTransportAction( new ActionFilters(Collections.emptySet()), transportService, - client, settings, stateManager, runner, - anomalyDetectionIndices, featureQuery, normalModelManager, hashRing, hackedClusterService, indexNameResolver, - threadPool, adCircuitBreakerService, adStats ); @@ -1116,24 +1069,55 @@ public void testNullRCFResult() { AnomalyResultTransportAction action = new AnomalyResultTransportAction( new ActionFilters(Collections.emptySet()), transportService, - client, settings, stateManager, runner, - anomalyDetectionIndices, featureQuery, normalModelManager, hashRing, clusterService, indexNameResolver, - threadPool, adCircuitBreakerService, adStats ); AnomalyResultTransportAction.RCFActionListener listener = action.new RCFActionListener( - null, "123-rcf-0", null, "123", null, null, null, null, null, 0, 0, 0, new AtomicInteger(), null + null, "123-rcf-0", null, "123", null, null, null, null, null, 0, new AtomicInteger(), null ); listener.onResponse(null); assertTrue(testAppender.containsMessage(AnomalyResultTransportAction.NULL_RESPONSE)); } + + @SuppressWarnings("unchecked") + public void testAllFeaturesDisabled() { + // doThrow(IllegalArgumentException.class).when(featureQuery) + // .getCurrentFeatures(any(AnomalyDetector.class), anyLong(), anyLong(), any(ActionListener.class)); + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + ActionListener listener = (ActionListener) args[3]; + listener.onFailure(new IllegalArgumentException()); + return null; + }).when(featureQuery).getCurrentFeatures(any(AnomalyDetector.class), anyLong(), anyLong(), any(ActionListener.class)); + when(detector.getEnabledFeatureIds()).thenReturn(Collections.emptyList()); + + AnomalyResultTransportAction action = new AnomalyResultTransportAction( + new ActionFilters(Collections.emptySet()), + transportService, + settings, + stateManager, + runner, + featureQuery, + normalModelManager, + hashRing, + clusterService, + indexNameResolver, + adCircuitBreakerService, + adStats + ); + + AnomalyResultRequest request = new AnomalyResultRequest(adID, 100, 200); + PlainActionFuture listener = new PlainActionFuture<>(); + action.doExecute(null, request, listener); + + assertException(listener, EndRunException.class, AnomalyResultTransportAction.ALL_FEATURES_DISABLED_ERR_MSG); + } } diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/CronTransportActionTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/CronTransportActionTests.java index bc6c6141..a1e556dc 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/CronTransportActionTests.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/CronTransportActionTests.java @@ -27,6 +27,7 @@ import com.amazon.opendistroforelasticsearch.ad.common.exception.JsonPathNotFoundException; import com.amazon.opendistroforelasticsearch.ad.feature.FeatureManager; import com.amazon.opendistroforelasticsearch.ad.ml.ModelManager; + import org.elasticsearch.Version; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.cluster.ClusterName; diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/DeleteModelTransportActionTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/DeleteModelTransportActionTests.java index 09f2ddee..7d60ffce 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/DeleteModelTransportActionTests.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/DeleteModelTransportActionTests.java @@ -30,6 +30,7 @@ import com.amazon.opendistroforelasticsearch.ad.constant.CommonErrorMessages; import com.amazon.opendistroforelasticsearch.ad.feature.FeatureManager; import com.amazon.opendistroforelasticsearch.ad.ml.ModelManager; + import org.elasticsearch.Version; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.support.ActionFilters; diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/RCFResultTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/RCFResultTests.java index 95766caa..9084ee6a 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/RCFResultTests.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/RCFResultTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). * You may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasItem; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -33,6 +34,8 @@ import com.amazon.opendistroforelasticsearch.ad.constant.CommonMessageAttributes; import com.amazon.opendistroforelasticsearch.ad.ml.ModelManager; import com.amazon.opendistroforelasticsearch.ad.ml.RcfResult; + +import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.PlainActionFuture; @@ -56,6 +59,7 @@ public class RCFResultTests extends ESTestCase { Gson gson = new GsonBuilder().create(); + @SuppressWarnings("unchecked") public void testNormal() { TransportService transportService = new TransportService( Settings.EMPTY, @@ -75,7 +79,12 @@ public void testNormal() { manager, adCircuitBreakerService ); - when(manager.getRcfResult(any(String.class), any(String.class), any(double[].class))).thenReturn(new RcfResult(0, 0, 25)); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(new RcfResult(0, 0, 25)); + return null; + }).when(manager).getRcfResult(any(String.class), any(String.class), any(double[].class), any(ActionListener.class)); + when(adCircuitBreakerService.isOpen()).thenReturn(false); final PlainActionFuture future = new PlainActionFuture<>(); @@ -87,6 +96,7 @@ public void testNormal() { assertEquals(25, response.getForestSize(), 0.001); } + @SuppressWarnings("unchecked") public void testExecutionException() { TransportService transportService = new TransportService( Settings.EMPTY, @@ -106,7 +116,9 @@ public void testExecutionException() { manager, adCircuitBreakerService ); - doThrow(NullPointerException.class).when(manager).getRcfResult(any(String.class), any(String.class), any(double[].class)); + doThrow(NullPointerException.class) + .when(manager) + .getRcfResult(any(String.class), any(String.class), any(double[].class), any(ActionListener.class)); when(adCircuitBreakerService.isOpen()).thenReturn(false); final PlainActionFuture future = new PlainActionFuture<>(); @@ -172,6 +184,7 @@ public void testJsonRequest() throws IOException, JsonPathNotFoundException { ); } + @SuppressWarnings("unchecked") public void testCircuitBreaker() { TransportService transportService = new TransportService( Settings.EMPTY, @@ -191,7 +204,11 @@ public void testCircuitBreaker() { manager, breakerService ); - when(manager.getRcfResult(any(String.class), any(String.class), any(double[].class))).thenReturn(new RcfResult(0, 0, 25)); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(new RcfResult(0, 0, 25)); + return null; + }).when(manager).getRcfResult(any(String.class), any(String.class), any(double[].class), any(ActionListener.class)); when(breakerService.isOpen()).thenReturn(true); final PlainActionFuture future = new PlainActionFuture<>(); diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/ThresholdResultTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/ThresholdResultTests.java index 4c2ca200..4a992126 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/ThresholdResultTests.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/ThresholdResultTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). * You may not use this file except in compliance with the License. @@ -19,9 +19,9 @@ import static org.hamcrest.Matchers.equalTo; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyDouble; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; import java.io.IOException; import java.util.Collections; @@ -31,6 +31,8 @@ import com.amazon.opendistroforelasticsearch.ad.constant.CommonMessageAttributes; import com.amazon.opendistroforelasticsearch.ad.ml.ModelManager; import com.amazon.opendistroforelasticsearch.ad.ml.ThresholdingResult; + +import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.PlainActionFuture; @@ -50,6 +52,7 @@ public class ThresholdResultTests extends ESTestCase { + @SuppressWarnings("unchecked") public void testNormal() { TransportService transportService = new TransportService( Settings.EMPTY, @@ -63,7 +66,11 @@ public void testNormal() { ModelManager manager = mock(ModelManager.class); ThresholdResultTransportAction action = new ThresholdResultTransportAction(mock(ActionFilters.class), transportService, manager); - when(manager.getThresholdingResult(any(String.class), any(String.class), anyDouble())).thenReturn(new ThresholdingResult(0, 1.0d)); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(new ThresholdingResult(0, 1.0d)); + return null; + }).when(manager).getThresholdingResult(any(String.class), any(String.class), anyDouble(), any(ActionListener.class)); final PlainActionFuture future = new PlainActionFuture<>(); ThresholdResultRequest request = new ThresholdResultRequest("123", "123-threshold", 2); @@ -74,6 +81,7 @@ public void testNormal() { assertEquals(1, response.getConfidence(), 0.001); } + @SuppressWarnings("unchecked") public void testExecutionException() { TransportService transportService = new TransportService( Settings.EMPTY, @@ -87,7 +95,9 @@ public void testExecutionException() { ModelManager manager = mock(ModelManager.class); ThresholdResultTransportAction action = new ThresholdResultTransportAction(mock(ActionFilters.class), transportService, manager); - doThrow(NullPointerException.class).when(manager).getThresholdingResult(any(String.class), any(String.class), anyDouble()); + doThrow(NullPointerException.class) + .when(manager) + .getThresholdingResult(any(String.class), any(String.class), anyDouble(), any(ActionListener.class)); final PlainActionFuture future = new PlainActionFuture<>(); ThresholdResultRequest request = new ThresholdResultRequest("123", "123-threshold", 2);