Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -578,12 +578,8 @@ public boolean equals(Object other) {
return (sizeInBytes == o.sizeInBytes) &&
ByteArrayMethods.arrayEquals(baseObject, baseOffset, o.baseObject, o.baseOffset,
sizeInBytes);
} else if (!(other instanceof InternalRow)) {
return false;
} else {
throw new IllegalArgumentException(
"Cannot compare UnsafeRow to " + other.getClass().getName());
}
return false;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -481,8 +481,13 @@ class CodegenContext {
case FloatType => s"(java.lang.Float.isNaN($c1) && java.lang.Float.isNaN($c2)) || $c1 == $c2"
case DoubleType => s"(java.lang.Double.isNaN($c1) && java.lang.Double.isNaN($c2)) || $c1 == $c2"
case dt: DataType if isPrimitiveType(dt) => s"$c1 == $c2"
case dt: DataType if dt.isInstanceOf[AtomicType] => s"$c1.equals($c2)"
case array: ArrayType => genComp(array, c1, c2) + " == 0"
case struct: StructType => genComp(struct, c1, c2) + " == 0"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about MapType?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MapType is not comparable. We currently support neither equals() nor hashCode() for MapData. See https://issues.apache.org/jira/browse/SPARK-18134 for a fun discussion on this.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

: ) That is a fun discussion.

Also accidentally found how Presto did it in a PR: https://github.com/prestodb/presto/pull/2469/files

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we do not plan to support MapType, could we add a negative test case?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed in #15956

case udt: UserDefinedType[_] => genEqual(udt.sqlType, c1, c2)
case other => s"$c1.equals($c2)"
case _ =>
throw new IllegalArgumentException(
"cannot generate equality code for un-comparable type: " + dataType.simpleString)
}

/**
Expand Down Expand Up @@ -512,6 +517,11 @@ class CodegenContext {
val funcCode: String =
s"""
public int $compareFunc(ArrayData a, ArrayData b) {
// when comparing unsafe arrays, try equals first as it compares the binary directly
// which is very fast.
if (a instanceof UnsafeArrayData && b instanceof UnsafeArrayData && a.equals(b)) {
return 0;
}
int lengthA = a.numElements();
int lengthB = b.numElements();
int $minLength = (lengthA > lengthB) ? lengthB : lengthA;
Expand Down Expand Up @@ -551,6 +561,11 @@ class CodegenContext {
val funcCode: String =
s"""
public int $compareFunc(InternalRow a, InternalRow b) {
// when comparing unsafe rows, try equals first as it compares the binary directly
// which is very fast.
if (a instanceof UnsafeRow && b instanceof UnsafeRow && a.equals(b)) {
return 0;
}
InternalRow i = null;
$comparisons
return 0;
Expand All @@ -561,7 +576,8 @@ class CodegenContext {
case other if other.isInstanceOf[AtomicType] => s"$c1.compare($c2)"
case udt: UserDefinedType[_] => genComp(udt.sqlType, c1, c2)
case _ =>
throw new IllegalArgumentException("cannot generate compare code for un-comparable type")
throw new IllegalArgumentException(
"cannot generate compare code for un-comparable type: " + dataType.simpleString)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,8 @@ abstract class BinaryComparison extends BinaryOperator with Predicate {
defineCodeGen(ctx, ev, (c1, c2) => s"${ctx.genComp(left.dataType, c1, c2)} $symbol 0")
}
}

protected lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType)
}


Expand All @@ -414,17 +416,7 @@ case class EqualTo(left: Expression, right: Expression)

override def symbol: String = "="

protected override def nullSafeEval(input1: Any, input2: Any): Any = {
if (left.dataType == FloatType) {
Utils.nanSafeCompareFloats(input1.asInstanceOf[Float], input2.asInstanceOf[Float]) == 0
} else if (left.dataType == DoubleType) {
Utils.nanSafeCompareDoubles(input1.asInstanceOf[Double], input2.asInstanceOf[Double]) == 0
} else if (left.dataType != BinaryType) {
input1 == input2
} else {
java.util.Arrays.equals(input1.asInstanceOf[Array[Byte]], input2.asInstanceOf[Array[Byte]])
}
}
protected override def nullSafeEval(left: Any, right: Any): Any = ordering.equiv(left, right)

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
defineCodeGen(ctx, ev, (c1, c2) => ctx.genEqual(left.dataType, c1, c2))
Expand Down Expand Up @@ -452,15 +444,7 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
} else if (input1 == null || input2 == null) {
false
} else {
if (left.dataType == FloatType) {
Utils.nanSafeCompareFloats(input1.asInstanceOf[Float], input2.asInstanceOf[Float]) == 0
} else if (left.dataType == DoubleType) {
Utils.nanSafeCompareDoubles(input1.asInstanceOf[Double], input2.asInstanceOf[Double]) == 0
} else if (left.dataType != BinaryType) {
input1 == input2
} else {
java.util.Arrays.equals(input1.asInstanceOf[Array[Byte]], input2.asInstanceOf[Array[Byte]])
}
ordering.equiv(input1, input2)
}
}

Expand All @@ -483,8 +467,6 @@ case class LessThan(left: Expression, right: Expression)

override def symbol: String = "<"

private lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType)

protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lt(input1, input2)
}

Expand All @@ -497,8 +479,6 @@ case class LessThanOrEqual(left: Expression, right: Expression)

override def symbol: String = "<="

private lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType)

protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lteq(input1, input2)
}

Expand All @@ -511,8 +491,6 @@ case class GreaterThan(left: Expression, right: Expression)

override def symbol: String = ">"

private lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType)

protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.gt(input1, input2)
}

Expand All @@ -525,7 +503,5 @@ case class GreaterThanOrEqual(left: Expression, right: Expression)

override def symbol: String = ">="

private lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType)

protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.gteq(input1, input2)
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import scala.collection.immutable.HashSet

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.RandomDataGenerator
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.types._


Expand Down Expand Up @@ -293,4 +295,31 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(EqualNullSafe(nullInt, normalInt), false)
checkEvaluation(EqualNullSafe(nullInt, nullInt), true)
}

// Checks that EqualTo evaluates to true when a generic array/struct value is
// compared with the value obtained by running the same data through an
// UnsafeProjection and reading it back from the resulting row.
test("EqualTo on complex type") {
  val intArrayType = ArrayType(IntegerType)
  val nestedRowType = new StructType()
    .add("1", StringType)
    .add("2", LongType)
    .add("3", ArrayType(IntegerType))

  val genericArray = new GenericArrayData(Array(1, 2, 3))
  val genericStruct = create_row("a", 1L, genericArray)

  // Wrap both values in a two-column row and project it, so the array and the
  // struct can be read back out of the projected row.
  val wrapperSchema = new StructType().add("array", intArrayType).add("struct", nestedRowType)
  val toUnsafe = UnsafeProjection.create(wrapperSchema)
  val projectedRow = toUnsafe(InternalRow(genericArray, genericStruct))

  val projectedArray = projectedRow.getArray(0)
  val projectedStruct = projectedRow.getStruct(1, 3)

  // Generic array vs. projected array.
  val arrayEquality = EqualTo(
    Literal.create(genericArray, intArrayType),
    Literal.create(projectedArray, intArrayType))
  checkEvaluation(arrayEquality, true)

  // Generic struct vs. projected struct.
  val structEquality = EqualTo(
    Literal.create(genericStruct, nestedRowType),
    Literal.create(projectedStruct, nestedRowType))
  checkEvaluation(structEquality, true)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2476,4 +2476,11 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}
}
}

// Regression test for SPARK-18053: saves a table whose single column "arr" is
// array(id) over spark.range(10), then expects exactly one row to satisfy
// `arr = ARRAY(1L)`.
test("SPARK-18053: ARRAY equality is broken") {
  withTable("array_tbl") {
    val singletonArrays = spark.range(10).select(array($"id").as("arr"))
    singletonArrays.write.saveAsTable("array_tbl")

    val matchingRows = sql("SELECT * FROM array_tbl where arr = ARRAY(1L)")
    assert(matchingRows.count == 1)
  }
}
}