Skip to content

Commit

Permalink
[opt](Nereids) polish aggregate function signature matching (#39352) (#…
Browse files Browse the repository at this point in the history
…39460)

pick from master #39352

use double to match string
- corr
- covar
- covar_samp
- stddev
- stddev_samp

use largeint to match string
- group_bit_and
- group_bit_or
- group_git_xor

use double to match decimalv3
- topn_weighted

optimize error message
- multi_distinct_sum
- multi_distinct_sum0
  • Loading branch information
morrySnow authored Aug 16, 2024
1 parent ec0e413 commit d56000e
Show file tree
Hide file tree
Showing 18 changed files with 153 additions and 111 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,12 @@ public class AvgWeighted extends AggregateFunction

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE, DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE, DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE, DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE, DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(DecimalV2Type.SYSTEM_DEFAULT, DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE, DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE, DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(DecimalV2Type.SYSTEM_DEFAULT, DoubleType.INSTANCE)
FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE, DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE, DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE, DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE, DoubleType.INSTANCE)
);

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,11 @@
public class BitmapAgg extends AggregateFunction
implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNotNullable {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(BitmapType.INSTANCE).args(TinyIntType.INSTANCE),
FunctionSignature.ret(BitmapType.INSTANCE).args(SmallIntType.INSTANCE),
FunctionSignature.ret(BitmapType.INSTANCE).args(BigIntType.INSTANCE),
FunctionSignature.ret(BitmapType.INSTANCE).args(IntegerType.INSTANCE),
FunctionSignature.ret(BitmapType.INSTANCE).args(BigIntType.INSTANCE)
);
FunctionSignature.ret(BitmapType.INSTANCE).args(SmallIntType.INSTANCE),
FunctionSignature.ret(BitmapType.INSTANCE).args(TinyIntType.INSTANCE)
);

public BitmapAgg(Expression arg0) {
super("bitmap_agg", arg0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,6 @@ public CollectList(boolean distinct, Expression arg0, Expression arg1) {
super("collect_list", distinct, arg0, arg1);
}

@Override
public FunctionSignature computeSignature(FunctionSignature signature) {
signature = signature.withReturnType(ArrayType.of(getArgumentType(0)));
return super.computeSignature(signature);
}

/**
* withDistinctAndChildren.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,12 @@ public class Corr extends AggregateFunction
implements UnaryExpression, ExplicitlyCastableSignature, PropagateNullable {

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE, TinyIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE, SmallIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE, DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE, BigIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE, FloatType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE, DoubleType.INSTANCE)
FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE, SmallIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE, TinyIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE, FloatType.INSTANCE)
);

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,12 @@ public class Covar extends AggregateFunction
implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNullable {

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE, TinyIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE, SmallIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE, DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE, BigIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE, FloatType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE, DoubleType.INSTANCE)
FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE, SmallIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE, TinyIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE, FloatType.INSTANCE)
);

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,12 @@ public class CovarSamp extends AggregateFunction
implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNullable {

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE, TinyIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE, SmallIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE, DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE, BigIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE, FloatType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE, DoubleType.INSTANCE)
FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE, SmallIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE, TinyIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE, FloatType.INSTANCE)
);

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,11 @@ public class GroupBitAnd extends NullableAggregateFunction
implements UnaryExpression, ExplicitlyCastableSignature {

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(TinyIntType.INSTANCE).args(TinyIntType.INSTANCE),
FunctionSignature.ret(SmallIntType.INSTANCE).args(SmallIntType.INSTANCE),
FunctionSignature.ret(IntegerType.INSTANCE).args(IntegerType.INSTANCE),
FunctionSignature.ret(LargeIntType.INSTANCE).args(LargeIntType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).args(BigIntType.INSTANCE),
FunctionSignature.ret(LargeIntType.INSTANCE).args(LargeIntType.INSTANCE)
FunctionSignature.ret(IntegerType.INSTANCE).args(IntegerType.INSTANCE),
FunctionSignature.ret(SmallIntType.INSTANCE).args(SmallIntType.INSTANCE),
FunctionSignature.ret(TinyIntType.INSTANCE).args(TinyIntType.INSTANCE)
);

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@ public class GroupBitOr extends NullableAggregateFunction
implements UnaryExpression, ExplicitlyCastableSignature {

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(TinyIntType.INSTANCE).args(TinyIntType.INSTANCE),
FunctionSignature.ret(SmallIntType.INSTANCE).args(SmallIntType.INSTANCE),
FunctionSignature.ret(IntegerType.INSTANCE).args(IntegerType.INSTANCE),
FunctionSignature.ret(LargeIntType.INSTANCE).args(LargeIntType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).args(BigIntType.INSTANCE),
FunctionSignature.ret(LargeIntType.INSTANCE).args(LargeIntType.INSTANCE)
FunctionSignature.ret(IntegerType.INSTANCE).args(IntegerType.INSTANCE),
FunctionSignature.ret(SmallIntType.INSTANCE).args(SmallIntType.INSTANCE),
FunctionSignature.ret(TinyIntType.INSTANCE).args(TinyIntType.INSTANCE)
);

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@ public class GroupBitXor extends NullableAggregateFunction
implements UnaryExpression, ExplicitlyCastableSignature {

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(TinyIntType.INSTANCE).args(TinyIntType.INSTANCE),
FunctionSignature.ret(SmallIntType.INSTANCE).args(SmallIntType.INSTANCE),
FunctionSignature.ret(IntegerType.INSTANCE).args(IntegerType.INSTANCE),
FunctionSignature.ret(LargeIntType.INSTANCE).args(LargeIntType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).args(BigIntType.INSTANCE),
FunctionSignature.ret(LargeIntType.INSTANCE).args(LargeIntType.INSTANCE)
FunctionSignature.ret(IntegerType.INSTANCE).args(IntegerType.INSTANCE),
FunctionSignature.ret(SmallIntType.INSTANCE).args(SmallIntType.INSTANCE),
FunctionSignature.ret(TinyIntType.INSTANCE).args(TinyIntType.INSTANCE)
);

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,25 +24,16 @@
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.DoubleType;
import org.apache.doris.nereids.types.LargeIntType;
import org.apache.doris.nereids.types.DataType;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;

import java.util.List;

/** MultiDistinctSum */
public class MultiDistinctSum extends NullableAggregateFunction implements UnaryExpression,
ExplicitlyCastableSignature, ComputePrecisionForSum, MultiDistinction {

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(BigIntType.INSTANCE).varArgs(BigIntType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).varArgs(DoubleType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).varArgs(LargeIntType.INSTANCE)
);

private final boolean mustUseMultiDistinctAgg;

public MultiDistinctSum(Expression arg0) {
Expand All @@ -65,8 +56,10 @@ private MultiDistinctSum(boolean mustUseMultiDistinctAgg, boolean distinct,

@Override
public void checkLegalityBeforeTypeCoercion() {
if (child().getDataType().isDateLikeType()) {
throw new AnalysisException("Sum in multi distinct functions do not support Date/Datetime type");
DataType argType = child().getDataType();
if ((!argType.isNumericType() && !argType.isBooleanType() && !argType.isNullType())
|| argType.isOnlyMetricType()) {
throw new AnalysisException("sum requires a numeric or boolean parameter: " + this.toSql());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,25 +25,16 @@
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.DoubleType;
import org.apache.doris.nereids.types.LargeIntType;
import org.apache.doris.nereids.types.DataType;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;

import java.util.List;

/** MultiDistinctSum0 */
public class MultiDistinctSum0 extends AggregateFunction implements UnaryExpression,
ExplicitlyCastableSignature, ComputePrecisionForSum, MultiDistinction, AlwaysNotNullable {

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(BigIntType.INSTANCE).varArgs(BigIntType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).varArgs(DoubleType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).varArgs(LargeIntType.INSTANCE)
);

private final boolean mustUseMultiDistinctAgg;

public MultiDistinctSum0(Expression arg0) {
Expand All @@ -61,8 +52,10 @@ private MultiDistinctSum0(boolean mustUseMultiDistinctAgg, boolean distinct, Exp

@Override
public void checkLegalityBeforeTypeCoercion() {
if (child().getDataType().isDateLikeType()) {
throw new AnalysisException("Sum0 in multi distinct functions do not support Date/Datetime type");
DataType argType = child().getDataType();
if ((!argType.isNumericType() && !argType.isBooleanType() && !argType.isNullType())
|| argType.isOnlyMetricType()) {
throw new AnalysisException("sum0 requires a numeric or boolean parameter: " + this.toSql());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@ public class Stddev extends NullableAggregateFunction

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE));

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@ public class StddevSamp extends AggregateFunction

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE));

/**
Expand Down
Loading

0 comments on commit d56000e

Please sign in to comment.