diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/rest/handler/IndexAnomalyDetectorActionHandler.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/rest/handler/IndexAnomalyDetectorActionHandler.java index 5e91e2b4..bd563fb3 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/rest/handler/IndexAnomalyDetectorActionHandler.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/rest/handler/IndexAnomalyDetectorActionHandler.java @@ -17,6 +17,7 @@ import static com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetector.ANOMALY_DETECTORS_INDEX; import static com.amazon.opendistroforelasticsearch.ad.util.RestHandlerUtils.XCONTENT_WITH_TYPE; +import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; import java.io.IOException; import java.time.Instant; @@ -53,6 +54,7 @@ import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.XContentFactory; +import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilders; @@ -245,31 +247,52 @@ private void updateAnomalyDetector(String detectorId) { ); } - private void onGetAnomalyDetectorResponse(GetResponse response) throws IOException { + private void onGetAnomalyDetectorResponse(GetResponse response) { if (!response.isExists()) { listener .onFailure(new ElasticsearchStatusException("AnomalyDetector is not found with id: " + detectorId, RestStatus.NOT_FOUND)); return; } + try (XContentParser parser = RestHandlerUtils.createXContentParserFromRegistry(xContentRegistry, response.getSourceAsBytesRef())) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser::getTokenLocation); + AnomalyDetector existingDetector = AnomalyDetector.parse(parser, response.getId(), response.getVersion()); + if (!hasCategoryField(existingDetector) && hasCategoryField(this.anomalyDetector)) { + validateAgainstExistingMultiEntityAnomalyDetector(detectorId); + } else { + validateCategoricalField(detectorId); + } + } catch (IOException e) { + String message = "Failed to parse anomaly detector " + detectorId; + logger.error(message, e); + listener.onFailure(new ElasticsearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR)); + } - validateCategoricalField(detectorId); + } + + private boolean hasCategoryField(AnomalyDetector detector) { + return detector.getCategoryField() != null && !detector.getCategoryField().isEmpty(); + } + + private void validateAgainstExistingMultiEntityAnomalyDetector(String detectorId) { + QueryBuilder query = QueryBuilders.boolQuery().filter(QueryBuilders.existsQuery(AnomalyDetector.CATEGORY_FIELD)); + + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query).size(0).timeout(requestTimeout); + + SearchRequest searchRequest = new SearchRequest(ANOMALY_DETECTORS_INDEX).source(searchSourceBuilder); + + client + .search( + searchRequest, + ActionListener + .wrap(response -> onSearchMultiEntityAdResponse(response, detectorId), exception -> listener.onFailure(exception)) + ); } private void createAnomalyDetector() { try { List categoricalFields = anomalyDetector.getCategoryField(); if (categoricalFields != null && categoricalFields.size() > 0) { - QueryBuilder query = QueryBuilders.boolQuery().filter(QueryBuilders.existsQuery(AnomalyDetector.CATEGORY_FIELD)); - - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query).size(0).timeout(requestTimeout); - - SearchRequest searchRequest = new SearchRequest(ANOMALY_DETECTORS_INDEX).source(searchSourceBuilder); - - client - .search( - searchRequest, - ActionListener.wrap(response -> onSearchMultiEntityAdResponse(response), exception -> listener.onFailure(exception)) - ); + validateAgainstExistingMultiEntityAnomalyDetector(null); } else { QueryBuilder query = QueryBuilders.matchAllQuery(); SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query).size(0).timeout(requestTimeout); @@ -298,13 +321,13 @@ private void onSearchSingleEntityAdResponse(SearchResponse response) throws IOEx } } - private void onSearchMultiEntityAdResponse(SearchResponse response) throws IOException { + private void onSearchMultiEntityAdResponse(SearchResponse response, String detectorId) throws IOException { if (response.getHits().getTotalHits().value >= maxMultiEntityAnomalyDetectors) { String errorMsg = EXCEEDED_MAX_MULTI_ENTITY_DETECTORS_PREFIX_MSG + maxMultiEntityAnomalyDetectors; logger.error(errorMsg); listener.onFailure(new IllegalArgumentException(errorMsg)); } else { - validateCategoricalField(null); + validateCategoricalField(detectorId); } } diff --git a/src/test/java/org/elasticsearch/action/admin/indices/mapping/get/IndexAnomalyDetectorActionHandlerTests.java b/src/test/java/org/elasticsearch/action/admin/indices/mapping/get/IndexAnomalyDetectorActionHandlerTests.java index 7ec8e62b..b0870e1e 100644 --- a/src/test/java/org/elasticsearch/action/admin/indices/mapping/get/IndexAnomalyDetectorActionHandlerTests.java +++ b/src/test/java/org/elasticsearch/action/admin/indices/mapping/get/IndexAnomalyDetectorActionHandlerTests.java @@ -311,7 +311,7 @@ public void doE } @SuppressWarnings("unchecked") - private void testValidTypeTepmlate(String filedTypeName) throws IOException { + private void testValidTypeTemplate(String filedTypeName) throws IOException { String field = "a"; AnomalyDetector detector = TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, Arrays.asList(field)); @@ -388,15 +388,15 @@ public void doE } public void testIpField() throws IOException { - testValidTypeTepmlate(CommonName.IP_TYPE); + testValidTypeTemplate(CommonName.IP_TYPE); } public void testKeywordField() throws IOException { - testValidTypeTepmlate(CommonName.KEYWORD_TYPE); + testValidTypeTemplate(CommonName.KEYWORD_TYPE); } @SuppressWarnings("unchecked") - private void testUpdateTepmlate(String fieldTypeName) throws IOException { + private void testUpdateTemplate(String fieldTypeName) throws IOException { String field = "a"; AnomalyDetector detector = TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, Arrays.asList(field)); @@ -404,8 +404,8 @@ private void testUpdateTepmlate(String fieldTypeName) throws IOException { int totalHits = 9; when(detectorResponse.getHits()).thenReturn(createSearchHits(totalHits)); - GetResponse getDetectorResponse = mock(GetResponse.class); - when(getDetectorResponse.isExists()).thenReturn(true); + GetResponse getDetectorResponse = TestHelpers + .createGetResponse(detector, detector.getDetectorId(), AnomalyDetector.ANOMALY_DETECTORS_INDEX); SearchResponse userIndexResponse = mock(SearchResponse.class); int userIndexHits = 0; @@ -485,15 +485,15 @@ public void doE } public void testUpdateIpField() throws IOException { - testUpdateTepmlate(CommonName.IP_TYPE); + testUpdateTemplate(CommonName.IP_TYPE); } public void testUpdateKeywordField() throws IOException { - testUpdateTepmlate(CommonName.KEYWORD_TYPE); + testUpdateTemplate(CommonName.KEYWORD_TYPE); } public void testUpdateTextField() throws IOException { - testUpdateTepmlate(TEXT_FIELD_TYPE); + testUpdateTemplate(TEXT_FIELD_TYPE); } @SuppressWarnings("unchecked") @@ -527,4 +527,151 @@ public void testMoreThanTenMultiEntityDetectors() throws IOException { assertTrue(value instanceof IllegalArgumentException); assertTrue(value.getMessage().contains(IndexAnomalyDetectorActionHandler.EXCEEDED_MAX_MULTI_ENTITY_DETECTORS_PREFIX_MSG)); } + + @SuppressWarnings("unchecked") + public void testTenMultiEntityDetectorsUpdateSingleEntityAdToMulti() throws IOException { + int totalHits = 10; + AnomalyDetector existingDetector = TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, null); + GetResponse getDetectorResponse = TestHelpers + .createGetResponse(existingDetector, existingDetector.getDetectorId(), AnomalyDetector.ANOMALY_DETECTORS_INDEX); + + SearchResponse searchResponse = mock(SearchResponse.class); + when(searchResponse.getHits()).thenReturn(createSearchHits(totalHits)); + + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + assertTrue(String.format("The size of args is %d. Its content is %s", args.length, Arrays.toString(args)), args.length == 2); + + assertTrue(args[0] instanceof SearchRequest); + assertTrue(args[1] instanceof ActionListener); + + ActionListener listener = (ActionListener) args[1]; + + listener.onResponse(searchResponse); + + return null; + }).when(clientMock).search(any(SearchRequest.class), any()); + + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + assertTrue(String.format("The size of args is %d. Its content is %s", args.length, Arrays.toString(args)), args.length == 2); + + assertTrue(args[0] instanceof GetRequest); + assertTrue(args[1] instanceof ActionListener); + + ActionListener listener = (ActionListener) args[1]; + + listener.onResponse(getDetectorResponse); + + return null; + }).when(clientMock).get(any(GetRequest.class), any()); + + ClusterName clusterName = new ClusterName("test"); + ClusterState clusterState = ClusterState.builder(clusterName).metadata(Metadata.builder().build()).build(); + when(clusterService.state()).thenReturn(clusterState); + + handler = new IndexAnomalyDetectorActionHandler( + clusterService, + clientMock, + channel, + anomalyDetectionIndices, + detectorId, + seqNo, + primaryTerm, + refreshPolicy, + detector, + requestTimeout, + maxSingleEntityAnomalyDetectors, + maxMultiEntityAnomalyDetectors, + maxAnomalyFeatures, + RestRequest.Method.PUT, + xContentRegistry(), + mock(RestClient.class), + null + ); + + handler.resolveUserAndStart(); + + ArgumentCaptor response = ArgumentCaptor.forClass(Exception.class); + verify(clientMock, times(1)).search(any(SearchRequest.class), any()); + verify(clientMock, times(1)).get(any(GetRequest.class), any()); + verify(channel).onFailure(response.capture()); + Exception value = response.getValue(); + assertTrue(value instanceof IllegalArgumentException); + assertTrue(value.getMessage().contains(IndexAnomalyDetectorActionHandler.EXCEEDED_MAX_MULTI_ENTITY_DETECTORS_PREFIX_MSG)); + } + + @SuppressWarnings("unchecked") + public void testTenMultiEntityDetectorsUpdateExistingMultiEntityAd() throws IOException { + int totalHits = 10; + AnomalyDetector detector = TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, Arrays.asList("a")); + GetResponse getDetectorResponse = TestHelpers + .createGetResponse(detector, detector.getDetectorId(), AnomalyDetector.ANOMALY_DETECTORS_INDEX); + + SearchResponse searchResponse = mock(SearchResponse.class); + when(searchResponse.getHits()).thenReturn(createSearchHits(totalHits)); + + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + assertTrue(String.format("The size of args is %d. Its content is %s", args.length, Arrays.toString(args)), args.length == 2); + + assertTrue(args[0] instanceof SearchRequest); + assertTrue(args[1] instanceof ActionListener); + + ActionListener listener = (ActionListener) args[1]; + + listener.onResponse(searchResponse); + + return null; + }).when(clientMock).search(any(SearchRequest.class), any()); + + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + assertTrue(String.format("The size of args is %d. Its content is %s", args.length, Arrays.toString(args)), args.length == 2); + + assertTrue(args[0] instanceof GetRequest); + assertTrue(args[1] instanceof ActionListener); + + ActionListener listener = (ActionListener) args[1]; + + listener.onResponse(getDetectorResponse); + + return null; + }).when(clientMock).get(any(GetRequest.class), any()); + + ClusterName clusterName = new ClusterName("test"); + ClusterState clusterState = ClusterState.builder(clusterName).metadata(Metadata.builder().build()).build(); + when(clusterService.state()).thenReturn(clusterState); + + handler = new IndexAnomalyDetectorActionHandler( + clusterService, + clientMock, + channel, + anomalyDetectionIndices, + detectorId, + seqNo, + primaryTerm, + refreshPolicy, + detector, + requestTimeout, + maxSingleEntityAnomalyDetectors, + maxMultiEntityAnomalyDetectors, + maxAnomalyFeatures, + RestRequest.Method.PUT, + xContentRegistry(), + mock(RestClient.class), + null + ); + + handler.resolveUserAndStart(); + + ArgumentCaptor response = ArgumentCaptor.forClass(Exception.class); + verify(clientMock, times(0)).search(any(SearchRequest.class), any()); + verify(clientMock, times(1)).get(any(GetRequest.class), any()); + verify(channel).onFailure(response.capture()); + Exception value = response.getValue(); + // make sure execution passes all necessary checks + assertTrue(value instanceof IllegalStateException); + assertTrue(value.getMessage().contains("NodeClient has not been initialized")); + } }