diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Coalesce.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Coalesce.java index 2ed864ba9a0c30..df211926c02a3b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Coalesce.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Coalesce.java @@ -19,61 +19,27 @@ import org.apache.doris.catalog.FunctionSignature; import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; +import org.apache.doris.nereids.trees.expressions.functions.CustomSignature; +import org.apache.doris.nereids.trees.expressions.functions.ExpressionTrait; +import org.apache.doris.nereids.trees.expressions.functions.SearchSignature; +import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; -import org.apache.doris.nereids.types.BigIntType; -import org.apache.doris.nereids.types.BitmapType; -import org.apache.doris.nereids.types.BooleanType; -import org.apache.doris.nereids.types.DateTimeType; -import org.apache.doris.nereids.types.DateTimeV2Type; -import org.apache.doris.nereids.types.DateType; -import org.apache.doris.nereids.types.DateV2Type; -import org.apache.doris.nereids.types.DecimalV2Type; -import org.apache.doris.nereids.types.DecimalV3Type; -import org.apache.doris.nereids.types.DoubleType; -import org.apache.doris.nereids.types.FloatType; -import org.apache.doris.nereids.types.IntegerType; -import org.apache.doris.nereids.types.LargeIntType; -import org.apache.doris.nereids.types.SmallIntType; -import org.apache.doris.nereids.types.StringType; -import org.apache.doris.nereids.types.TimeType; -import org.apache.doris.nereids.types.TimeV2Type; -import org.apache.doris.nereids.types.TinyIntType; -import org.apache.doris.nereids.types.VarcharType; +import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.util.ExpressionUtils; +import org.apache.doris.nereids.util.TypeCoercionUtils; import com.google.common.base.Preconditions; -import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; /** * ScalarFunction 'coalesce'. This class is generated by GenerateFunction. */ -public class Coalesce extends ScalarFunction - implements ExplicitlyCastableSignature { - - public static final List SIGNATURES = ImmutableList.of( - FunctionSignature.ret(BooleanType.INSTANCE).varArgs(BooleanType.INSTANCE), - FunctionSignature.ret(TinyIntType.INSTANCE).varArgs(TinyIntType.INSTANCE), - FunctionSignature.ret(SmallIntType.INSTANCE).varArgs(SmallIntType.INSTANCE), - FunctionSignature.ret(IntegerType.INSTANCE).varArgs(IntegerType.INSTANCE), - FunctionSignature.ret(BigIntType.INSTANCE).varArgs(BigIntType.INSTANCE), - FunctionSignature.ret(LargeIntType.INSTANCE).varArgs(LargeIntType.INSTANCE), - FunctionSignature.ret(DoubleType.INSTANCE).varArgs(DoubleType.INSTANCE), - FunctionSignature.ret(FloatType.INSTANCE).varArgs(FloatType.INSTANCE), - FunctionSignature.ret(DateTimeV2Type.SYSTEM_DEFAULT).varArgs(DateTimeV2Type.SYSTEM_DEFAULT), - FunctionSignature.ret(DateTimeType.INSTANCE).varArgs(DateTimeType.INSTANCE), - FunctionSignature.ret(DateV2Type.INSTANCE).varArgs(DateV2Type.INSTANCE), - FunctionSignature.ret(DateType.INSTANCE).varArgs(DateType.INSTANCE), - FunctionSignature.ret(TimeType.INSTANCE).varArgs(TimeType.INSTANCE), - FunctionSignature.ret(TimeV2Type.INSTANCE).varArgs(TimeV2Type.INSTANCE), - FunctionSignature.ret(DecimalV3Type.WILDCARD).varArgs(DecimalV3Type.WILDCARD), - FunctionSignature.ret(DecimalV2Type.SYSTEM_DEFAULT).varArgs(DecimalV2Type.SYSTEM_DEFAULT), - FunctionSignature.ret(BitmapType.INSTANCE).varArgs(BitmapType.INSTANCE), - FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).varArgs(VarcharType.SYSTEM_DEFAULT), - FunctionSignature.ret(StringType.INSTANCE).varArgs(StringType.INSTANCE) - ); +public class Coalesce extends ScalarFunction implements CustomSignature { /** * constructor with 1 or more arguments. @@ -106,8 +72,33 @@ public Coalesce withChildren(List children) { } @Override - public List getSignatures() { - return SIGNATURES; + public FunctionSignature customSignature() { + Map> partitioned = getArguments().stream() + .collect(Collectors.partitioningBy( + e -> (e instanceof Literal && ((Literal) e).isStringLikeLiteral()))); + List forFindCommon = partitioned.get(false).stream() + .map(ExpressionTrait::getDataType) + .collect(Collectors.toList()); + Optional commonType = TypeCoercionUtils.findWiderCommonTypeForCaseWhen(forFindCommon); + if (!commonType.isPresent()) { + SearchSignature.throwCanNotFoundFunctionException(this.getName(), getArguments()); + return null; + } else { + for (Expression stringLiteral : partitioned.get(true)) { + Optional option = TypeCoercionUtils.characterLiteralTypeCoercion( + ((Literal) stringLiteral).getStringValue(), commonType.get()); + if (!option.isPresent()) { + List commonTypes = Lists.newArrayList(commonType.get(), stringLiteral.getDataType()); + commonType = TypeCoercionUtils.findWiderCommonTypeForCaseWhen(commonTypes); + if (!commonType.isPresent()) { + SearchSignature.throwCanNotFoundFunctionException(this.getName(), getArguments()); + } else { + break; + } + } + } + return FunctionSignature.ret(commonType.get()).varArgs(commonType.get()); + } } @Override diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyConditionalFunctionTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyConditionalFunctionTest.java index 152c0f542e3140..a2831e8cd2c13d 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyConditionalFunctionTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyConditionalFunctionTest.java @@ -60,10 +60,10 @@ public void testCoalesce() { assertRewrite(new Coalesce(slot, nonNullableSlot), new Coalesce(slot, nonNullableSlot)); // coalesce(null, null) -> null - assertRewrite(new Coalesce(NullLiteral.INSTANCE, NullLiteral.INSTANCE), new NullLiteral(BooleanType.INSTANCE)); + assertRewrite(new Coalesce(NullLiteral.INSTANCE, NullLiteral.INSTANCE), NullLiteral.INSTANCE); // coalesce(null) -> null - assertRewrite(new Coalesce(NullLiteral.INSTANCE), new NullLiteral(BooleanType.INSTANCE)); + assertRewrite(new Coalesce(NullLiteral.INSTANCE), NullLiteral.INSTANCE); // coalesce(non-nullable_slot) -> non-nullable_slot assertRewrite(new Coalesce(nonNullableSlot), nonNullableSlot);