[ML] Add bucket_script agg support to data frames (elastic#41594) (el…
benwtrent authored Apr 29, 2019
1 parent a01f451 commit 92a820b
Showing 9 changed files with 248 additions and 30 deletions.
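
For context, the shape of the feature: a pivot's aggregation section can now include a bucket_script pipeline aggregation that references sibling metric aggregations via buckets_path. Below is a minimal sketch using Elasticsearch's Java aggregation builders (field and aggregation names mirror the REST test further down; the standalone class and main method are illustrative and not part of this commit):

import java.util.Collections;
import java.util.Map;

import org.elasticsearch.script.Script;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.search.aggregations.PipelineAggregatorBuilders;

public class BucketScriptPivotSketch {
    public static void main(String[] args) {
        // Regular metric aggregation over a source field, supported before this change.
        AggregationBuilder avgRating = AggregationBuilders.avg("avg_rating").field("stars");

        // Pipeline aggregation referencing the metric above via buckets_path.
        // With this commit, such builders are parsed out of the transform config,
        // attached to the composite aggregation, and their single-value results
        // are written into the destination documents.
        Map<String, String> bucketsPath = Collections.singletonMap("param_1", "avg_rating");
        PipelineAggregationBuilder avgRatingAgain = PipelineAggregatorBuilders.bucketScript(
                "avg_rating_again", bucketsPath, new Script("return params.param_1"));

        System.out.println(avgRating.getName() + " feeds " + avgRatingAgain.getName());
    }
}

The JSON equivalent of this pair is what testPivotWithBucketScriptAgg submits in the test change below.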
@@ -21,6 +21,7 @@
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.AggregatorFactories;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.xpack.core.dataframe.DataFrameMessages;

import java.io.IOException;
@@ -66,6 +67,10 @@ public Collection<AggregationBuilder> getAggregatorFactories() {
return aggregations.getAggregatorFactories();
}

public Collection<PipelineAggregationBuilder> getPipelineAggregatorFactories() {
return aggregations.getPipelineAggregatorFactories();
}

public static AggregationConfig fromXContent(final XContentParser parser, boolean lenient) throws IOException {
NamedXContentRegistry registry = parser.getXContentRegistry();
Map<String, Object> source = parser.mapOrdered();
@@ -368,6 +368,57 @@ public void testPivotWithScriptedMetricAgg() throws Exception {
assertEquals(711.0, actual.doubleValue(), 0.000001);
}

public void testPivotWithBucketScriptAgg() throws Exception {
String transformId = "bucketScriptPivot";
String dataFrameIndex = "bucket_script_pivot_reviews";
setupDataAccessRole(DATA_ACCESS_ROLE, REVIEWS_INDEX_NAME, dataFrameIndex);

final Request createDataframeTransformRequest = createRequestWithAuth("PUT", DATAFRAME_ENDPOINT + transformId,
BASIC_AUTH_VALUE_DATA_FRAME_ADMIN_WITH_SOME_DATA_ACCESS);

String config = "{"
+ " \"source\": {\"index\":\"" + REVIEWS_INDEX_NAME + "\"},"
+ " \"dest\": {\"index\":\"" + dataFrameIndex + "\"},";

config += " \"pivot\": {"
+ " \"group_by\": {"
+ " \"reviewer\": {"
+ " \"terms\": {"
+ " \"field\": \"user_id\""
+ " } } },"
+ " \"aggregations\": {"
+ " \"avg_rating\": {"
+ " \"avg\": {"
+ " \"field\": \"stars\""
+ " } },"
+ " \"avg_rating_again\": {"
+ " \"bucket_script\": {"
+ " \"buckets_path\": {\"param_1\": \"avg_rating\"},"
+ " \"script\": \"return params.param_1\""
+ " } }"
+ " } }"
+ "}";

createDataframeTransformRequest.setJsonEntity(config);
Map<String, Object> createDataframeTransformResponse = entityAsMap(client().performRequest(createDataframeTransformRequest));
assertThat(createDataframeTransformResponse.get("acknowledged"), equalTo(Boolean.TRUE));

startAndWaitForTransform(transformId, dataFrameIndex, BASIC_AUTH_VALUE_DATA_FRAME_ADMIN_WITH_SOME_DATA_ACCESS);
assertTrue(indexExists(dataFrameIndex));

// we expect 27 documents, one per distinct user_id
Map<String, Object> indexStats = getAsMap(dataFrameIndex + "/_stats");
assertEquals(27, XContentMapValues.extractValue("_all.total.docs.count", indexStats));

// get and check some users
Map<String, Object> searchResult = getAsMap(dataFrameIndex + "/_search?q=reviewer:user_4");
assertEquals(1, XContentMapValues.extractValue("hits.total.value", searchResult));
Number actual = (Number) ((List<?>) XContentMapValues.extractValue("hits.hits._source.avg_rating", searchResult)).get(0);
assertEquals(3.878048780, actual.doubleValue(), 0.000001);
actual = (Number) ((List<?>) XContentMapValues.extractValue("hits.hits._source.avg_rating_again", searchResult)).get(0);
assertEquals(3.878048780, actual.doubleValue(), 0.000001);
}

private void assertOnePivotValue(String query, double expected) throws IOException {
Map<String, Object> searchResult = getAsMap(query);

@@ -10,6 +10,7 @@
import org.apache.logging.log4j.Logger;
import org.elasticsearch.search.aggregations.Aggregation;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.search.aggregations.bucket.composite.CompositeAggregation;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation.SingleValue;
@@ -21,7 +22,9 @@

import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static org.elasticsearch.xpack.dataframe.transforms.pivot.SchemaUtil.isNumericType;
@@ -42,6 +45,7 @@ final class AggregationResultUtils {
public static Stream<Map<String, Object>> extractCompositeAggregationResults(CompositeAggregation agg,
GroupConfig groups,
Collection<AggregationBuilder> aggregationBuilders,
Collection<PipelineAggregationBuilder> pipelineAggs,
Map<String, String> fieldTypeMap,
DataFrameIndexerTransformStats stats) {
return agg.getBuckets().stream().map(bucket -> {
@@ -58,18 +62,21 @@ public static Stream<Map<String, Object>> extractCompositeAggregationResults(Com
document.put(destinationFieldName, value);
});

- for (AggregationBuilder aggregationBuilder : aggregationBuilders) {
- String aggName = aggregationBuilder.getName();
+ List<String> aggNames = aggregationBuilders.stream().map(AggregationBuilder::getName).collect(Collectors.toList());
+ aggNames.addAll(pipelineAggs.stream().map(PipelineAggregationBuilder::getName).collect(Collectors.toList()));
+
+ for (String aggName: aggNames) {
final String fieldType = fieldTypeMap.get(aggName);

// TODO: support other aggregation types
Aggregation aggResult = bucket.getAggregations().get(aggName);

if (aggResult instanceof NumericMetricsAggregation.SingleValue) {
NumericMetricsAggregation.SingleValue aggResultSingleValue = (SingleValue) aggResult;
- // If the type is numeric, simply gather the `value` type, otherwise utilize `getValueAsString` so we don't lose
- // formatted outputs.
- if (isNumericType(fieldType)) {
+ // If the type is numeric or if the formatted string is the same as simply making the value a string,
+ // gather the `value` type, otherwise utilize `getValueAsString` so we don't lose formatted outputs.
+ if (isNumericType(fieldType) ||
+ (aggResultSingleValue.getValueAsString().equals(String.valueOf(aggResultSingleValue.value())))) {
document.put(aggName, aggResultSingleValue.value());
} else {
document.put(aggName, aggResultSingleValue.getValueAsString());
@@ -35,7 +35,8 @@ enum AggregationType {
MAX("max", SOURCE),
MIN("min", SOURCE),
SUM("sum", SOURCE),
- SCRIPTED_METRIC("scripted_metric", DYNAMIC);
+ SCRIPTED_METRIC("scripted_metric", DYNAMIC),
+ BUCKET_SCRIPT("bucket_script", DYNAMIC);

private final String aggregationType;
private final String targetMapping;
@@ -19,6 +19,7 @@
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.search.aggregations.bucket.composite.CompositeAggregation;
import org.elasticsearch.search.aggregations.bucket.composite.CompositeAggregationBuilder;
import org.elasticsearch.search.builder.SearchSourceBuilder;
@@ -102,10 +103,12 @@ public Stream<Map<String, Object>> extractResults(CompositeAggregation agg,

GroupConfig groups = config.getGroupConfig();
Collection<AggregationBuilder> aggregationBuilders = config.getAggregationConfig().getAggregatorFactories();
Collection<PipelineAggregationBuilder> pipelineAggregationBuilders = config.getAggregationConfig().getPipelineAggregatorFactories();

return AggregationResultUtils.extractCompositeAggregationResults(agg,
groups,
aggregationBuilders,
pipelineAggregationBuilders,
fieldTypeMap,
dataFrameIndexerTransformStats);
}
@@ -148,6 +151,7 @@ private static CompositeAggregationBuilder createCompositeAggregation(PivotConfi
LoggingDeprecationHandler.INSTANCE, BytesReference.bytes(builder).streamInput());
compositeAggregation = CompositeAggregationBuilder.parse(COMPOSITE_AGGREGATION_NAME, parser);
config.getAggregationConfig().getAggregatorFactories().forEach(agg -> compositeAggregation.subAggregation(agg));
config.getAggregationConfig().getPipelineAggregatorFactories().forEach(agg -> compositeAggregation.subAggregation(agg));
} catch (IOException e) {
throw new RuntimeException(DataFrameMessages.DATA_FRAME_TRANSFORM_PIVOT_FAILED_TO_CREATE_COMPOSITE_AGGREGATION, e);
}
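
For reference, a rough sketch of the composite request this method assembles after the change: pipeline builders are attached with subAggregation alongside the regular metric builders. The aggregation name and the standalone class are assumptions for illustration; the real code parses the composite from XContent and pulls the builders from the transform's AggregationConfig.

import java.util.Collections;
import java.util.List;

import org.elasticsearch.script.Script;
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.PipelineAggregatorBuilders;
import org.elasticsearch.search.aggregations.bucket.composite.CompositeAggregationBuilder;
import org.elasticsearch.search.aggregations.bucket.composite.CompositeValuesSourceBuilder;
import org.elasticsearch.search.aggregations.bucket.composite.TermsValuesSourceBuilder;

public class CompositeWithPipelineSketch {
    public static CompositeAggregationBuilder build() {
        // group_by source, equivalent to the "reviewer" terms group_by in the test.
        List<CompositeValuesSourceBuilder<?>> sources = Collections.singletonList(
                new TermsValuesSourceBuilder("reviewer").field("user_id"));
        // Aggregation name is illustrative; the real code uses COMPOSITE_AGGREGATION_NAME.
        CompositeAggregationBuilder composite = new CompositeAggregationBuilder("composite_buckets", sources);

        // Metric sub-aggregation, forwarded as before.
        composite.subAggregation(AggregationBuilders.avg("avg_rating").field("stars"));

        // Pipeline sub-aggregation, forwarded by the new forEach above.
        composite.subAggregation(PipelineAggregatorBuilders.bucketScript(
                "avg_rating_again",
                Collections.singletonMap("param_1", "avg_rating"),
                new Script("return params.param_1")));
        return composite;
    }
}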
@@ -15,6 +15,7 @@
import org.elasticsearch.client.Client;
import org.elasticsearch.index.mapper.NumberFieldMapper;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.ScriptedMetricAggregationBuilder;
import org.elasticsearch.search.aggregations.support.ValuesSourceAggregationBuilder;
import org.elasticsearch.xpack.core.ClientHelper;
@@ -85,6 +86,12 @@ public static void deduceMappings(final Client client,
}
}

// Pipeline aggs reference other aggregations in the request rather than source fields, so they contribute no
// source field names to the payload. Some of them (e.g. avg_bucket) do have a determinable output value type.
for (PipelineAggregationBuilder agg : config.getAggregationConfig().getPipelineAggregatorFactories()) {
aggregationTypes.put(agg.getName(), agg.getType());
}

Map<String, String> allFieldNames = new HashMap<>();
allFieldNames.putAll(aggregationSourceFieldNames);
allFieldNames.putAll(fieldNamesForGrouping);