This repository has been archived by the owner on Aug 2, 2022. It is now read-only.

Use callbacks and bug fix #83

Merged
18 changes: 6 additions & 12 deletions build.gradle
@@ -251,23 +251,17 @@ List<String> jacocoExclusions = [
'com.amazon.opendistroforelasticsearch.ad.common.exception.AnomalyDetectionException',
'com.amazon.opendistroforelasticsearch.ad.util.ClientUtil',

'com.amazon.opendistroforelasticsearch.ad.ml.*',
'com.amazon.opendistroforelasticsearch.ad.feature.*',
'com.amazon.opendistroforelasticsearch.ad.dataprocessor.*',
'com.amazon.opendistroforelasticsearch.ad.AnomalyDetectorRunner',
'com.amazon.opendistroforelasticsearch.ad.resthandler.RestGetAnomalyResultAction',
'com.amazon.opendistroforelasticsearch.ad.metrics.MetricFactory',
'com.amazon.opendistroforelasticsearch.ad.indices.AnomalyDetectionIndices',
'com.amazon.opendistroforelasticsearch.ad.transport.ForwardAction',
'com.amazon.opendistroforelasticsearch.ad.transport.ForwardTransportAction',
'com.amazon.opendistroforelasticsearch.ad.transport.StopDetectorAction',
'com.amazon.opendistroforelasticsearch.ad.transport.StopDetectorRequest',
'com.amazon.opendistroforelasticsearch.ad.transport.StopDetectorResponse',
'com.amazon.opendistroforelasticsearch.ad.transport.StopDetectorTransportAction',
'com.amazon.opendistroforelasticsearch.ad.transport.ADStatsAction',
'com.amazon.opendistroforelasticsearch.ad.transport.CronRequest',
'com.amazon.opendistroforelasticsearch.ad.transport.DeleteDetectorAction',
'com.amazon.opendistroforelasticsearch.ad.util.ParseUtils'
'com.amazon.opendistroforelasticsearch.ad.transport.CronTransportAction',
'com.amazon.opendistroforelasticsearch.ad.transport.CronRequest',
'com.amazon.opendistroforelasticsearch.ad.transport.ADStatsAction',
'com.amazon.opendistroforelasticsearch.ad.AnomalyDetectorRunner',
'com.amazon.opendistroforelasticsearch.ad.indices.AnomalyDetectionIndices',
'com.amazon.opendistroforelasticsearch.ad.util.ParseUtils',
]

jacocoTestCoverageVerification {
AnomalyDetectorJobRunner.java

@@ -28,6 +28,7 @@
import com.amazon.opendistroforelasticsearch.ad.transport.AnomalyResultResponse;
import com.amazon.opendistroforelasticsearch.ad.transport.AnomalyResultTransportAction;
import com.amazon.opendistroforelasticsearch.ad.transport.handler.AnomalyResultHandler;
import com.amazon.opendistroforelasticsearch.ad.util.ClientUtil;
import com.amazon.opendistroforelasticsearch.jobscheduler.spi.JobExecutionContext;
import com.amazon.opendistroforelasticsearch.jobscheduler.spi.LockModel;
import com.amazon.opendistroforelasticsearch.jobscheduler.spi.ScheduledJobParameter;
@@ -39,7 +40,9 @@
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.get.GetRequest;
import org.elasticsearch.action.get.GetResponse;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.index.IndexResponse;
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.settings.Settings;
@@ -71,6 +74,7 @@ public class AnomalyDetectorJobRunner implements ScheduledJobRunner {
private Settings settings;
private int maxRetryForEndRunException;
private Client client;
private ClientUtil clientUtil;
private ThreadPool threadPool;
private AnomalyResultHandler anomalyResultHandler;
private ConcurrentHashMap<String, Integer> detectorEndRunExceptionCount;
@@ -97,6 +101,10 @@ public void setClient(Client client) {
this.client = client;
}

public void setClientUtil(ClientUtil clientUtil) {
this.clientUtil = clientUtil;
}

public void setThreadPool(ThreadPool threadPool) {
this.threadPool = threadPool;
}
@@ -258,7 +266,7 @@ protected void handleAdException(
) {
String detectorId = jobParameter.getName();
if (exception instanceof EndRunException) {
log.error("EndRunException happened when executed anomaly result action for " + detectorId, exception);
log.error("EndRunException happened when executing anomaly result action for " + detectorId, exception);

if (((EndRunException) exception).isEndNow()) {
// Stop AD job if EndRunException shows we should end job now.
@@ -349,9 +357,8 @@ private void stopAdJob(String detectorId) {
try {
GetRequest getRequest = new GetRequest(AnomalyDetectorJob.ANOMALY_DETECTOR_JOB_INDEX).id(detectorId);

client.get(getRequest, ActionListener.wrap(response -> {
clientUtil.<GetRequest, GetResponse>asyncRequest(getRequest, client::get, ActionListener.wrap(response -> {
if (response.isExists()) {
String s = response.getSourceAsString();
try (
XContentParser parser = XContentType.JSON
.xContent()
Expand All @@ -374,14 +381,19 @@ private void stopAdJob(String detectorId) {
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.source(newJob.toXContent(XContentBuilder.builder(XContentType.JSON.xContent()), XCONTENT_WITH_TYPE))
.id(detectorId);
client.index(indexRequest, ActionListener.wrap(indexResponse -> {
if (indexResponse != null
&& (indexResponse.getResult() == CREATED || indexResponse.getResult() == UPDATED)) {
log.info("AD Job was disabled by JobRunner for " + detectorId);
} else {
log.warn("Failed to disable AD job for " + detectorId);
}
}, exception -> log.error("JobRunner failed to update AD job as disabled for " + detectorId, exception)));
clientUtil
.<IndexRequest, IndexResponse>asyncRequest(
indexRequest,
client::index,
ActionListener.wrap(indexResponse -> {
if (indexResponse != null
&& (indexResponse.getResult() == CREATED || indexResponse.getResult() == UPDATED)) {
log.info("AD Job was disabled by JobRunner for " + detectorId);
} else {
log.warn("Failed to disable AD job for " + detectorId);
}
}, exception -> log.error("JobRunner failed to update AD job as disabled for " + detectorId, exception))
);
}
} catch (IOException e) {
log.error("JobRunner failed to stop detector job " + detectorId, e);
AnomalyDetectorPlugin.java

@@ -143,6 +143,7 @@ public class AnomalyDetectorPlugin extends Plugin implements ActionPlugin, Scrip
private ThreadPool threadPool;
private IndexNameExpressionResolver indexNameExpressionResolver;
private ADStats adStats;
private ClientUtil clientUtil;

static {
SpecialPermission.check();
@@ -174,6 +175,7 @@ public List<RestHandler> getRestHandlers(
);
AnomalyDetectorJobRunner jobRunner = AnomalyDetectorJobRunner.getJobRunnerInstance();
jobRunner.setClient(client);
jobRunner.setClientUtil(clientUtil);
jobRunner.setThreadPool(threadPool);
jobRunner.setAnomalyResultHandler(anomalyResultHandler);
jobRunner.setSettings(settings);
@@ -237,7 +239,7 @@ public Collection<Object> createComponents(
Settings settings = environment.settings();
Clock clock = Clock.systemUTC();
Throttler throttler = new Throttler(clock);
ClientUtil clientUtil = new ClientUtil(settings, client, throttler, threadPool);
this.clientUtil = new ClientUtil(settings, client, throttler, threadPool);
IndexUtils indexUtils = new IndexUtils(client, clientUtil, clusterService);
anomalyDetectionIndices = new AnomalyDetectionIndices(client, clusterService, threadPool, settings, clientUtil);
this.clusterService = clusterService;
ModelManager.java

@@ -272,6 +272,35 @@ public Entry<Integer, Integer> getPartitionedForestSizes(RandomCutForest forest,
return new SimpleImmutableEntry<>(numPartitions, forestSize);
}

/**
 * Construct an RCF model and then partition it by forest size.
 *
 * An RCF model is constructed based on the number of input features.
 *
 * The model is first partitioned into the desired size based on the heap.
 * If there are more partitions than the number of nodes in the cluster,
 * the model is partitioned by the number of nodes and verified to
 * ensure the size of a partition does not exceed the max size limit based on the heap.
 *
 * @param detectorId ID of the detector, with no effect on partitioning
 * @param rcfNumFeatures the number of features
 * @return a pair of the number of partitions and the size of a partition (number of trees)
 * @throws LimitExceededException when there is insufficient resource available
 */
public Entry<Integer, Integer> getPartitionedForestSizes(String detectorId, int rcfNumFeatures) {
Review thread:

Contributor: If we are going to refactor this method, I suggest the new API just take a detector object, which contains all the needed info and is simpler to use.

kaituo (Member, Author), Apr 10, 2020: We only use the detector ID as part of the error message. We don't need other detector information.

Contributor: If the model manager takes a detector, it can compute the feature dimensions and the partitioning, so the detector will be the only input needed; that saves the client the work of providing a second rcfNumFeatures input. That's why I suggest doing that.

kaituo (Member, Author), Apr 11, 2020: Makes sense. Done.

kaituo (Member, Author): Please take a look at the recent commit: d8ea9cf

return getPartitionedForestSizes(
RandomCutForest
.builder()
.dimensions(rcfNumFeatures)
.sampleSize(rcfNumSamplesInTree)
.numberOfTrees(rcfNumTrees)
.outputAfter(rcfNumSamplesInTree)
.parallelExecutionEnabled(false)
.build(),
detectorId
);
}
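
For intuition, the partitioning rule the javadoc describes can be sketched as below. This is an illustrative restatement with invented parameters (modelBytes, maxBytesPerPartition, numNodes, totalTrees), not the plugin's actual code, which derives these quantities from the forest, heap settings, and cluster state:

import java.util.AbstractMap.SimpleImmutableEntry;
import java.util.Map.Entry;

class PartitionRuleSketch {
    // Illustrative only: pick the partition count, fall back to the node
    // count when needed, and fail when a partition would still be too big.
    static Entry<Integer, Integer> partition(long modelBytes, long maxBytesPerPartition, int numNodes, int totalTrees) {
        // Smallest number of partitions whose pieces each fit under the heap-based limit.
        int numPartitions = (int) Math.ceil((double) modelBytes / maxBytesPerPartition);
        if (numPartitions > numNodes) {
            // More partitions than nodes: repartition by node count, then verify
            // that a partition still fits under the max size limit.
            numPartitions = numNodes;
            if (modelBytes / numPartitions > maxBytesPerPartition) {
                // The plugin throws LimitExceededException here.
                throw new IllegalStateException("insufficient resources to host the model");
            }
        }
        // Partition size is reported as the number of trees per partition.
        return new SimpleImmutableEntry<>(numPartitions, totalTrees / numPartitions);
    }
}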

/**
* Gets the estimated size of an RCF model.
*
@@ -545,17 +574,8 @@ public void trainModel(AnomalyDetector anomalyDetector, double[][] dataPoints) {
int rcfNumFeatures = dataPoints[0].length;

// Create partitioned RCF models
Entry<Integer, Integer> partitionResults = getPartitionedForestSizes(
RandomCutForest
.builder()
.dimensions(rcfNumFeatures)
.sampleSize(rcfNumSamplesInTree)
.numberOfTrees(rcfNumTrees)
.outputAfter(rcfNumSamplesInTree)
.parallelExecutionEnabled(false)
.build(),
anomalyDetector.getDetectorId()
);
Entry<Integer, Integer> partitionResults = getPartitionedForestSizes(anomalyDetector.getDetectorId(), rcfNumFeatures);

int numForests = partitionResults.getKey();
int forestSize = partitionResults.getValue();
double[] scores = new double[dataPoints.length];
ADStateManager.java

@@ -1,5 +1,5 @@
/*
* Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
* Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
@@ -22,7 +22,6 @@
import java.time.Instant;
import java.util.Map;
import java.util.Optional;
import java.util.Random;
import java.util.AbstractMap.SimpleEntry;
import java.util.Map.Entry;
import java.util.concurrent.ConcurrentHashMap;
@@ -31,10 +30,11 @@
import com.amazon.opendistroforelasticsearch.ad.common.exception.LimitExceededException;
import com.amazon.opendistroforelasticsearch.ad.ml.ModelManager;
import com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetector;
import com.amazon.opendistroforelasticsearch.ad.settings.AnomalyDetectorSettings;
import com.amazon.opendistroforelasticsearch.ad.util.ClientUtil;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.get.GetRequest;
import org.elasticsearch.action.get.GetResponse;
import org.elasticsearch.client.Client;
@@ -44,8 +44,6 @@
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType;

import com.amazon.randomcutforest.RandomCutForest;

/**
* ADStateManager is used by transport layer to manage AnomalyDetector object
* and the number of partitions for a detector id.
@@ -56,7 +54,6 @@ public class ADStateManager {
private ConcurrentHashMap<String, Entry<AnomalyDetector, Instant>> currentDetectors;
private ConcurrentHashMap<String, Entry<Integer, Instant>> partitionNumber;
private Client client;
private Random random;
private ModelManager modelManager;
private NamedXContentRegistry xContentRegistry;
private ClientUtil clientUtil;
@@ -77,7 +74,6 @@ public ADStateManager(
) {
this.currentDetectors = new ConcurrentHashMap<>();
this.client = client;
this.random = new Random();
this.modelManager = modelManager;
this.xContentRegistry = xContentRegistry;
this.partitionNumber = new ConcurrentHashMap<>();
@@ -91,67 +87,63 @@ public ADStateManager(
/**
* Get the RCF model's partition number for detector adID
* @param adID detector id
* @param detector the detector object
* @return the RCF model's partition number for adID
* @throws InterruptedException when we cannot get the anomaly detector object for adID before timeout
* @throws LimitExceededException when there is insufficient resource available
*/
public int getPartitionNumber(String adID) throws InterruptedException {
public int getPartitionNumber(String adID, Optional<AnomalyDetector> detector) throws InterruptedException {
Review thread:

Contributor: Minor: why not validate the detector first and just pass the detector afterwards, saving all the repetitive and unlikely handling of a non-existent detector?

kaituo (Member, Author): Good point. Done.

Entry<Integer, Instant> partitonAndTime = partitionNumber.get(adID);
if (partitonAndTime != null) {
partitonAndTime.setValue(clock.instant());
return partitonAndTime.getKey();
}

Optional<AnomalyDetector> detector = getAnomalyDetector(adID);
if (!detector.isPresent()) {
throw new AnomalyDetectionException(adID, "AnomalyDetector is not found");
}

RandomCutForest forest = RandomCutForest
.builder()
.dimensions(detector.get().getFeatureAttributes().size() * AnomalyDetectorSettings.SHINGLE_SIZE)
.sampleSize(AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE)
.numberOfTrees(AnomalyDetectorSettings.NUM_TREES)
.parallelExecutionEnabled(false)
.build();
int partitionNum = modelManager.getPartitionedForestSizes(forest, adID).getKey();
int partitionNum = modelManager.getPartitionedForestSizes(adID, detector.get().getEnabledFeatureIds().size()).getKey();
partitionNumber.putIfAbsent(adID, new SimpleEntry<>(partitionNum, clock.instant()));
return partitionNum;
}

public Optional<AnomalyDetector> getAnomalyDetector(String adID) {
public void getAnomalyDetector(String adID, ActionListener<Optional<AnomalyDetector>> listener) {
Entry<AnomalyDetector, Instant> detectorAndTime = currentDetectors.get(adID);
if (detectorAndTime != null) {
detectorAndTime.setValue(clock.instant());
return Optional.of(detectorAndTime.getKey());
listener.onResponse(Optional.of(detectorAndTime.getKey()));
return;
}

GetRequest request = new GetRequest(AnomalyDetector.ANOMALY_DETECTORS_INDEX, adID);

Optional<GetResponse> getResponse = clientUtil.<GetRequest, GetResponse>timedRequest(request, LOG, client::get);

return onGetResponse(getResponse, adID);
clientUtil.<GetRequest, GetResponse>asyncRequest(request, client::get, onGetResponse(adID, listener));
}

private Optional<AnomalyDetector> onGetResponse(Optional<GetResponse> asResponse, String adID) {
if (!asResponse.isPresent() || !asResponse.get().isExists()) {
return Optional.empty();
}
private ActionListener<GetResponse> onGetResponse(String adID, ActionListener<Optional<AnomalyDetector>> listener) {
return ActionListener.wrap(response -> {
if (response == null || !response.isExists()) {
listener.onResponse(Optional.empty());
return;
}

GetResponse response = asResponse.get();
String xc = response.getSourceAsString();
LOG.debug("Fetched anomaly detector: {}", xc);

try (XContentParser parser = XContentType.JSON.xContent().createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, xc)) {
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser::getTokenLocation);
AnomalyDetector detector = AnomalyDetector.parse(parser, response.getId());
currentDetectors.put(adID, new SimpleEntry<>(detector, clock.instant()));
return Optional.of(detector);
} catch (Exception t) {
LOG.error("Fail to parse detector {}", adID);
LOG.error("Stack trace:", t);
return Optional.empty();
}
String xc = response.getSourceAsString();
LOG.info("Fetched anomaly detector: {}", xc);

try (
XContentParser parser = XContentType.JSON.xContent().createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, xc)
) {
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser::getTokenLocation);
AnomalyDetector detector = AnomalyDetector.parse(parser, response.getId());
currentDetectors.put(adID, new SimpleEntry<>(detector, clock.instant()));
listener.onResponse(Optional.of(detector));
} catch (Exception t) {
LOG.error("Fail to parse detector {}", adID);
LOG.error("Stack trace:", t);
listener.onResponse(Optional.empty());
}
}, listener::onFailure);
}
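
With getAnomalyDetector now callback-based, a caller's follow-up work moves inside the listener. A hypothetical caller, shape only (the real callers are the transport actions updated in this PR):

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;

class StateManagerCallerSketch {
    private static final Logger LOG = LogManager.getLogger(StateManagerCallerSketch.class);

    // Hypothetical usage: logic that used to run after a blocking fetch now
    // runs inside the ActionListener once the GET (or cache hit) completes.
    void fetchAndPartition(ADStateManager stateManager, String adID) {
        stateManager.getAnomalyDetector(adID, ActionListener.wrap(detector -> {
            if (!detector.isPresent()) {
                LOG.warn("AnomalyDetector {} is not found or could not be parsed", adID);
                return;
            }
            // ActionListener.wrap takes a CheckedConsumer, so the checked
            // InterruptedException from getPartitionNumber may propagate here
            // and be routed to the failure handler below.
            int partitions = stateManager.getPartitionNumber(adID, detector);
            LOG.info("Detector {} uses {} RCF partitions", adID, partitions);
        }, exception -> LOG.error("Failed to get detector " + adID, exception)));
    }
}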
