Skip to content
This repository has been archived by the owner on Aug 2, 2022. It is now read-only.

add anomaly feature attribution to model output #232

Merged
merged 2 commits into from
Oct 1, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
import com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetector;
import com.amazon.opendistroforelasticsearch.ad.util.DiscoveryNodeFilterer;
import com.amazon.randomcutforest.RandomCutForest;
import com.amazon.randomcutforest.returntypes.DiVector;
import com.amazon.randomcutforest.serialize.RandomCutForestSerDe;
import com.google.gson.Gson;

Expand Down Expand Up @@ -206,28 +207,48 @@ public ModelManager(
*
* Final RCF score is calculated by averaging scores weighted by model size (number of trees).
* Confidence is the weighted average of confidence with confidence for missing models being 0.
* Attribution is normalized weighted average for the most recent feature dimensions.
*
* @param rcfResults RCF results from partitioned models
* @param numFeatures number of features for attribution
* @return combined RCF result
*/
public CombinedRcfResult combineRcfResults(List<RcfResult> rcfResults) {
public CombinedRcfResult combineRcfResults(List<RcfResult> rcfResults, int numFeatures) {
CombinedRcfResult combinedResult = null;
if (rcfResults.isEmpty()) {
combinedResult = new CombinedRcfResult(0, 0);
combinedResult = new CombinedRcfResult(0, 0, new double[0]);
} else {
int totalForestSize = rcfResults.stream().mapToInt(RcfResult::getForestSize).sum();
if (totalForestSize == 0) {
combinedResult = new CombinedRcfResult(0, 0);
combinedResult = new CombinedRcfResult(0, 0, new double[0]);
} else {
double score = rcfResults.stream().mapToDouble(r -> r.getScore() * r.getForestSize()).sum() / totalForestSize;
double confidence = rcfResults.stream().mapToDouble(r -> r.getConfidence() * r.getForestSize()).sum() / Math
.max(rcfNumTrees, totalForestSize);
combinedResult = new CombinedRcfResult(score, confidence);
double[] attribution = combineAttribution(rcfResults, numFeatures, totalForestSize);
combinedResult = new CombinedRcfResult(score, confidence, combineAttribution(rcfResults, numFeatures, totalForestSize));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[minor] It looks like the variable attribution was meant to be used in line 229, but didn't get used and instead combineAttribution was invoked again with identical args.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good catch. I fixed it in the new commit.

}
}
return combinedResult;
}

private double[] combineAttribution(List<RcfResult> rcfResults, int numFeatures, int totalForestSize) {
double[] combined = new double[numFeatures];
double sum = 0;
for (RcfResult result : rcfResults) {
double[] attribution = result.getAttribution();
for (int i = 0; i < numFeatures; i++) {
double attr = attribution[attribution.length - numFeatures + i] * result.getForestSize() / totalForestSize;
combined[i] += attr;
sum += attr;
}
}
for (int i = 0; i < numFeatures; i++) {
combined[i] /= sum;
}
return combined;
}

/**
* Gets the detector id from the model id.
*
Expand Down Expand Up @@ -349,13 +370,25 @@ public void getRcfResult(String detectorId, String modelId, double[] point, Acti
}

private void getRcfResult(ModelState<RandomCutForest> modelState, double[] point, ActionListener<RcfResult> listener) {
modelState.setLastUsedTime(clock.instant());

RandomCutForest rcf = modelState.getModel();
double score = rcf.getAnomalyScore(point);
double confidence = computeRcfConfidence(rcf);
int forestSize = rcf.getNumberOfTrees();
double[] attribution = getAnomalyAttribution(rcf, point);
rcf.update(point);
modelState.setLastUsedTime(clock.instant());
listener.onResponse(new RcfResult(score, confidence, forestSize));
listener.onResponse(new RcfResult(score, confidence, forestSize, attribution));
}

private double[] getAnomalyAttribution(RandomCutForest rcf, double[] point) {
DiVector vec = rcf.getAnomalyAttribution(point);
vec.renormalize(1d);
double[] attribution = new double[vec.getDimensions()];
for (int i = 0; i < attribution.length; i++) {
attribution[i] = vec.getHighLowSum(i);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Questions:
First, how to interpret both high and low are non-zero? Is it really high or low? Does it mean RCF trees think the value can be both higher or lower than the recently observed data trends for that column? Do we need a majority win rule to say it is actually high or low?
Second, when doing a high low sum, we lose direction. Is there any way to preserve the direction?
Third, when users see two features' attribution like x: 1% and y 99%, it tells users y is the place anomaly happens. It might as well not to show x's 1%. I feel an attribution score less than 1/d (d is the number of features) is not useful to users. Any comments on this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. for a given sum, the larger value indicates the direction/relative position. Note even data within normal range has non-zero values for both.
  2. since the current ux design only shows feature contribution, only contribution is computed here. The direction is not lost. It can be added when it's needed.
  3. 1/99 or 0/100 probably won't make a difference for users. In general, additional rules should be avoided for simplicity and correctness as they might introduce their own problems. As an extreme example, the contribution from two features is 49/51, if 49 is omitted, after normalization the result could be 0/100.

}
return attribution;
}

private Optional<ModelState<RandomCutForest>> restoreCheckpoint(Optional<String> rcfCheckpoint, String modelId, String detectorId) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

package com.amazon.opendistroforelasticsearch.ad.ml;

import java.util.Arrays;
import java.util.Objects;

/**
Expand All @@ -25,18 +26,21 @@ public class RcfResult {
private final double score;
private final double confidence;
private final int forestSize;
private final double[] attribution;

/**
* Constructor with all arguments.
*
* @param score RCF score
* @param confidence RCF confidence
* @param forestSize number of RCF trees used for the score
* @param attribution anomaly score attribution
*/
public RcfResult(double score, double confidence, int forestSize) {
public RcfResult(double score, double confidence, int forestSize, double[] attribution) {
this.score = score;
this.confidence = confidence;
this.forestSize = forestSize;
this.attribution = attribution;
}

/**
Expand Down Expand Up @@ -66,6 +70,15 @@ public int getForestSize() {
return forestSize;
}

/**
* Returns anomaly score attribution.
*
* @return anomaly score attribution
*/
public double[] getAttribution() {
return attribution;
}

@Override
public boolean equals(Object o) {
if (this == o)
Expand All @@ -75,11 +88,12 @@ public boolean equals(Object o) {
RcfResult that = (RcfResult) o;
return Objects.equals(this.score, that.score)
&& Objects.equals(this.confidence, that.confidence)
&& Objects.equals(this.forestSize, that.forestSize);
&& Objects.equals(this.forestSize, that.forestSize)
&& Arrays.equals(this.attribution, that.attribution);
}

@Override
public int hashCode() {
return Objects.hash(score, confidence, forestSize);
return Objects.hash(score, confidence, forestSize, attribution);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

package com.amazon.opendistroforelasticsearch.ad.ml.rcf;

import java.util.Arrays;
import java.util.Objects;

/**
Expand All @@ -24,16 +25,19 @@ public class CombinedRcfResult {

private final double score;
private final double confidence;
private final double[] attribution;

/**
* Constructor with all arguments.
*
* @param score combined RCF score
* @param confidence confidence of the score
* @param attribution score attribution normalized to 1
*/
public CombinedRcfResult(double score, double confidence) {
public CombinedRcfResult(double score, double confidence, double[] attribution) {
this.score = score;
this.confidence = confidence;
this.attribution = attribution;
}

/**
Expand All @@ -54,18 +58,29 @@ public double getConfidence() {
return confidence;
}

/**
* Return score attribution normalized to 1.
*
* @return score attribution normalized to 1
*/
public double[] getAttribution() {
return attribution;
}

@Override
public boolean equals(Object o) {
if (this == o)
return true;
if (o == null || getClass() != o.getClass())
return false;
CombinedRcfResult that = (CombinedRcfResult) o;
return Objects.equals(this.score, that.score) && Objects.equals(this.confidence, that.confidence);
return Objects.equals(this.score, that.score)
&& Objects.equals(this.confidence, that.confidence)
&& Arrays.equals(this.attribution, that.attribution);
}

@Override
public int hashCode() {
return Objects.hash(score, confidence);
return Objects.hash(score, confidence, attribution);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,8 @@ private ActionListener<SinglePointFeatures> onFeatureResponse(
featureInResponse,
rcfPartitionNum,
responseCount,
adID
adID,
detector.getEnabledFeatureIds().size()
);

transportService
Expand Down Expand Up @@ -463,12 +464,12 @@ private void findException(Throwable cause, String adID, AtomicReference<Anomaly
}
}

private CombinedRcfResult getCombinedResult(List<RCFResultResponse> rcfResults) {
private CombinedRcfResult getCombinedResult(List<RCFResultResponse> rcfResults, int numFeatures) {
List<RcfResult> rcfResultLib = new ArrayList<>();
for (RCFResultResponse result : rcfResults) {
rcfResultLib.add(new RcfResult(result.getRCFScore(), result.getConfidence(), result.getForestSize()));
rcfResultLib.add(new RcfResult(result.getRCFScore(), result.getConfidence(), result.getForestSize(), result.getAttribution()));
}
return modelManager.combineRcfResults(rcfResultLib);
return modelManager.combineRcfResults(rcfResultLib, numFeatures);
}

void handleExecuteException(Exception ex, ActionListener<AnomalyResultResponse> listener, String adID) {
Expand All @@ -495,6 +496,7 @@ class RCFActionListener implements ActionListener<RCFResultResponse> {
private int nodeCount;
private final AtomicInteger responseCount;
private final String adID;
private int numEnabledFeatures;

RCFActionListener(
List<RCFResultResponse> rcfResults,
Expand All @@ -508,7 +510,8 @@ class RCFActionListener implements ActionListener<RCFResultResponse> {
List<FeatureData> features,
int nodeCount,
AtomicInteger responseCount,
String adID
String adID,
int numEnabledFeatures
) {
this.rcfResults = rcfResults;
this.modelID = modelID;
Expand All @@ -522,6 +525,7 @@ class RCFActionListener implements ActionListener<RCFResultResponse> {
this.nodeCount = nodeCount;
this.responseCount = responseCount;
this.adID = adID;
this.numEnabledFeatures = numEnabledFeatures;
}

@Override
Expand All @@ -537,7 +541,7 @@ public void onResponse(RCFResultResponse response) {
LOG.error("Unexpected exception: {} for {}", ex, adID);
} finally {
if (nodeCount == responseCount.incrementAndGet()) {
handleRCFResults();
handleRCFResults(numEnabledFeatures);
}
}
}
Expand All @@ -550,12 +554,12 @@ public void onFailure(Exception e) {
LOG.error("Unexpected exception: {} for {}", ex, adID);
} finally {
if (nodeCount == responseCount.incrementAndGet()) {
handleRCFResults();
handleRCFResults(numEnabledFeatures);
}
}
}

private void handleRCFResults() {
private void handleRCFResults(int numFeatures) {
try {
AnomalyDetectionException exception = coldStartIfNoModel(failure, detector);
if (exception != null) {
Expand All @@ -568,7 +572,7 @@ private void handleRCFResults() {
return;
}

CombinedRcfResult combinedResult = getCombinedResult(rcfResults);
CombinedRcfResult combinedResult = getCombinedResult(rcfResults, numFeatures);
double combinedScore = combinedResult.getScore();

final AtomicReference<AnomalyResultResponse> anomalyResultResponse = new AtomicReference<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,21 +27,25 @@ public class RCFResultResponse extends ActionResponse implements ToXContentObjec
public static final String RCF_SCORE_JSON_KEY = "rcfScore";
public static final String CONFIDENCE_JSON_KEY = "confidence";
public static final String FOREST_SIZE_JSON_KEY = "forestSize";
public static final String ATTRIBUTION_JSON_KEY = "attribution";
private double rcfScore;
private double confidence;
private int forestSize;
private double[] attribution;

public RCFResultResponse(double rcfScore, double confidence, int forestSize) {
public RCFResultResponse(double rcfScore, double confidence, int forestSize, double[] attribution) {
this.rcfScore = rcfScore;
this.confidence = confidence;
this.forestSize = forestSize;
this.attribution = attribution;
}

public RCFResultResponse(StreamInput in) throws IOException {
super(in);
rcfScore = in.readDouble();
confidence = in.readDouble();
forestSize = in.readVInt();
attribution = in.readDoubleArray();
}

public double getRCFScore() {
Expand All @@ -56,11 +60,21 @@ public int getForestSize() {
return forestSize;
}

/**
* Returns RCF score attribution.
*
* @return RCF score attribution.
*/
public double[] getAttribution() {
return attribution;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeDouble(rcfScore);
out.writeDouble(confidence);
out.writeVInt(forestSize);
out.writeDoubleArray(attribution);
}

@Override
Expand All @@ -69,6 +83,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(RCF_SCORE_JSON_KEY, rcfScore);
builder.field(CONFIDENCE_JSON_KEY, confidence);
builder.field(FOREST_SIZE_JSON_KEY, forestSize);
builder.field(ATTRIBUTION_JSON_KEY, attribution);
builder.endObject();
return builder;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,14 @@ protected void doExecute(Task task, RCFResultRequest request, ActionListener<RCF
ActionListener
.wrap(
result -> listener
.onResponse(new RCFResultResponse(result.getScore(), result.getConfidence(), result.getForestSize())),
.onResponse(
new RCFResultResponse(
result.getScore(),
result.getConfidence(),
result.getForestSize(),
result.getAttribution()
)
),
exception -> {
LOG.warn(exception);
listener.onFailure(exception);
Expand Down
Loading