Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,7 @@ public RexNode visitWindowFunction(WindowFunction node, CalcitePlanContext conte
(arguments.isEmpty() || arguments.size() == 1)
? Collections.emptyList()
: arguments.subList(1, arguments.size());
PPLFuncImpTable.INSTANCE.validateAggFunctionSignature(functionName, field, args);
return PlanUtils.makeOver(
context, functionName, field, args, partitions, List.of(), node.getWindowFrame());
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ public class PPLOperandTypes {
private PPLOperandTypes() {}

public static final UDFOperandMetadata NONE = UDFOperandMetadata.wrap(OperandTypes.family());
public static final UDFOperandMetadata OPTIONAL_ANY =
UDFOperandMetadata.wrap(
(CompositeOperandTypeChecker)
OperandTypes.family(SqlTypeFamily.ANY).or(OperandTypes.family()));
public static final UDFOperandMetadata OPTIONAL_INTEGER =
UDFOperandMetadata.wrap(
(CompositeOperandTypeChecker) OperandTypes.INTEGER.or(OperandTypes.family()));
Expand All @@ -43,6 +47,10 @@ private PPLOperandTypes() {}
UDFOperandMetadata.wrap(
(CompositeOperandTypeChecker)
OperandTypes.ANY.or(OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.INTEGER)));
public static final UDFOperandMetadata ANY_OPTIONAL_TIMESTAMP =
UDFOperandMetadata.wrap(
(CompositeOperandTypeChecker)
OperandTypes.ANY.or(OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.TIMESTAMP)));
public static final UDFOperandMetadata INTEGER_INTEGER =
UDFOperandMetadata.wrap((FamilyOperandTypeChecker) OperandTypes.INTEGER_INTEGER);
public static final UDFOperandMetadata STRING_STRING =
Expand Down Expand Up @@ -121,6 +129,12 @@ private PPLOperandTypes() {}
(CompositeOperandTypeChecker)
OperandTypes.DATETIME.or(
OperandTypes.family(SqlTypeFamily.DATETIME, SqlTypeFamily.INTEGER)));
public static final UDFOperandMetadata ANY_DATETIME_OR_STRING =
UDFOperandMetadata.wrap(
(CompositeOperandTypeChecker)
OperandTypes.family(SqlTypeFamily.ANY)
.or(OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.DATETIME))
.or(OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.STRING)));

public static final UDFOperandMetadata DATETIME_DATETIME =
UDFOperandMetadata.wrap(OperandTypes.family(SqlTypeFamily.DATETIME, SqlTypeFamily.DATETIME));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.calcite.plan.RelOptTable;
Expand All @@ -22,6 +23,7 @@
import org.apache.calcite.rel.RelShuttle;
import org.apache.calcite.rel.core.TableScan;
import org.apache.calcite.rel.logical.LogicalProject;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
Expand Down Expand Up @@ -419,4 +421,19 @@ public Void visitInputRef(RexInputRef inputRef) {
visitor.visitEach(rexNodes);
return selectedColumns;
}

/**
* Get a string representation of the argument types expressed in ExprType for error messages.
*
* @param argTypes the list of argument types as {@link RelDataType}
* @return a string in the format [type1,type2,...] representing the argument types
*/
public static String getActualSignature(List<RelDataType> argTypes) {
return "["
+ argTypes.stream()
.map(OpenSearchTypeFactory::convertRelDataTypeToExprType)
.map(Objects::toString)
.collect(Collectors.joining(","))
+ "]";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -381,8 +381,8 @@ public enum BuiltinFunctionName {
.put("stddev", BuiltinFunctionName.STDDEV_POP)
.put("stddev_pop", BuiltinFunctionName.STDDEV_POP)
.put("stddev_samp", BuiltinFunctionName.STDDEV_SAMP)
// .put("earliest", BuiltinFunctionName.EARLIEST)
// .put("latest", BuiltinFunctionName.LATEST)
.put("earliest", BuiltinFunctionName.EARLIEST)
.put("latest", BuiltinFunctionName.LATEST)
.put("distinct_count_approx", BuiltinFunctionName.DISTINCT_COUNT_APPROX)
.put("dc", BuiltinFunctionName.DISTINCT_COUNT_APPROX)
.put("distinct_count", BuiltinFunctionName.DISTINCT_COUNT_APPROX)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -412,25 +412,40 @@ public void registerExternalAggOperator(
aggExternalFunctionRegistry.put(functionName, Pair.of(signature, handler));
}

public void validateAggFunctionSignature(
BuiltinFunctionName functionName, RexNode field, List<RexNode> argList) {
var implementation = getImplementation(functionName);
validateFunctionArgs(implementation, functionName, field, argList);
}

public RelBuilder.AggCall resolveAgg(
BuiltinFunctionName functionName,
boolean distinct,
RexNode field,
List<RexNode> argList,
CalcitePlanContext context) {
var implementation = aggExternalFunctionRegistry.get(functionName);
if (implementation == null) {
implementation = aggFunctionRegistry.get(functionName);
}
if (implementation == null) {
throw new IllegalStateException(String.format("Cannot resolve function: %s", functionName));
}
var implementation = getImplementation(functionName);

// Validation is done based on original argument types to generate error from user perspective.
validateFunctionArgs(implementation, functionName, field, argList);

var handler = implementation.getValue();
return handler.apply(distinct, field, argList, context);
}

static void validateFunctionArgs(
Pair<CalciteFuncSignature, AggHandler> implementation,
BuiltinFunctionName functionName,
RexNode field,
List<RexNode> argList) {
CalciteFuncSignature signature = implementation.getKey();

List<RelDataType> argTypes = new ArrayList<>();
if (field != null) {
argTypes.add(field.getType());
}
// Currently only PERCENTILE_APPROX and TAKE have additional arguments.

// Currently only PERCENTILE_APPROX, TAKE, EARLIEST, and LATEST have additional arguments.
// Their additional arguments will always come as a map of <argName, value>
List<RelDataType> additionalArgTypes =
argList.stream().map(PlanUtils::derefMapCall).map(RexNode::getType).collect(Collectors.toList());
Expand All @@ -448,8 +463,18 @@ public RelBuilder.AggCall resolveAgg(
signature.getTypeChecker().getAllowedSignatures(),
getActualSignature(argTypes)));
}
var handler = implementation.getValue();
return handler.apply(distinct, field, argList, context);
}

private Pair<CalciteFuncSignature, AggHandler> getImplementation(
BuiltinFunctionName functionName) {
var implementation = aggExternalFunctionRegistry.get(functionName);
if (implementation == null) {
implementation = aggFunctionRegistry.get(functionName);
}
if (implementation == null) {
throw new IllegalStateException(String.format("Cannot resolve function: %s", functionName));
}
return implementation;
}


Expand Down Expand Up @@ -497,10 +522,10 @@ public RexNode resolve(
}
} catch (Exception e) {
throw new ExpressionEvaluationException(
String.format(
"Cannot resolve function: %s, arguments: %s, caused by: %s",
functionName, getActualSignature(argTypes), e.getMessage()),
e);
String.format(
"Cannot resolve function: %s, arguments: %s, caused by: %s",
functionName, getActualSignature(argTypes), e.getMessage()),
e);
}
StringJoiner allowedSignatures = new StringJoiner(",");
for (var implement : implementList) {
Expand All @@ -510,9 +535,9 @@ functionName, getActualSignature(argTypes), e.getMessage()),
}
}
throw new ExpressionEvaluationException(
String.format(
"%s function expects {%s}, but got %s",
functionName, allowedSignatures, getActualSignature(argTypes)));
String.format(
"%s function expects {%s}, but got %s",
functionName, allowedSignatures, getActualSignature(argTypes)));
}

/**
Expand Down Expand Up @@ -1081,21 +1106,6 @@ void registerOperator(BuiltinFunctionName functionName, SqlAggFunction aggFuncti
register(functionName, handler, typeChecker);
}

private static RexNode resolveTimeField(List<RexNode> argList, CalcitePlanContext ctx) {
if (argList.isEmpty()) {
// Try to find @timestamp field
var timestampField =
ctx.relBuilder.peek().getRowType().getField("@timestamp", false, false);
if (timestampField == null) {
throw new IllegalArgumentException(
"Default @timestamp field not found. Please specify a time field explicitly.");
}
return ctx.rexBuilder.makeInputRef(timestampField.getType(), timestampField.getIndex());
} else {
return PlanUtils.derefMapCall(argList.get(0));
}
}

void populate() {
registerOperator(MAX, SqlStdOperatorTable.MAX);
registerOperator(MIN, SqlStdOperatorTable.MIN);
Expand Down Expand Up @@ -1125,8 +1135,7 @@ void populate() {
return ctx.relBuilder.count(distinct, null, field);
}
},
wrapSqlOperandTypeChecker(
SqlStdOperatorTable.COUNT.getOperandTypeChecker(), COUNT.name(), false));
wrapSqlOperandTypeChecker(PPLOperandTypes.OPTIONAL_ANY, COUNT.name(), false));

register(
PERCENTILE_APPROX,
Expand Down Expand Up @@ -1173,20 +1182,22 @@ void populate() {
register(
EARLIEST,
(distinct, field, argList, ctx) -> {
RexNode timeField = resolveTimeField(argList, ctx);
return ctx.relBuilder.aggregateCall(SqlStdOperatorTable.ARG_MIN, field, timeField);
List<RexNode> args = resolveTimeField(argList, ctx);
return UserDefinedFunctionUtils.makeAggregateCall(
SqlStdOperatorTable.ARG_MIN, List.of(field), args, ctx.relBuilder);
},
wrapSqlOperandTypeChecker(
SqlStdOperatorTable.ARG_MIN.getOperandTypeChecker(), EARLIEST.name(), false));
PPLOperandTypes.ANY_OPTIONAL_TIMESTAMP, EARLIEST.name(), false));

register(
LATEST,
(distinct, field, argList, ctx) -> {
RexNode timeField = resolveTimeField(argList, ctx);
return ctx.relBuilder.aggregateCall(SqlStdOperatorTable.ARG_MAX, field, timeField);
List<RexNode> args = resolveTimeField(argList, ctx);
return UserDefinedFunctionUtils.makeAggregateCall(
SqlStdOperatorTable.ARG_MAX, List.of(field), args, ctx.relBuilder);
},
wrapSqlOperandTypeChecker(
SqlStdOperatorTable.ARG_MAX.getOperandTypeChecker(), LATEST.name(), false));
PPLOperandTypes.ANY_OPTIONAL_TIMESTAMP, EARLIEST.name(), false));

// Register FIRST function - uses document order
register(
Expand All @@ -1210,7 +1221,6 @@ void populate() {
}
}


/**
* Get a string representation of the argument types expressed in ExprType for error messages.
*
Expand All @@ -1226,6 +1236,21 @@ private static String getActualSignature(List<RelDataType> argTypes) {
+ "]";
}

static List<RexNode> resolveTimeField(List<RexNode> argList, CalcitePlanContext ctx) {
if (argList.isEmpty()) {
// Try to find @timestamp field
var timestampField = ctx.relBuilder.peek().getRowType().getField("@timestamp", false, false);
if (timestampField == null) {
throw new IllegalArgumentException(
"Default @timestamp field not found. Please specify a time field explicitly.");
}
return List.of(
ctx.rexBuilder.makeInputRef(timestampField.getType(), timestampField.getIndex()));
} else {
return argList.stream().map(PlanUtils::derefMapCall).collect(Collectors.toList());
}
}

/**
* Wraps a {@link SqlOperandTypeChecker} into a {@link PPLTypeChecker} for use in function
* signature validation.
Expand All @@ -1238,42 +1263,42 @@ private static String getActualSignature(List<RelDataType> argTypes) {
private static PPLTypeChecker wrapSqlOperandTypeChecker(
SqlOperandTypeChecker typeChecker, String functionName, boolean isUserDefinedFunction) {
PPLTypeChecker pplTypeChecker;
// Only the composite operand type checker for UDFs are concerned here.
if (isUserDefinedFunction
&& typeChecker instanceof CompositeOperandTypeChecker) {
// UDFs implement their own composite type checkers, which always use OR logic for
// argument types. Verifying the composition type would require accessing a protected field in
// CompositeOperandTypeChecker. If access to this field is not allowed, type checking will
// be skipped, so we avoid checking the composition type here.
CompositeOperandTypeChecker compositeTypeChecker = (CompositeOperandTypeChecker) typeChecker;
pplTypeChecker = PPLTypeChecker.wrapComposite(compositeTypeChecker, false);
} else if (typeChecker instanceof ImplicitCastOperandTypeChecker) {
ImplicitCastOperandTypeChecker implicitCastTypeChecker = (ImplicitCastOperandTypeChecker) typeChecker;
pplTypeChecker = PPLTypeChecker.wrapFamily(implicitCastTypeChecker);
if (typeChecker instanceof ImplicitCastOperandTypeChecker) {
ImplicitCastOperandTypeChecker implicitCastTypeChecker = (ImplicitCastOperandTypeChecker) typeChecker;
pplTypeChecker = PPLTypeChecker.wrapFamily(implicitCastTypeChecker);
} else if (typeChecker instanceof CompositeOperandTypeChecker) {
// If compositeTypeChecker contains operand checkers other than family type checkers or
// other than OR compositions, the function with be registered with a null type checker,
// which means the function will not be type checked.
CompositeOperandTypeChecker compositeTypeChecker = (CompositeOperandTypeChecker) typeChecker;
try {
pplTypeChecker = PPLTypeChecker.wrapComposite(compositeTypeChecker, true);
} catch (IllegalArgumentException | UnsupportedOperationException e) {
logger.debug(
String.format(
"Failed to create composite type checker for operator: %s. Will skip its type"
+ " checking",
functionName),
e);
pplTypeChecker = null;
}
CompositeOperandTypeChecker compositeTypeChecker = (CompositeOperandTypeChecker) typeChecker;
// UDFs implement their own composite type checkers, which always use OR logic for
// argument
// types. Verifying the composition type would require accessing a protected field in
// CompositeOperandTypeChecker. If access to this field is not allowed, type checking will
// be skipped, so we avoid checking the composition type here.

// If compositeTypeChecker contains operand checkers other than family type checkers or
// other than OR compositions, the function with be registered with a null type checker,
// which means the function will not be type checked.
try {
pplTypeChecker = PPLTypeChecker.wrapComposite(compositeTypeChecker, !isUserDefinedFunction);
} catch (IllegalArgumentException | UnsupportedOperationException e) {
logger.debug(
String.format(
"Failed to create composite type checker for operator: %s. Will skip its type"
+ " checking",
functionName),
e);
pplTypeChecker = null;
}
} else if (typeChecker instanceof SameOperandTypeChecker) {
// Comparison operators like EQUAL, GREATER_THAN, LESS_THAN, etc.
// SameOperandTypeCheckers like COALESCE, IFNULL, etc.
SameOperandTypeChecker comparableTypeChecker = (SameOperandTypeChecker) typeChecker;
pplTypeChecker = PPLTypeChecker.wrapComparable(comparableTypeChecker);
SameOperandTypeChecker comparableTypeChecker = (SameOperandTypeChecker) typeChecker;
// Comparison operators like EQUAL, GREATER_THAN, LESS_THAN, etc.
// SameOperandTypeCheckers like COALESCE, IFNULL, etc.
pplTypeChecker = PPLTypeChecker.wrapComparable(comparableTypeChecker);
} else if (typeChecker instanceof UDFOperandMetadata.UDTOperandMetadata) {
UDFOperandMetadata.UDTOperandMetadata udtOperandMetadata = (UDFOperandMetadata.UDTOperandMetadata) typeChecker;
pplTypeChecker = PPLTypeChecker.wrapUDT(udtOperandMetadata.getAllowSignatures());
UDFOperandMetadata.UDTOperandMetadata udtOperandMetadata =
(UDFOperandMetadata.UDTOperandMetadata) typeChecker;
pplTypeChecker = PPLTypeChecker.wrapUDT(udtOperandMetadata.getAllowSignatures());
} else if (typeChecker != null) {
pplTypeChecker = PPLTypeChecker.wrapDefault(typeChecker);
} else {
logger.info(
"Cannot create type checker for function: {}. Will skip its type checking", functionName);
Expand Down Expand Up @@ -1306,4 +1331,4 @@ private static SqlOperandTypeChecker extractTypeCheckerFromUDF(SqlOperator opera
}
return typeChecker;
}
}
}
Loading
Loading