Skip to content
This repository has been archived by the owner on Aug 2, 2022. It is now read-only.

Fix issue where max number of multi-entity detector doesn't work for UpdateDetector #285

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import static com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetector.ANOMALY_DETECTORS_INDEX;
import static com.amazon.opendistroforelasticsearch.ad.util.RestHandlerUtils.XCONTENT_WITH_TYPE;
import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;

import java.io.IOException;
import java.time.Instant;
Expand Down Expand Up @@ -53,6 +54,7 @@
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
Expand Down Expand Up @@ -245,31 +247,52 @@ private void updateAnomalyDetector(String detectorId) {
);
}

private void onGetAnomalyDetectorResponse(GetResponse response) throws IOException {
private void onGetAnomalyDetectorResponse(GetResponse response) {
if (!response.isExists()) {
listener
.onFailure(new ElasticsearchStatusException("AnomalyDetector is not found with id: " + detectorId, RestStatus.NOT_FOUND));
return;
}
try (XContentParser parser = RestHandlerUtils.createXContentParserFromRegistry(xContentRegistry, response.getSourceAsBytesRef())) {
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser::getTokenLocation);
AnomalyDetector existingDetector = AnomalyDetector.parse(parser, response.getId(), response.getVersion());
if (!hasCategoryField(existingDetector) && hasCategoryField(this.anomalyDetector)) {
validateAgainstExistingMultiEntityAnomalyDetector(detectorId);
} else {
validateCategoricalField(detectorId);
}
} catch (IOException e) {
String message = "Failed to parse anomaly detector " + detectorId;
logger.error(message, e);
listener.onFailure(new ElasticsearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR));
}

validateCategoricalField(detectorId);
}

private boolean hasCategoryField(AnomalyDetector detector) {
return detector.getCategoryField() != null && !detector.getCategoryField().isEmpty();
}

private void validateAgainstExistingMultiEntityAnomalyDetector(String detectorId) {
QueryBuilder query = QueryBuilders.boolQuery().filter(QueryBuilders.existsQuery(AnomalyDetector.CATEGORY_FIELD));

SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query).size(0).timeout(requestTimeout);

SearchRequest searchRequest = new SearchRequest(ANOMALY_DETECTORS_INDEX).source(searchSourceBuilder);

client
.search(
searchRequest,
ActionListener
.wrap(response -> onSearchMultiEntityAdResponse(response, detectorId), exception -> listener.onFailure(exception))
);
}

private void createAnomalyDetector() {
try {
List<String> categoricalFields = anomalyDetector.getCategoryField();
if (categoricalFields != null && categoricalFields.size() > 0) {
QueryBuilder query = QueryBuilders.boolQuery().filter(QueryBuilders.existsQuery(AnomalyDetector.CATEGORY_FIELD));

SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query).size(0).timeout(requestTimeout);

SearchRequest searchRequest = new SearchRequest(ANOMALY_DETECTORS_INDEX).source(searchSourceBuilder);

client
.search(
searchRequest,
ActionListener.wrap(response -> onSearchMultiEntityAdResponse(response), exception -> listener.onFailure(exception))
);
validateAgainstExistingMultiEntityAnomalyDetector(null);
} else {
QueryBuilder query = QueryBuilders.matchAllQuery();
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query).size(0).timeout(requestTimeout);
Expand Down Expand Up @@ -298,13 +321,13 @@ private void onSearchSingleEntityAdResponse(SearchResponse response) throws IOEx
}
}

private void onSearchMultiEntityAdResponse(SearchResponse response) throws IOException {
private void onSearchMultiEntityAdResponse(SearchResponse response, String detectorId) throws IOException {
if (response.getHits().getTotalHits().value >= maxMultiEntityAnomalyDetectors) {
String errorMsg = EXCEEDED_MAX_MULTI_ENTITY_DETECTORS_PREFIX_MSG + maxMultiEntityAnomalyDetectors;
logger.error(errorMsg);
listener.onFailure(new IllegalArgumentException(errorMsg));
} else {
validateCategoricalField(null);
validateCategoricalField(detectorId);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ public <Request extends ActionRequest, Response extends ActionResponse> void doE
}

@SuppressWarnings("unchecked")
private void testValidTypeTepmlate(String filedTypeName) throws IOException {
private void testValidTypeTemplate(String filedTypeName) throws IOException {
String field = "a";
AnomalyDetector detector = TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, Arrays.asList(field));

Expand Down Expand Up @@ -388,24 +388,24 @@ public <Request extends ActionRequest, Response extends ActionResponse> void doE
}

public void testIpField() throws IOException {
testValidTypeTepmlate(CommonName.IP_TYPE);
testValidTypeTemplate(CommonName.IP_TYPE);
}

public void testKeywordField() throws IOException {
testValidTypeTepmlate(CommonName.KEYWORD_TYPE);
testValidTypeTemplate(CommonName.KEYWORD_TYPE);
}

@SuppressWarnings("unchecked")
private void testUpdateTepmlate(String fieldTypeName) throws IOException {
private void testUpdateTemplate(String fieldTypeName) throws IOException {
String field = "a";
AnomalyDetector detector = TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, Arrays.asList(field));

SearchResponse detectorResponse = mock(SearchResponse.class);
int totalHits = 9;
when(detectorResponse.getHits()).thenReturn(createSearchHits(totalHits));

GetResponse getDetectorResponse = mock(GetResponse.class);
when(getDetectorResponse.isExists()).thenReturn(true);
GetResponse getDetectorResponse = TestHelpers
.createGetResponse(detector, detector.getDetectorId(), AnomalyDetector.ANOMALY_DETECTORS_INDEX);

SearchResponse userIndexResponse = mock(SearchResponse.class);
int userIndexHits = 0;
Expand Down Expand Up @@ -485,15 +485,15 @@ public <Request extends ActionRequest, Response extends ActionResponse> void doE
}

public void testUpdateIpField() throws IOException {
testUpdateTepmlate(CommonName.IP_TYPE);
testUpdateTemplate(CommonName.IP_TYPE);
}

public void testUpdateKeywordField() throws IOException {
testUpdateTepmlate(CommonName.KEYWORD_TYPE);
testUpdateTemplate(CommonName.KEYWORD_TYPE);
}

public void testUpdateTextField() throws IOException {
testUpdateTepmlate(TEXT_FIELD_TYPE);
testUpdateTemplate(TEXT_FIELD_TYPE);
}

@SuppressWarnings("unchecked")
Expand Down Expand Up @@ -527,4 +527,151 @@ public void testMoreThanTenMultiEntityDetectors() throws IOException {
assertTrue(value instanceof IllegalArgumentException);
assertTrue(value.getMessage().contains(IndexAnomalyDetectorActionHandler.EXCEEDED_MAX_MULTI_ENTITY_DETECTORS_PREFIX_MSG));
}

@SuppressWarnings("unchecked")
public void testTenMultiEntityDetectorsUpdateSingleEntityAdToMulti() throws IOException {
int totalHits = 10;
AnomalyDetector existingDetector = TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, null);
GetResponse getDetectorResponse = TestHelpers
.createGetResponse(existingDetector, existingDetector.getDetectorId(), AnomalyDetector.ANOMALY_DETECTORS_INDEX);

SearchResponse searchResponse = mock(SearchResponse.class);
when(searchResponse.getHits()).thenReturn(createSearchHits(totalHits));

doAnswer(invocation -> {
Object[] args = invocation.getArguments();
assertTrue(String.format("The size of args is %d. Its content is %s", args.length, Arrays.toString(args)), args.length == 2);

assertTrue(args[0] instanceof SearchRequest);
assertTrue(args[1] instanceof ActionListener);

ActionListener<SearchResponse> listener = (ActionListener<SearchResponse>) args[1];

listener.onResponse(searchResponse);

return null;
}).when(clientMock).search(any(SearchRequest.class), any());

doAnswer(invocation -> {
Object[] args = invocation.getArguments();
assertTrue(String.format("The size of args is %d. Its content is %s", args.length, Arrays.toString(args)), args.length == 2);

assertTrue(args[0] instanceof GetRequest);
assertTrue(args[1] instanceof ActionListener);

ActionListener<GetResponse> listener = (ActionListener<GetResponse>) args[1];

listener.onResponse(getDetectorResponse);

return null;
}).when(clientMock).get(any(GetRequest.class), any());

ClusterName clusterName = new ClusterName("test");
ClusterState clusterState = ClusterState.builder(clusterName).metadata(Metadata.builder().build()).build();
when(clusterService.state()).thenReturn(clusterState);

handler = new IndexAnomalyDetectorActionHandler(
clusterService,
clientMock,
channel,
anomalyDetectionIndices,
detectorId,
seqNo,
primaryTerm,
refreshPolicy,
detector,
requestTimeout,
maxSingleEntityAnomalyDetectors,
maxMultiEntityAnomalyDetectors,
maxAnomalyFeatures,
RestRequest.Method.PUT,
xContentRegistry(),
mock(RestClient.class),
null
);

handler.resolveUserAndStart();

ArgumentCaptor<Exception> response = ArgumentCaptor.forClass(Exception.class);
verify(clientMock, times(1)).search(any(SearchRequest.class), any());
verify(clientMock, times(1)).get(any(GetRequest.class), any());
verify(channel).onFailure(response.capture());
Exception value = response.getValue();
assertTrue(value instanceof IllegalArgumentException);
assertTrue(value.getMessage().contains(IndexAnomalyDetectorActionHandler.EXCEEDED_MAX_MULTI_ENTITY_DETECTORS_PREFIX_MSG));
}

@SuppressWarnings("unchecked")
public void testTenMultiEntityDetectorsUpdateExistingMultiEntityAd() throws IOException {
int totalHits = 10;
AnomalyDetector detector = TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, Arrays.asList("a"));
GetResponse getDetectorResponse = TestHelpers
.createGetResponse(detector, detector.getDetectorId(), AnomalyDetector.ANOMALY_DETECTORS_INDEX);

SearchResponse searchResponse = mock(SearchResponse.class);
when(searchResponse.getHits()).thenReturn(createSearchHits(totalHits));

doAnswer(invocation -> {
Object[] args = invocation.getArguments();
assertTrue(String.format("The size of args is %d. Its content is %s", args.length, Arrays.toString(args)), args.length == 2);

assertTrue(args[0] instanceof SearchRequest);
assertTrue(args[1] instanceof ActionListener);

ActionListener<SearchResponse> listener = (ActionListener<SearchResponse>) args[1];

listener.onResponse(searchResponse);

return null;
}).when(clientMock).search(any(SearchRequest.class), any());

doAnswer(invocation -> {
Object[] args = invocation.getArguments();
assertTrue(String.format("The size of args is %d. Its content is %s", args.length, Arrays.toString(args)), args.length == 2);

assertTrue(args[0] instanceof GetRequest);
assertTrue(args[1] instanceof ActionListener);

ActionListener<GetResponse> listener = (ActionListener<GetResponse>) args[1];

listener.onResponse(getDetectorResponse);

return null;
}).when(clientMock).get(any(GetRequest.class), any());

ClusterName clusterName = new ClusterName("test");
ClusterState clusterState = ClusterState.builder(clusterName).metadata(Metadata.builder().build()).build();
when(clusterService.state()).thenReturn(clusterState);

handler = new IndexAnomalyDetectorActionHandler(
clusterService,
clientMock,
channel,
anomalyDetectionIndices,
detectorId,
seqNo,
primaryTerm,
refreshPolicy,
detector,
requestTimeout,
maxSingleEntityAnomalyDetectors,
maxMultiEntityAnomalyDetectors,
maxAnomalyFeatures,
RestRequest.Method.PUT,
xContentRegistry(),
mock(RestClient.class),
null
);

handler.resolveUserAndStart();

ArgumentCaptor<Exception> response = ArgumentCaptor.forClass(Exception.class);
verify(clientMock, times(0)).search(any(SearchRequest.class), any());
verify(clientMock, times(1)).get(any(GetRequest.class), any());
verify(channel).onFailure(response.capture());
Exception value = response.getValue();
// make sure execution passes all necessary checks
assertTrue(value instanceof IllegalStateException);
assertTrue(value.getMessage().contains("NodeClient has not been initialized"));
}
}