diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManager.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManager.java index 352f9a86..e7605058 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManager.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManager.java @@ -53,6 +53,7 @@ import com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetector; import com.amazon.opendistroforelasticsearch.ad.util.DiscoveryNodeFilterer; import com.amazon.randomcutforest.RandomCutForest; +import com.amazon.randomcutforest.returntypes.DiVector; import com.amazon.randomcutforest.serialize.RandomCutForestSerDe; import com.google.gson.Gson; @@ -206,28 +207,48 @@ public ModelManager( * * Final RCF score is calculated by averaging scores weighted by model size (number of trees). * Confidence is the weighted average of confidence with confidence for missing models being 0. + * Attribution is normalized weighted average for the most recent feature dimensions. * * @param rcfResults RCF results from partitioned models + * @param numFeatures number of features for attribution * @return combined RCF result */ - public CombinedRcfResult combineRcfResults(List rcfResults) { + public CombinedRcfResult combineRcfResults(List rcfResults, int numFeatures) { CombinedRcfResult combinedResult = null; if (rcfResults.isEmpty()) { - combinedResult = new CombinedRcfResult(0, 0); + combinedResult = new CombinedRcfResult(0, 0, new double[0]); } else { int totalForestSize = rcfResults.stream().mapToInt(RcfResult::getForestSize).sum(); if (totalForestSize == 0) { - combinedResult = new CombinedRcfResult(0, 0); + combinedResult = new CombinedRcfResult(0, 0, new double[0]); } else { double score = rcfResults.stream().mapToDouble(r -> r.getScore() * r.getForestSize()).sum() / totalForestSize; double confidence = rcfResults.stream().mapToDouble(r -> r.getConfidence() * r.getForestSize()).sum() / Math .max(rcfNumTrees, totalForestSize); - combinedResult = new CombinedRcfResult(score, confidence); + double[] attribution = combineAttribution(rcfResults, numFeatures, totalForestSize); + combinedResult = new CombinedRcfResult(score, confidence, attribution); } } return combinedResult; } + private double[] combineAttribution(List rcfResults, int numFeatures, int totalForestSize) { + double[] combined = new double[numFeatures]; + double sum = 0; + for (RcfResult result : rcfResults) { + double[] attribution = result.getAttribution(); + for (int i = 0; i < numFeatures; i++) { + double attr = attribution[attribution.length - numFeatures + i] * result.getForestSize() / totalForestSize; + combined[i] += attr; + sum += attr; + } + } + for (int i = 0; i < numFeatures; i++) { + combined[i] /= sum; + } + return combined; + } + /** * Gets the detector id from the model id. * @@ -349,13 +370,25 @@ public void getRcfResult(String detectorId, String modelId, double[] point, Acti } private void getRcfResult(ModelState modelState, double[] point, ActionListener listener) { + modelState.setLastUsedTime(clock.instant()); + RandomCutForest rcf = modelState.getModel(); double score = rcf.getAnomalyScore(point); double confidence = computeRcfConfidence(rcf); int forestSize = rcf.getNumberOfTrees(); + double[] attribution = getAnomalyAttribution(rcf, point); rcf.update(point); - modelState.setLastUsedTime(clock.instant()); - listener.onResponse(new RcfResult(score, confidence, forestSize)); + listener.onResponse(new RcfResult(score, confidence, forestSize, attribution)); + } + + private double[] getAnomalyAttribution(RandomCutForest rcf, double[] point) { + DiVector vec = rcf.getAnomalyAttribution(point); + vec.renormalize(1d); + double[] attribution = new double[vec.getDimensions()]; + for (int i = 0; i < attribution.length; i++) { + attribution[i] = vec.getHighLowSum(i); + } + return attribution; } private Optional> restoreCheckpoint(Optional rcfCheckpoint, String modelId, String detectorId) { diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/RcfResult.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/RcfResult.java index 8164fefc..f384f4f9 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/RcfResult.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/RcfResult.java @@ -15,6 +15,7 @@ package com.amazon.opendistroforelasticsearch.ad.ml; +import java.util.Arrays; import java.util.Objects; /** @@ -25,6 +26,7 @@ public class RcfResult { private final double score; private final double confidence; private final int forestSize; + private final double[] attribution; /** * Constructor with all arguments. @@ -32,11 +34,13 @@ public class RcfResult { * @param score RCF score * @param confidence RCF confidence * @param forestSize number of RCF trees used for the score + * @param attribution anomaly score attribution */ - public RcfResult(double score, double confidence, int forestSize) { + public RcfResult(double score, double confidence, int forestSize, double[] attribution) { this.score = score; this.confidence = confidence; this.forestSize = forestSize; + this.attribution = attribution; } /** @@ -66,6 +70,15 @@ public int getForestSize() { return forestSize; } + /** + * Returns anomaly score attribution. + * + * @return anomaly score attribution + */ + public double[] getAttribution() { + return attribution; + } + @Override public boolean equals(Object o) { if (this == o) @@ -75,11 +88,12 @@ public boolean equals(Object o) { RcfResult that = (RcfResult) o; return Objects.equals(this.score, that.score) && Objects.equals(this.confidence, that.confidence) - && Objects.equals(this.forestSize, that.forestSize); + && Objects.equals(this.forestSize, that.forestSize) + && Arrays.equals(this.attribution, that.attribution); } @Override public int hashCode() { - return Objects.hash(score, confidence, forestSize); + return Objects.hash(score, confidence, forestSize, attribution); } } diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/rcf/CombinedRcfResult.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/rcf/CombinedRcfResult.java index 412e60c4..39cf36fc 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/rcf/CombinedRcfResult.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/rcf/CombinedRcfResult.java @@ -15,6 +15,7 @@ package com.amazon.opendistroforelasticsearch.ad.ml.rcf; +import java.util.Arrays; import java.util.Objects; /** @@ -24,16 +25,19 @@ public class CombinedRcfResult { private final double score; private final double confidence; + private final double[] attribution; /** * Constructor with all arguments. * * @param score combined RCF score * @param confidence confidence of the score + * @param attribution score attribution normalized to 1 */ - public CombinedRcfResult(double score, double confidence) { + public CombinedRcfResult(double score, double confidence, double[] attribution) { this.score = score; this.confidence = confidence; + this.attribution = attribution; } /** @@ -54,6 +58,15 @@ public double getConfidence() { return confidence; } + /** + * Return score attribution normalized to 1. + * + * @return score attribution normalized to 1 + */ + public double[] getAttribution() { + return attribution; + } + @Override public boolean equals(Object o) { if (this == o) @@ -61,11 +74,13 @@ public boolean equals(Object o) { if (o == null || getClass() != o.getClass()) return false; CombinedRcfResult that = (CombinedRcfResult) o; - return Objects.equals(this.score, that.score) && Objects.equals(this.confidence, that.confidence); + return Objects.equals(this.score, that.score) + && Objects.equals(this.confidence, that.confidence) + && Arrays.equals(this.attribution, that.attribution); } @Override public int hashCode() { - return Objects.hash(score, confidence); + return Objects.hash(score, confidence, attribution); } } diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/AnomalyResultTransportAction.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/AnomalyResultTransportAction.java index 45146188..686621d8 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/AnomalyResultTransportAction.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/AnomalyResultTransportAction.java @@ -363,7 +363,8 @@ private ActionListener onFeatureResponse( featureInResponse, rcfPartitionNum, responseCount, - adID + adID, + detector.getEnabledFeatureIds().size() ); transportService @@ -463,12 +464,12 @@ private void findException(Throwable cause, String adID, AtomicReference rcfResults) { + private CombinedRcfResult getCombinedResult(List rcfResults, int numFeatures) { List rcfResultLib = new ArrayList<>(); for (RCFResultResponse result : rcfResults) { - rcfResultLib.add(new RcfResult(result.getRCFScore(), result.getConfidence(), result.getForestSize())); + rcfResultLib.add(new RcfResult(result.getRCFScore(), result.getConfidence(), result.getForestSize(), result.getAttribution())); } - return modelManager.combineRcfResults(rcfResultLib); + return modelManager.combineRcfResults(rcfResultLib, numFeatures); } void handleExecuteException(Exception ex, ActionListener listener, String adID) { @@ -495,6 +496,7 @@ class RCFActionListener implements ActionListener { private int nodeCount; private final AtomicInteger responseCount; private final String adID; + private int numEnabledFeatures; RCFActionListener( List rcfResults, @@ -508,7 +510,8 @@ class RCFActionListener implements ActionListener { List features, int nodeCount, AtomicInteger responseCount, - String adID + String adID, + int numEnabledFeatures ) { this.rcfResults = rcfResults; this.modelID = modelID; @@ -522,6 +525,7 @@ class RCFActionListener implements ActionListener { this.nodeCount = nodeCount; this.responseCount = responseCount; this.adID = adID; + this.numEnabledFeatures = numEnabledFeatures; } @Override @@ -537,7 +541,7 @@ public void onResponse(RCFResultResponse response) { LOG.error("Unexpected exception: {} for {}", ex, adID); } finally { if (nodeCount == responseCount.incrementAndGet()) { - handleRCFResults(); + handleRCFResults(numEnabledFeatures); } } } @@ -550,12 +554,12 @@ public void onFailure(Exception e) { LOG.error("Unexpected exception: {} for {}", ex, adID); } finally { if (nodeCount == responseCount.incrementAndGet()) { - handleRCFResults(); + handleRCFResults(numEnabledFeatures); } } } - private void handleRCFResults() { + private void handleRCFResults(int numFeatures) { try { AnomalyDetectionException exception = coldStartIfNoModel(failure, detector); if (exception != null) { @@ -568,7 +572,7 @@ private void handleRCFResults() { return; } - CombinedRcfResult combinedResult = getCombinedResult(rcfResults); + CombinedRcfResult combinedResult = getCombinedResult(rcfResults, numFeatures); double combinedScore = combinedResult.getScore(); final AtomicReference anomalyResultResponse = new AtomicReference<>(); diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/RCFResultResponse.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/RCFResultResponse.java index e5938225..74008c88 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/RCFResultResponse.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/RCFResultResponse.java @@ -27,14 +27,17 @@ public class RCFResultResponse extends ActionResponse implements ToXContentObjec public static final String RCF_SCORE_JSON_KEY = "rcfScore"; public static final String CONFIDENCE_JSON_KEY = "confidence"; public static final String FOREST_SIZE_JSON_KEY = "forestSize"; + public static final String ATTRIBUTION_JSON_KEY = "attribution"; private double rcfScore; private double confidence; private int forestSize; + private double[] attribution; - public RCFResultResponse(double rcfScore, double confidence, int forestSize) { + public RCFResultResponse(double rcfScore, double confidence, int forestSize, double[] attribution) { this.rcfScore = rcfScore; this.confidence = confidence; this.forestSize = forestSize; + this.attribution = attribution; } public RCFResultResponse(StreamInput in) throws IOException { @@ -42,6 +45,7 @@ public RCFResultResponse(StreamInput in) throws IOException { rcfScore = in.readDouble(); confidence = in.readDouble(); forestSize = in.readVInt(); + attribution = in.readDoubleArray(); } public double getRCFScore() { @@ -56,11 +60,21 @@ public int getForestSize() { return forestSize; } + /** + * Returns RCF score attribution. + * + * @return RCF score attribution. + */ + public double[] getAttribution() { + return attribution; + } + @Override public void writeTo(StreamOutput out) throws IOException { out.writeDouble(rcfScore); out.writeDouble(confidence); out.writeVInt(forestSize); + out.writeDoubleArray(attribution); } @Override @@ -69,6 +83,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(RCF_SCORE_JSON_KEY, rcfScore); builder.field(CONFIDENCE_JSON_KEY, confidence); builder.field(FOREST_SIZE_JSON_KEY, forestSize); + builder.field(ATTRIBUTION_JSON_KEY, attribution); builder.endObject(); return builder; } diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/RCFResultTransportAction.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/RCFResultTransportAction.java index c2b28058..d46bdf73 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/RCFResultTransportAction.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/RCFResultTransportAction.java @@ -65,7 +65,14 @@ protected void doExecute(Task task, RCFResultRequest request, ActionListener listener - .onResponse(new RCFResultResponse(result.getScore(), result.getConfidence(), result.getForestSize())), + .onResponse( + new RCFResultResponse( + result.getScore(), + result.getConfidence(), + result.getForestSize(), + result.getAttribution() + ) + ), exception -> { LOG.warn(exception); listener.onFailure(exception); diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManagerTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManagerTests.java index 57ad3799..0f441fa7 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManagerTests.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManagerTests.java @@ -75,6 +75,7 @@ import com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetector; import com.amazon.opendistroforelasticsearch.ad.util.DiscoveryNodeFilterer; import com.amazon.randomcutforest.RandomCutForest; +import com.amazon.randomcutforest.returntypes.DiVector; import com.amazon.randomcutforest.serialize.RandomCutForestSerDe; import com.google.gson.Gson; @@ -142,6 +143,9 @@ public class ModelManagerTests { private String failModelId; private String successCheckpoint; private String failCheckpoint; + private double[] attribution; + private double[] point; + private DiVector attributionVec; @Mock private ActionListener rcfResultListener; @@ -171,8 +175,16 @@ public void setup() { modelTtl = Duration.ofHours(1); checkpointInterval = Duration.ofHours(1); shingleSize = 1; + attribution = new double[] { 1, 1 }; + attributionVec = new DiVector(attribution.length); + for (int i = 0; i < attribution.length; i++) { + attributionVec.high[i] = attribution[i]; + attributionVec.low[i] = attribution[i] - 1; + } + point = new double[] { 2 }; - rcf = RandomCutForest.builder().dimensions(numFeatures).sampleSize(numSamples).numberOfTrees(numTrees).build(); + rcf = spy(RandomCutForest.builder().dimensions(numFeatures).sampleSize(numSamples).numberOfTrees(numTrees).build()); + when(rcf.getAnomalyAttribution(point)).thenReturn(attributionVec); when(jvmService.info().getMem().getHeapMax().getBytes()).thenReturn(10_000_000_000L); @@ -267,28 +279,48 @@ public void getDetectorIdForModelId_throwIllegalArgument_forInvalidId(String mod } private Object[] combineRcfResultsData() { + double[] attribution = new double[] { 1 }; return new Object[] { - new Object[] { asList(), new CombinedRcfResult(0, 0) }, - new Object[] { asList(new RcfResult(0, 0, 0)), new CombinedRcfResult(0, 0) }, - new Object[] { asList(new RcfResult(1, 0, 50)), new CombinedRcfResult(1, 0) }, - new Object[] { asList(new RcfResult(1, 0, 50), new RcfResult(2, 0, 50)), new CombinedRcfResult(1.5, 0) }, + new Object[] { asList(), 1, new CombinedRcfResult(0, 0, new double[0]) }, + new Object[] { asList(new RcfResult(0, 0, 0, new double[0])), 1, new CombinedRcfResult(0, 0, new double[0]) }, + new Object[] { asList(new RcfResult(1, 0, 50, attribution)), 1, new CombinedRcfResult(1, 0, attribution) }, + new Object[] { + asList(new RcfResult(1, 0, 50, attribution), new RcfResult(2, 0, 50, attribution)), + 1, + new CombinedRcfResult(1.5, 0, attribution) }, + new Object[] { + asList(new RcfResult(1, 0, 40, attribution), new RcfResult(2, 0, 60, attribution), new RcfResult(3, 0, 100, attribution)), + 1, + new CombinedRcfResult(2.3, 0, attribution) }, + new Object[] { asList(new RcfResult(0, 1, 100, attribution)), 1, new CombinedRcfResult(0, 1, attribution) }, + new Object[] { asList(new RcfResult(0, 1, 50, attribution)), 1, new CombinedRcfResult(0, 0.5, attribution) }, + new Object[] { asList(new RcfResult(0, 0.5, 1000, attribution)), 1, new CombinedRcfResult(0, 0.5, attribution) }, + new Object[] { + asList(new RcfResult(0, 1, 50, attribution), new RcfResult(0, 0, 50, attribution)), + 1, + new CombinedRcfResult(0, 0.5, attribution) }, new Object[] { - asList(new RcfResult(1, 0, 40), new RcfResult(2, 0, 60), new RcfResult(3, 0, 100)), - new CombinedRcfResult(2.3, 0) }, - new Object[] { asList(new RcfResult(0, 1, 100)), new CombinedRcfResult(0, 1) }, - new Object[] { asList(new RcfResult(0, 1, 50)), new CombinedRcfResult(0, 0.5) }, - new Object[] { asList(new RcfResult(0, 0.5, 1000)), new CombinedRcfResult(0, 0.5) }, - new Object[] { asList(new RcfResult(0, 1, 50), new RcfResult(0, 0, 50)), new CombinedRcfResult(0, 0.5) }, - new Object[] { asList(new RcfResult(0, 0.5, 50), new RcfResult(0, 0.5, 50)), new CombinedRcfResult(0, 0.5) }, + asList(new RcfResult(0, 0.5, 50, attribution), new RcfResult(0, 0.5, 50, attribution)), + 1, + new CombinedRcfResult(0, 0.5, attribution) }, new Object[] { - asList(new RcfResult(0, 1, 20), new RcfResult(0, 1, 30), new RcfResult(0, 0.5, 50)), - new CombinedRcfResult(0, 0.75) }, }; + asList(new RcfResult(0, 1, 20, attribution), new RcfResult(0, 1, 30, attribution), new RcfResult(0, 0.5, 50, attribution)), + 1, + new CombinedRcfResult(0, 0.75, attribution) }, + new Object[] { + asList(new RcfResult(1, 0, 20, new double[] { 0, 0, .5, .5 }), new RcfResult(1, 0, 80, new double[] { 0, .5, .25, .25 })), + 2, + new CombinedRcfResult(1, 0, new double[] { .5, .5 }) }, + new Object[] { + asList(new RcfResult(1, 0, 25, new double[] { 0, 0, 1, .0 }), new RcfResult(1, 0, 75, new double[] { 0, 0, 0, 1 })), + 2, + new CombinedRcfResult(1, 0, new double[] { .25, .75 }) }, }; } @Test @Parameters(method = "combineRcfResultsData") - public void combineRcfResults_returnExpected(List results, CombinedRcfResult expected) { - assertEquals(expected, modelManager.combineRcfResults(results)); + public void combineRcfResults_returnExpected(List results, int numFeatures, CombinedRcfResult expected) { + assertEquals(expected, modelManager.combineRcfResults(results, numFeatures)); } private ImmutableOpenMap createDataNodes(int numDataNodes) { @@ -300,28 +332,29 @@ private ImmutableOpenMap createDataNodes(int numDataNodes } private Object[] getPartitionedForestSizesData() { + RandomCutForest rcf = RandomCutForest.builder().dimensions(1).sampleSize(10).numberOfTrees(100).build(); return new Object[] { // one partition given sufficient large nodes - new Object[] { 100L, 100_000L, createDataNodes(10), pair(1, 100) }, + new Object[] { rcf, 100L, 100_000L, createDataNodes(10), pair(1, 100) }, // two paritions given sufficient medium nodes - new Object[] { 100L, 50_000L, createDataNodes(10), pair(2, 50) }, + new Object[] { rcf, 100L, 50_000L, createDataNodes(10), pair(2, 50) }, // ten partitions given sufficent small nodes - new Object[] { 100L, 10_000L, createDataNodes(10), pair(10, 10) }, + new Object[] { rcf, 100L, 10_000L, createDataNodes(10), pair(10, 10) }, // five double-sized paritions given fewer small nodes - new Object[] { 100L, 10_000L, createDataNodes(5), pair(5, 20) }, + new Object[] { rcf, 100L, 10_000L, createDataNodes(5), pair(5, 20) }, // one large-sized partition given one small node - new Object[] { 100L, 1_000L, createDataNodes(1), pair(1, 100) } }; + new Object[] { rcf, 100L, 1_000L, createDataNodes(1), pair(1, 100) } }; } @Test @Parameters(method = "getPartitionedForestSizesData") public void getPartitionedForestSizes_returnExpected( + RandomCutForest rcf, long totalModelSize, long heapSize, ImmutableOpenMap dataNodes, Entry expected ) { - when(modelManager.estimateModelSize(rcf)).thenReturn(totalModelSize); when(jvmService.info().getMem().getHeapMax().getBytes()).thenReturn(heapSize); when(nodeFilter.getEligibleDataNodes()).thenReturn(dataNodes.values().toArray(DiscoveryNode.class)); @@ -330,15 +363,17 @@ public void getPartitionedForestSizes_returnExpected( } private Object[] getPartitionedForestSizesLimitExceededData() { + RandomCutForest rcf = RandomCutForest.builder().dimensions(1).sampleSize(10).numberOfTrees(100).build(); return new Object[] { - new Object[] { 101L, 1_000L, createDataNodes(1) }, - new Object[] { 201L, 1_000L, createDataNodes(2) }, - new Object[] { 3001L, 10_000L, createDataNodes(3) } }; + new Object[] { rcf, 101L, 1_000L, createDataNodes(1) }, + new Object[] { rcf, 201L, 1_000L, createDataNodes(2) }, + new Object[] { rcf, 3001L, 10_000L, createDataNodes(3) } }; } @Test(expected = LimitExceededException.class) @Parameters(method = "getPartitionedForestSizesLimitExceededData") public void getPartitionedForestSizes_throwLimitExceeded( + RandomCutForest rcf, long totalModelSize, long heapSize, ImmutableOpenMap dataNodes @@ -379,11 +414,12 @@ public void getRcfResult_returnExpectedToListener() { when(forest.getLambda()).thenReturn(rcfTimeDecay); when(forest.getSampleSize()).thenReturn(numSamples); when(forest.getTotalUpdates()).thenReturn((long) numSamples); + when(forest.getAnomalyAttribution(point)).thenReturn(attributionVec); ActionListener listener = mock(ActionListener.class); modelManager.getRcfResult(detectorId, rcfModelId, point, listener); - RcfResult expected = new RcfResult(score, 0, numTrees); + RcfResult expected = new RcfResult(score, 0, numTrees, new double[] { 0.5, 0.5 }); verify(listener).onResponse(eq(expected)); when(forest.getTotalUpdates()).thenReturn(numSamples + 1L); @@ -825,12 +861,8 @@ public void maintenance_stopInactiveRcfModel() { @Test public void maintenance_keepActiveRcfModel() { - String modelId = "testModelId"; - double[] point = new double[0]; - RandomCutForest forest = mock(RandomCutForest.class); - when(checkpointDao.getModelCheckpoint(modelId)).thenReturn(Optional.of(checkpoint)); - when(rcfSerde.fromJson(checkpoint)).thenReturn(forest); - when(rcfSerde.toJson(forest)).thenReturn(checkpoint); + when(rcfSerde.fromJson(checkpoint)).thenReturn(rcf); + when(rcfSerde.toJson(rcf)).thenReturn(checkpoint); when(clock.instant()).thenReturn(Instant.MIN, Instant.EPOCH, Instant.EPOCH); modelManager.getRcfResult(detectorId, modelId, point, rcfResultListener); @@ -1028,6 +1060,7 @@ public void maintenance_returnExpectedToListener_doNothing() { return null; }).when(checkpointDao).putModelCheckpoint(eq(rcfModelId), eq(checkpoint), any(ActionListener.class)); when(rcfSerde.fromJson(checkpoint)).thenReturn(forest); + when(forest.getAnomalyAttribution(point)).thenReturn(attributionVec); when(rcfSerde.toJson(forest)).thenReturn(checkpoint); when(clock.instant()).thenReturn(Instant.MIN); ActionListener scoreListener = mock(ActionListener.class); diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/ml/RcfResultTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/ml/RcfResultTests.java index 7720f929..10a796d2 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/ml/RcfResultTests.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/ml/RcfResultTests.java @@ -16,6 +16,9 @@ package com.amazon.opendistroforelasticsearch.ad.ml; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import java.util.Arrays; import junitparams.JUnitParamsRunner; import junitparams.Parameters; @@ -29,12 +32,14 @@ public class RcfResultTests { private double score = 1.; private double confidence = 0; private int forestSize = 10; - private RcfResult rcfResult = new RcfResult(score, confidence, forestSize); + private double[] attribution = new double[] { 1. }; + private RcfResult rcfResult = new RcfResult(score, confidence, forestSize, attribution); @Test public void getters_returnExcepted() { assertEquals(score, rcfResult.getScore(), 1e-8); assertEquals(forestSize, rcfResult.getForestSize()); + assertTrue(Arrays.equals(attribution, rcfResult.getAttribution())); } private Object[] equalsData() { @@ -42,11 +47,12 @@ private Object[] equalsData() { new Object[] { rcfResult, null, false }, new Object[] { rcfResult, rcfResult, true }, new Object[] { rcfResult, 1, false }, - new Object[] { rcfResult, new RcfResult(score, confidence, forestSize), true }, - new Object[] { rcfResult, new RcfResult(score + 1, confidence, forestSize), false }, - new Object[] { rcfResult, new RcfResult(score, confidence, forestSize + 1), false }, - new Object[] { rcfResult, new RcfResult(score + 1, confidence, forestSize + 1), false }, - new Object[] { rcfResult, new RcfResult(score, confidence + 1, forestSize), false }, }; + new Object[] { rcfResult, new RcfResult(score, confidence, forestSize, attribution), true }, + new Object[] { rcfResult, new RcfResult(score + 1, confidence, forestSize, attribution), false }, + new Object[] { rcfResult, new RcfResult(score, confidence, forestSize + 1, attribution), false }, + new Object[] { rcfResult, new RcfResult(score + 1, confidence, forestSize + 1, attribution), false }, + new Object[] { rcfResult, new RcfResult(score, confidence + 1, forestSize, attribution), false }, + new Object[] { rcfResult, new RcfResult(score, confidence, forestSize, new double[] { 2. }), false }, }; } @Test @@ -57,10 +63,11 @@ public void equals_returnExpected(RcfResult result, Object other, boolean expect private Object[] hashCodeData() { return new Object[] { - new Object[] { rcfResult, new RcfResult(score, confidence, forestSize), true }, - new Object[] { rcfResult, new RcfResult(score + 1, confidence, forestSize), false }, - new Object[] { rcfResult, new RcfResult(score, confidence, forestSize + 1), false }, - new Object[] { rcfResult, new RcfResult(score + 1, confidence, forestSize + 1), false }, }; + new Object[] { rcfResult, new RcfResult(score, confidence, forestSize, attribution), true }, + new Object[] { rcfResult, new RcfResult(score + 1, confidence, forestSize, attribution), false }, + new Object[] { rcfResult, new RcfResult(score, confidence, forestSize + 1, attribution), false }, + new Object[] { rcfResult, new RcfResult(score + 1, confidence, forestSize + 1, attribution), false }, + new Object[] { rcfResult, new RcfResult(score, confidence, forestSize, new double[] { 2. }), false }, }; } @Test diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/ml/rcf/CombinedRcfResultTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/ml/rcf/CombinedRcfResultTests.java index f1fd737c..8527957c 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/ml/rcf/CombinedRcfResultTests.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/ml/rcf/CombinedRcfResultTests.java @@ -16,6 +16,9 @@ package com.amazon.opendistroforelasticsearch.ad.ml.rcf; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import java.util.Arrays; import junitparams.JUnitParamsRunner; import junitparams.Parameters; @@ -28,12 +31,14 @@ public class CombinedRcfResultTests { private double score = 1.; private double confidence = .5; - private CombinedRcfResult rcfResult = new CombinedRcfResult(score, confidence); + private double[] attribution = new double[] { 1. }; + private CombinedRcfResult rcfResult = new CombinedRcfResult(score, confidence, attribution); @Test public void getters_returnExcepted() { assertEquals(score, rcfResult.getScore(), 1e-8); assertEquals(confidence, rcfResult.getConfidence(), 1e-8); + assertTrue(Arrays.equals(attribution, rcfResult.getAttribution())); } private Object[] equalsData() { @@ -41,10 +46,11 @@ private Object[] equalsData() { new Object[] { rcfResult, null, false }, new Object[] { rcfResult, rcfResult, true }, new Object[] { rcfResult, 1, false }, - new Object[] { rcfResult, new CombinedRcfResult(score, confidence), true }, - new Object[] { rcfResult, new CombinedRcfResult(score + 1, confidence), false }, - new Object[] { rcfResult, new CombinedRcfResult(score, confidence + 1), false }, - new Object[] { rcfResult, new CombinedRcfResult(score + 1, confidence + 1), false }, }; + new Object[] { rcfResult, new CombinedRcfResult(score, confidence, attribution), true }, + new Object[] { rcfResult, new CombinedRcfResult(score + 1, confidence, attribution), false }, + new Object[] { rcfResult, new CombinedRcfResult(score, confidence + 1, attribution), false }, + new Object[] { rcfResult, new CombinedRcfResult(score + 1, confidence + 1, attribution), false }, + new Object[] { rcfResult, new CombinedRcfResult(score, confidence, new double[] { 0. }), false }, }; } @Test @@ -55,10 +61,11 @@ public void equals_returnExpected(CombinedRcfResult result, Object other, boolea private Object[] hashCodeData() { return new Object[] { - new Object[] { rcfResult, new CombinedRcfResult(score, confidence), true }, - new Object[] { rcfResult, new CombinedRcfResult(score + 1, confidence), false }, - new Object[] { rcfResult, new CombinedRcfResult(score, confidence + 1), false }, - new Object[] { rcfResult, new CombinedRcfResult(score + 1, confidence + 1), false }, }; + new Object[] { rcfResult, new CombinedRcfResult(score, confidence, attribution), true }, + new Object[] { rcfResult, new CombinedRcfResult(score + 1, confidence, attribution), false }, + new Object[] { rcfResult, new CombinedRcfResult(score, confidence + 1, attribution), false }, + new Object[] { rcfResult, new CombinedRcfResult(score + 1, confidence + 1, attribution), false }, + new Object[] { rcfResult, new CombinedRcfResult(score, confidence, new double[] { 0. }), false }, }; } @Test diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/AnomalyResultTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/AnomalyResultTests.java index 37a1037e..b65707a1 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/AnomalyResultTests.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/AnomalyResultTests.java @@ -218,10 +218,10 @@ public void setUp() throws Exception { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); - listener.onResponse(new RcfResult(0.2, 0, 100)); + listener.onResponse(new RcfResult(0.2, 0, 100, new double[] { 1 })); return null; }).when(normalModelManager).getRcfResult(any(String.class), any(String.class), any(double[].class), any(ActionListener.class)); - when(normalModelManager.combineRcfResults(any())).thenReturn(new CombinedRcfResult(0, 1.0d)); + when(normalModelManager.combineRcfResults(any(), anyInt())).thenReturn(new CombinedRcfResult(0, 1.0d, new double[] { 1 })); rcfModelID = "123-rcf-1"; when(normalModelManager.getRcfModelId(any(String.class), anyInt())).thenReturn(rcfModelID); @@ -564,7 +564,7 @@ public T read(StreamInput in) throws IOException { @Override @SuppressWarnings("unchecked") public void handleResponse(T response) { - handler.handleResponse((T) new RCFResultResponse(1, 1, 100)); + handler.handleResponse((T) new RCFResultResponse(1, 1, 100, new double[0])); } @Override @@ -992,7 +992,7 @@ public void testOnFailureNull() throws IOException { threadPool ); AnomalyResultTransportAction.RCFActionListener listener = action.new RCFActionListener( - null, null, null, null, null, null, null, null, null, 0, new AtomicInteger(), null + null, null, null, null, null, null, null, null, null, 0, new AtomicInteger(), null, 1 ); listener.onFailure(null); } @@ -1341,7 +1341,7 @@ public void testNullRCFResult() { threadPool ); AnomalyResultTransportAction.RCFActionListener listener = action.new RCFActionListener( - null, "123-rcf-0", null, "123", null, null, null, null, null, 0, new AtomicInteger(), null + null, "123-rcf-0", null, "123", null, null, null, null, null, 0, new AtomicInteger(), null, 1 ); listener.onResponse(null); assertTrue(testAppender.containsMessage(AnomalyResultTransportAction.NULL_RESPONSE)); diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/RCFResultTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/RCFResultTests.java index a11ca716..5bc5e4b7 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/RCFResultTests.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/RCFResultTests.java @@ -25,6 +25,7 @@ import static org.mockito.Mockito.when; import java.io.IOException; +import java.util.Arrays; import java.util.Collections; import org.elasticsearch.action.ActionListener; @@ -58,6 +59,8 @@ public class RCFResultTests extends ESTestCase { Gson gson = new GsonBuilder().create(); + private double[] attribution = new double[] { 1. }; + @SuppressWarnings("unchecked") public void testNormal() { TransportService transportService = new TransportService( @@ -80,7 +83,7 @@ public void testNormal() { ); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); - listener.onResponse(new RcfResult(0, 0, 25)); + listener.onResponse(new RcfResult(0, 0, 25, attribution)); return null; }).when(manager).getRcfResult(any(String.class), any(String.class), any(double[].class), any(ActionListener.class)); @@ -93,6 +96,7 @@ public void testNormal() { RCFResultResponse response = future.actionGet(); assertEquals(0, response.getRCFScore(), 0.001); assertEquals(25, response.getForestSize(), 0.001); + assertTrue(Arrays.equals(attribution, response.getAttribution())); } @SuppressWarnings("unchecked") @@ -128,7 +132,7 @@ public void testExecutionException() { } public void testSerialzationResponse() throws IOException { - RCFResultResponse response = new RCFResultResponse(0.3, 0, 26); + RCFResultResponse response = new RCFResultResponse(0.3, 0, 26, attribution); BytesStreamOutput output = new BytesStreamOutput(); response.writeTo(output); @@ -136,16 +140,20 @@ public void testSerialzationResponse() throws IOException { RCFResultResponse readResponse = RCFResultAction.INSTANCE.getResponseReader().read(streamInput); assertThat(response.getForestSize(), equalTo(readResponse.getForestSize())); assertThat(response.getRCFScore(), equalTo(readResponse.getRCFScore())); + assertArrayEquals(response.getAttribution(), readResponse.getAttribution(), 1e-6); } public void testJsonResponse() throws IOException, JsonPathNotFoundException { - RCFResultResponse response = new RCFResultResponse(0.3, 0, 26); + RCFResultResponse response = new RCFResultResponse(0.3, 0, 26, attribution); XContentBuilder builder = jsonBuilder(); response.toXContent(builder, ToXContent.EMPTY_PARAMS); String json = Strings.toString(builder); assertEquals(JsonDeserializer.getDoubleValue(json, RCFResultResponse.RCF_SCORE_JSON_KEY), response.getRCFScore(), 0.001); assertEquals(JsonDeserializer.getDoubleValue(json, RCFResultResponse.FOREST_SIZE_JSON_KEY), response.getForestSize(), 0.001); + assertTrue( + Arrays.equals(JsonDeserializer.getDoubleArrayValue(json, RCFResultResponse.ATTRIBUTION_JSON_KEY), response.getAttribution()) + ); } public void testEmptyID() { @@ -205,7 +213,7 @@ public void testCircuitBreaker() { ); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); - listener.onResponse(new RcfResult(0, 0, 25)); + listener.onResponse(new RcfResult(0, 0, 25, attribution)); return null; }).when(manager).getRcfResult(any(String.class), any(String.class), any(double[].class), any(ActionListener.class)); when(breakerService.isOpen()).thenReturn(true);