Skip to content

Commit bd9f3bb

Browse files
committed
Merge comparators and their IP variants so that coercion works for IP comparison
- when not merging, ip comparing will also pass the type checker of Calcite's comparators Signed-off-by: Yuanchun Shen <yuanchu@amazon.com>
1 parent b126b87 commit bd9f3bb

File tree

5 files changed

+65
-40
lines changed

5 files changed

+65
-40
lines changed

core/src/main/java/org/opensearch/sql/expression/function/CoercionUtils.java

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,15 +86,11 @@ public class CoercionUtils {
8686
}
8787

8888
private static @Nullable RexNode cast(RexBuilder builder, ExprType targetType, RexNode arg) {
89-
// Implement the logic to check if fromType can be cast to toType
90-
// This could involve checking if the types are compatible, or if a cast is possible
91-
// For example, you might check if fromType is a subtype of toType, or if a conversion exists
9289
ExprType argType = OpenSearchTypeFactory.convertRelDataTypeToExprType(arg.getType());
9390
if (!argType.shouldCast(targetType)) {
9491
return arg;
9592
}
9693

97-
// If the arg is string
9894
if (WideningTypeRule.distance(argType, targetType) != WideningTypeRule.IMPOSSIBLE_WIDENING) {
9995
return builder.makeCast(OpenSearchTypeFactory.convertExprTypeToRelDataType(targetType), arg);
10096
}

core/src/main/java/org/opensearch/sql/expression/function/PPLBuiltinOperators.java

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -106,14 +106,13 @@ public class PPLBuiltinOperators extends ReflectiveSqlOperatorTable {
106106
public static final SqlOperator SHA2 = CryptographicFunction.sha2().toUDF("SHA2");
107107
public static final SqlOperator CIDRMATCH = new CidrMatchFunction().toUDF("CIDRMATCH");
108108

109-
// IP comparing functions
110-
public static final SqlOperator NOT_EQUALS_IP =
111-
CompareIpFunction.notEquals().toUDF("NOT_EQUALS_IP");
112-
public static final SqlOperator EQUALS_IP = CompareIpFunction.equals().toUDF("EQUALS_IP");
113-
public static final SqlOperator GREATER_IP = CompareIpFunction.greater().toUDF("GREATER_IP");
114-
public static final SqlOperator GTE_IP = CompareIpFunction.greaterOrEquals().toUDF("GTE_IP");
115-
public static final SqlOperator LESS_IP = CompareIpFunction.less().toUDF("LESS_IP");
116-
public static final SqlOperator LTE_IP = CompareIpFunction.lessOrEquals().toUDF("LTE_IP");
109+
// Comparing functions
110+
public static final SqlOperator NOT_EQUALS = CompareIpFunction.notEquals().toUDF("NOT_EQUALS");
111+
public static final SqlOperator EQUALS = CompareIpFunction.equals().toUDF("EQUALS");
112+
public static final SqlOperator GREATER = CompareIpFunction.greater().toUDF("GREATER");
113+
public static final SqlOperator GTE = CompareIpFunction.greaterOrEquals().toUDF("GTE");
114+
public static final SqlOperator LESS = CompareIpFunction.less().toUDF("LESS");
115+
public static final SqlOperator LTE = CompareIpFunction.lessOrEquals().toUDF("LTE");
117116

118117
// Condition function
119118
public static final SqlOperator EARLIEST = new EarliestFunction().toUDF("EARLIEST");

core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,7 @@
231231
import org.apache.calcite.rex.RexBuilder;
232232
import org.apache.calcite.rex.RexLambda;
233233
import org.apache.calcite.rex.RexNode;
234+
import org.apache.calcite.sql.SqlOperandCountRange;
234235
import org.apache.calcite.sql.SqlOperator;
235236
import org.apache.calcite.sql.fun.SqlLibraryOperators;
236237
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
@@ -240,6 +241,7 @@
240241
import org.apache.calcite.sql.type.OperandTypes;
241242
import org.apache.calcite.sql.type.ReturnTypes;
242243
import org.apache.calcite.sql.type.SameOperandTypeChecker;
244+
import org.apache.calcite.sql.type.SqlOperandCountRanges;
243245
import org.apache.calcite.sql.type.SqlOperandTypeChecker;
244246
import org.apache.calcite.sql.type.SqlTypeFamily;
245247
import org.apache.calcite.sql.type.SqlTypeName;
@@ -577,7 +579,11 @@ public void registerOperator(BuiltinFunctionName functionName, SqlOperator... op
577579
}
578580

579581
// Only the composite operand type checker for UDFs are concerned here.
580-
if (operator instanceof SqlUserDefinedFunction
582+
if (BuiltinFunctionName.COMPARATORS.contains(functionName)) {
583+
// Comparison operators like EQUAL, GREATER_THAN, LESS_THAN, etc.
584+
register(
585+
functionName, wrapWithComparableTypeChecker(operator, SqlOperandCountRanges.of(2)));
586+
} else if (operator instanceof SqlUserDefinedFunction
581587
&& typeChecker instanceof CompositeOperandTypeChecker compositeTypeChecker) {
582588
// UDFs implement their own composite type checkers, which always use OR logic for
583589
// argument
@@ -596,9 +602,11 @@ public void registerOperator(BuiltinFunctionName functionName, SqlOperator... op
596602
register(
597603
functionName, wrapWithCompositeTypeChecker(operator, compositeTypeChecker, true));
598604
} else if (typeChecker instanceof SameOperandTypeChecker comparableTypeChecker) {
599-
// Comparison operators like EQUAL, GREATER_THAN, LESS_THAN, etc.
600605
// SameOperandTypeCheckers like COALESCE, IFNULL, etc.
601-
register(functionName, wrapWithComparableTypeChecker(operator, comparableTypeChecker));
606+
register(
607+
functionName,
608+
wrapWithComparableTypeChecker(
609+
operator, comparableTypeChecker.getOperandCountRange()));
602610
} else if (typeChecker instanceof UDFOperandMetadata.IPOperandMetadata) {
603611
register(
604612
functionName,
@@ -681,7 +689,7 @@ public PPLTypeChecker getTypeChecker() {
681689
}
682690

683691
private static FunctionImp wrapWithComparableTypeChecker(
684-
SqlOperator operator, SameOperandTypeChecker typeChecker) {
692+
SqlOperator operator, SqlOperandCountRange countRange) {
685693
return new FunctionImp() {
686694
@Override
687695
public RexNode resolve(RexBuilder builder, RexNode... args) {
@@ -690,7 +698,7 @@ public RexNode resolve(RexBuilder builder, RexNode... args) {
690698

691699
@Override
692700
public PPLTypeChecker getTypeChecker() {
693-
return PPLTypeChecker.wrapComparable(typeChecker);
701+
return PPLTypeChecker.comparable(countRange);
694702
}
695703
};
696704
}
@@ -727,12 +735,12 @@ public PPLTypeChecker getTypeChecker() {
727735

728736
void populate() {
729737
// register operators for comparison
730-
registerOperator(NOTEQUAL, PPLBuiltinOperators.NOT_EQUALS_IP, SqlStdOperatorTable.NOT_EQUALS);
731-
registerOperator(EQUAL, PPLBuiltinOperators.EQUALS_IP, SqlStdOperatorTable.EQUALS);
732-
registerOperator(GREATER, PPLBuiltinOperators.GREATER_IP, SqlStdOperatorTable.GREATER_THAN);
733-
registerOperator(GTE, PPLBuiltinOperators.GTE_IP, SqlStdOperatorTable.GREATER_THAN_OR_EQUAL);
734-
registerOperator(LESS, PPLBuiltinOperators.LESS_IP, SqlStdOperatorTable.LESS_THAN);
735-
registerOperator(LTE, PPLBuiltinOperators.LTE_IP, SqlStdOperatorTable.LESS_THAN_OR_EQUAL);
738+
registerOperator(NOTEQUAL, PPLBuiltinOperators.NOT_EQUALS);
739+
registerOperator(EQUAL, PPLBuiltinOperators.EQUALS);
740+
registerOperator(GREATER, PPLBuiltinOperators.GREATER);
741+
registerOperator(GTE, PPLBuiltinOperators.GTE);
742+
registerOperator(LESS, PPLBuiltinOperators.LESS);
743+
registerOperator(LTE, PPLBuiltinOperators.LTE);
736744

737745
// Register std operator
738746
registerOperator(AND, SqlStdOperatorTable.AND);
@@ -1048,7 +1056,7 @@ void populate() {
10481056
builder.makeCall(SqlStdOperatorTable.EQUALS, arg1, arg2),
10491057
builder.makeNullLiteral(arg1.getType()),
10501058
arg1),
1051-
PPLTypeChecker.wrapComparable((SameOperandTypeChecker) OperandTypes.SAME_SAME)));
1059+
PPLTypeChecker.comparable(SqlOperandCountRanges.of(2))));
10521060
register(
10531061
IS_EMPTY,
10541062
createFunctionImpWithTypeChecker(

core/src/main/java/org/opensearch/sql/expression/function/PPLTypeChecker.java

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@
1717
import org.apache.calcite.rel.type.RelDataType;
1818
import org.apache.calcite.rel.type.RelDataTypeField;
1919
import org.apache.calcite.sql.SqlIntervalQualifier;
20+
import org.apache.calcite.sql.SqlOperandCountRange;
2021
import org.apache.calcite.sql.parser.SqlParserPos;
2122
import org.apache.calcite.sql.type.CompositeOperandTypeChecker;
2223
import org.apache.calcite.sql.type.FamilyOperandTypeChecker;
2324
import org.apache.calcite.sql.type.ImplicitCastOperandTypeChecker;
24-
import org.apache.calcite.sql.type.SameOperandTypeChecker;
2525
import org.apache.calcite.sql.type.SqlOperandTypeChecker;
2626
import org.apache.calcite.sql.type.SqlTypeFamily;
2727
import org.apache.calcite.sql.type.SqlTypeName;
@@ -248,11 +248,11 @@ public List<List<ExprType>> getParameterTypes() {
248248

249249
@RequiredArgsConstructor
250250
class PPLComparableTypeChecker implements PPLTypeChecker {
251-
private final SameOperandTypeChecker innerTypeChecker;
251+
private final SqlOperandCountRange countRange;
252252

253253
@Override
254254
public boolean checkOperandTypes(List<RelDataType> types) {
255-
if (!innerTypeChecker.getOperandCountRange().isValidCount(types.size())) {
255+
if (!countRange.isValidCount(types.size())) {
256256
return false;
257257
}
258258
// Check comparability of consecutive operands
@@ -313,8 +313,8 @@ private static boolean isComparable(RelDataType type1, RelDataType type2) {
313313

314314
@Override
315315
public String getAllowedSignatures() {
316-
int min = innerTypeChecker.getOperandCountRange().getMin();
317-
int max = innerTypeChecker.getOperandCountRange().getMax();
316+
int min = countRange.getMin();
317+
int max = countRange.getMax();
318318
final String typeName = "COMPARABLE_TYPE";
319319
if (min == -1 || max == -1) {
320320
// If the range is unbounded, we cannot provide a specific signature
@@ -347,9 +347,7 @@ public boolean checkOperandTypes(List<RelDataType> types) {
347347
}
348348
RelDataType type1 = types.get(0);
349349
RelDataType type2 = types.get(1);
350-
return areIpAndStringTypes(type1, type2)
351-
|| areIpAndStringTypes(type2, type1)
352-
|| (type1 instanceof ExprIPType && type2 instanceof ExprIPType);
350+
return type1 instanceof ExprIPType && type2 instanceof ExprIPType;
353351
}
354352

355353
@Override
@@ -363,10 +361,6 @@ public String getAllowedSignatures() {
363361
public List<List<ExprType>> getParameterTypes() {
364362
return List.of(List.of(ExprCoreType.IP, ExprCoreType.IP));
365363
}
366-
367-
private static boolean areIpAndStringTypes(RelDataType typeIp, RelDataType typeString) {
368-
return typeIp instanceof ExprIPType && typeString.getFamily() == SqlTypeFamily.CHARACTER;
369-
}
370364
}
371365

372366
class PPLCidrTypeChecker implements PPLTypeChecker {
@@ -392,7 +386,7 @@ public String getAllowedSignatures() {
392386

393387
@Override
394388
public List<List<ExprType>> getParameterTypes() {
395-
return List.of(List.of(ExprCoreType.IP, ExprCoreType.IP));
389+
return List.of(List.of(ExprCoreType.IP, ExprCoreType.STRING));
396390
}
397391
}
398392

@@ -467,8 +461,8 @@ static PPLCompositeTypeChecker wrapComposite(
467461
return new PPLCompositeTypeChecker(typeChecker);
468462
}
469463

470-
static PPLComparableTypeChecker wrapComparable(SameOperandTypeChecker typeChecker) {
471-
return new PPLComparableTypeChecker(typeChecker);
464+
static PPLComparableTypeChecker comparable(SqlOperandCountRange countRange) {
465+
return new PPLComparableTypeChecker(countRange);
472466
}
473467

474468
// Util Functions

core/src/main/java/org/opensearch/sql/expression/function/udf/ip/CompareIpFunction.java

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,14 @@
1010
import org.apache.calcite.adapter.enumerable.NullPolicy;
1111
import org.apache.calcite.adapter.enumerable.RexToLixTranslator;
1212
import org.apache.calcite.linq4j.tree.Expression;
13+
import org.apache.calcite.linq4j.tree.ExpressionType;
1314
import org.apache.calcite.linq4j.tree.Expressions;
1415
import org.apache.calcite.rex.RexCall;
1516
import org.apache.calcite.sql.type.ReturnTypes;
1617
import org.apache.calcite.sql.type.SqlReturnTypeInference;
18+
import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory;
1719
import org.opensearch.sql.data.model.ExprIpValue;
20+
import org.opensearch.sql.data.type.ExprCoreType;
1821
import org.opensearch.sql.expression.function.ImplementorUDF;
1922
import org.opensearch.sql.expression.function.UDFOperandMetadata;
2023

@@ -66,7 +69,8 @@ public SqlReturnTypeInference getReturnTypeInference() {
6669

6770
@Override
6871
public UDFOperandMetadata getOperandMetadata() {
69-
return new UDFOperandMetadata.IPOperandMetadata();
72+
// Its type checker will be assigned according to the function name.
73+
return null;
7074
}
7175

7276
public static class CompareImplementor implements NotNullImplementor {
@@ -79,6 +83,22 @@ public CompareImplementor(ComparisonType comparisonType) {
7983
@Override
8084
public Expression implement(
8185
RexToLixTranslator translator, RexCall call, List<Expression> translatedOperands) {
86+
87+
if (!containsIpOperands(call)) {
88+
// Call built-in compare function for non-IP operands
89+
ExpressionType expressionType =
90+
switch (comparisonType) {
91+
case EQUALS -> ExpressionType.Equal;
92+
case NOT_EQUALS -> ExpressionType.NotEqual;
93+
case LESS -> ExpressionType.LessThan;
94+
case LESS_OR_EQUAL -> ExpressionType.LessThanOrEqual;
95+
case GREATER -> ExpressionType.GreaterThan;
96+
case GREATER_OR_EQUAL -> ExpressionType.GreaterThanOrEqual;
97+
};
98+
return Expressions.makeBinary(
99+
expressionType, translatedOperands.get(0), translatedOperands.get(1));
100+
}
101+
82102
Expression compareResult =
83103
Expressions.call(
84104
CompareImplementor.class,
@@ -89,6 +109,14 @@ public Expression implement(
89109
return generateComparisonExpression(compareResult, comparisonType);
90110
}
91111

112+
private static boolean containsIpOperands(RexCall call) {
113+
var left = call.getOperands().get(0);
114+
var right = call.getOperands().get(1);
115+
var leftType = OpenSearchTypeFactory.convertRelDataTypeToExprType(left.getType());
116+
var rightType = OpenSearchTypeFactory.convertRelDataTypeToExprType(right.getType());
117+
return leftType == ExprCoreType.IP || rightType == ExprCoreType.IP;
118+
}
119+
92120
private static Expression generateComparisonExpression(
93121
Expression compareResult, ComparisonType comparisonType) {
94122
return switch (comparisonType) {

0 commit comments

Comments
 (0)