Skip to content
This repository has been archived by the owner on Aug 2, 2022. It is now read-only.

Commit

Permalink
add async getRcfResult (#69)
Browse files Browse the repository at this point in the history
  • Loading branch information
wnbts authored Mar 19, 2020
1 parent 6a8883b commit 8272123
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@

import com.google.gson.Gson;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.monitor.jvm.JvmService;

Expand Down Expand Up @@ -319,6 +320,58 @@ public RcfResult getRcfResult(String detectorId, String modelId, double[] point)
return new RcfResult(score, confidence, forestSize);
}

/**
* Returns to listener the RCF anomaly result using the specified model.
*
* @param detectorId ID of the detector
* @param modelId ID of the model to score the point
* @param point features of the data point
* @param listener onResponse is called with RCF result for the input point, including a score
* onFailure is called with ResourceNotFoundException when the model is not found
* onFailure is called with LimitExceededException when a limit is exceeded for the model
*/
public void getRcfResult(String detectorId, String modelId, double[] point, ActionListener<RcfResult> listener) {
if (forests.containsKey(modelId)) {
getRcfResult(forests.get(modelId), point, listener);
} else {
checkpointDao
.getModelCheckpoint(
modelId,
ActionListener
.wrap(checkpoint -> processRcfCheckpoint(checkpoint, modelId, detectorId, point, listener), listener::onFailure)
);
}
}

private void getRcfResult(ModelState<RandomCutForest> modelState, double[] point, ActionListener<RcfResult> listener) {
RandomCutForest rcf = modelState.getModel();
double score = rcf.getAnomalyScore(point);
double confidence = computeRcfConfidence(rcf);
int forestSize = rcf.getNumberOfTrees();
rcf.update(point);
modelState.setLastUsedTime(clock.instant());
listener.onResponse(new RcfResult(score, confidence, forestSize));
}

private void processRcfCheckpoint(
Optional<String> rcfCheckpoint,
String modelId,
String detectorId,
double[] point,
ActionListener<RcfResult> listener
) {
Optional<ModelState<RandomCutForest>> model = rcfCheckpoint
.map(checkpoint -> AccessController.doPrivileged((PrivilegedAction<RandomCutForest>) () -> rcfSerde.fromJson(checkpoint)))
.filter(rcf -> isHostingAllowed(detectorId, rcf))
.map(rcf -> new ModelState<>(rcf, modelId, detectorId, ModelType.RCF.getName(), clock.instant()));
if (model.isPresent()) {
forests.put(modelId, model.get());
getRcfResult(model.get(), point, listener);
} else {
throw new ResourceNotFoundException(detectorId, CommonErrorMessages.NO_CHECKPOINT_ERR_MSG + modelId);
}
}

/**
* Gets the result using the specified thresholding model.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import junitparams.JUnitParamsRunner;
import junitparams.Parameters;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.collect.ImmutableOpenMap;
Expand All @@ -47,6 +48,7 @@
import org.junit.runner.RunWith;

import org.mockito.Answers;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;

Expand All @@ -65,8 +67,10 @@
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;

import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyObject;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
Expand Down Expand Up @@ -130,6 +134,7 @@ public class ModelManagerTests {
private String modelId;
private String rcfModelId;
private String thresholdModelId;
private String checkpoint;

@Before
public void setup() {
Expand Down Expand Up @@ -188,6 +193,7 @@ public void setup() {
modelId = "modelId";
rcfModelId = "detectorId_model_rcf_1";
thresholdModelId = "detectorId_model_threshold";
checkpoint = "testcheckpoint";
}

private Object[] getDetectorIdForModelIdData() {
Expand Down Expand Up @@ -363,6 +369,72 @@ public void getRcfResult_throwLimitExceeded_whenHeapLimitReached() {
modelManager.getRcfResult(detectorId, modelId, new double[0]);
}

@Test
@SuppressWarnings("unchecked")
public void getRcfResult_returnExpectedToListener() {
double[] point = new double[0];
RandomCutForest forest = mock(RandomCutForest.class);
double score = 11.;

doAnswer(invocation -> {
ActionListener<Optional<String>> listener = invocation.getArgument(1);
listener.onResponse(Optional.of(checkpoint));
return null;
}).when(checkpointDao).getModelCheckpoint(eq(rcfModelId), any(ActionListener.class));
when(rcfSerde.fromJson(checkpoint)).thenReturn(forest);
when(forest.getAnomalyScore(point)).thenReturn(score);
when(forest.getNumberOfTrees()).thenReturn(numTrees);
when(forest.getLambda()).thenReturn(rcfTimeDecay);
when(forest.getSampleSize()).thenReturn(numSamples);
when(forest.getTotalUpdates()).thenReturn((long) numSamples);

ActionListener<RcfResult> listener = mock(ActionListener.class);
modelManager.getRcfResult(detectorId, rcfModelId, point, listener);

RcfResult expected = new RcfResult(score, 0, numTrees);
verify(listener).onResponse(eq(expected));

when(forest.getTotalUpdates()).thenReturn(numSamples + 1L);
listener = mock(ActionListener.class);
modelManager.getRcfResult(detectorId, rcfModelId, point, listener);

ArgumentCaptor<RcfResult> responseCaptor = ArgumentCaptor.forClass(RcfResult.class);
verify(listener).onResponse(responseCaptor.capture());
assertEquals(0.091353632, responseCaptor.getValue().getConfidence(), 1e-6);
}

@Test
@SuppressWarnings("unchecked")
public void getRcfResult_throwToListener_whenNoCheckpoint() {
doAnswer(invocation -> {
ActionListener<Optional<String>> listener = invocation.getArgument(1);
listener.onResponse(Optional.empty());
return null;
}).when(checkpointDao).getModelCheckpoint(eq(rcfModelId), any(ActionListener.class));

ActionListener<RcfResult> listener = mock(ActionListener.class);
modelManager.getRcfResult(detectorId, rcfModelId, new double[0], listener);

verify(listener).onFailure(any(ResourceNotFoundException.class));
}

@Test
@SuppressWarnings("unchecked")
public void getRcfResult_throwToListener_whenHeapLimitExceed() {
doAnswer(invocation -> {
ActionListener<Optional<String>> listener = invocation.getArgument(1);
listener.onResponse(Optional.of(checkpoint));
return null;
}).when(checkpointDao).getModelCheckpoint(eq(rcfModelId), any(ActionListener.class));
when(rcfSerde.fromJson(checkpoint)).thenReturn(rcf);
when(jvmService.info().getMem().getHeapMax().getBytes()).thenReturn(1_000L);

ActionListener<RcfResult> listener = mock(ActionListener.class);
modelManager.getRcfResult(detectorId, rcfModelId, new double[0], listener);

verify(listener).onFailure(any(LimitExceededException.class));
}

@Test
public void getThresholdingResult_returnExpected() {
String modelId = "testModelId";
Expand Down

0 comments on commit 8272123

Please sign in to comment.