Skip to content
This repository has been archived by the owner on Aug 2, 2022. It is now read-only.

Commit

Permalink
Fix for stats API (#287)
Browse files Browse the repository at this point in the history
* Add more tests

* Fix for stats API

This PR fixes two issues for the stats API.

First, we didn't propagate multi-entity detectors' model execution exceptions back from the remote invocation.  This problem may impair the stats API's ability to report the total failure count and thus hide an issue we should have reported during monitoring.  This PR fixes the issue by having coordinating nodes collect the exceptions raised on model hosting nodes.

Second, we didn't show active multi-entity detectors' model information in the stats API.  This PR adds this information to the stats API output.

This PR also adds unit tests for ModelManager.

Testing done:
1. added unit tests
2. manually verified the two issues are resolved.
  • Loading branch information
kaituo authored Oct 22, 2020
1 parent ad02fd5 commit 2f18231
Show file tree
Hide file tree
Showing 19 changed files with 806 additions and 87 deletions.
3 changes: 0 additions & 3 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -261,9 +261,6 @@ List<String> jacocoExclusions = [
'com.amazon.opendistroforelasticsearch.ad.transport.SearchAnomalyResultTransportAction*',

// TODO: hc caused coverage to drop
//'com.amazon.opendistroforelasticsearch.ad.ml.ModelManager',
'com.amazon.opendistroforelasticsearch.ad.transport.AnomalyResultTransportAction',
'com.amazon.opendistroforelasticsearch.ad.transport.AnomalyResultTransportAction.EntityResultListener',
'com.amazon.opendistroforelasticsearch.ad.NodeStateManager',
'com.amazon.opendistroforelasticsearch.ad.transport.handler.MultiEntityResultHandler',
'com.amazon.opendistroforelasticsearch.ad.transport.EntityProfileTransportAction*',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ public Collection<Object> createComponents(
.<String, ADStat<?>>builder()
.put(StatNames.AD_EXECUTE_REQUEST_COUNT.getName(), new ADStat<>(false, new CounterSupplier()))
.put(StatNames.AD_EXECUTE_FAIL_COUNT.getName(), new ADStat<>(false, new CounterSupplier()))
.put(StatNames.MODEL_INFORMATION.getName(), new ADStat<>(false, new ModelsOnNodeSupplier(modelManager)))
.put(StatNames.MODEL_INFORMATION.getName(), new ADStat<>(false, new ModelsOnNodeSupplier(modelManager, cacheProvider)))
.put(
StatNames.ANOMALY_DETECTORS_INDEX_STATUS.getName(),
new ADStat<>(true, new IndexStatusSupplier(indexUtils, AnomalyDetector.ANOMALY_DETECTORS_INDEX))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@
import java.time.Instant;
import java.util.AbstractMap.SimpleImmutableEntry;
import java.util.Comparator;
import java.util.List;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentSkipListSet;
import java.util.stream.Collectors;

import org.apache.commons.lang.builder.EqualsBuilder;
import org.apache.commons.lang.builder.HashCodeBuilder;
Expand Down Expand Up @@ -525,4 +527,8 @@ public boolean expired(Duration stateTtl) {
public String getDetectorId() {
return detectorId;
}

/**
 * Returns the model states currently stored in this buffer.
 *
 * @return a newly collected list holding every value of {@code items}
 */
public List<ModelState<?>> getAllModels() {
    List<ModelState<?>> modelStates = items.values().stream().collect(Collectors.toList());
    return modelStates;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

package com.amazon.opendistroforelasticsearch.ad.caching;

import java.util.List;

import com.amazon.opendistroforelasticsearch.ad.CleanState;
import com.amazon.opendistroforelasticsearch.ad.MaintenanceState;
import com.amazon.opendistroforelasticsearch.ad.ml.EntityModel;
Expand Down Expand Up @@ -72,4 +74,11 @@ public interface EntityCache extends MaintenanceState, CleanState {
* @return RCF model total updates of specific entity
*/
long getTotalUpdates(String detectorId, String entityModelId);

/**
 * Gets the model states of all models hosted on a node.
 *
 * @return list of model states
 */
List<ModelState<?>> getAllModels();
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import java.time.Instant;
import java.util.AbstractMap.SimpleImmutableEntry;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
Expand Down Expand Up @@ -554,4 +556,16 @@ public int getTotalActiveEntities() {
activeEnities.values().stream().forEach(cacheBuffer -> { total.addAndGet(cacheBuffer.getActiveEntities()); });
return total.get();
}

/**
 * Gets the model states of all models hosted on a node by gathering the
 * models of every cache buffer in {@code activeEnities}.
 *
 * @return list of model states; empty when there are no cache buffers
 */
@Override
public List<ModelState<?>> getAllModels() {
    List<ModelState<?>> allStates = new ArrayList<>();
    // Append each buffer's models in the iteration order of the map's values.
    activeEnities.values().stream().map(buffer -> buffer.getAllModels()).forEach(allStates::addAll);
    return allStates;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -812,8 +812,8 @@ public void getFeaturesByEntities(
new ThreadedActionListener<>(logger, threadPool, AnomalyDetectorPlugin.AD_THREAD_POOL_NAME, termsListener, false)
);

} catch (IOException e) {
throw new EndRunException(detector.getDetectorId(), CommonErrorMessages.INVALID_SEARCH_QUERY_MSG, e, true);
} catch (Exception e) {
throw new EndRunException(detector.getDetectorId(), CommonErrorMessages.INVALID_SEARCH_QUERY_MSG, e, false);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
Expand Down Expand Up @@ -1022,7 +1023,9 @@ public void processEntityCheckpoint(
modelState.setLastCheckpointTime(clock.instant().minus(checkpointInterval));
}

assert (modelState.getModel() != null);
if (modelState.getModel() == null) {
modelState.setModel(new EntityModel(modelId, new ArrayDeque<>(), null, null));
}
maybeTrainBeforeScore(modelState, entityName);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,17 @@
import java.util.Set;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import com.amazon.opendistroforelasticsearch.ad.caching.CacheProvider;
import com.amazon.opendistroforelasticsearch.ad.ml.ModelManager;

/**
* ModelsOnNodeSupplier provides a List of ModelState info for the models the node contains
*/
public class ModelsOnNodeSupplier implements Supplier<List<Map<String, Object>>> {
private ModelManager modelManager;
private CacheProvider cache;

/**
* Set that contains the model stats that should be exposed.
Expand All @@ -45,16 +48,18 @@ public class ModelsOnNodeSupplier implements Supplier<List<Map<String, Object>>>
* Constructor
*
* @param modelManager object that manages the model partitions hosted on the node
* @param cache object that manages multi-entity detectors' models
*/
public ModelsOnNodeSupplier(ModelManager modelManager) {
public ModelsOnNodeSupplier(ModelManager modelManager, CacheProvider cache) {
this.modelManager = modelManager;
this.cache = cache;
}

@Override
public List<Map<String, Object>> get() {
List<Map<String, Object>> values = new ArrayList<>();
modelManager
.getAllModels()
Stream
.concat(modelManager.getAllModels().stream(), cache.get().getAllModels().stream())
.forEach(
modelState -> values
.add(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -104,6 +105,8 @@ public class AnomalyResultTransportAction extends HandledTransportAction<ActionR
.getExceptionName(new LimitExceededException("", ""));
static final String NULL_RESPONSE = "Received null response from";
static final String BUG_RESPONSE = "We might have bugs.";
static final String TROUBLE_QUERYING_ERR_MSG = "Having trouble querying data: ";
static final String NO_ACK_ERR = "no acknowledgements from model hosting nodes.";

private final TransportService transportService;
private final NodeStateManager stateManager;
Expand Down Expand Up @@ -213,7 +216,6 @@ public AnomalyResultTransportAction(
@Override
protected void doExecute(Task task, ActionRequest actionRequest, ActionListener<AnomalyResultResponse> listener) {
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {

AnomalyResultRequest request = AnomalyResultRequest.fromActionRequest(actionRequest);
ActionListener<AnomalyResultResponse> original = listener;
listener = ActionListener.wrap(original::onResponse, e -> {
Expand All @@ -233,7 +235,6 @@ protected void doExecute(Task task, ActionRequest actionRequest, ActionListener<
listener.onFailure(new LimitExceededException(adID, CommonErrorMessages.MEMORY_CIRCUIT_BROKEN_ERR_MSG, false));
return;
}

try {
stateManager.getAnomalyDetector(adID, onGetDetector(listener, adID, request));
} catch (Exception ex) {
Expand Down Expand Up @@ -297,7 +298,7 @@ private ActionListener<Optional<AnomalyDetector>> onGetDetector(
)
);
} else {
entityFeatures
Set<Entry<DiscoveryNode, Map<String, double[]>>> node2Entities = entityFeatures
.entrySet()
.stream()
.collect(
Expand All @@ -307,26 +308,29 @@ private ActionListener<Optional<AnomalyDetector>> onGetDetector(
Collectors.toMap(Entry::getKey, Entry::getValue)
)
)
.entrySet()
.stream()
.forEach(nodeEntity -> {
DiscoveryNode node = nodeEntity.getKey();
transportService
.sendRequest(
node,
EntityResultAction.NAME,
new EntityResultRequest(adID, nodeEntity.getValue(), dataStartTime, dataEndTime),
this.option,
new ActionListenerResponseHandler<>(
new EntityResultListener(node.getId(), adID),
AcknowledgedResponse::new,
ThreadPool.Names.SAME
)
);
});
.entrySet();

int nodeCount = node2Entities.size();
AtomicInteger responseCount = new AtomicInteger();

final AtomicReference<AnomalyDetectionException> failure = new AtomicReference<>();
node2Entities.stream().forEach(nodeEntity -> {
DiscoveryNode node = nodeEntity.getKey();
transportService
.sendRequest(
node,
EntityResultAction.NAME,
new EntityResultRequest(adID, nodeEntity.getValue(), dataStartTime, dataEndTime),
this.option,
new ActionListenerResponseHandler<>(
new EntityResultListener(node.getId(), adID, responseCount, nodeCount, failure, listener),
AcknowledgedResponse::new,
ThreadPool.Names.SAME
)
);
});
}

listener.onResponse(new AnomalyResultResponse(0, 0, 0, new ArrayList<FeatureData>()));
}, exception -> handleFailure(exception, listener, adID));

threadPool
Expand Down Expand Up @@ -482,7 +486,7 @@ private ActionListener<SinglePointFeatures> onFeatureResponse(

private void handleFailure(Exception exception, ActionListener<AnomalyResultResponse> listener, String adID) {
if (exception instanceof IndexNotFoundException) {
listener.onFailure(new EndRunException(adID, "Having trouble querying data: " + exception.getMessage(), true));
listener.onFailure(new EndRunException(adID, TROUBLE_QUERYING_ERR_MSG + exception.getMessage(), true));
} else if (exception instanceof EndRunException) {
// invalid feature query
listener.onFailure(exception);
Expand Down Expand Up @@ -555,7 +559,7 @@ private void findException(Throwable cause, String adID, AtomicReference<Anomaly
&& causeException.getMessage().contains(CommonName.CHECKPOINT_INDEX_NAME))) {
failure.set(new ResourceNotFoundException(adID, causeException.getMessage()));
} else if (ExceptionUtil.isException(causeException, LimitExceededException.class, LIMIT_EXCEEDED_EXCEPTION_NAME_UNDERSCORE)) {
failure.set(new LimitExceededException(adID, causeException.getMessage()));
failure.set(new LimitExceededException(adID, causeException.getMessage(), false));
} else if (causeException instanceof ElasticsearchTimeoutException) {
// we can have ElasticsearchTimeoutException when a node tries to load RCF or
// threshold model
Expand Down Expand Up @@ -787,7 +791,7 @@ private void handleThresholdResult() {
}

private void handlePredictionFailure(Exception e, String adID, String nodeID, AtomicReference<AnomalyDetectionException> failure) {
LOG.error(new ParameterizedMessage("Received an error from node {} while fetching anomaly grade for {}", nodeID, adID), e);
LOG.error(new ParameterizedMessage("Received an error from node {} while doing model inference for {}", nodeID, adID), e);
if (e == null) {
return;
}
Expand All @@ -801,6 +805,8 @@ private void handlePredictionFailure(Exception e, String adID, String nodeID, At

/**
* Check if the input exception indicates connection issues.
* During blue-green deployment, we may see ActionNotFoundTransportException.
* Count that as connection issue and isolate that node if it continues to happen.
*
* @param e exception
* @return true if we get disconnected from the node or the node is not in the
Expand All @@ -811,7 +817,8 @@ private boolean hasConnectionIssue(Throwable e) {
|| e instanceof NodeClosedException
|| e instanceof ReceiveTimeoutTransportException
|| e instanceof NodeNotConnectedException
|| e instanceof ConnectException;
|| e instanceof ConnectException
|| e instanceof ActionNotFoundTransportException;
}

private void handleConnectionException(String node) {
Expand Down Expand Up @@ -1015,18 +1022,45 @@ private Optional<AnomalyDetectionException> coldStartIfNoCheckPoint(AnomalyDetec
class EntityResultListener implements ActionListener<AcknowledgedResponse> {
private String nodeId;
private final String adID;
private AtomicInteger responseCount;
private int nodeCount;
private ActionListener<AnomalyResultResponse> listener;
private List<AcknowledgedResponse> ackResponses;
private AtomicReference<AnomalyDetectionException> failure;

EntityResultListener(String nodeId, String adID) {
EntityResultListener(
String nodeId,
String adID,
AtomicInteger responseCount,
int nodeCount,
AtomicReference<AnomalyDetectionException> failure,
ActionListener<AnomalyResultResponse> listener
) {
this.nodeId = nodeId;
this.adID = adID;
this.responseCount = responseCount;
this.nodeCount = nodeCount;
this.failure = failure;
this.listener = listener;
this.ackResponses = new ArrayList<>();
}

@Override
public void onResponse(AcknowledgedResponse response) {
stateManager.resetBackpressureCounter(nodeId);
if (response.isAcknowledged() == false) {
LOG.error("Cannot send entities' features to {} for {}", nodeId, adID);
stateManager.addPressure(nodeId);
try {
stateManager.resetBackpressureCounter(nodeId);
if (response.isAcknowledged() == false) {
LOG.error("Cannot send entities' features to {} for {}", nodeId, adID);
stateManager.addPressure(nodeId);
} else {
ackResponses.add(response);
}
} catch (Exception ex) {
LOG.error("Unexpected exception: {} for {}", ex, adID);
} finally {
if (nodeCount == responseCount.incrementAndGet()) {
handleEntityResponses();
}
}
}

Expand All @@ -1035,13 +1069,28 @@ public void onFailure(Exception e) {
if (e == null) {
return;
}
Throwable cause = ExceptionsHelper.unwrapCause(e);
// in case of connection issue or the other node has no multi-entity
// transport actions (e.g., blue green deployment)
if (hasConnectionIssue(cause) || cause instanceof ActionNotFoundTransportException) {
handleConnectionException(nodeId);
try {
LOG.error(new ParameterizedMessage("Cannot send entities' features to {} for {}", nodeId, adID), e);

handlePredictionFailure(e, adID, nodeId, failure);

} catch (Exception ex) {
LOG.error("Unexpected exception: {} for {}", ex, adID);
} finally {
if (nodeCount == responseCount.incrementAndGet()) {
handleEntityResponses();
}
}
}

private void handleEntityResponses() {
if (failure.get() != null) {
listener.onFailure(failure.get());
} else if (ackResponses.isEmpty()) {
listener.onFailure(new InternalFailure(adID, NO_ACK_ERR));
} else {
listener.onResponse(new AnomalyResultResponse(0, 0, 0, new ArrayList<FeatureData>()));
}
LOG.error(new ParameterizedMessage("Cannot send entities' features to {} for {}", nodeId, adID), e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,9 @@ public void setupTestNodes(Settings settings) {
}

public void tearDownTestNodes() {
if (testNodes == null) {
return;
}
for (FakeNode testNode : testNodes) {
testNode.close();
}
Expand All @@ -238,7 +241,7 @@ public void assertException(
Class<? extends Exception> exceptionType,
String msg
) {
Exception e = expectThrows(exceptionType, () -> listener.actionGet());
Exception e = expectThrows(exceptionType, () -> listener.actionGet(20_000));
assertThat(e.getMessage(), containsString(msg));
}

Expand Down
Loading

0 comments on commit 2f18231

Please sign in to comment.