Skip to content

Commit

Permalink
Handle nested and aliased fields correctly when copying mapping. (#50918)
Browse files Browse the repository at this point in the history
  • Loading branch information
przemekwitek authored Jan 14, 2020
1 parent 30c3b34 commit 3c6f649
Show file tree
Hide file tree
Showing 3 changed files with 175 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
private static final String NUMERICAL_FIELD = "numerical-field";
private static final String DISCRETE_NUMERICAL_FIELD = "discrete-numerical-field";
private static final String KEYWORD_FIELD = "keyword-field";
private static final String NESTED_FIELD = "outer-field.inner-field";
private static final String ALIAS_TO_KEYWORD_FIELD = "alias-to-keyword-field";
private static final String ALIAS_TO_NESTED_FIELD = "alias-to-nested-field";
private static final List<Boolean> BOOLEAN_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList(false, true));
private static final List<Double> NUMERICAL_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList(1.0, 2.0));
private static final List<Integer> DISCRETE_NUMERICAL_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList(10, 20));
Expand Down Expand Up @@ -301,7 +304,6 @@ public void testStopAndRestart() throws Exception {
assertInferenceModelPersisted(jobId);
assertMlResultsFieldMappings(predictedClassField, "keyword");
assertEvaluation(KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField);

}

public void testDependentVariableCardinalityTooHighError() throws Exception {
Expand Down Expand Up @@ -343,6 +345,63 @@ public void testDependentVariableCardinalityTooHighButWithQueryMakesItWithinRang
assertProgress(jobId, 100, 100, 100, 100);
}

/**
 * Runs a classification job whose dependent variable is {@code NESTED_FIELD}, a field addressed by a
 * dotted path, and verifies that the job finishes and produces the expected results.
 *
 * The checks cover: reported progress, the stored progress document count, persisted model state,
 * the persisted inference model, the mapping type of the prediction field and the evaluation.
 */
public void testDependentVariableIsNested() throws Exception {
    // NOTE(review): initialize(...) presumably sets jobId/sourceIndex/destIndex; its body is not visible here.
    initialize("dependent_variable_is_nested");
    // 100 training rows, with NESTED_FIELD passed as the dependent variable.
    indexData(sourceIndex, 100, 0, NESTED_FIELD);

    Classification analysis = new Classification(NESTED_FIELD);
    DataFrameAnalyticsConfig analyticsConfig = buildAnalytics(jobId, sourceIndex, destIndex, null, analysis);
    registerAnalytics(analyticsConfig);
    putAnalytics(analyticsConfig);
    startAnalytics(jobId);
    waitUntilAnalyticsIsStopped(jobId);

    assertProgress(jobId, 100, 100, 100, 100);
    assertThat(
        searchStoredProgress(jobId).getHits().getTotalHits().value,
        equalTo(1L));
    assertModelStatePersisted(stateDocId());
    assertInferenceModelPersisted(jobId);

    // The prediction field name keeps the dotted path of the dependent variable.
    final String predictionField = NESTED_FIELD + "_prediction";
    assertMlResultsFieldMappings(predictionField, "keyword");
    assertEvaluation(NESTED_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictionField);
}

/**
 * Runs a classification job whose dependent variable is {@code ALIAS_TO_KEYWORD_FIELD} and verifies
 * that the job finishes and produces the expected results.
 *
 * Data is indexed with {@code KEYWORD_FIELD} as the dependent variable, while the analysis itself
 * refers to the field through {@code ALIAS_TO_KEYWORD_FIELD}.
 */
public void testDependentVariableIsAliasToKeyword() throws Exception {
    // NOTE(review): initialize(...) presumably sets jobId/sourceIndex/destIndex; its body is not visible here.
    initialize("dependent_variable_is_alias");
    // 100 training rows, with the concrete KEYWORD_FIELD passed as the dependent variable.
    indexData(sourceIndex, 100, 0, KEYWORD_FIELD);

    Classification analysis = new Classification(ALIAS_TO_KEYWORD_FIELD);
    DataFrameAnalyticsConfig analyticsConfig = buildAnalytics(jobId, sourceIndex, destIndex, null, analysis);
    registerAnalytics(analyticsConfig);
    putAnalytics(analyticsConfig);
    startAnalytics(jobId);
    waitUntilAnalyticsIsStopped(jobId);

    assertProgress(jobId, 100, 100, 100, 100);
    assertThat(
        searchStoredProgress(jobId).getHits().getTotalHits().value,
        equalTo(1L));
    assertModelStatePersisted(stateDocId());
    assertInferenceModelPersisted(jobId);

    // The prediction field is named after the alias, not after the field it points to.
    final String predictionField = ALIAS_TO_KEYWORD_FIELD + "_prediction";
    assertMlResultsFieldMappings(predictionField, "keyword");
    assertEvaluation(ALIAS_TO_KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictionField);
}

/**
 * Runs a classification job whose dependent variable is {@code ALIAS_TO_NESTED_FIELD} and verifies
 * that the job finishes and produces the expected results.
 *
 * Data is indexed with {@code NESTED_FIELD} as the dependent variable, while the analysis itself
 * refers to the field through {@code ALIAS_TO_NESTED_FIELD}.
 */
public void testDependentVariableIsAliasToNested() throws Exception {
    // NOTE(review): initialize(...) presumably sets jobId/sourceIndex/destIndex; its body is not visible here.
    initialize("dependent_variable_is_alias_to_nested");
    // 100 training rows, with the concrete NESTED_FIELD passed as the dependent variable.
    indexData(sourceIndex, 100, 0, NESTED_FIELD);

    Classification analysis = new Classification(ALIAS_TO_NESTED_FIELD);
    DataFrameAnalyticsConfig analyticsConfig = buildAnalytics(jobId, sourceIndex, destIndex, null, analysis);
    registerAnalytics(analyticsConfig);
    putAnalytics(analyticsConfig);
    startAnalytics(jobId);
    waitUntilAnalyticsIsStopped(jobId);

    assertProgress(jobId, 100, 100, 100, 100);
    assertThat(
        searchStoredProgress(jobId).getHits().getTotalHits().value,
        equalTo(1L));
    assertModelStatePersisted(stateDocId());
    assertInferenceModelPersisted(jobId);

    // The prediction field is named after the alias, not after the field it points to.
    final String predictionField = ALIAS_TO_NESTED_FIELD + "_prediction";
    assertMlResultsFieldMappings(predictionField, "keyword");
    assertEvaluation(ALIAS_TO_NESTED_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictionField);
}

public void testTwoJobsWithSameRandomizeSeedUseSameTrainingSet() throws Exception {
String sourceIndex = "classification_two_jobs_with_same_randomize_seed_source";
String dependentVariable = KEYWORD_FIELD;
Expand Down Expand Up @@ -434,7 +493,10 @@ private static void createIndex(String index) {
BOOLEAN_FIELD, "type=boolean",
NUMERICAL_FIELD, "type=double",
DISCRETE_NUMERICAL_FIELD, "type=integer",
KEYWORD_FIELD, "type=keyword")
KEYWORD_FIELD, "type=keyword",
NESTED_FIELD, "type=keyword",
ALIAS_TO_KEYWORD_FIELD, "type=alias,path=" + KEYWORD_FIELD,
ALIAS_TO_NESTED_FIELD, "type=alias,path=" + NESTED_FIELD)
.get();
}

Expand All @@ -446,7 +508,8 @@ private static void indexData(String sourceIndex, int numTrainingRows, int numNo
BOOLEAN_FIELD, BOOLEAN_FIELD_VALUES.get(i % BOOLEAN_FIELD_VALUES.size()),
NUMERICAL_FIELD, NUMERICAL_FIELD_VALUES.get(i % NUMERICAL_FIELD_VALUES.size()),
DISCRETE_NUMERICAL_FIELD, DISCRETE_NUMERICAL_FIELD_VALUES.get(i % DISCRETE_NUMERICAL_FIELD_VALUES.size()),
KEYWORD_FIELD, KEYWORD_FIELD_VALUES.get(i % KEYWORD_FIELD_VALUES.size()));
KEYWORD_FIELD, KEYWORD_FIELD_VALUES.get(i % KEYWORD_FIELD_VALUES.size()),
NESTED_FIELD, KEYWORD_FIELD_VALUES.get(i % KEYWORD_FIELD_VALUES.size()));
IndexRequest indexRequest = new IndexRequest(sourceIndex).source(source.toArray());
bulkRequestBuilder.add(indexRequest);
}
Expand All @@ -465,6 +528,9 @@ private static void indexData(String sourceIndex, int numTrainingRows, int numNo
if (KEYWORD_FIELD.equals(dependentVariable) == false) {
source.addAll(List.of(KEYWORD_FIELD, KEYWORD_FIELD_VALUES.get(i % KEYWORD_FIELD_VALUES.size())));
}
if (NESTED_FIELD.equals(dependentVariable) == false) {
source.addAll(List.of(NESTED_FIELD, KEYWORD_FIELD_VALUES.get(i % KEYWORD_FIELD_VALUES.size())));
}
IndexRequest indexRequest = new IndexRequest(sourceIndex).source(source.toArray());
bulkRequestBuilder.add(indexRequest);
}
Expand All @@ -487,10 +553,12 @@ private static Map<String, Object> getDestDoc(DataFrameAnalyticsConfig config, S
}

/**
* Wrapper around extractValue with implicit casting to the appropriate type.
* Wrapper around extractValue that:
* - allows dots (".") in the path elements provided as arguments
* - supports implicit casting to the appropriate type
*/
private static <T> T getFieldValue(Map<String, Object> doc, String... path) {
return (T)extractValue(doc, path);
return (T)extractValue(String.join(".", path), doc);
}

private static <T> void assertTopClasses(Map<String, Object> resultsObject,
Expand Down Expand Up @@ -582,8 +650,14 @@ private void assertMlResultsFieldMappings(String predictedClassField, String exp
.mappings()
.get(destIndex)
.sourceAsMap();
assertThat(getFieldValue(mappings, "properties", "ml", "properties", predictedClassField, "type"), equalTo(expectedType));
assertThat(
mappings.toString(),
getFieldValue(
mappings,
"properties", "ml", "properties", String.join(".properties.", predictedClassField.split("\\.")), "type"),
equalTo(expectedType));
assertThat(
mappings.toString(),
getFieldValue(mappings, "properties", "ml", "properties", "top_classes", "properties", "class_name", "type"),
equalTo(expectedType));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.IndexSortConfig;
import org.elasticsearch.index.mapper.FieldAliasMapper;
import org.elasticsearch.index.mapper.KeywordFieldMapper;
import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
Expand All @@ -38,6 +40,7 @@
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;

import static org.elasticsearch.common.xcontent.support.XContentMapValues.extractValue;
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;

/**
Expand Down Expand Up @@ -155,21 +158,36 @@ private static Integer findMaxSettingValue(GetSettingsResponse settingsResponse,
return maxValue;
}

@SuppressWarnings("unchecked")
private static Map<String, Object> createAdditionalMappings(DataFrameAnalyticsConfig config, Map<String, Object> mappingsProperties) {
Map<String, Object> properties = new HashMap<>();
properties.put(ID_COPY, Map.of("type", "keyword"));
properties.put(ID_COPY, Map.of("type", KeywordFieldMapper.CONTENT_TYPE));
for (Map.Entry<String, String> entry
: config.getAnalysis().getExplicitlyMappedFields(config.getDest().getResultsField()).entrySet()) {
String destFieldPath = entry.getKey();
String sourceFieldPath = entry.getValue();
Object sourceFieldMapping = mappingsProperties.get(sourceFieldPath);
if (sourceFieldMapping != null) {
Object sourceFieldMapping = extractMapping(sourceFieldPath, mappingsProperties);
if (sourceFieldMapping instanceof Map) {
Map<String, Object> sourceFieldMappingAsMap = (Map) sourceFieldMapping;
// If the source field is an alias, fetch the concrete field that the alias points to.
if (FieldAliasMapper.CONTENT_TYPE.equals(sourceFieldMappingAsMap.get("type"))) {
String path = (String) sourceFieldMappingAsMap.get(FieldAliasMapper.Names.PATH);
sourceFieldMapping = extractMapping(path, mappingsProperties);
}
}
// We may have updated the value of {@code sourceFieldMapping} in the "if" block above.
// Hence, we need to check the "instanceof" condition again.
if (sourceFieldMapping instanceof Map) {
properties.put(destFieldPath, sourceFieldMapping);
}
}
return properties;
}

/**
 * Looks up the mapping entry for a field given by its dotted path.
 *
 * The path is split on {@code "."} and the elements are re-joined with
 * {@code "." + PROPERTIES + "."} between them, so that {@code a.b} becomes
 * {@code a.<PROPERTIES>.b}. The resulting path is handed to {@code extractValue}.
 *
 * @param path dotted field path; must not be {@code null}
 * @param mappingsProperties the map the rewritten path is resolved against
 * @return whatever {@code extractValue} yields for the rewritten path
 */
private static Object extractMapping(String path, Map<String, Object> mappingsProperties) {
    String[] pathElements = path.split("\\.");
    String separator = "." + PROPERTIES + ".";
    String mappingPath = String.join(separator, pathElements);
    return extractValue(mappingPath, mappingsProperties);
}

private static Map<String, Object> createMetaData(String analyticsId, Clock clock) {
Map<String, Object> metadata = new HashMap<>();
metadata.put(CREATION_DATE_MILLIS, clock.millis());
Expand Down Expand Up @@ -227,4 +245,3 @@ private static void checkResultsFieldIsNotPresentInProperties(DataFrameAnalytics
}
}
}

Loading

0 comments on commit 3c6f649

Please sign in to comment.