Skip to content

Commit

Permalink
Implement stddev and variance functions in SQL and PPL (#115)
Browse files Browse the repository at this point in the history
* Implement variance frontend and backend

* Support constructing AggregationResponseParser during the Aggregator build stage

* add var and varp for PPL

Signed-off-by: penghuo <penghuo@gmail.com>

* add UT

Signed-off-by: penghuo <penghuo@gmail.com>

* fix UT

Signed-off-by: penghuo <penghuo@gmail.com>

* fix doc format

Signed-off-by: penghuo <penghuo@gmail.com>

* fix doc format

Signed-off-by: penghuo <penghuo@gmail.com>

* fix the doc

Signed-off-by: penghuo <penghuo@gmail.com>

* add stddev_samp and stddev_pop

Signed-off-by: penghuo <penghuo@gmail.com>

* fix UT coverage

* address comments

Signed-off-by: penghuo <penghuo@gmail.com>
  • Loading branch information
penghuo authored Jun 11, 2021
1 parent b39e7b6 commit 8632f80
Show file tree
Hide file tree
Showing 24 changed files with 1,414 additions and 8 deletions.
1 change: 1 addition & 0 deletions core/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ dependencies {
compile group: 'org.springframework', name: 'spring-beans', version: '5.2.5.RELEASE'
compile group: 'org.apache.commons', name: 'commons-lang3', version: '3.10'
compile group: 'com.facebook.presto', name: 'presto-matching', version: '0.240'
compile group: 'org.apache.commons', name: 'commons-math3', version: '3.6.1'
compile project(':common')

testImplementation('org.junit.jupiter:junit-jupiter:5.6.2')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,8 @@ public Expression visitNot(Not node, AnalysisContext context) {

@Override
public Expression visitAggregateFunction(AggregateFunction node, AnalysisContext context) {
Optional<BuiltinFunctionName> builtinFunctionName = BuiltinFunctionName.of(node.getFuncName());
Optional<BuiltinFunctionName> builtinFunctionName =
BuiltinFunctionName.ofAggregation(node.getFuncName());
if (builtinFunctionName.isPresent()) {
Expression arg = node.getField().accept(this, context);
Aggregator aggregator = (Aggregator) repository.compile(
Expand Down
16 changes: 16 additions & 0 deletions core/src/main/java/org/opensearch/sql/expression/DSL.java
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,22 @@ public Aggregator count(Expression... expressions) {
return aggregate(BuiltinFunctionName.COUNT, expressions);
}

/**
 * Builds the aggregator registered under {@link BuiltinFunctionName#VARSAMP}.
 *
 * @param expressions aggregator arguments.
 * @return the aggregator produced by {@code aggregate} for the given arguments.
 */
public Aggregator varSamp(Expression... expressions) {
  final BuiltinFunctionName sampleVariance = BuiltinFunctionName.VARSAMP;
  return aggregate(sampleVariance, expressions);
}

/**
 * Builds the aggregator registered under {@link BuiltinFunctionName#VARPOP}.
 *
 * @param expressions aggregator arguments.
 * @return the aggregator produced by {@code aggregate} for the given arguments.
 */
public Aggregator varPop(Expression... expressions) {
  final BuiltinFunctionName populationVariance = BuiltinFunctionName.VARPOP;
  return aggregate(populationVariance, expressions);
}

/**
 * Builds the aggregator registered under {@link BuiltinFunctionName#STDDEV_SAMP}.
 *
 * @param expressions aggregator arguments.
 * @return the aggregator produced by {@code aggregate} for the given arguments.
 */
public Aggregator stddevSamp(Expression... expressions) {
  final BuiltinFunctionName sampleStdDev = BuiltinFunctionName.STDDEV_SAMP;
  return aggregate(sampleStdDev, expressions);
}

/**
 * Builds the aggregator registered under {@link BuiltinFunctionName#STDDEV_POP}.
 *
 * @param expressions aggregator arguments.
 * @return the aggregator produced by {@code aggregate} for the given arguments.
 */
public Aggregator stddevPop(Expression... expressions) {
  final BuiltinFunctionName populationStdDev = BuiltinFunctionName.STDDEV_POP;
  return aggregate(populationStdDev, expressions);
}

public RankingWindowFunction rowNumber() {
return (RankingWindowFunction) repository.compile(
BuiltinFunctionName.ROW_NUMBER.getName(), Collections.emptyList());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@
import static org.opensearch.sql.data.type.ExprCoreType.STRING;
import static org.opensearch.sql.data.type.ExprCoreType.TIME;
import static org.opensearch.sql.data.type.ExprCoreType.TIMESTAMP;
import static org.opensearch.sql.expression.aggregation.StdDevAggregator.stddevPopulation;
import static org.opensearch.sql.expression.aggregation.StdDevAggregator.stddevSample;
import static org.opensearch.sql.expression.aggregation.VarianceAggregator.variancePopulation;
import static org.opensearch.sql.expression.aggregation.VarianceAggregator.varianceSample;

import com.google.common.collect.ImmutableMap;
import java.util.Collections;
Expand Down Expand Up @@ -68,6 +72,10 @@ public static void register(BuiltinFunctionRepository repository) {
repository.register(count());
repository.register(min());
repository.register(max());
repository.register(varSamp());
repository.register(varPop());
repository.register(stddevSamp());
repository.register(stddevPop());
}

private static FunctionResolver avg() {
Expand Down Expand Up @@ -159,4 +167,48 @@ private static FunctionResolver max() {
.build()
);
}

/**
 * Resolver for {@link BuiltinFunctionName#VARSAMP}. It registers a single signature taking one
 * DOUBLE argument, built through {@code varianceSample} with a DOUBLE return type.
 */
private static FunctionResolver varSamp() {
  final FunctionName name = BuiltinFunctionName.VARSAMP.getName();
  final FunctionSignature doubleArgSignature =
      new FunctionSignature(name, Collections.singletonList(DOUBLE));
  final FunctionBuilder sampleVarianceBuilder = arguments -> varianceSample(arguments, DOUBLE);
  return new FunctionResolver(
      name,
      new ImmutableMap.Builder<FunctionSignature, FunctionBuilder>()
          .put(doubleArgSignature, sampleVarianceBuilder)
          .build());
}

/**
 * Resolver for {@link BuiltinFunctionName#VARPOP}. It registers a single signature taking one
 * DOUBLE argument, built through {@code variancePopulation} with a DOUBLE return type.
 */
private static FunctionResolver varPop() {
  final FunctionName name = BuiltinFunctionName.VARPOP.getName();
  final FunctionSignature doubleArgSignature =
      new FunctionSignature(name, Collections.singletonList(DOUBLE));
  final FunctionBuilder populationVarianceBuilder =
      arguments -> variancePopulation(arguments, DOUBLE);
  return new FunctionResolver(
      name,
      new ImmutableMap.Builder<FunctionSignature, FunctionBuilder>()
          .put(doubleArgSignature, populationVarianceBuilder)
          .build());
}

/**
 * Resolver for {@link BuiltinFunctionName#STDDEV_SAMP}. It registers a single signature taking
 * one DOUBLE argument, built through {@code stddevSample} with a DOUBLE return type.
 */
private static FunctionResolver stddevSamp() {
  final FunctionName name = BuiltinFunctionName.STDDEV_SAMP.getName();
  final FunctionSignature doubleArgSignature =
      new FunctionSignature(name, Collections.singletonList(DOUBLE));
  final FunctionBuilder sampleStdDevBuilder = arguments -> stddevSample(arguments, DOUBLE);
  return new FunctionResolver(
      name,
      new ImmutableMap.Builder<FunctionSignature, FunctionBuilder>()
          .put(doubleArgSignature, sampleStdDevBuilder)
          .build());
}

/**
 * Resolver for {@link BuiltinFunctionName#STDDEV_POP}. It registers a single signature taking
 * one DOUBLE argument, built through {@code stddevPopulation} with a DOUBLE return type.
 */
private static FunctionResolver stddevPop() {
  final FunctionName name = BuiltinFunctionName.STDDEV_POP.getName();
  final FunctionSignature doubleArgSignature =
      new FunctionSignature(name, Collections.singletonList(DOUBLE));
  final FunctionBuilder populationStdDevBuilder =
      arguments -> stddevPopulation(arguments, DOUBLE);
  return new FunctionResolver(
      name,
      new ImmutableMap.Builder<FunctionSignature, FunctionBuilder>()
          .put(doubleArgSignature, populationStdDevBuilder)
          .build());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package org.opensearch.sql.expression.aggregation;

import static org.opensearch.sql.data.model.ExprValueUtils.doubleValue;
import static org.opensearch.sql.utils.ExpressionUtils.format;

import java.util.ArrayList;
import java.util.List;
import org.apache.commons.math3.stat.descriptive.moment.StandardDeviation;
import org.opensearch.sql.common.utils.StringUtils;
import org.opensearch.sql.data.model.ExprNullValue;
import org.opensearch.sql.data.model.ExprValue;
import org.opensearch.sql.data.type.ExprCoreType;
import org.opensearch.sql.expression.Expression;
import org.opensearch.sql.expression.function.BuiltinFunctionName;

/**
* StandardDeviation Aggregator.
*/
public class StdDevAggregator extends Aggregator<StdDevAggregator.StdDevState> {

  /** True for sample standard deviation, false for population standard deviation. */
  private final boolean isSampleStdDev;

  /**
   * Build a population standard deviation {@link StdDevAggregator}.
   *
   * @param arguments aggregator arguments.
   * @param returnType aggregator return type.
   * @return aggregator registered under {@link BuiltinFunctionName#STDDEV_POP}.
   */
  public static Aggregator stddevPopulation(List<Expression> arguments,
                                            ExprCoreType returnType) {
    return new StdDevAggregator(false, arguments, returnType);
  }

  /**
   * Build a sample standard deviation {@link StdDevAggregator}.
   *
   * @param arguments aggregator arguments.
   * @param returnType aggregator return type.
   * @return aggregator registered under {@link BuiltinFunctionName#STDDEV_SAMP}.
   */
  public static Aggregator stddevSample(List<Expression> arguments,
                                        ExprCoreType returnType) {
    return new StdDevAggregator(true, arguments, returnType);
  }

  /**
   * StdDevAggregator constructor.
   *
   * @param isSampleStdDev true for sample standard deviation aggregator, false for population
   *                       standard deviation aggregator. Must not be null: the flag is unboxed
   *                       while selecting the function name.
   * @param arguments aggregator arguments.
   * @param returnType aggregator return types.
   */
  public StdDevAggregator(
      Boolean isSampleStdDev, List<Expression> arguments, ExprCoreType returnType) {
    super(
        isSampleStdDev
            ? BuiltinFunctionName.STDDEV_SAMP.getName()
            : BuiltinFunctionName.STDDEV_POP.getName(),
        arguments,
        returnType);
    this.isSampleStdDev = isSampleStdDev;
  }

  /**
   * Creates an empty state configured for the same flavour (sample or population) as this
   * aggregator.
   */
  @Override
  public StdDevAggregator.StdDevState create() {
    return new StdDevAggregator.StdDevState(isSampleStdDev);
  }

  /**
   * Records {@code value} in {@code state} and returns that same state instance.
   */
  @Override
  protected StdDevAggregator.StdDevState iterate(ExprValue value,
                                                 StdDevAggregator.StdDevState state) {
    state.evaluate(value);
    return state;
  }

  @Override
  public String toString() {
    return StringUtils.format(
        "%s(%s)", isSampleStdDev ? "stddev_samp" : "stddev_pop", format(getArguments()));
  }

  /**
   * Aggregation state for the standard deviation. Every observed value is buffered, so memory
   * grows with the number of values; the statistic is computed from the buffer when
   * {@link #result()} is called.
   */
  protected static class StdDevState implements AggregationState {

    private final StandardDeviation standardDeviation;

    private final List<Double> values = new ArrayList<>();

    public StdDevState(boolean isSampleStdDev) {
      this.standardDeviation = new StandardDeviation(isSampleStdDev);
    }

    /**
     * Buffers the double representation of {@code value}.
     */
    public void evaluate(ExprValue value) {
      values.add(value.doubleValue());
    }

    /**
     * Returns the standard deviation of the buffered values, or a null value when nothing has
     * been buffered.
     */
    @Override
    public ExprValue result() {
      return values.isEmpty()
          ? ExprNullValue.of()
          : doubleValue(standardDeviation.evaluate(
              values.stream().mapToDouble(Double::doubleValue).toArray()));
    }
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package org.opensearch.sql.expression.aggregation;

import static org.opensearch.sql.data.model.ExprValueUtils.doubleValue;
import static org.opensearch.sql.utils.ExpressionUtils.format;

import java.util.ArrayList;
import java.util.List;
import org.apache.commons.math3.stat.descriptive.moment.Variance;
import org.opensearch.sql.common.utils.StringUtils;
import org.opensearch.sql.data.model.ExprNullValue;
import org.opensearch.sql.data.model.ExprValue;
import org.opensearch.sql.data.type.ExprCoreType;
import org.opensearch.sql.expression.Expression;
import org.opensearch.sql.expression.function.BuiltinFunctionName;

/**
* Variance Aggregator.
*/
public class VarianceAggregator extends Aggregator<VarianceAggregator.VarianceState> {

  /** True when computing sample variance, false when computing population variance. */
  private final boolean isSampleVariance;

  /**
   * Creates the population variance flavour of {@link VarianceAggregator}.
   */
  public static Aggregator variancePopulation(List<Expression> arguments,
                                              ExprCoreType returnType) {
    return new VarianceAggregator(false, arguments, returnType);
  }

  /**
   * Creates the sample variance flavour of {@link VarianceAggregator}.
   */
  public static Aggregator varianceSample(List<Expression> arguments,
                                          ExprCoreType returnType) {
    return new VarianceAggregator(true, arguments, returnType);
  }

  /**
   * Creates a variance aggregator of the requested flavour.
   *
   * @param isSampleVariance true selects sample variance, false selects population variance.
   * @param arguments aggregator arguments.
   * @param returnType aggregator return type.
   */
  public VarianceAggregator(
      Boolean isSampleVariance, List<Expression> arguments, ExprCoreType returnType) {
    super(
        isSampleVariance
            ? BuiltinFunctionName.VARSAMP.getName()
            : BuiltinFunctionName.VARPOP.getName(),
        arguments,
        returnType);
    this.isSampleVariance = isSampleVariance;
  }

  @Override
  public VarianceState create() {
    return new VarianceState(isSampleVariance);
  }

  @Override
  protected VarianceState iterate(ExprValue value, VarianceState state) {
    // The state is mutated in place and handed back to the caller.
    state.evaluate(value);
    return state;
  }

  @Override
  public String toString() {
    final String displayName = isSampleVariance ? "var_samp" : "var_pop";
    return StringUtils.format("%s(%s)", displayName, format(getArguments()));
  }

  /**
   * Aggregation state that buffers every observed value and derives the variance from the
   * buffer when the result is requested.
   */
  protected static class VarianceState implements AggregationState {

    private final Variance variance;

    private final List<Double> values = new ArrayList<>();

    public VarianceState(boolean isSampleVariance) {
      this.variance = new Variance(isSampleVariance);
    }

    /**
     * Appends the double representation of {@code value} to the buffer.
     */
    public void evaluate(ExprValue value) {
      values.add(value.doubleValue());
    }

    @Override
    public ExprValue result() {
      // No buffered values yields a null value.
      if (values.isEmpty()) {
        return ExprNullValue.of();
      }
      final double[] samples = values.stream().mapToDouble(Double::doubleValue).toArray();
      return doubleValue(variance.evaluate(samples));
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
package org.opensearch.sql.expression.function;

import com.google.common.collect.ImmutableMap;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import lombok.Getter;
Expand Down Expand Up @@ -126,6 +127,14 @@ public enum BuiltinFunctionName {
COUNT(FunctionName.of("count")),
MIN(FunctionName.of("min")),
MAX(FunctionName.of("max")),
// sample variance
VARSAMP(FunctionName.of("var_samp")),
// population variance
VARPOP(FunctionName.of("var_pop")),
// sample standard deviation.
STDDEV_SAMP(FunctionName.of("stddev_samp")),
// population standard deviation.
STDDEV_POP(FunctionName.of("stddev_pop")),

/**
* Text Functions.
Expand Down Expand Up @@ -189,7 +198,28 @@ public enum BuiltinFunctionName {
ALL_NATIVE_FUNCTIONS = builder.build();
}

/**
 * Lower-case aggregation function names, aliases included, mapped to their builtin function.
 * The aliases {@code variance}, {@code std} and {@code stddev} all map to the population
 * variants.
 */
private static final Map<String, BuiltinFunctionName> AGGREGATION_FUNC_MAPPING =
    ImmutableMap.<String, BuiltinFunctionName>builder()
        // basic aggregations
        .put("max", BuiltinFunctionName.MAX)
        .put("min", BuiltinFunctionName.MIN)
        .put("avg", BuiltinFunctionName.AVG)
        .put("count", BuiltinFunctionName.COUNT)
        .put("sum", BuiltinFunctionName.SUM)
        // variance and its alias
        .put("var_pop", BuiltinFunctionName.VARPOP)
        .put("var_samp", BuiltinFunctionName.VARSAMP)
        .put("variance", BuiltinFunctionName.VARPOP)
        // standard deviation and its aliases
        .put("std", BuiltinFunctionName.STDDEV_POP)
        .put("stddev", BuiltinFunctionName.STDDEV_POP)
        .put("stddev_pop", BuiltinFunctionName.STDDEV_POP)
        .put("stddev_samp", BuiltinFunctionName.STDDEV_SAMP)
        .build();

/**
 * Looks up a builtin function in {@code ALL_NATIVE_FUNCTIONS} by name.
 *
 * @param str function name, wrapped with {@code FunctionName.of} before the lookup.
 * @return the matching builtin function, or an empty Optional when there is no entry.
 */
public static Optional<BuiltinFunctionName> of(String str) {
  final FunctionName lookupKey = FunctionName.of(str);
  return Optional.ofNullable(ALL_NATIVE_FUNCTIONS.get(lookupKey));
}

/**
 * Looks up an aggregation function in {@code AGGREGATION_FUNC_MAPPING}. The name is lower-cased
 * with {@link Locale#ROOT} first, so the lookup is case-insensitive.
 *
 * @param functionName aggregation function name or alias; must not be null.
 * @return the matching builtin function, or an empty Optional when there is no entry.
 */
public static Optional<BuiltinFunctionName> ofAggregation(String functionName) {
  final String normalizedName = functionName.toLowerCase(Locale.ROOT);
  return Optional.ofNullable(AGGREGATION_FUNC_MAPPING.get(normalizedName));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,14 @@ public void aggregation_filter() {
);
}

/**
 * Checks that analyzing an aggregate call named "variance" yields the same expression as
 * {@code dsl.varPop} over the same field.
 */
@Test
public void variance_mapto_varPop() {
assertAnalyzeEqual(
// expected: the varPop aggregator over the integer field
dsl.varPop(DSL.ref("integer_value", INTEGER)),
// input: unresolved aggregate call using the "variance" name
AstDSL.aggregate("variance", qualifiedName("integer_value"))
);
}

/**
 * Runs the expression analyzer on the given unresolved expression using the shared
 * analysis context.
 *
 * @param unresolvedExpression expression to analyze.
 * @return the expression returned by the analyzer.
 */
protected Expression analyze(UnresolvedExpression unresolvedExpression) {
  final Expression analyzed =
      expressionAnalyzer.analyze(unresolvedExpression, analysisContext);
  return analyzed;
}
Expand Down
Loading

0 comments on commit 8632f80

Please sign in to comment.