Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feat](Nereids) support fold constant by fe (#40441)(#40772)(#40744)(#40745)(40820) #41837

Merged
merged 7 commits into from
Oct 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,9 @@ public String getStringValueInFe(FormatOptions options) {
String timeStr = getStringValue();
return timeStr.substring(1, timeStr.length() - 1);
} else {
if (Double.isInfinite(getValue())) {
return Double.toString(getValue());
}
return BigDecimal.valueOf(getValue()).toPlainString();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@
import com.google.common.collect.Lists;
import org.apache.commons.codec.digest.DigestUtils;

import java.time.DateTimeException;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
Expand Down Expand Up @@ -449,6 +450,8 @@ public Expression visitCast(Cast cast, ExpressionRewriteContext context) {
// If cast is from type coercion, we don't use NULL literal and will throw exception.
throw t;
}
} catch (DateTimeException e) {
return new NullLiteral(dataType);
}
}
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,4 @@
*/
String name();

/**
* args type
*/
String[] argTypes();

/**
* return type
*/
String returnType();
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,27 +18,27 @@
package org.apache.doris.nereids.trees.expressions;

import org.apache.doris.catalog.Env;
import org.apache.doris.common.AnalysisException;
import org.apache.doris.nereids.exceptions.NotSupportedException;
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.executable.DateTimeAcquire;
import org.apache.doris.nereids.trees.expressions.functions.executable.DateTimeArithmetic;
import org.apache.doris.nereids.trees.expressions.functions.executable.DateTimeExtractAndTransform;
import org.apache.doris.nereids.trees.expressions.functions.executable.ExecutableFunctions;
import org.apache.doris.nereids.trees.expressions.functions.executable.NumericArithmetic;
import org.apache.doris.nereids.trees.expressions.functions.executable.StringArithmetic;
import org.apache.doris.nereids.trees.expressions.functions.executable.TimeRoundSeries;
import org.apache.doris.nereids.trees.expressions.literal.DateLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DateTimeLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.util.TypeCoercionUtils;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMultimap;

import java.lang.reflect.Array;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

Expand All @@ -49,7 +49,7 @@ public enum ExpressionEvaluator {

INSTANCE;

private ImmutableMultimap<String, FunctionInvoker> functions;
private ImmutableMultimap<String, Method> functions;

ExpressionEvaluator() {
registerFunctions();
Expand All @@ -65,23 +65,16 @@ public Expression eval(Expression expression) {
}

String fnName = null;
DataType[] args = null;
DataType ret = expression.getDataType();
if (expression instanceof BinaryArithmetic) {
BinaryArithmetic arithmetic = (BinaryArithmetic) expression;
fnName = arithmetic.getLegacyOperator().getName();
args = new DataType[]{arithmetic.left().getDataType(), arithmetic.right().getDataType()};
} else if (expression instanceof TimestampArithmetic) {
TimestampArithmetic arithmetic = (TimestampArithmetic) expression;
fnName = arithmetic.getFuncName();
args = new DataType[]{arithmetic.left().getDataType(), arithmetic.right().getDataType()};
} else if (expression instanceof BoundFunction) {
BoundFunction function = ((BoundFunction) expression);
fnName = function.getName();
args = new DataType[function.arity()];
for (int i = 0; i < function.children().size(); i++) {
args[i] = function.child(i).getDataType();
}
}

if ((Env.getCurrentEnv().isNullResultWithOneNullParamFunction(fnName))) {
Expand All @@ -92,40 +85,104 @@ public Expression eval(Expression expression) {
}
}

return invoke(expression, fnName, args);
return invoke(expression, fnName);
}

private Expression invoke(Expression expression, String fnName, DataType[] args) {
FunctionSignature signature = new FunctionSignature(fnName, args, null);
FunctionInvoker invoker = getFunction(signature);
if (invoker != null) {
private Expression invoke(Expression expression, String fnName) {
Method method = getFunction(fnName, expression.children());
if (method != null) {
try {
return invoker.invoke(expression.children());
} catch (AnalysisException e) {
int varSize = method.getParameterTypes().length;
if (varSize == 0) {
return (Literal) method.invoke(null, expression.children().toArray());
}
boolean hasVarArgs = method.getParameterTypes()[varSize - 1].isArray();
if (hasVarArgs) {
int fixedArgsSize = varSize - 1;
int inputSize = expression.children().size();
Class<?>[] parameterTypes = method.getParameterTypes();
Class<?> parameterType = parameterTypes[varSize - 1];
Class<?> componentType = parameterType.getComponentType();
Object varArgs = Array.newInstance(componentType, inputSize - fixedArgsSize);
for (int i = fixedArgsSize; i < inputSize; i++) {
Array.set(varArgs, i - fixedArgsSize, expression.children().get(i));
}
Object[] objects = new Object[fixedArgsSize + 1];
for (int i = 0; i < fixedArgsSize; i++) {
objects[i] = expression.children().get(i);
}
objects[fixedArgsSize] = varArgs;

return (Literal) method.invoke(null, objects);
}
return (Literal) method.invoke(null, expression.children().toArray());
} catch (InvocationTargetException e) {
if (e.getTargetException() instanceof NotSupportedException) {
throw new NotSupportedException(e.getTargetException().getMessage());
} else {
return expression;
}
} catch (IllegalAccessException | IllegalArgumentException e) {
return expression;
}
}
return expression;
}

private FunctionInvoker getFunction(FunctionSignature signature) {
Collection<FunctionInvoker> functionInvokers = functions.get(signature.getName());
for (FunctionInvoker candidate : functionInvokers) {
DataType[] candidateTypes = candidate.getSignature().getArgTypes();
DataType[] expectedTypes = signature.getArgTypes();
private boolean canDownCastTo(Class<?> expect, Class<?> input) {
if (DateLiteral.class.isAssignableFrom(expect)
|| DateTimeLiteral.class.isAssignableFrom(expect)) {
return expect.equals(input);
}
return expect.isAssignableFrom(input);
}

if (candidateTypes.length != expectedTypes.length) {
continue;
}
private Method getFunction(String fnName, List<Expression> inputs) {
Collection<Method> expectMethods = functions.get(fnName);
for (Method expect : expectMethods) {
boolean match = true;
for (int i = 0; i < candidateTypes.length; i++) {
if (!(expectedTypes[i].toCatalogDataType().matchesType(candidateTypes[i].toCatalogDataType()))) {
match = false;
break;
int varSize = expect.getParameterTypes().length;
if (varSize == 0) {
if (inputs.size() == 0) {
return expect;
} else {
continue;
}
}
boolean hasVarArgs = expect.getParameterTypes()[varSize - 1].isArray();
if (hasVarArgs) {
int fixedArgsSize = varSize - 1;
int inputSize = inputs.size();
if (inputSize <= fixedArgsSize) {
continue;
}
Class<?>[] expectVarTypes = expect.getParameterTypes();
for (int i = 0; i < fixedArgsSize; i++) {
if (!canDownCastTo(expectVarTypes[i], inputs.get(i).getClass())) {
match = false;
}
}
Class<?> varArgsType = expectVarTypes[varSize - 1];
Class<?> varArgType = varArgsType.getComponentType();
for (int i = fixedArgsSize; i < inputSize; i++) {
if (!canDownCastTo(varArgType, inputs.get(i).getClass())) {
match = false;
}
}
} else {
int inputSize = inputs.size();
if (inputSize != varSize) {
continue;
}
Class<?>[] expectVarTypes = expect.getParameterTypes();
for (int i = 0; i < varSize; i++) {
if (!canDownCastTo(expectVarTypes[i], inputs.get(i).getClass())) {
match = false;
}
}
}
if (match) {
return candidate;
return expect;
}
}
return null;
Expand All @@ -135,14 +192,14 @@ private void registerFunctions() {
if (functions != null) {
return;
}
ImmutableMultimap.Builder<String, FunctionInvoker> mapBuilder = new ImmutableMultimap.Builder<>();
ImmutableMultimap.Builder<String, Method> mapBuilder = new ImmutableMultimap.Builder<>();
List<Class<?>> classes = ImmutableList.of(
DateTimeAcquire.class,
DateTimeExtractAndTransform.class,
ExecutableFunctions.class,
DateLiteral.class,
DateTimeArithmetic.class,
NumericArithmetic.class,
StringArithmetic.class,
TimeRoundSeries.class
);
for (Class<?> cls : classes) {
Expand All @@ -159,78 +216,10 @@ private void registerFunctions() {
this.functions = mapBuilder.build();
}

private void registerFEFunction(ImmutableMultimap.Builder<String, FunctionInvoker> mapBuilder,
private void registerFEFunction(ImmutableMultimap.Builder<String, Method> mapBuilder,
Method method, ExecFunction annotation) {
if (annotation != null) {
String name = annotation.name();
DataType returnType = DataType.convertFromString(annotation.returnType());
List<DataType> argTypes = new ArrayList<>();
for (String type : annotation.argTypes()) {
argTypes.add(TypeCoercionUtils.replaceDecimalV3WithWildcard(DataType.convertFromString(type)));
}
DataType[] array = new DataType[argTypes.size()];
for (int i = 0; i < argTypes.size(); i++) {
array[i] = argTypes.get(i);
}
FunctionSignature signature = new FunctionSignature(name, array, returnType);
mapBuilder.put(name, new FunctionInvoker(method, signature));
mapBuilder.put(annotation.name(), method);
}
}

/**
* function invoker.
*/
public static class FunctionInvoker {
private final Method method;
private final FunctionSignature signature;

public FunctionInvoker(Method method, FunctionSignature signature) {
this.method = method;
this.signature = signature;
}

public Method getMethod() {
return method;
}

public FunctionSignature getSignature() {
return signature;
}

public Literal invoke(List<Expression> args) throws AnalysisException {
try {
return (Literal) method.invoke(null, args.toArray());
} catch (InvocationTargetException | IllegalAccessException | IllegalArgumentException e) {
throw new AnalysisException(e.getLocalizedMessage());
}
}
}

/**
* function signature.
*/
public static class FunctionSignature {
private final String name;
private final DataType[] argTypes;
private final DataType returnType;

public FunctionSignature(String name, DataType[] argTypes, DataType returnType) {
this.name = name;
this.argTypes = argTypes;
this.returnType = returnType;
}

public DataType[] getArgTypes() {
return argTypes;
}

public DataType getReturnType() {
return returnType;
}

public String getName() {
return name;
}
}

}
Loading
Loading