From 2f182316036f6ee206edd7d6c482f830784f0bff Mon Sep 17 00:00:00 2001 From: Kaituo Li Date: Thu, 22 Oct 2020 15:54:52 -0700 Subject: [PATCH] Fix for stats API (#287) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add more tests * Fix for stats API This PR fixes two issues for the stats API. First, we didn't propagate multi-entity detectors' model execution exceptions for the remote invocation.  This problem may impact the stats API's ability to report the total failure count and thus hide an issue we should have reported during monitoring.  This PR fixes the issue by collecting model host nodes' exceptions on the coordinating node. Second, we didn't show active multi-entity detectors' model information in the stats API.  This PR places this information into the stats API output. This PR also adds unit tests for ModelManager. Testing done: 1. added unit tests 2. manually verified the two issues are resolved. --- build.gradle | 3 - .../ad/AnomalyDetectorPlugin.java | 2 +- .../ad/caching/CacheBuffer.java | 6 + .../ad/caching/EntityCache.java | 9 + .../ad/caching/PriorityCache.java | 14 + .../ad/feature/SearchFeatureDao.java | 4 +- .../ad/ml/ModelManager.java | 5 +- .../stats/suppliers/ModelsOnNodeSupplier.java | 11 +- .../AnomalyResultTransportAction.java | 121 +++-- .../ad/AbstractADTest.java | 5 +- .../ad/TestHelpers.java | 2 +- .../ad/caching/CacheBufferTests.java | 1 + .../ad/caching/PriorityCacheTests.java | 5 + .../ad/ml/ModelManagerTests.java | 61 +++ .../ad/stats/ADStatsTests.java | 19 +- .../suppliers/ModelsOnNodeSupplierTests.java | 25 +- .../ADStatsNodesTransportActionTests.java | 8 +- .../ad/transport/MultientityResultTests.java | 498 ++++++++++++++++++ .../ad/util/MLUtil.java | 94 ++-- 19 files changed, 806 insertions(+), 87 deletions(-) create mode 100644 src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/MultientityResultTests.java diff --git a/build.gradle b/build.gradle index d2aa92b7..4b37620c 100644 
--- a/build.gradle +++ b/build.gradle @@ -261,9 +261,6 @@ List jacocoExclusions = [ 'com.amazon.opendistroforelasticsearch.ad.transport.SearchAnomalyResultTransportAction*', // TODO: hc caused coverage to drop - //'com.amazon.opendistroforelasticsearch.ad.ml.ModelManager', - 'com.amazon.opendistroforelasticsearch.ad.transport.AnomalyResultTransportAction', - 'com.amazon.opendistroforelasticsearch.ad.transport.AnomalyResultTransportAction.EntityResultListener', 'com.amazon.opendistroforelasticsearch.ad.NodeStateManager', 'com.amazon.opendistroforelasticsearch.ad.transport.handler.MultiEntityResultHandler', 'com.amazon.opendistroforelasticsearch.ad.transport.EntityProfileTransportAction*', diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorPlugin.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorPlugin.java index a9dbcf0a..85e921eb 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorPlugin.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorPlugin.java @@ -442,7 +442,7 @@ public Collection createComponents( .>builder() .put(StatNames.AD_EXECUTE_REQUEST_COUNT.getName(), new ADStat<>(false, new CounterSupplier())) .put(StatNames.AD_EXECUTE_FAIL_COUNT.getName(), new ADStat<>(false, new CounterSupplier())) - .put(StatNames.MODEL_INFORMATION.getName(), new ADStat<>(false, new ModelsOnNodeSupplier(modelManager))) + .put(StatNames.MODEL_INFORMATION.getName(), new ADStat<>(false, new ModelsOnNodeSupplier(modelManager, cacheProvider))) .put( StatNames.ANOMALY_DETECTORS_INDEX_STATUS.getName(), new ADStat<>(true, new IndexStatusSupplier(indexUtils, AnomalyDetector.ANOMALY_DETECTORS_INDEX)) diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/caching/CacheBuffer.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/caching/CacheBuffer.java index 192013b7..cdec3bd6 100644 --- 
a/src/main/java/com/amazon/opendistroforelasticsearch/ad/caching/CacheBuffer.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/caching/CacheBuffer.java @@ -20,10 +20,12 @@ import java.time.Instant; import java.util.AbstractMap.SimpleImmutableEntry; import java.util.Comparator; +import java.util.List; import java.util.Map.Entry; import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentSkipListSet; +import java.util.stream.Collectors; import org.apache.commons.lang.builder.EqualsBuilder; import org.apache.commons.lang.builder.HashCodeBuilder; @@ -525,4 +527,8 @@ public boolean expired(Duration stateTtl) { public String getDetectorId() { return detectorId; } + + public List> getAllModels() { + return items.values().stream().collect(Collectors.toList()); + } } diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/caching/EntityCache.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/caching/EntityCache.java index fd42cc13..9915f605 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/caching/EntityCache.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/caching/EntityCache.java @@ -15,6 +15,8 @@ package com.amazon.opendistroforelasticsearch.ad.caching; +import java.util.List; + import com.amazon.opendistroforelasticsearch.ad.CleanState; import com.amazon.opendistroforelasticsearch.ad.MaintenanceState; import com.amazon.opendistroforelasticsearch.ad.ml.EntityModel; @@ -72,4 +74,11 @@ public interface EntityCache extends MaintenanceState, CleanState { * @return RCF model total updates of specific entity */ long getTotalUpdates(String detectorId, String entityModelId); + + /** + * Gets modelStates of all model hosted on a node + * + * @return list of modelStates + */ + List> getAllModels(); } diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/caching/PriorityCache.java 
b/src/main/java/com/amazon/opendistroforelasticsearch/ad/caching/PriorityCache.java index 43f11aa9..508d8e42 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/caching/PriorityCache.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/caching/PriorityCache.java @@ -23,6 +23,8 @@ import java.time.Instant; import java.util.AbstractMap.SimpleImmutableEntry; import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.List; import java.util.Map; import java.util.Map.Entry; import java.util.Optional; @@ -554,4 +556,16 @@ public int getTotalActiveEntities() { activeEnities.values().stream().forEach(cacheBuffer -> { total.addAndGet(cacheBuffer.getActiveEntities()); }); return total.get(); } + + /** + * Gets modelStates of all model hosted on a node + * + * @return list of modelStates + */ + @Override + public List> getAllModels() { + List> states = new ArrayList<>(); + activeEnities.values().stream().forEach(cacheBuffer -> states.addAll(cacheBuffer.getAllModels())); + return states; + } } diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/feature/SearchFeatureDao.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/feature/SearchFeatureDao.java index c696e980..b445c0a8 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/feature/SearchFeatureDao.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/feature/SearchFeatureDao.java @@ -812,8 +812,8 @@ public void getFeaturesByEntities( new ThreadedActionListener<>(logger, threadPool, AnomalyDetectorPlugin.AD_THREAD_POOL_NAME, termsListener, false) ); - } catch (IOException e) { - throw new EndRunException(detector.getDetectorId(), CommonErrorMessages.INVALID_SEARCH_QUERY_MSG, e, true); + } catch (Exception e) { + throw new EndRunException(detector.getDetectorId(), CommonErrorMessages.INVALID_SEARCH_QUERY_MSG, e, false); } } diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManager.java 
b/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManager.java index dfcf7e6c..cfd759ab 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManager.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManager.java @@ -20,6 +20,7 @@ import java.time.Clock; import java.time.Duration; import java.time.Instant; +import java.util.ArrayDeque; import java.util.Arrays; import java.util.HashMap; import java.util.Iterator; @@ -1022,7 +1023,9 @@ public void processEntityCheckpoint( modelState.setLastCheckpointTime(clock.instant().minus(checkpointInterval)); } - assert (modelState.getModel() != null); + if (modelState.getModel() == null) { + modelState.setModel(new EntityModel(modelId, new ArrayDeque<>(), null, null)); + } maybeTrainBeforeScore(modelState, entityName); } diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/stats/suppliers/ModelsOnNodeSupplier.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/stats/suppliers/ModelsOnNodeSupplier.java index d8cb4b66..669b5da8 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/stats/suppliers/ModelsOnNodeSupplier.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/stats/suppliers/ModelsOnNodeSupplier.java @@ -27,7 +27,9 @@ import java.util.Set; import java.util.function.Supplier; import java.util.stream.Collectors; +import java.util.stream.Stream; +import com.amazon.opendistroforelasticsearch.ad.caching.CacheProvider; import com.amazon.opendistroforelasticsearch.ad.ml.ModelManager; /** @@ -35,6 +37,7 @@ */ public class ModelsOnNodeSupplier implements Supplier>> { private ModelManager modelManager; + private CacheProvider cache; /** * Set that contains the model stats that should be exposed. 
@@ -45,16 +48,18 @@ public class ModelsOnNodeSupplier implements Supplier>> * Constructor * * @param modelManager object that manages the model partitions hosted on the node + * @param cache object that manages multi-entity detectors' models */ - public ModelsOnNodeSupplier(ModelManager modelManager) { + public ModelsOnNodeSupplier(ModelManager modelManager, CacheProvider cache) { this.modelManager = modelManager; + this.cache = cache; } @Override public List> get() { List> values = new ArrayList<>(); - modelManager - .getAllModels() + Stream + .concat(modelManager.getAllModels().stream(), cache.get().getAllModels().stream()) .forEach( modelState -> values .add( diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/AnomalyResultTransportAction.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/AnomalyResultTransportAction.java index 514709ac..6bdeb9c8 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/AnomalyResultTransportAction.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/AnomalyResultTransportAction.java @@ -22,6 +22,7 @@ import java.util.Map; import java.util.Map.Entry; import java.util.Optional; +import java.util.Set; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; @@ -104,6 +105,8 @@ public class AnomalyResultTransportAction extends HandledTransportAction listener) { try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - AnomalyResultRequest request = AnomalyResultRequest.fromActionRequest(actionRequest); ActionListener original = listener; listener = ActionListener.wrap(original::onResponse, e -> { @@ -233,7 +235,6 @@ protected void doExecute(Task task, ActionRequest actionRequest, ActionListener< listener.onFailure(new LimitExceededException(adID, CommonErrorMessages.MEMORY_CIRCUIT_BROKEN_ERR_MSG, false)); return; } - try 
{ stateManager.getAnomalyDetector(adID, onGetDetector(listener, adID, request)); } catch (Exception ex) { @@ -297,7 +298,7 @@ private ActionListener> onGetDetector( ) ); } else { - entityFeatures + Set>> node2Entities = entityFeatures .entrySet() .stream() .collect( @@ -307,26 +308,29 @@ private ActionListener> onGetDetector( Collectors.toMap(Entry::getKey, Entry::getValue) ) ) - .entrySet() - .stream() - .forEach(nodeEntity -> { - DiscoveryNode node = nodeEntity.getKey(); - transportService - .sendRequest( - node, - EntityResultAction.NAME, - new EntityResultRequest(adID, nodeEntity.getValue(), dataStartTime, dataEndTime), - this.option, - new ActionListenerResponseHandler<>( - new EntityResultListener(node.getId(), adID), - AcknowledgedResponse::new, - ThreadPool.Names.SAME - ) - ); - }); + .entrySet(); + + int nodeCount = node2Entities.size(); + AtomicInteger responseCount = new AtomicInteger(); + + final AtomicReference failure = new AtomicReference<>(); + node2Entities.stream().forEach(nodeEntity -> { + DiscoveryNode node = nodeEntity.getKey(); + transportService + .sendRequest( + node, + EntityResultAction.NAME, + new EntityResultRequest(adID, nodeEntity.getValue(), dataStartTime, dataEndTime), + this.option, + new ActionListenerResponseHandler<>( + new EntityResultListener(node.getId(), adID, responseCount, nodeCount, failure, listener), + AcknowledgedResponse::new, + ThreadPool.Names.SAME + ) + ); + }); } - listener.onResponse(new AnomalyResultResponse(0, 0, 0, new ArrayList())); }, exception -> handleFailure(exception, listener, adID)); threadPool @@ -482,7 +486,7 @@ private ActionListener onFeatureResponse( private void handleFailure(Exception exception, ActionListener listener, String adID) { if (exception instanceof IndexNotFoundException) { - listener.onFailure(new EndRunException(adID, "Having trouble querying data: " + exception.getMessage(), true)); + listener.onFailure(new EndRunException(adID, TROUBLE_QUERYING_ERR_MSG + exception.getMessage(), 
true)); } else if (exception instanceof EndRunException) { // invalid feature query listener.onFailure(exception); @@ -555,7 +559,7 @@ private void findException(Throwable cause, String adID, AtomicReference failure) { - LOG.error(new ParameterizedMessage("Received an error from node {} while fetching anomaly grade for {}", nodeID, adID), e); + LOG.error(new ParameterizedMessage("Received an error from node {} while doing model inference for {}", nodeID, adID), e); if (e == null) { return; } @@ -801,6 +805,8 @@ private void handlePredictionFailure(Exception e, String adID, String nodeID, At /** * Check if the input exception indicates connection issues. + * During blue-green deployment, we may see ActionNotFoundTransportException. + * Count that as connection issue and isolate that node if it continues to happen. * * @param e exception * @return true if we get disconnected from the node or the node is not in the @@ -811,7 +817,8 @@ private boolean hasConnectionIssue(Throwable e) { || e instanceof NodeClosedException || e instanceof ReceiveTimeoutTransportException || e instanceof NodeNotConnectedException - || e instanceof ConnectException; + || e instanceof ConnectException + || e instanceof ActionNotFoundTransportException; } private void handleConnectionException(String node) { @@ -1015,18 +1022,45 @@ private Optional coldStartIfNoCheckPoint(AnomalyDetec class EntityResultListener implements ActionListener { private String nodeId; private final String adID; + private AtomicInteger responseCount; + private int nodeCount; + private ActionListener listener; + private List ackResponses; + private AtomicReference failure; - EntityResultListener(String nodeId, String adID) { + EntityResultListener( + String nodeId, + String adID, + AtomicInteger responseCount, + int nodeCount, + AtomicReference failure, + ActionListener listener + ) { this.nodeId = nodeId; this.adID = adID; + this.responseCount = responseCount; + this.nodeCount = nodeCount; + this.failure = failure; + 
this.listener = listener; + this.ackResponses = new ArrayList<>(); } @Override public void onResponse(AcknowledgedResponse response) { - stateManager.resetBackpressureCounter(nodeId); - if (response.isAcknowledged() == false) { - LOG.error("Cannot send entities' features to {} for {}", nodeId, adID); - stateManager.addPressure(nodeId); + try { + stateManager.resetBackpressureCounter(nodeId); + if (response.isAcknowledged() == false) { + LOG.error("Cannot send entities' features to {} for {}", nodeId, adID); + stateManager.addPressure(nodeId); + } else { + ackResponses.add(response); + } + } catch (Exception ex) { + LOG.error("Unexpected exception: {} for {}", ex, adID); + } finally { + if (nodeCount == responseCount.incrementAndGet()) { + handleEntityResponses(); + } } } @@ -1035,13 +1069,28 @@ public void onFailure(Exception e) { if (e == null) { return; } - Throwable cause = ExceptionsHelper.unwrapCause(e); - // in case of connection issue or the other node has no multi-entity - // transport actions (e.g., blue green deployment) - if (hasConnectionIssue(cause) || cause instanceof ActionNotFoundTransportException) { - handleConnectionException(nodeId); + try { + LOG.error(new ParameterizedMessage("Cannot send entities' features to {} for {}", nodeId, adID), e); + + handlePredictionFailure(e, adID, nodeId, failure); + + } catch (Exception ex) { + LOG.error("Unexpected exception: {} for {}", ex, adID); + } finally { + if (nodeCount == responseCount.incrementAndGet()) { + handleEntityResponses(); + } + } + } + + private void handleEntityResponses() { + if (failure.get() != null) { + listener.onFailure(failure.get()); + } else if (ackResponses.isEmpty()) { + listener.onFailure(new InternalFailure(adID, NO_ACK_ERR)); + } else { + listener.onResponse(new AnomalyResultResponse(0, 0, 0, new ArrayList())); } - LOG.error(new ParameterizedMessage("Cannot send entities' features to {} for {}", nodeId, adID), e); } } } diff --git 
a/src/test/java/com/amazon/opendistroforelasticsearch/ad/AbstractADTest.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/AbstractADTest.java index 9d15ccd0..c8581808 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/AbstractADTest.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/AbstractADTest.java @@ -227,6 +227,9 @@ public void setupTestNodes(Settings settings) { } public void tearDownTestNodes() { + if (testNodes == null) { + return; + } for (FakeNode testNode : testNodes) { testNode.close(); } @@ -238,7 +241,7 @@ public void assertException( Class exceptionType, String msg ) { - Exception e = expectThrows(exceptionType, () -> listener.actionGet()); + Exception e = expectThrows(exceptionType, () -> listener.actionGet(20_000)); assertThat(e.getMessage(), containsString(msg)); } diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/TestHelpers.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/TestHelpers.java index e8f1d6b3..7f6e2dd1 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/TestHelpers.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/TestHelpers.java @@ -226,7 +226,7 @@ public static AnomalyDetector randomAnomalyDetectorUsingCategoryFields(String de ImmutableList.of(randomFeature(true)), randomQuery(), randomIntervalTimeConfiguration(), - randomIntervalTimeConfiguration(), + new IntervalTimeConfiguration(0, ChronoUnit.MINUTES), randomIntBetween(1, 2000), null, randomInt(), diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/caching/CacheBufferTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/caching/CacheBufferTests.java index afabdc9e..de93b6ce 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/caching/CacheBufferTests.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/caching/CacheBufferTests.java @@ -167,6 +167,7 @@ public void testMaintenance() { cacheBuffer.put(modelId3, 
MLUtil.randomModelState(initialPriority, modelId3)); cacheBuffer.maintenance(); assertEquals(3, cacheBuffer.getActiveEntities()); + assertEquals(3, cacheBuffer.getAllModels().size()); when(clock.instant()).thenReturn(Instant.MAX); cacheBuffer.maintenance(); assertEquals(0, cacheBuffer.getActiveEntities()); diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/caching/PriorityCacheTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/caching/PriorityCacheTests.java index 798e781f..c1f3dc1d 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/caching/PriorityCacheTests.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/caching/PriorityCacheTests.java @@ -186,6 +186,7 @@ public void testCacheHit() { // cache miss due to door keeper assertEquals(null, cacheProvider.get(modelId1, detector, point, entityName)); assertEquals(1, cacheProvider.getTotalActiveEntities()); + assertEquals(1, cacheProvider.getAllModels().size()); ModelState hitState = cacheProvider.get(modelId1, detector, point, entityName); assertEquals(detectorId, hitState.getDetectorId()); EntityModel model = hitState.getModel(); @@ -248,10 +249,12 @@ public void testSharedCache() { } assertEquals(2, cacheProvider.getActiveEntities(detectorId2)); assertEquals(3, cacheProvider.getTotalActiveEntities()); + assertEquals(3, cacheProvider.getAllModels().size()); when(memoryTracker.memoryToShed()).thenReturn(memoryPerEntity); cacheProvider.maintenance(); assertEquals(2, cacheProvider.getTotalActiveEntities()); + assertEquals(2, cacheProvider.getAllModels().size()); assertEquals(1, cacheProvider.getActiveEntities(detectorId2)); } @@ -377,9 +380,11 @@ public void testExpiredCacheBuffer() { cacheProvider.get(modelId2, detector, point, entityName); } assertEquals(2, cacheProvider.getTotalActiveEntities()); + assertEquals(2, cacheProvider.getAllModels().size()); when(clock.instant()).thenReturn(Instant.now()); cacheProvider.maintenance(); assertEquals(0, 
cacheProvider.getTotalActiveEntities()); + assertEquals(0, cacheProvider.getAllModels().size()); for (int i = 0; i < 2; i++) { // doorkeeper should have been reset diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManagerTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManagerTests.java index 2527e756..692f1d93 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManagerTests.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManagerTests.java @@ -71,6 +71,8 @@ import org.powermock.modules.junit4.PowerMockRunner; import org.powermock.modules.junit4.PowerMockRunnerDelegate; +import test.com.amazon.opendistroforelasticsearch.ad.util.MLUtil; + import com.amazon.opendistroforelasticsearch.ad.AnomalyDetectorPlugin; import com.amazon.opendistroforelasticsearch.ad.MemoryTracker; import com.amazon.opendistroforelasticsearch.ad.caching.EntityCache; @@ -164,6 +166,7 @@ public class ModelManagerTests { private double[] attribution; private double[] point; private DiVector attributionVec; + private String entityName; @Mock private ActionListener rcfResultListener; @@ -171,6 +174,7 @@ public class ModelManagerTests { @Mock private ActionListener thresholdResultListener; private MemoryTracker memoryTracker; + private Instant now; @Before public void setup() { @@ -232,6 +236,9 @@ public void setup() { modelPartitioner = spy(new ModelPartitioner(numSamples, numTrees, nodeFilter, memoryTracker)); + now = Instant.now(); + when(clock.instant()).thenReturn(now); + modelManager = spy( new ModelManager( rcfSerde, @@ -284,6 +291,8 @@ public void setup() { listener.onResponse(Optional.of(failCheckpoint)); return null; }).when(checkpointDao).getModelCheckpoint(eq(failModelId), any(ActionListener.class)); + + entityName = "1.0.2.3"; } private Object[] getDetectorIdForModelIdData() { @@ -1188,4 +1197,56 @@ public void getPreviewResults_returnAnomalies_forLastAnomaly() { public void 
getPreviewResults_throwIllegalArgument_forInvalidInput() { modelManager.getPreviewResults(new double[0][0]); } + + @Test + public void getNullState() { + assertEquals(new ThresholdingResult(0, 0, 0), modelManager.getAnomalyResultForEntity("", new double[] {}, "", null, "")); + } + + @Test + public void getEmptyStateFullSamples() { + ModelState state = MLUtil.randomModelStateWithSample(false, numMinSamples); + assertEquals( + new ThresholdingResult(0, 0, 0), + modelManager.getAnomalyResultForEntity(detectorId, new double[] { -1 }, entityName, state, modelId) + ); + assertEquals(numMinSamples, state.getModel().getSamples().size()); + } + + @Test + public void getEmptyStateNotFullSamples() { + ModelState state = MLUtil.randomModelStateWithSample(false, numMinSamples - 1); + assertEquals( + new ThresholdingResult(0, 0, 0), + modelManager.getAnomalyResultForEntity(detectorId, new double[] { -1 }, entityName, state, modelId) + ); + assertEquals(numMinSamples, state.getModel().getSamples().size()); + } + + @Test + public void scoreSamples() { + ModelState state = MLUtil.randomNonEmptyModelState(); + modelManager.getAnomalyResultForEntity(detectorId, new double[] { -1 }, entityName, state, modelId); + assertEquals(0, state.getModel().getSamples().size()); + assertEquals(now, state.getLastUsedTime()); + } + + @Test + public void processEmptyCheckpoint() { + ModelState modelState = MLUtil.randomModelStateWithSample(false, numMinSamples - 1); + modelManager.processEntityCheckpoint(Optional.empty(), modelId, entityName, modelState); + assertEquals(now.minus(checkpointInterval), modelState.getLastCheckpointTime()); + } + + @Test + public void processNonEmptyCheckpoint() { + EntityModel model = MLUtil.createNonEmptyModel(modelId); + ModelState modelState = MLUtil.randomModelStateWithSample(false, numMinSamples); + Instant checkpointTime = Instant.ofEpochMilli(1000); + modelManager + .processEntityCheckpoint(Optional.of(new SimpleImmutableEntry<>(model, checkpointTime)), modelId, 
entityName, modelState); + assertEquals(checkpointTime, modelState.getLastCheckpointTime()); + assertEquals(0, modelState.getModel().getSamples().size()); + assertEquals(now, modelState.getLastUsedTime()); + } } diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/stats/ADStatsTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/stats/ADStatsTests.java index 356ec8c1..3e17e984 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/stats/ADStatsTests.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/stats/ADStatsTests.java @@ -34,6 +34,11 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import test.com.amazon.opendistroforelasticsearch.ad.util.MLUtil; + +import com.amazon.opendistroforelasticsearch.ad.caching.CacheProvider; +import com.amazon.opendistroforelasticsearch.ad.caching.EntityCache; +import com.amazon.opendistroforelasticsearch.ad.ml.EntityModel; import com.amazon.opendistroforelasticsearch.ad.ml.HybridThresholdingModel; import com.amazon.opendistroforelasticsearch.ad.ml.ModelManager; import com.amazon.opendistroforelasticsearch.ad.ml.ModelState; @@ -58,6 +63,9 @@ public class ADStatsTests extends ESTestCase { @Mock private ModelManager modelManager; + @Mock + private CacheProvider cacheProvider; + @Before public void setup() { MockitoAnnotations.initMocks(this); @@ -76,6 +84,15 @@ public void setup() { ); when(modelManager.getAllModels()).thenReturn(modelsInformation); + + ModelState entityModel1 = MLUtil.randomNonEmptyModelState(); + ModelState entityModel2 = MLUtil.randomNonEmptyModelState(); + + List> entityModelsInformation = new ArrayList<>(Arrays.asList(entityModel1, entityModel2)); + EntityCache cache = mock(EntityCache.class); + when(cacheProvider.get()).thenReturn(cache); + when(cache.getAllModels()).thenReturn(entityModelsInformation); + IndexUtils indexUtils = mock(IndexUtils.class); when(indexUtils.getIndexHealthStatus(anyString())).thenReturn("yellow"); @@ -90,7 
+107,7 @@ public void setup() { statsMap = new HashMap>() { { put(nodeStatName1, new ADStat<>(false, new CounterSupplier())); - put(nodeStatName2, new ADStat<>(false, new ModelsOnNodeSupplier(modelManager))); + put(nodeStatName2, new ADStat<>(false, new ModelsOnNodeSupplier(modelManager, cacheProvider))); put(clusterStatName1, new ADStat<>(true, new IndexStatusSupplier(indexUtils, "index1"))); put(clusterStatName2, new ADStat<>(true, new IndexStatusSupplier(indexUtils, "index2"))); } diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/stats/suppliers/ModelsOnNodeSupplierTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/stats/suppliers/ModelsOnNodeSupplierTests.java index bd932fba..e21a8be6 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/stats/suppliers/ModelsOnNodeSupplierTests.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/stats/suppliers/ModelsOnNodeSupplierTests.java @@ -16,6 +16,7 @@ package com.amazon.opendistroforelasticsearch.ad.stats.suppliers; import static com.amazon.opendistroforelasticsearch.ad.stats.suppliers.ModelsOnNodeSupplier.MODEL_STATE_STAT_KEYS; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; import java.time.Clock; @@ -24,6 +25,7 @@ import java.util.List; import java.util.Map; import java.util.stream.Collectors; +import java.util.stream.Stream; import org.elasticsearch.test.ESTestCase; import org.junit.Before; @@ -31,6 +33,11 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import test.com.amazon.opendistroforelasticsearch.ad.util.MLUtil; + +import com.amazon.opendistroforelasticsearch.ad.caching.CacheProvider; +import com.amazon.opendistroforelasticsearch.ad.caching.EntityCache; +import com.amazon.opendistroforelasticsearch.ad.ml.EntityModel; import com.amazon.opendistroforelasticsearch.ad.ml.HybridThresholdingModel; import com.amazon.opendistroforelasticsearch.ad.ml.ModelManager; import 
com.amazon.opendistroforelasticsearch.ad.ml.ModelState; @@ -41,10 +48,14 @@ public class ModelsOnNodeSupplierTests extends ESTestCase { private HybridThresholdingModel thresholdingModel; private List> expectedResults; private Clock clock; + private List> entityModelsInformation; @Mock private ModelManager modelManager; + @Mock + private CacheProvider cacheProvider; + @Before public void setup() { MockitoAnnotations.initMocks(this); @@ -64,16 +75,24 @@ public void setup() { ); when(modelManager.getAllModels()).thenReturn(expectedResults); + + ModelState entityModel1 = MLUtil.randomNonEmptyModelState(); + ModelState entityModel2 = MLUtil.randomNonEmptyModelState(); + + entityModelsInformation = new ArrayList<>(Arrays.asList(entityModel1, entityModel2)); + EntityCache cache = mock(EntityCache.class); + when(cacheProvider.get()).thenReturn(cache); + when(cache.getAllModels()).thenReturn(entityModelsInformation); } @Test public void testGet() { - ModelsOnNodeSupplier modelsOnNodeSupplier = new ModelsOnNodeSupplier(modelManager); + ModelsOnNodeSupplier modelsOnNodeSupplier = new ModelsOnNodeSupplier(modelManager, cacheProvider); List> results = modelsOnNodeSupplier.get(); assertEquals( "get fails to return correct result", - expectedResults - .stream() + Stream + .concat(expectedResults.stream(), entityModelsInformation.stream()) .map( modelState -> modelState .getModelStateAsMap() diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/ADStatsNodesTransportActionTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/ADStatsNodesTransportActionTests.java index 8c66031b..40d8f607 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/ADStatsNodesTransportActionTests.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/ADStatsNodesTransportActionTests.java @@ -16,6 +16,7 @@ package com.amazon.opendistroforelasticsearch.ad.transport; import static org.mockito.Mockito.mock; +import 
static org.mockito.Mockito.when; import java.time.Clock; import java.util.Arrays; @@ -34,6 +35,8 @@ import org.junit.Before; import org.junit.Test; +import com.amazon.opendistroforelasticsearch.ad.caching.CacheProvider; +import com.amazon.opendistroforelasticsearch.ad.caching.EntityCache; import com.amazon.opendistroforelasticsearch.ad.ml.ModelManager; import com.amazon.opendistroforelasticsearch.ad.stats.ADStat; import com.amazon.opendistroforelasticsearch.ad.stats.ADStats; @@ -69,6 +72,9 @@ public void setUp() throws Exception { indexNameResolver ); ModelManager modelManager = mock(ModelManager.class); + CacheProvider cacheProvider = mock(CacheProvider.class); + EntityCache cache = mock(EntityCache.class); + when(cacheProvider.get()).thenReturn(cache); clusterStatName1 = "clusterStat1"; clusterStatName2 = "clusterStat2"; @@ -78,7 +84,7 @@ public void setUp() throws Exception { statsMap = new HashMap>() { { put(nodeStatName1, new ADStat<>(false, new CounterSupplier())); - put(nodeStatName2, new ADStat<>(false, new ModelsOnNodeSupplier(modelManager))); + put(nodeStatName2, new ADStat<>(false, new ModelsOnNodeSupplier(modelManager, cacheProvider))); put(clusterStatName1, new ADStat<>(true, new IndexStatusSupplier(indexUtils, "index1"))); put(clusterStatName2, new ADStat<>(true, new IndexStatusSupplier(indexUtils, "index2"))); } diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/MultientityResultTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/MultientityResultTests.java new file mode 100644 index 00000000..6e7804d3 --- /dev/null +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/MultientityResultTests.java @@ -0,0 +1,498 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. 
+ * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package com.amazon.opendistroforelasticsearch.ad.transport; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.time.Clock; +import java.time.Instant; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Optional; +import java.util.function.Function; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.client.Client; +import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.index.IndexNotFoundException; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.Transport; +import org.elasticsearch.transport.TransportException; +import org.elasticsearch.transport.TransportInterceptor; +import 
org.elasticsearch.transport.TransportRequest; +import org.elasticsearch.transport.TransportRequestOptions; +import org.elasticsearch.transport.TransportResponse; +import org.elasticsearch.transport.TransportResponseHandler; +import org.elasticsearch.transport.TransportService; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; + +import test.com.amazon.opendistroforelasticsearch.ad.util.MLUtil; + +import com.amazon.opendistroforelasticsearch.ad.AbstractADTest; +import com.amazon.opendistroforelasticsearch.ad.NodeStateManager; +import com.amazon.opendistroforelasticsearch.ad.TestHelpers; +import com.amazon.opendistroforelasticsearch.ad.breaker.ADCircuitBreakerService; +import com.amazon.opendistroforelasticsearch.ad.caching.CacheProvider; +import com.amazon.opendistroforelasticsearch.ad.caching.EntityCache; +import com.amazon.opendistroforelasticsearch.ad.cluster.HashRing; +import com.amazon.opendistroforelasticsearch.ad.common.exception.AnomalyDetectionException; +import com.amazon.opendistroforelasticsearch.ad.common.exception.EndRunException; +import com.amazon.opendistroforelasticsearch.ad.common.exception.InternalFailure; +import com.amazon.opendistroforelasticsearch.ad.common.exception.LimitExceededException; +import com.amazon.opendistroforelasticsearch.ad.constant.CommonErrorMessages; +import com.amazon.opendistroforelasticsearch.ad.feature.FeatureManager; +import com.amazon.opendistroforelasticsearch.ad.feature.SearchFeatureDao; +import com.amazon.opendistroforelasticsearch.ad.indices.AnomalyDetectionIndices; +import com.amazon.opendistroforelasticsearch.ad.ml.CheckpointDao; +import com.amazon.opendistroforelasticsearch.ad.ml.ModelManager; +import com.amazon.opendistroforelasticsearch.ad.ml.ModelPartitioner; +import com.amazon.opendistroforelasticsearch.ad.ml.ThresholdingResult; +import com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetector; +import 
com.amazon.opendistroforelasticsearch.ad.settings.AnomalyDetectorSettings; +import com.amazon.opendistroforelasticsearch.ad.stats.ADStat; +import com.amazon.opendistroforelasticsearch.ad.stats.ADStats; +import com.amazon.opendistroforelasticsearch.ad.stats.StatNames; +import com.amazon.opendistroforelasticsearch.ad.stats.suppliers.CounterSupplier; +import com.amazon.opendistroforelasticsearch.ad.transport.handler.MultiEntityResultHandler; +import com.amazon.opendistroforelasticsearch.ad.util.ClientUtil; +import com.amazon.opendistroforelasticsearch.ad.util.IndexUtils; + +public class MultientityResultTests extends AbstractADTest { + private AnomalyResultTransportAction action; + private AnomalyResultRequest request; + private TransportInterceptor entityResultInterceptor; + private Clock clock; + private AnomalyDetector detector; + private NodeStateManager stateManager; + private static Settings settings; + private TransportService transportService; + private SearchFeatureDao searchFeatureDao; + private Client client; + private FeatureManager featureQuery; + private ModelManager normalModelManager; + private ModelPartitioner normalModelPartitioner; + private HashRing hashRing; + private ClusterService clusterService; + private IndexNameExpressionResolver indexNameResolver; + private ADCircuitBreakerService adCircuitBreakerService; + private ADStats adStats; + private ThreadPool mockThreadPool; + private String detectorId; + private Instant now; + private String modelId; + private MultiEntityResultHandler anomalyResultHandler; + private CheckpointDao checkpointDao; + private CacheProvider provider; + private AnomalyDetectionIndices indexUtil; + + @BeforeClass + public static void setUpBeforeClass() { + setUpThreadPool(AnomalyResultTests.class.getSimpleName()); + } + + @AfterClass + public static void tearDownAfterClass() { + tearDownThreadPool(); + } + + @SuppressWarnings({ "serial", "unchecked" }) + @Override + @Before + public void setUp() throws Exception { + 
super.setUp(); + now = Instant.now(); + clock = mock(Clock.class); + when(clock.instant()).thenReturn(now); + + detectorId = "123"; + modelId = "abc"; + String categoryField = "a"; + detector = TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, Collections.singletonList(categoryField)); + + stateManager = mock(NodeStateManager.class); + // make sure parameters are not null, otherwise this mock won't get invoked + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(detector)); + return null; + }).when(stateManager).getAnomalyDetector(anyString(), any(ActionListener.class)); + when(stateManager.getLastIndexThrottledTime()).thenReturn(Instant.MIN); + + settings = Settings.builder().put(AnomalyDetectorSettings.COOLDOWN_MINUTES.getKey(), TimeValue.timeValueMinutes(5)).build(); + + request = new AnomalyResultRequest(detectorId, 100, 200); + + transportService = mock(TransportService.class); + + client = mock(Client.class); + ThreadContext threadContext = new ThreadContext(settings); + mockThreadPool = mock(ThreadPool.class); + setUpADThreadPool(mockThreadPool); + when(client.threadPool()).thenReturn(mockThreadPool); + when(mockThreadPool.getThreadContext()).thenReturn(threadContext); + + featureQuery = mock(FeatureManager.class); + + normalModelManager = mock(ModelManager.class); + when(normalModelManager.getEntityModelId(anyString(), anyString())).thenReturn(modelId); + + normalModelPartitioner = mock(ModelPartitioner.class); + + hashRing = mock(HashRing.class); + + clusterService = mock(ClusterService.class); + + indexNameResolver = new IndexNameExpressionResolver(); + + adCircuitBreakerService = mock(ADCircuitBreakerService.class); + when(adCircuitBreakerService.isOpen()).thenReturn(false); + + IndexUtils indexUtils = new IndexUtils(client, mock(ClientUtil.class), clusterService, indexNameResolver); + Map> statsMap = new HashMap>() { + { + put(StatNames.AD_EXECUTE_REQUEST_COUNT.getName(), new 
ADStat<>(false, new CounterSupplier())); + put(StatNames.AD_EXECUTE_FAIL_COUNT.getName(), new ADStat<>(false, new CounterSupplier())); + } + }; + adStats = new ADStats(indexUtils, normalModelManager, statsMap); + + searchFeatureDao = mock(SearchFeatureDao.class); + + action = new AnomalyResultTransportAction( + new ActionFilters(Collections.emptySet()), + transportService, + settings, + client, + stateManager, + featureQuery, + normalModelManager, + normalModelPartitioner, + hashRing, + clusterService, + indexNameResolver, + adCircuitBreakerService, + adStats, + mockThreadPool, + searchFeatureDao + ); + + anomalyResultHandler = mock(MultiEntityResultHandler.class); + checkpointDao = mock(CheckpointDao.class); + provider = mock(CacheProvider.class); + indexUtil = mock(AnomalyDetectionIndices.class); + } + + @Override + @After + public final void tearDown() throws Exception { + tearDownTestNodes(); + super.tearDown(); + } + + @SuppressWarnings("unchecked") + public void testQueryError() { + // non-EndRunException won't stop action from running + when(stateManager.fetchColdStartException(anyString())).thenReturn(Optional.of(new AnomalyDetectionException(detectorId, ""))); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(3); + listener + .onFailure( + new EndRunException( + detectorId, + CommonErrorMessages.INVALID_SEARCH_QUERY_MSG, + new NoSuchElementException("No value present"), + false + ) + ); + return null; + }).when(searchFeatureDao).getFeaturesByEntities(any(), anyLong(), anyLong(), any()); + + PlainActionFuture listener = new PlainActionFuture<>(); + + action.doExecute(null, request, listener); + + verify(stateManager, times(1)).getAnomalyDetector(anyString(), any(ActionListener.class)); + + assertException(listener, EndRunException.class, CommonErrorMessages.INVALID_SEARCH_QUERY_MSG); + } + + public void testIndexNotFound() { + // non-EndRunException won't stop action from running + 
when(stateManager.fetchColdStartException(anyString())).thenReturn(Optional.of(new AnomalyDetectionException(detectorId, ""))); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(3); + listener.onFailure(new IndexNotFoundException("", "")); + return null; + }).when(searchFeatureDao).getFeaturesByEntities(any(), anyLong(), anyLong(), any()); + + PlainActionFuture listener = new PlainActionFuture<>(); + + action.doExecute(null, request, listener); + + assertException(listener, EndRunException.class, AnomalyResultTransportAction.TROUBLE_QUERYING_ERR_MSG); + } + + public void testColdStartEndRunException() { + when(stateManager.fetchColdStartException(anyString())) + .thenReturn( + Optional + .of( + new EndRunException( + detectorId, + CommonErrorMessages.INVALID_SEARCH_QUERY_MSG, + new NoSuchElementException("No value present"), + false + ) + ) + ); + PlainActionFuture listener = new PlainActionFuture<>(); + action.doExecute(null, request, listener); + assertException(listener, EndRunException.class, CommonErrorMessages.INVALID_SEARCH_QUERY_MSG); + } + + public void testEmptyFeatures() { + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(3); + listener.onResponse(new HashMap()); + return null; + }).when(searchFeatureDao).getFeaturesByEntities(any(), anyLong(), anyLong(), any()); + + PlainActionFuture listener = new PlainActionFuture<>(); + + action.doExecute(null, request, listener); + + AnomalyResultResponse response = listener.actionGet(10000L); + assertEquals(Double.NaN, response.getAnomalyGrade(), 0.01); + } + + private TransportResponseHandler entityResultHandler(TransportResponseHandler handler) { + return new TransportResponseHandler() { + @Override + public T read(StreamInput in) throws IOException { + return handler.read(in); + } + + @Override + @SuppressWarnings("unchecked") + public void handleResponse(T response) { + handler.handleResponse(response); + } + + @Override + public void 
handleException(TransportException exp) { + handler.handleException(exp); + } + + @Override + public String executor() { + return handler.executor(); + } + }; + } + + private TransportResponseHandler unackEntityResultHandler(TransportResponseHandler handler) { + return new TransportResponseHandler() { + @Override + public T read(StreamInput in) throws IOException { + return handler.read(in); + } + + @Override + @SuppressWarnings("unchecked") + public void handleResponse(T response) { + handler.handleResponse((T) new AcknowledgedResponse(false)); + } + + @Override + public void handleException(TransportException exp) { + handler.handleException(exp); + } + + @Override + public String executor() { + return handler.executor(); + } + }; + } + + private void setUpEntityResult() { + // register entity result action + new EntityResultTransportAction( + new ActionFilters(Collections.emptySet()), + // since we send requests to testNodes[1] + testNodes[1].transportService, + normalModelManager, + adCircuitBreakerService, + anomalyResultHandler, + checkpointDao, + provider, + stateManager, + settings, + clock, + indexUtil + ); + + EntityCache entityCache = mock(EntityCache.class); + when(provider.get()).thenReturn(entityCache); + when(entityCache.get(any(), any(), any(), anyString())).thenReturn(MLUtil.randomNonEmptyModelState()); + + when(normalModelManager.getAnomalyResultForEntity(anyString(), any(), anyString(), any(), anyString())) + .thenReturn(new ThresholdingResult(0, 1, 1)); + } + + private void setUpTransportInterceptor( + Function, TransportResponseHandler> interceptor + ) { + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(3); + Map features = new HashMap(); + features.put("1.0.2.3", new double[] { 0 }); + features.put("2.0.2.3", new double[] { 1 }); + listener.onResponse(features); + return null; + }).when(searchFeatureDao).getFeaturesByEntities(any(), anyLong(), anyLong(), any()); + + entityResultInterceptor = new 
TransportInterceptor() { + @Override + public AsyncSender interceptSender(AsyncSender sender) { + return new AsyncSender() { + @SuppressWarnings("unchecked") + @Override + public void sendRequest( + Transport.Connection connection, + String action, + TransportRequest request, + TransportRequestOptions options, + TransportResponseHandler handler + ) { + if (action.equals(EntityResultAction.NAME)) { + sender + .sendRequest( + connection, + action, + request, + options, + interceptor.apply((TransportResponseHandler) handler) + ); + } else { + sender.sendRequest(connection, action, request, options, handler); + } + } + }; + } + }; + + setupTestNodes(settings, entityResultInterceptor); + + // mock hashing ring response. This has to happen after setting up test nodes with the failure interceptor + when(hashRing.getOwningNode(any(String.class))).thenReturn(Optional.of(testNodes[1].discoveryNode())); + + TransportService realTransportService = testNodes[0].transportService; + ClusterService realClusterService = testNodes[0].clusterService; + + action = new AnomalyResultTransportAction( + new ActionFilters(Collections.emptySet()), + realTransportService, + settings, + client, + stateManager, + featureQuery, + normalModelManager, + normalModelPartitioner, + hashRing, + realClusterService, + indexNameResolver, + adCircuitBreakerService, + adStats, + threadPool, + searchFeatureDao + ); + } + + public void testNonEmptyFeatures() { + setUpTransportInterceptor(this::entityResultHandler); + setUpEntityResult(); + + PlainActionFuture listener = new PlainActionFuture<>(); + + action.doExecute(null, request, listener); + + AnomalyResultResponse response = listener.actionGet(10000L); + assertEquals(0d, response.getAnomalyGrade(), 0.01); + } + + public void testCircuitBreakerOpen() { + setUpTransportInterceptor(this::entityResultHandler); + + ADCircuitBreakerService openBreaker = mock(ADCircuitBreakerService.class); + when(openBreaker.isOpen()).thenReturn(true); + // register entity 
result action + new EntityResultTransportAction( + new ActionFilters(Collections.emptySet()), + // since we send requests to testNodes[1] + testNodes[1].transportService, + normalModelManager, + openBreaker, + anomalyResultHandler, + checkpointDao, + provider, + stateManager, + settings, + clock, + indexUtil + ); + + PlainActionFuture listener = new PlainActionFuture<>(); + + action.doExecute(null, request, listener); + assertException(listener, LimitExceededException.class, CommonErrorMessages.MEMORY_CIRCUIT_BROKEN_ERR_MSG); + } + + public void testNotAck() { + setUpTransportInterceptor(this::unackEntityResultHandler); + setUpEntityResult(); + + PlainActionFuture listener = new PlainActionFuture<>(); + + action.doExecute(null, request, listener); + + assertException(listener, InternalFailure.class, AnomalyResultTransportAction.NO_ACK_ERR); + verify(stateManager, times(1)).addPressure(anyString()); + } +} diff --git a/src/test/java/test/com/amazon/opendistroforelasticsearch/ad/util/MLUtil.java b/src/test/java/test/com/amazon/opendistroforelasticsearch/ad/util/MLUtil.java index d8431339..71288db1 100644 --- a/src/test/java/test/com/amazon/opendistroforelasticsearch/ad/util/MLUtil.java +++ b/src/test/java/test/com/amazon/opendistroforelasticsearch/ad/util/MLUtil.java @@ -38,6 +38,7 @@ */ public class MLUtil { private static Random random = new Random(42); + private static int minSampleSize = AnomalyDetectorSettings.NUM_MIN_SAMPLES; private static String randomString(int targetStringLength) { int leftLimit = 97; // letter 'a' @@ -58,54 +59,79 @@ public static Queue createQueueSamples(int size) { } public static ModelState randomModelState() { - return randomModelState(random.nextBoolean(), random.nextFloat(), randomString(15)); + return randomModelState(random.nextBoolean(), random.nextFloat(), randomString(15), random.nextInt(minSampleSize)); } - public static ModelState randomModelState(boolean fullModel, float priority, String modelId) { + public static ModelState 
randomModelState(boolean fullModel, float priority, String modelId, int sampleSize) { String detectorId = randomString(5); - Queue samples = createQueueSamples(random.nextInt(128)); EntityModel model = null; if (fullModel) { - RandomCutForest rcf = RandomCutForest - .builder() - .dimensions(1) - .sampleSize(AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE) - .numberOfTrees(AnomalyDetectorSettings.MULTI_ENTITY_NUM_TREES) - .lambda(AnomalyDetectorSettings.TIME_DECAY) - .outputAfter(AnomalyDetectorSettings.NUM_MIN_SAMPLES) - .parallelExecutionEnabled(false) - .build(); - int numDataPoints = random.nextInt(1000) + AnomalyDetectorSettings.NUM_MIN_SAMPLES; - double[] scores = new double[numDataPoints]; - for (int j = 0; j < numDataPoints; j++) { - double[] dataPoint = new double[] { random.nextDouble() }; - scores[j] = rcf.getAnomalyScore(dataPoint); - rcf.update(dataPoint); - } - - double[] nonZeroScores = DoubleStream.of(scores).filter(score -> score > 0).toArray(); - ThresholdingModel threshold = new HybridThresholdingModel( - AnomalyDetectorSettings.THRESHOLD_MIN_PVALUE, - AnomalyDetectorSettings.THRESHOLD_MAX_RANK_ERROR, - AnomalyDetectorSettings.THRESHOLD_MAX_SCORE, - AnomalyDetectorSettings.THRESHOLD_NUM_LOGNORMAL_QUANTILES, - AnomalyDetectorSettings.THRESHOLD_DOWNSAMPLES, - AnomalyDetectorSettings.THRESHOLD_MAX_SAMPLES - ); - threshold.train(nonZeroScores); - model = new EntityModel(modelId, samples, rcf, threshold); + model = createNonEmptyModel(modelId, sampleSize); } else { - model = new EntityModel(modelId, samples, null, null); + model = createEmptyModel(modelId, sampleSize); } return new ModelState<>(model, modelId, detectorId, ModelType.ENTITY.getName(), Clock.systemUTC(), priority); } public static ModelState randomNonEmptyModelState() { - return randomModelState(true, random.nextFloat(), randomString(15)); + return randomModelState(true, random.nextFloat(), randomString(15), random.nextInt(minSampleSize)); + } + + public static ModelState 
randomEmptyModelState() { + return randomModelState(false, random.nextFloat(), randomString(15), random.nextInt(minSampleSize)); } public static ModelState randomModelState(float priority, String modelId) { - return randomModelState(random.nextBoolean(), priority, modelId); + return randomModelState(random.nextBoolean(), priority, modelId, random.nextInt(minSampleSize)); + } + + public static ModelState randomModelStateWithSample(boolean fullModel, int sampleSize) { + return randomModelState(fullModel, random.nextFloat(), randomString(15), sampleSize); + } + + public static EntityModel createEmptyModel(String modelId, int sampleSize) { + Queue samples = createQueueSamples(sampleSize); + return new EntityModel(modelId, samples, null, null); + } + + public static EntityModel createEmptyModel(String modelId) { + return createEmptyModel(modelId, random.nextInt(minSampleSize)); + } + + public static EntityModel createNonEmptyModel(String modelId, int sampleSize) { + Queue samples = createQueueSamples(sampleSize); + RandomCutForest rcf = RandomCutForest + .builder() + .dimensions(1) + .sampleSize(AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE) + .numberOfTrees(AnomalyDetectorSettings.MULTI_ENTITY_NUM_TREES) + .lambda(AnomalyDetectorSettings.TIME_DECAY) + .outputAfter(AnomalyDetectorSettings.NUM_MIN_SAMPLES) + .parallelExecutionEnabled(false) + .build(); + int numDataPoints = random.nextInt(1000) + AnomalyDetectorSettings.NUM_MIN_SAMPLES; + double[] scores = new double[numDataPoints]; + for (int j = 0; j < numDataPoints; j++) { + double[] dataPoint = new double[] { random.nextDouble() }; + scores[j] = rcf.getAnomalyScore(dataPoint); + rcf.update(dataPoint); + } + + double[] nonZeroScores = DoubleStream.of(scores).filter(score -> score > 0).toArray(); + ThresholdingModel threshold = new HybridThresholdingModel( + AnomalyDetectorSettings.THRESHOLD_MIN_PVALUE, + AnomalyDetectorSettings.THRESHOLD_MAX_RANK_ERROR, + AnomalyDetectorSettings.THRESHOLD_MAX_SCORE, + 
AnomalyDetectorSettings.THRESHOLD_NUM_LOGNORMAL_QUANTILES, + AnomalyDetectorSettings.THRESHOLD_DOWNSAMPLES, + AnomalyDetectorSettings.THRESHOLD_MAX_SAMPLES + ); + threshold.train(nonZeroScores); + return new EntityModel(modelId, samples, rcf, threshold); + } + + public static EntityModel createNonEmptyModel(String modelId) { + return createNonEmptyModel(modelId, random.nextInt(minSampleSize)); } }