Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Require that the dependent variable column has at most 2 distinct values in classification analysis. #47858

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,12 @@ public List<RequiredField> getRequiredFields() {
return Collections.singletonList(new RequiredField(dependentVariable, Types.categorical()));
}

/**
 * Caps the dependent variable at two distinct values.
 *
 * @return a single-entry map from the dependent variable to its cardinality limit
 */
@Override
public Map<String, Long> getFieldCardinalityLimits() {
    // The C++ backend currently supports binomial classification only, hence the limit of two classes.
    final long maxDistinctValues = 2L;
    return Collections.singletonMap(dependentVariable, maxDistinctValues);
}

@Override
public boolean supportsMissingValues() {
return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable {
*/
List<RequiredField> getRequiredFields();

/**
 * Returns the cardinality limits that this analysis imposes on selected (analysis-specific) fields.
 *
 * @return {@link Map} from field name to the maximum number of distinct values that field may have;
 *         the implementations visible here return an empty map when no field is restricted
 */
Map<String, Long> getFieldCardinalityLimits();

/**
* @return {@code true} if this analysis supports data frame rows with missing values
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,11 @@ public List<RequiredField> getRequiredFields() {
return Collections.emptyList();
}

/**
 * This analysis places no cardinality limit on any field.
 *
 * @return an empty map
 */
@Override
public Map<String, Long> getFieldCardinalityLimits() {
    return Collections.<String, Long>emptyMap();
}

@Override
public boolean supportsMissingValues() {
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,11 @@ public List<RequiredField> getRequiredFields() {
return Collections.singletonList(new RequiredField(dependentVariable, Types.numerical()));
}

/**
 * No field is subject to a cardinality limit for this analysis.
 *
 * @return an empty map
 */
@Override
public Map<String, Long> getFieldCardinalityLimits() {
    return Collections.<String, Long>emptyMap();
}

@Override
public boolean supportsMissingValues() {
return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
import java.io.IOException;

import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.nullValue;

public class ClassificationTests extends AbstractSerializingTestCase<Classification> {

Expand Down Expand Up @@ -65,4 +68,8 @@ public void testConstructor_GivenTrainingPercentIsGreaterThan100() {

assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
}

/**
 * Checks that {@code getFieldCardinalityLimits()} returns a non-null map for a test instance,
 * so callers can use the result without a null check.
 */
public void testFieldCardinalityLimitsIsNonNull() {
assertThat(createTestInstance().getFieldCardinalityLimits(), is(not(nullValue())));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.nullValue;

public class OutlierDetectionTests extends AbstractSerializingTestCase<OutlierDetection> {

Expand Down Expand Up @@ -82,6 +84,10 @@ public void testGetParams_GivenExplicitValues() {
assertThat(params.get(OutlierDetection.STANDARDIZATION_ENABLED.getPreferredName()), is(false));
}

/**
 * Checks that {@code getFieldCardinalityLimits()} returns a non-null map for a test instance,
 * so callers can use the result without a null check.
 */
public void testFieldCardinalityLimitsIsNonNull() {
assertThat(createTestInstance().getFieldCardinalityLimits(), is(not(nullValue())));
}

public void testGetStateDocId() {
OutlierDetection outlierDetection = createRandom();
assertThat(outlierDetection.persistsState(), is(false));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.nullValue;

public class RegressionTests extends AbstractSerializingTestCase<Regression> {

Expand Down Expand Up @@ -66,6 +68,10 @@ public void testConstructor_GivenTrainingPercentIsGreaterThan100() {
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
}

/**
 * Checks that {@code getFieldCardinalityLimits()} returns a non-null map for a test instance,
 * so callers can use the result without a null check.
 */
public void testFieldCardinalityLimitsIsNonNull() {
assertThat(createTestInstance().getFieldCardinalityLimits(), is(not(nullValue())));
}

public void testGetStateDocId() {
Regression regression = createRandom();
assertThat(regression.persistsState(), is(true));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.elasticsearch.xpack.ml.integration;

import com.google.common.collect.Ordering;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.bulk.BulkRequestBuilder;
import org.elasticsearch.action.bulk.BulkResponse;
import org.elasticsearch.action.get.GetResponse;
Expand Down Expand Up @@ -37,10 +38,10 @@

public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {

private static final String NUMERICAL_FEATURE_FIELD = "feature";
private static final String DEPENDENT_VARIABLE_FIELD = "variable";
private static final List<Double> NUMERICAL_FEATURE_VALUES = Collections.unmodifiableList(Arrays.asList(1.0, 2.0, 3.0));
private static final List<String> DEPENDENT_VARIABLE_VALUES = Collections.unmodifiableList(Arrays.asList("dog", "cat", "cow"));
private static final String NUMERICAL_FIELD = "numerical-field";
private static final String KEYWORD_FIELD = "keyword-field";
private static final List<Double> NUMERICAL_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList(1.0, 2.0, 3.0, 4.0));
private static final List<String> KEYWORD_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList("dog", "cat"));

private String jobId;
private String sourceIndex;
Expand All @@ -53,36 +54,9 @@ public void cleanup() throws Exception {

public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws Exception {
initialize("classification_single_numeric_feature_and_mixed_data_set");
indexData(sourceIndex, 300, 50, NUMERICAL_FIELD_VALUES, KEYWORD_FIELD_VALUES);

{ // Index 350 rows, 300 of them being training rows.
client().admin().indices().prepareCreate(sourceIndex)
.addMapping("_doc", NUMERICAL_FEATURE_FIELD, "type=double", DEPENDENT_VARIABLE_FIELD, "type=keyword")
.get();

BulkRequestBuilder bulkRequestBuilder = client().prepareBulk()
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
for (int i = 0; i < 300; i++) {
Double field = NUMERICAL_FEATURE_VALUES.get(i % 3);
String value = DEPENDENT_VARIABLE_VALUES.get(i % 3);

IndexRequest indexRequest = new IndexRequest(sourceIndex)
.source(NUMERICAL_FEATURE_FIELD, field, DEPENDENT_VARIABLE_FIELD, value);
bulkRequestBuilder.add(indexRequest);
}
for (int i = 300; i < 350; i++) {
Double field = NUMERICAL_FEATURE_VALUES.get(i % 3);

IndexRequest indexRequest = new IndexRequest(sourceIndex)
.source(NUMERICAL_FEATURE_FIELD, field);
bulkRequestBuilder.add(indexRequest);
}
BulkResponse bulkResponse = bulkRequestBuilder.get();
if (bulkResponse.hasFailures()) {
fail("Failed to index data: " + bulkResponse.buildFailureMessage());
}
}

DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(DEPENDENT_VARIABLE_FIELD));
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(KEYWORD_FIELD));
registerAnalytics(config);
putAnalytics(config);

Expand All @@ -97,10 +71,10 @@ public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws
Map<String, Object> destDoc = getDestDoc(config, hit);
Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(destDoc);

assertThat(resultsObject.containsKey("variable_prediction"), is(true));
assertThat((String) resultsObject.get("variable_prediction"), is(in(DEPENDENT_VARIABLE_VALUES)));
assertThat(resultsObject.containsKey("keyword-field_prediction"), is(true));
assertThat((String) resultsObject.get("keyword-field_prediction"), is(in(KEYWORD_FIELD_VALUES)));
assertThat(resultsObject.containsKey("is_training"), is(true));
assertThat(resultsObject.get("is_training"), is(destDoc.containsKey(DEPENDENT_VARIABLE_FIELD)));
assertThat(resultsObject.get("is_training"), is(destDoc.containsKey(KEYWORD_FIELD)));
assertThat(resultsObject.containsKey("top_classes"), is(false));
}

Expand All @@ -117,9 +91,9 @@ public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws

public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Exception {
initialize("classification_only_training_data_and_training_percent_is_100");
indexTrainingData(sourceIndex, 300);
indexData(sourceIndex, 300, 0, NUMERICAL_FIELD_VALUES, KEYWORD_FIELD_VALUES);

DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(DEPENDENT_VARIABLE_FIELD));
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(KEYWORD_FIELD));
registerAnalytics(config);
putAnalytics(config);

Expand All @@ -133,8 +107,8 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Excepti
for (SearchHit hit : sourceData.getHits()) {
Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(getDestDoc(config, hit));

assertThat(resultsObject.containsKey("variable_prediction"), is(true));
assertThat((String) resultsObject.get("variable_prediction"), is(in(DEPENDENT_VARIABLE_VALUES)));
assertThat(resultsObject.containsKey("keyword-field_prediction"), is(true));
assertThat((String) resultsObject.get("keyword-field_prediction"), is(in(KEYWORD_FIELD_VALUES)));
assertThat(resultsObject.containsKey("is_training"), is(true));
assertThat(resultsObject.get("is_training"), is(true));
assertThat(resultsObject.containsKey("top_classes"), is(false));
Expand All @@ -153,15 +127,15 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Excepti

public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception {
initialize("classification_only_training_data_and_training_percent_is_50");
indexTrainingData(sourceIndex, 300);
indexData(sourceIndex, 300, 0, NUMERICAL_FIELD_VALUES, KEYWORD_FIELD_VALUES);

DataFrameAnalyticsConfig config =
buildAnalytics(
jobId,
sourceIndex,
destIndex,
null,
new Classification(DEPENDENT_VARIABLE_FIELD, BoostedTreeParamsTests.createRandom(), null, null, 50.0));
new Classification(KEYWORD_FIELD, BoostedTreeParamsTests.createRandom(), null, null, 50.0));
registerAnalytics(config);
putAnalytics(config);

Expand All @@ -176,8 +150,8 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception
SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get();
for (SearchHit hit : sourceData.getHits()) {
Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(getDestDoc(config, hit));
assertThat(resultsObject.containsKey("variable_prediction"), is(true));
assertThat((String) resultsObject.get("variable_prediction"), is(in(DEPENDENT_VARIABLE_VALUES)));
assertThat(resultsObject.containsKey("keyword-field_prediction"), is(true));
assertThat((String) resultsObject.get("keyword-field_prediction"), is(in(KEYWORD_FIELD_VALUES)));

assertThat(resultsObject.containsKey("is_training"), is(true));
// Let's just assert there's both training and non-training results
Expand Down Expand Up @@ -205,7 +179,7 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception
@AwaitsFix(bugUrl = "https://github.com/elastic/ml-cpp/issues/712")
public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows_TopClassesRequested() throws Exception {
initialize("classification_top_classes_requested");
indexTrainingData(sourceIndex, 300);
indexData(sourceIndex, 300, 0, NUMERICAL_FIELD_VALUES, KEYWORD_FIELD_VALUES);

int numTopClasses = 2;
DataFrameAnalyticsConfig config =
Expand All @@ -214,7 +188,7 @@ public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows_TopClasse
sourceIndex,
destIndex,
null,
new Classification(DEPENDENT_VARIABLE_FIELD, BoostedTreeParamsTests.createRandom(), null, numTopClasses, null));
new Classification(KEYWORD_FIELD, BoostedTreeParamsTests.createRandom(), null, numTopClasses, null));
registerAnalytics(config);
putAnalytics(config);

Expand All @@ -229,8 +203,8 @@ public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows_TopClasse
Map<String, Object> destDoc = getDestDoc(config, hit);
Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(destDoc);

assertThat(resultsObject.containsKey("variable_prediction"), is(true));
assertThat((String) resultsObject.get("variable_prediction"), is(in(DEPENDENT_VARIABLE_VALUES)));
assertThat(resultsObject.containsKey("keyword-field_prediction"), is(true));
assertThat((String) resultsObject.get("keyword-field_prediction"), is(in(KEYWORD_FIELD_VALUES)));
assertTopClasses(resultsObject, numTopClasses);
}

Expand All @@ -245,25 +219,47 @@ public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows_TopClasse
"Finished analysis");
}

/**
 * Starting a classification job whose dependent variable has three distinct values
 * must fail with a 400 error that names the field and the limit.
 */
public void testDependentVariableCardinalityTooHighError() {
    initialize("cardinality_too_high");
    // Three distinct keyword values: one more than the limit quoted in the expected error below.
    List<String> threeClasses = Arrays.asList("dog", "cat", "fox");
    indexData(sourceIndex, 6, 5, NUMERICAL_FIELD_VALUES, threeClasses);

    Classification analysis = new Classification(KEYWORD_FIELD);
    DataFrameAnalyticsConfig analyticsConfig = buildAnalytics(jobId, sourceIndex, destIndex, null, analysis);
    registerAnalytics(analyticsConfig);
    putAnalytics(analyticsConfig);

    // Creating the job succeeds; the cardinality check is expected to reject it on start.
    ElasticsearchStatusException failure = expectThrows(ElasticsearchStatusException.class, () -> startAnalytics(jobId));
    assertThat(failure.status().getStatus(), equalTo(400));
    assertThat(failure.getMessage(), equalTo("Field [keyword-field] must have at most [2] distinct values but there were at least [3]"));
}

/**
 * Sets the per-test job id and derives the source and destination index names from it.
 *
 * @param jobId identifier of the analytics job; also used as the prefix of both index names
 */
private void initialize(String jobId) {
    final String source = jobId + "_source_index";
    this.jobId = jobId;
    this.sourceIndex = source;
    // The destination index name is built on top of the source index name.
    this.destIndex = source + "_results";
}

private static void indexTrainingData(String sourceIndex, int numRows) {
private static void indexData(String sourceIndex,
int numTrainingRows, int numNonTrainingRows,
List<Double> numericalFieldValues, List<String> keywordFieldValues) {
client().admin().indices().prepareCreate(sourceIndex)
.addMapping("_doc", NUMERICAL_FEATURE_FIELD, "type=double", DEPENDENT_VARIABLE_FIELD, "type=keyword")
.addMapping("_doc", NUMERICAL_FIELD, "type=double", KEYWORD_FIELD, "type=keyword")
.get();

BulkRequestBuilder bulkRequestBuilder = client().prepareBulk()
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
for (int i = 0; i < numRows; i++) {
Double field = NUMERICAL_FEATURE_VALUES.get(i % 3);
String value = DEPENDENT_VARIABLE_VALUES.get(i % 3);
for (int i = 0; i < numTrainingRows; i++) {
Double numericalValue = numericalFieldValues.get(i % numericalFieldValues.size());
String keywordValue = keywordFieldValues.get(i % keywordFieldValues.size());

IndexRequest indexRequest = new IndexRequest(sourceIndex)
.source(NUMERICAL_FIELD, numericalValue, KEYWORD_FIELD, keywordValue);
bulkRequestBuilder.add(indexRequest);
}
for (int i = numTrainingRows; i < numTrainingRows + numNonTrainingRows; i++) {
Double numericalValue = numericalFieldValues.get(i % numericalFieldValues.size());

IndexRequest indexRequest = new IndexRequest(sourceIndex)
.source(NUMERICAL_FEATURE_FIELD, field, DEPENDENT_VARIABLE_FIELD, value);
.source(NUMERICAL_FIELD, numericalValue);
bulkRequestBuilder.add(indexRequest);
}
BulkResponse bulkResponse = bulkRequestBuilder.get();
Expand Down Expand Up @@ -302,10 +298,10 @@ private static void assertTopClasses(Map<String, Object> resultsObject, int numT
classNames.add((String) topClass.get("class_name"));
classProbabilities.add((Double) topClass.get("class_probability"));
}
// Assert that all the class names come from the set of dependent variable values.
classNames.forEach(className -> assertThat(className, is(in(DEPENDENT_VARIABLE_VALUES))));
// Assert that all the predicted class names come from the set of keyword field values.
classNames.forEach(className -> assertThat(className, is(in(KEYWORD_FIELD_VALUES))));
// Assert that the first class listed in top classes is the same as the predicted class.
assertThat(classNames.get(0), equalTo(resultsObject.get("variable_prediction")));
assertThat(classNames.get(0), equalTo(resultsObject.get("keyword-field_prediction")));
// Assert that all the class probabilities lie within [0, 1] interval.
classProbabilities.forEach(p -> assertThat(p, allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(1.0))));
// Assert that the top classes are listed in the order of decreasing probabilities.
Expand Down
Loading