diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java index 32d8ad0b79845..d1e49169d75e7 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java @@ -63,6 +63,9 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { private static final String NUMERICAL_FIELD = "numerical-field"; private static final String DISCRETE_NUMERICAL_FIELD = "discrete-numerical-field"; private static final String KEYWORD_FIELD = "keyword-field"; + private static final String NESTED_FIELD = "outer-field.inner-field"; + private static final String ALIAS_TO_KEYWORD_FIELD = "alias-to-keyword-field"; + private static final String ALIAS_TO_NESTED_FIELD = "alias-to-nested-field"; private static final List BOOLEAN_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList(false, true)); private static final List NUMERICAL_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList(1.0, 2.0)); private static final List DISCRETE_NUMERICAL_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList(10, 20)); @@ -301,7 +304,6 @@ public void testStopAndRestart() throws Exception { assertInferenceModelPersisted(jobId); assertMlResultsFieldMappings(predictedClassField, "keyword"); assertEvaluation(KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField); - } public void testDependentVariableCardinalityTooHighError() throws Exception { @@ -343,6 +345,63 @@ public void testDependentVariableCardinalityTooHighButWithQueryMakesItWithinRang assertProgress(jobId, 100, 100, 100, 100); } + public void testDependentVariableIsNested() throws Exception { + initialize("dependent_variable_is_nested"); + String predictedClassField = NESTED_FIELD + "_prediction"; + indexData(sourceIndex, 100, 0, NESTED_FIELD); + + DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(NESTED_FIELD)); + registerAnalytics(config); + putAnalytics(config); + startAnalytics(jobId); + waitUntilAnalyticsIsStopped(jobId); + + assertProgress(jobId, 100, 100, 100, 100); + assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L)); + assertModelStatePersisted(stateDocId()); + assertInferenceModelPersisted(jobId); + assertMlResultsFieldMappings(predictedClassField, "keyword"); + assertEvaluation(NESTED_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField); + } + + public void testDependentVariableIsAliasToKeyword() throws Exception { + initialize("dependent_variable_is_alias"); + String predictedClassField = ALIAS_TO_KEYWORD_FIELD + "_prediction"; + indexData(sourceIndex, 100, 0, KEYWORD_FIELD); + + DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(ALIAS_TO_KEYWORD_FIELD)); + registerAnalytics(config); + putAnalytics(config); + startAnalytics(jobId); + waitUntilAnalyticsIsStopped(jobId); + + assertProgress(jobId, 100, 100, 100, 100); + assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L)); + assertModelStatePersisted(stateDocId()); + assertInferenceModelPersisted(jobId); + assertMlResultsFieldMappings(predictedClassField, "keyword"); + assertEvaluation(ALIAS_TO_KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField); + } + + public void testDependentVariableIsAliasToNested() throws Exception { + initialize("dependent_variable_is_alias_to_nested"); + String predictedClassField = ALIAS_TO_NESTED_FIELD + "_prediction"; + indexData(sourceIndex, 100, 0, NESTED_FIELD); + + DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(ALIAS_TO_NESTED_FIELD)); + registerAnalytics(config); + putAnalytics(config); + startAnalytics(jobId); + waitUntilAnalyticsIsStopped(jobId); + + assertProgress(jobId, 100, 100, 100, 100); + assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L)); + assertModelStatePersisted(stateDocId()); + assertInferenceModelPersisted(jobId); + assertMlResultsFieldMappings(predictedClassField, "keyword"); + assertEvaluation(ALIAS_TO_NESTED_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField); + } + public void testTwoJobsWithSameRandomizeSeedUseSameTrainingSet() throws Exception { String sourceIndex = "classification_two_jobs_with_same_randomize_seed_source"; String dependentVariable = KEYWORD_FIELD; @@ -434,7 +493,10 @@ private static void createIndex(String index) { BOOLEAN_FIELD, "type=boolean", NUMERICAL_FIELD, "type=double", DISCRETE_NUMERICAL_FIELD, "type=integer", - KEYWORD_FIELD, "type=keyword") + KEYWORD_FIELD, "type=keyword", + NESTED_FIELD, "type=keyword", + ALIAS_TO_KEYWORD_FIELD, "type=alias,path=" + KEYWORD_FIELD, + ALIAS_TO_NESTED_FIELD, "type=alias,path=" + NESTED_FIELD) .get(); } @@ -446,7 +508,8 @@ private static void indexData(String sourceIndex, int numTrainingRows, int numNo BOOLEAN_FIELD, BOOLEAN_FIELD_VALUES.get(i % BOOLEAN_FIELD_VALUES.size()), NUMERICAL_FIELD, NUMERICAL_FIELD_VALUES.get(i % NUMERICAL_FIELD_VALUES.size()), DISCRETE_NUMERICAL_FIELD, DISCRETE_NUMERICAL_FIELD_VALUES.get(i % DISCRETE_NUMERICAL_FIELD_VALUES.size()), - KEYWORD_FIELD, KEYWORD_FIELD_VALUES.get(i % KEYWORD_FIELD_VALUES.size())); + KEYWORD_FIELD, KEYWORD_FIELD_VALUES.get(i % KEYWORD_FIELD_VALUES.size()), + NESTED_FIELD, KEYWORD_FIELD_VALUES.get(i % KEYWORD_FIELD_VALUES.size())); IndexRequest indexRequest = new IndexRequest(sourceIndex).source(source.toArray()); bulkRequestBuilder.add(indexRequest); } @@ -465,6 +528,9 @@ private static void indexData(String sourceIndex, int numTrainingRows, int numNo if (KEYWORD_FIELD.equals(dependentVariable) == false) { source.addAll(List.of(KEYWORD_FIELD, KEYWORD_FIELD_VALUES.get(i % KEYWORD_FIELD_VALUES.size()))); } + if (NESTED_FIELD.equals(dependentVariable) == false) { + source.addAll(List.of(NESTED_FIELD, KEYWORD_FIELD_VALUES.get(i % KEYWORD_FIELD_VALUES.size()))); + } IndexRequest indexRequest = new IndexRequest(sourceIndex).source(source.toArray()); bulkRequestBuilder.add(indexRequest); } @@ -487,10 +553,12 @@ private static Map getDestDoc(DataFrameAnalyticsConfig config, S } /** - * Wrapper around extractValue with implicit casting to the appropriate type. + * Wrapper around extractValue that: + * - allows dots (".") in the path elements provided as arguments + * - supports implicit casting to the appropriate type */ private static T getFieldValue(Map doc, String... path) { - return (T)extractValue(doc, path); + return (T)extractValue(String.join(".", path), doc); } private static void assertTopClasses(Map resultsObject, @@ -582,8 +650,14 @@ private void assertMlResultsFieldMappings(String predictedClassField, String exp .mappings() .get(destIndex) .sourceAsMap(); - assertThat(getFieldValue(mappings, "properties", "ml", "properties", predictedClassField, "type"), equalTo(expectedType)); assertThat( + mappings.toString(), + getFieldValue( + mappings, + "properties", "ml", "properties", String.join(".properties.", predictedClassField.split("\\.")), "type"), + equalTo(expectedType)); + assertThat( + mappings.toString(), getFieldValue(mappings, "properties", "ml", "properties", "top_classes", "properties", "class_name", "type"), equalTo(expectedType)); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsIndex.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsIndex.java index 47ecb0ec2b6b8..ec3e192ab4ceb 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsIndex.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsIndex.java @@ -24,6 +24,8 @@ import org.elasticsearch.common.Nullable; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.index.IndexSortConfig; +import org.elasticsearch.index.mapper.FieldAliasMapper; +import org.elasticsearch.index.mapper.KeywordFieldMapper; import org.elasticsearch.search.sort.SortOrder; import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; @@ -38,6 +40,7 @@ import java.util.concurrent.atomic.AtomicReference; import java.util.function.Supplier; +import static org.elasticsearch.common.xcontent.support.XContentMapValues.extractValue; import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; /** @@ -155,21 +158,36 @@ private static Integer findMaxSettingValue(GetSettingsResponse settingsResponse, return maxValue; } + @SuppressWarnings("unchecked") private static Map createAdditionalMappings(DataFrameAnalyticsConfig config, Map mappingsProperties) { Map properties = new HashMap<>(); - properties.put(ID_COPY, Map.of("type", "keyword")); + properties.put(ID_COPY, Map.of("type", KeywordFieldMapper.CONTENT_TYPE)); for (Map.Entry entry : config.getAnalysis().getExplicitlyMappedFields(config.getDest().getResultsField()).entrySet()) { String destFieldPath = entry.getKey(); String sourceFieldPath = entry.getValue(); - Object sourceFieldMapping = mappingsProperties.get(sourceFieldPath); - if (sourceFieldMapping != null) { + Object sourceFieldMapping = extractMapping(sourceFieldPath, mappingsProperties); + if (sourceFieldMapping instanceof Map) { + Map sourceFieldMappingAsMap = (Map) sourceFieldMapping; + // If the source field is an alias, fetch the concrete field that the alias points to. + if (FieldAliasMapper.CONTENT_TYPE.equals(sourceFieldMappingAsMap.get("type"))) { + String path = (String) sourceFieldMappingAsMap.get(FieldAliasMapper.Names.PATH); + sourceFieldMapping = extractMapping(path, mappingsProperties); + } + } + // We may have updated the value of {@code sourceFieldMapping} in the "if" block above. + // Hence, we need to check the "instanceof" condition again. + if (sourceFieldMapping instanceof Map) { properties.put(destFieldPath, sourceFieldMapping); } } return properties; } + private static Object extractMapping(String path, Map mappingsProperties) { + return extractValue(String.join("." + PROPERTIES + ".", path.split("\\.")), mappingsProperties); + } + private static Map createMetaData(String analyticsId, Clock clock) { Map metadata = new HashMap<>(); metadata.put(CREATION_DATE_MILLIS, clock.millis()); @@ -227,4 +245,3 @@ private static void checkResultsFieldIsNotPresentInProperties(DataFrameAnalytics } } } - diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsIndexTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsIndexTests.java index 7cd00f68e4b8f..a4a9bfffab818 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsIndexTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsIndexTests.java @@ -71,7 +71,11 @@ public class DataFrameAnalyticsIndexTests extends ESTestCase { private static final String ANALYTICS_ID = "some-analytics-id"; private static final String[] SOURCE_INDEX = new String[] {"source-index"}; private static final String DEST_INDEX = "dest-index"; - private static final String DEPENDENT_VARIABLE = "dep_var"; + private static final String NUMERICAL_FIELD = "numerical-field"; + private static final String OUTER_FIELD = "outer-field"; + private static final String INNER_FIELD = "inner-field"; + private static final String ALIAS_TO_NUMERICAL_FIELD = "alias-to-numerical-field"; + private static final String ALIAS_TO_NESTED_FIELD = "alias-to-nested-field"; private static final int CURRENT_TIME_MILLIS = 123456789; private static final String CREATED_BY = "data-frame-analytics"; @@ -116,17 +120,18 @@ private Map testCreateDestinationIndex(DataFrameAnalysis analysi doAnswer(callListenerOnResponse(getSettingsResponse)) .when(client).execute(eq(GetSettingsAction.INSTANCE), getSettingsRequestCaptor.capture(), any()); - Map index1Mappings = + Map indexMappings = Map.of( "properties", - Map.of("field_1", "field_1_mappings", "field_2", "field_2_mappings", DEPENDENT_VARIABLE, Map.of("type", "integer"))); - MappingMetaData index1MappingMetaData = new MappingMetaData("_doc", index1Mappings); - - Map index2Mappings = - Map.of( - "properties", - Map.of("field_1", "field_1_mappings", "field_2", "field_2_mappings", DEPENDENT_VARIABLE, Map.of("type", "integer"))); - MappingMetaData index2MappingMetaData = new MappingMetaData("_doc", index2Mappings); + Map.of( + "field_1", "field_1_mappings", + "field_2", "field_2_mappings", + NUMERICAL_FIELD, Map.of("type", "integer"), + OUTER_FIELD, Map.of("properties", Map.of(INNER_FIELD, Map.of("type", "integer"))), + ALIAS_TO_NUMERICAL_FIELD, Map.of("type", "alias", "path", NUMERICAL_FIELD), + ALIAS_TO_NESTED_FIELD, Map.of("type", "alias", "path", "outer-field.inner-field"))); + MappingMetaData index1MappingMetaData = new MappingMetaData("_doc", indexMappings); + MappingMetaData index2MappingMetaData = new MappingMetaData("_doc", indexMappings); ImmutableOpenMap.Builder mappings = ImmutableOpenMap.builder(); mappings.put("index_1", index1MappingMetaData); @@ -143,7 +148,9 @@ private Map testCreateDestinationIndex(DataFrameAnalysis analysi config, ActionListener.wrap( response -> {}, - e -> fail(e.getMessage()))); + e -> fail(e.getMessage()) + ) + ); GetSettingsRequest capturedGetSettingsRequest = getSettingsRequestCaptor.getValue(); assertThat(capturedGetSettingsRequest.indices(), equalTo(SOURCE_INDEX)); @@ -166,6 +173,10 @@ private Map testCreateDestinationIndex(DataFrameAnalysis analysi assertThat(extractValue("_doc.properties.ml__id_copy.type", map), equalTo("keyword")); assertThat(extractValue("_doc.properties.field_1", map), equalTo("field_1_mappings")); assertThat(extractValue("_doc.properties.field_2", map), equalTo("field_2_mappings")); + assertThat(extractValue("_doc.properties.numerical-field.type", map), equalTo("integer")); + assertThat(extractValue("_doc.properties.outer-field.properties.inner-field.type", map), equalTo("integer")); + assertThat(extractValue("_doc.properties.alias-to-numerical-field.type", map), equalTo("alias")); + assertThat(extractValue("_doc.properties.alias-to-nested-field.type", map), equalTo("alias")); assertThat(extractValue("_doc._meta.analytics", map), equalTo(ANALYTICS_ID)); assertThat(extractValue("_doc._meta.creation_date_in_millis", map), equalTo(CURRENT_TIME_MILLIS)); assertThat(extractValue("_doc._meta.created_by", map), equalTo(CREATED_BY)); @@ -178,13 +189,31 @@ public void testCreateDestinationIndex_OutlierDetection() throws IOException { } public void testCreateDestinationIndex_Regression() throws IOException { - Map map = testCreateDestinationIndex(new Regression(DEPENDENT_VARIABLE)); - assertThat(extractValue("_doc.properties.ml.dep_var_prediction.type", map), equalTo("integer")); + Map map = testCreateDestinationIndex(new Regression(NUMERICAL_FIELD)); + assertThat(extractValue("_doc.properties.ml.numerical-field_prediction.type", map), equalTo("integer")); } public void testCreateDestinationIndex_Classification() throws IOException { - Map map = testCreateDestinationIndex(new Classification(DEPENDENT_VARIABLE)); - assertThat(extractValue("_doc.properties.ml.dep_var_prediction.type", map), equalTo("integer")); + Map map = testCreateDestinationIndex(new Classification(NUMERICAL_FIELD)); + assertThat(extractValue("_doc.properties.ml.numerical-field_prediction.type", map), equalTo("integer")); + assertThat(extractValue("_doc.properties.ml.top_classes.class_name.type", map), equalTo("integer")); + } + + public void testCreateDestinationIndex_Classification_DependentVariableIsNested() throws IOException { + Map map = testCreateDestinationIndex(new Classification(OUTER_FIELD + "." + INNER_FIELD)); + assertThat(extractValue("_doc.properties.ml.outer-field.inner-field_prediction.type", map), equalTo("integer")); + assertThat(extractValue("_doc.properties.ml.top_classes.class_name.type", map), equalTo("integer")); + } + + public void testCreateDestinationIndex_Classification_DependentVariableIsAlias() throws IOException { + Map map = testCreateDestinationIndex(new Classification(ALIAS_TO_NUMERICAL_FIELD)); + assertThat(extractValue("_doc.properties.ml.alias-to-numerical-field_prediction.type", map), equalTo("integer")); + assertThat(extractValue("_doc.properties.ml.top_classes.class_name.type", map), equalTo("integer")); + } + + public void testCreateDestinationIndex_Classification_DependentVariableIsAliasToNested() throws IOException { + Map map = testCreateDestinationIndex(new Classification(ALIAS_TO_NESTED_FIELD)); + assertThat(extractValue("_doc.properties.ml.alias-to-nested-field_prediction.type", map), equalTo("integer")); assertThat(extractValue("_doc.properties.ml.top_classes.class_name.type", map), equalTo("integer")); } @@ -213,10 +242,15 @@ public void testCreateDestinationIndex_ResultsFieldsExistsInSourceIndex() { ); } - private Map testUpdateMappingsToDestIndex(DataFrameAnalysis analysis, - Map properties) throws IOException { + private Map testUpdateMappingsToDestIndex(DataFrameAnalysis analysis) throws IOException { DataFrameAnalyticsConfig config = createConfig(analysis); + Map properties = Map.of( + NUMERICAL_FIELD, Map.of("type", "integer"), + OUTER_FIELD, Map.of("properties", Map.of(INNER_FIELD, Map.of("type", "integer"))), + ALIAS_TO_NUMERICAL_FIELD, Map.of("type", "alias", "path", NUMERICAL_FIELD), + ALIAS_TO_NESTED_FIELD, Map.of("type", "alias", "path", OUTER_FIELD + "." + INNER_FIELD) + ); ImmutableOpenMap.Builder mappings = ImmutableOpenMap.builder(); mappings.put("", new MappingMetaData("_doc", Map.of("properties", properties))); GetIndexResponse getIndexResponse = @@ -252,19 +286,35 @@ private Map testUpdateMappingsToDestIndex(DataFrameAnalysis anal } public void testUpdateMappingsToDestIndex_OutlierDetection() throws IOException { - testUpdateMappingsToDestIndex(new OutlierDetection.Builder().build(), Map.of(DEPENDENT_VARIABLE, Map.of("type", "integer"))); + testUpdateMappingsToDestIndex(new OutlierDetection.Builder().build()); } public void testUpdateMappingsToDestIndex_Regression() throws IOException { - Map map = - testUpdateMappingsToDestIndex(new Regression(DEPENDENT_VARIABLE), Map.of(DEPENDENT_VARIABLE, Map.of("type", "integer"))); - assertThat(extractValue("properties.ml.dep_var_prediction.type", map), equalTo("integer")); + Map map = testUpdateMappingsToDestIndex(new Regression(NUMERICAL_FIELD)); + assertThat(extractValue("properties.ml.numerical-field_prediction.type", map), equalTo("integer")); } public void testUpdateMappingsToDestIndex_Classification() throws IOException { - Map map = - testUpdateMappingsToDestIndex(new Classification(DEPENDENT_VARIABLE), Map.of(DEPENDENT_VARIABLE, Map.of("type", "integer"))); - assertThat(extractValue("properties.ml.dep_var_prediction.type", map), equalTo("integer")); + Map map = testUpdateMappingsToDestIndex(new Classification(NUMERICAL_FIELD)); + assertThat(extractValue("properties.ml.numerical-field_prediction.type", map), equalTo("integer")); + assertThat(extractValue("properties.ml.top_classes.class_name.type", map), equalTo("integer")); + } + + public void testUpdateMappingsToDestIndex_Classification_DependentVariableIsNested() throws IOException { + Map map = testUpdateMappingsToDestIndex(new Classification(OUTER_FIELD + "." + INNER_FIELD)); + assertThat(extractValue("properties.ml.outer-field.inner-field_prediction.type", map), equalTo("integer")); + assertThat(extractValue("properties.ml.top_classes.class_name.type", map), equalTo("integer")); + } + + public void testUpdateMappingsToDestIndex_Classification_DependentVariableIsAlias() throws IOException { + Map map = testUpdateMappingsToDestIndex(new Classification(ALIAS_TO_NUMERICAL_FIELD)); + assertThat(extractValue("properties.ml.alias-to-numerical-field_prediction.type", map), equalTo("integer")); + assertThat(extractValue("properties.ml.top_classes.class_name.type", map), equalTo("integer")); + } + + public void testUpdateMappingsToDestIndex_Classification_DependentVariableIsAliasToNested() throws IOException { + Map map = testUpdateMappingsToDestIndex(new Classification(ALIAS_TO_NESTED_FIELD)); + assertThat(extractValue("properties.ml.alias-to-nested-field_prediction.type", map), equalTo("integer")); assertThat(extractValue("properties.ml.top_classes.class_name.type", map), equalTo("integer")); }