-
Notifications
You must be signed in to change notification settings - Fork 36
add anomaly feature attribution to model output #232
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -53,6 +53,7 @@ | |
import com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetector; | ||
import com.amazon.opendistroforelasticsearch.ad.util.DiscoveryNodeFilterer; | ||
import com.amazon.randomcutforest.RandomCutForest; | ||
import com.amazon.randomcutforest.returntypes.DiVector; | ||
import com.amazon.randomcutforest.serialize.RandomCutForestSerDe; | ||
import com.google.gson.Gson; | ||
|
||
|
@@ -206,28 +207,48 @@ public ModelManager( | |
* | ||
* Final RCF score is calculated by averaging scores weighted by model size (number of trees). | ||
* Confidence is the weighted average of confidence with confidence for missing models being 0. | ||
* Attribution is normalized weighted average for the most recent feature dimensions. | ||
* | ||
* @param rcfResults RCF results from partitioned models | ||
* @param numFeatures number of features for attribution | ||
* @return combined RCF result | ||
*/ | ||
public CombinedRcfResult combineRcfResults(List<RcfResult> rcfResults) { | ||
public CombinedRcfResult combineRcfResults(List<RcfResult> rcfResults, int numFeatures) { | ||
CombinedRcfResult combinedResult = null; | ||
if (rcfResults.isEmpty()) { | ||
combinedResult = new CombinedRcfResult(0, 0); | ||
combinedResult = new CombinedRcfResult(0, 0, new double[0]); | ||
} else { | ||
int totalForestSize = rcfResults.stream().mapToInt(RcfResult::getForestSize).sum(); | ||
if (totalForestSize == 0) { | ||
combinedResult = new CombinedRcfResult(0, 0); | ||
combinedResult = new CombinedRcfResult(0, 0, new double[0]); | ||
} else { | ||
double score = rcfResults.stream().mapToDouble(r -> r.getScore() * r.getForestSize()).sum() / totalForestSize; | ||
double confidence = rcfResults.stream().mapToDouble(r -> r.getConfidence() * r.getForestSize()).sum() / Math | ||
.max(rcfNumTrees, totalForestSize); | ||
combinedResult = new CombinedRcfResult(score, confidence); | ||
double[] attribution = combineAttribution(rcfResults, numFeatures, totalForestSize); | ||
combinedResult = new CombinedRcfResult(score, confidence, combineAttribution(rcfResults, numFeatures, totalForestSize)); | ||
} | ||
} | ||
return combinedResult; | ||
} | ||
|
||
private double[] combineAttribution(List<RcfResult> rcfResults, int numFeatures, int totalForestSize) { | ||
double[] combined = new double[numFeatures]; | ||
double sum = 0; | ||
for (RcfResult result : rcfResults) { | ||
double[] attribution = result.getAttribution(); | ||
for (int i = 0; i < numFeatures; i++) { | ||
double attr = attribution[attribution.length - numFeatures + i] * result.getForestSize() / totalForestSize; | ||
combined[i] += attr; | ||
sum += attr; | ||
} | ||
} | ||
for (int i = 0; i < numFeatures; i++) { | ||
combined[i] /= sum; | ||
} | ||
return combined; | ||
} | ||
|
||
/** | ||
* Gets the detector id from the model id. | ||
* | ||
|
@@ -349,13 +370,25 @@ public void getRcfResult(String detectorId, String modelId, double[] point, Acti | |
} | ||
|
||
private void getRcfResult(ModelState<RandomCutForest> modelState, double[] point, ActionListener<RcfResult> listener) { | ||
modelState.setLastUsedTime(clock.instant()); | ||
|
||
RandomCutForest rcf = modelState.getModel(); | ||
double score = rcf.getAnomalyScore(point); | ||
double confidence = computeRcfConfidence(rcf); | ||
int forestSize = rcf.getNumberOfTrees(); | ||
double[] attribution = getAnomalyAttribution(rcf, point); | ||
rcf.update(point); | ||
modelState.setLastUsedTime(clock.instant()); | ||
listener.onResponse(new RcfResult(score, confidence, forestSize)); | ||
listener.onResponse(new RcfResult(score, confidence, forestSize, attribution)); | ||
} | ||
|
||
private double[] getAnomalyAttribution(RandomCutForest rcf, double[] point) { | ||
DiVector vec = rcf.getAnomalyAttribution(point); | ||
vec.renormalize(1d); | ||
double[] attribution = new double[vec.getDimensions()]; | ||
for (int i = 0; i < attribution.length; i++) { | ||
attribution[i] = vec.getHighLowSum(i); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Questions: There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
} | ||
return attribution; | ||
} | ||
|
||
private Optional<ModelState<RandomCutForest>> restoreCheckpoint(Optional<String> rcfCheckpoint, String modelId, String detectorId) { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[minor] It looks like the variable
attribution
was meant to be used in line 229, but didn't get used and insteadcombineAttribution
was invoked again with identical args.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good catch. I fixed it in the new commit.