/*
 * Patch summary: adds the "<=>" comparator to the default expression evaluator.
 *  - DefaultExpressionEvaluator: transform and evaluation implementations of the new
 *    visitNullSafeComparator hook.
 *  - DefaultExpressionUtils: new compareNullSafe helper; the binary comparator that was
 *    local to compareBinary becomes a shared static field.
 *  - ExpressionVisitor: new abstract visitNullSafeComparator, dispatched for "<=>".
 *  - DefaultExpressionEvaluatorSuite: expected "<=>" results added to the comparator test.
 */
diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluator.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluator.java
index 0257bbddc3a..49366bc030f 100644
--- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluator.java
+++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluator.java
@@ -38,6 +38,7 @@
 import static io.delta.kernel.defaults.internal.expressions.DefaultExpressionUtils.booleanWrapperVector;
 import static io.delta.kernel.defaults.internal.expressions.DefaultExpressionUtils.childAt;
 import static io.delta.kernel.defaults.internal.expressions.DefaultExpressionUtils.compare;
+import static io.delta.kernel.defaults.internal.expressions.DefaultExpressionUtils.compareNullSafe;
 import static io.delta.kernel.defaults.internal.expressions.DefaultExpressionUtils.evalNullability;
 import static io.delta.kernel.defaults.internal.expressions.ImplicitCastExpression.canCastTo;
 
@@ -156,6 +157,19 @@ ExpressionTransformResult visitComparator(Predicate predicate) {
             }
         }
 
+        @Override
+        ExpressionTransformResult visitNullSafeComparator(Predicate predicate) {
+            switch (predicate.getName()) {
+                case "<=>": // the only name handled by this hook
+                    return new ExpressionTransformResult(
+                        transformBinaryComparator(predicate),
+                        BooleanType.BOOLEAN);
+                default:
+                    throw DeltaErrors.unsupportedExpression(
+                        predicate, Optional.of("unsupported expression encountered"));
+            }
+        }
+
         @Override
         ExpressionTransformResult visitLiteral(Literal literal) {
             // nothing to validate or rewrite
@@ -445,6 +459,27 @@ ColumnVector visitComparator(Predicate predicate) {
             return new DefaultBooleanVector(numRows, Optional.of(nullability), result);
         }
 
+        @Override
+        ColumnVector visitNullSafeComparator(Predicate predicate) {
+            PredicateChildrenEvalResult argResults = evalBinaryExpressionChildren(predicate);
+            int numRows = argResults.rowCount;
+            boolean[] result = new boolean[numRows];
+            int[] compareResult = compareNullSafe(argResults.leftResult, argResults.rightResult);
+            switch (predicate.getName()) {
+                case "<=>":
+                    for (int rowId = 0; rowId < numRows; rowId++) {
+                        result[rowId] = compareResult[rowId] == 0; // 0: equal, or both null
+                    }
+                    break;
+                default:
+                    throw DeltaErrors.unsupportedExpression(
+                        predicate,
+                        Optional.of("unsupported expression encountered"));
+            }
+            // Unlike visitComparator, no nullability is passed; presumably no row is null.
+            return new DefaultBooleanVector(numRows, Optional.empty(), result);
+        }
+
         @Override
         ColumnVector visitLiteral(Literal literal) {
             DataType dataType = literal.getDataType();
diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionUtils.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionUtils.java
index e078a7a7cd9..d64d7cc9935 100644
--- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionUtils.java
+++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionUtils.java
@@ -33,6 +33,20 @@
  * Utility methods used by the default expression evaluator.
  */
 class DefaultExpressionUtils {
+    // Byte-wise order; a strict prefix sorts first. NOTE(review): raw Comparator, confirm.
+    static final Comparator binaryComparator = (leftOp, rightOp) -> {
+        int i = 0;
+        while (i < leftOp.length && i < rightOp.length) {
+            if (leftOp[i] != rightOp[i]) {
+                return Byte.compare(leftOp[i], rightOp[i]);
+            }
+            i++;
+        }
+        return Integer.compare(leftOp.length, rightOp.length);
+    };
+    // Natural ordering, used by compareNullSafe for decimal and string operands.
+    static final Comparator bigDecimalComparator = Comparator.naturalOrder();
+    static final Comparator stringComparator = Comparator.naturalOrder();
     private DefaultExpressionUtils() {}
 
     /**
@@ -127,6 +141,53 @@ static int[] compare(ColumnVector left, ColumnVector right) {
         return result;
     }
 
+    static int[] compareNullSafe(ColumnVector left, ColumnVector right) { // per-row, null-aware
+        checkArgument(
+            left.getSize() == right.getSize(),
+            "Left and right operand have different vector sizes.");
+        DataType dataType = left.getDataType(); // right assumed to match - TODO confirm
+        int numRows = left.getSize();
+        int[] result = new int[numRows]; // 0 = equal, non-zero = not equal
+        for (int rowId = 0; rowId < left.getSize(); rowId++) {
+            if (left.isNullAt(rowId) && right.isNullAt(rowId)) {
+                result[rowId] = 0; // both null: equal
+            } else if (left.isNullAt(rowId) || right.isNullAt(rowId)) {
+                result[rowId] = 1; // exactly one null: not equal
+            } else {
+                if (dataType instanceof BooleanType) {
+                    result[rowId] =
+                        Boolean.compare(left.getBoolean(rowId), right.getBoolean(rowId));
+                } else if (dataType instanceof ByteType) {
+                    result[rowId] = Byte.compare(left.getByte(rowId), right.getByte(rowId));
+                } else if (dataType instanceof ShortType) {
+                    result[rowId] = Short.compare(left.getShort(rowId), right.getShort(rowId));
+                } else if (dataType instanceof IntegerType || dataType instanceof DateType) {
+                    result[rowId] = Integer.compare(left.getInt(rowId), right.getInt(rowId));
+                } else if (dataType instanceof LongType || dataType instanceof TimestampType) {
+                    result[rowId] = Long.compare(left.getLong(rowId), right.getLong(rowId));
+                } else if (dataType instanceof FloatType) {
+                    result[rowId] = Float.compare(left.getFloat(rowId), right.getFloat(rowId));
+                } else if (dataType instanceof DoubleType) {
+                    result[rowId] = Double.compare(
+                        left.getDouble(rowId), right.getDouble(rowId));
+                } else if (dataType instanceof DecimalType) {
+                    result[rowId] = bigDecimalComparator.compare(
+                        left.getDecimal(rowId), right.getDecimal(rowId));
+                } else if (dataType instanceof StringType) {
+                    result[rowId] = stringComparator.compare(
+                        left.getString(rowId), right.getString(rowId));
+                } else if (dataType instanceof BinaryType) {
+                    result[rowId] = binaryComparator.compare(
+                        left.getBinary(rowId), right.getBinary(rowId));
+                } else {
+                    throw new UnsupportedOperationException(dataType + " can not be compared.");
+                }
+            }
+        }
+        return result;
+    }
+
+
     static void compareBoolean(ColumnVector left, ColumnVector right, int[] result) {
         for (int rowId = 0; rowId < left.getSize(); rowId++) {
             if (!left.isNullAt(rowId) && !right.isNullAt(rowId)) {
@@ -202,19 +263,10 @@ static void compareDecimal(ColumnVector left, ColumnVector right, int[] result)
     }
 
     static void compareBinary(ColumnVector left, ColumnVector right, int[] result) {
-        Comparator comparator = (leftOp, rightOp) -> {
-            int i = 0;
-            while (i < leftOp.length && i < rightOp.length) {
-                if (leftOp[i] != rightOp[i]) {
-                    return Byte.compare(leftOp[i], rightOp[i]);
-                }
-                i++;
-            }
-            return Integer.compare(leftOp.length, rightOp.length);
-        };
         for (int rowId = 0; rowId < left.getSize(); rowId++) {
             if (!left.isNullAt(rowId) && !right.isNullAt(rowId)) {
-                result[rowId] = comparator.compare(left.getBinary(rowId), right.getBinary(rowId));
+                result[rowId] = binaryComparator.compare( // shared static comparator
+                    left.getBinary(rowId), right.getBinary(rowId));
             }
         }
     }
diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionVisitor.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionVisitor.java
index bd219f55fda..8627bb4c6cd 100644
--- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionVisitor.java
+++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionVisitor.java
@@ -41,6 +41,8 @@ abstract class ExpressionVisitor {
 
     abstract R visitComparator(Predicate predicate);
 
+    abstract R visitNullSafeComparator(Predicate predicate); // called for "<=>"
+
     abstract R visitLiteral(Literal literal);
 
     abstract R visitColumn(Column column);
@@ -95,6 +97,8 @@ private R visitScalarExpression(ScalarExpression expression) {
             case ">":
             case ">=":
                 return visitComparator(new Predicate(name, children));
+            case "<=>": // routed to its own hook, not visitComparator
+                return visitNullSafeComparator(new Predicate(name, children));
             case "ELEMENT_AT":
                 return visitElementAt(expression);
             case "NOT":
diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluatorSuite.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluatorSuite.scala
index 53f2c7ae09f..0bfbc67d601 100644
--- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluatorSuite.scala
+++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluatorSuite.scala
@@ -305,7 +305,7 @@ class DefaultExpressionEvaluatorSuite extends AnyFunSuite with ExpressionSuiteBa
       "Coalesce is only supported for boolean type expressions")
   }
 
-  test("evaluate expression: comparators (=, <, <=, >, >=)") {
+  test("evaluate expression: comparators (=, <, <=, >, >=, <=>)") {
     // Literals for each data type from the data type value range, used as inputs to comparator
     // (small, big, small, null)
     val literals = Seq(
@@ -350,7 +350,8 @@ class DefaultExpressionEvaluatorSuite extends AnyFunSuite with ExpressionSuiteBa
       "<=" -> Seq(true, false, true, null, null, null),
       ">" -> Seq(false, true, false, null, null, null),
       ">=" -> Seq(false, true, true, null, null, null),
-      "=" -> Seq(false, false, true, null, null, null)
+      "=" -> Seq(false, false, true, null, null, null),
+      "<=>" -> Seq(false, false, true, false, false, true) // never null, unlike "="
     )
 
     literals.foreach {