diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 6b069844c3da..0dc42740abcf 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -601,6 +601,24 @@ def test_parquet_with_udt(self): point = df1.head().point self.assertEqual(point, PythonOnlyPoint(1.0, 2.0)) + def test_unionAll_with_udt(self): + from pyspark.sql.tests import ExamplePoint, ExamplePointUDT + row1 = (1.0, ExamplePoint(1.0, 2.0)) + row2 = (2.0, ExamplePoint(3.0, 4.0)) + schema = StructType([StructField("label", DoubleType(), False), + StructField("point", ExamplePointUDT(), False)]) + df1 = self.sqlCtx.createDataFrame([row1], schema) + df2 = self.sqlCtx.createDataFrame([row2], schema) + + result = df1.unionAll(df2).orderBy("label").collect() + self.assertEqual( + result, + [ + Row(label=1.0, point=ExamplePoint(1.0, 2.0)), + Row(label=2.0, point=ExamplePoint(3.0, 4.0)) + ] + ) + def test_column_operators(self): ci = self.df.key cs = self.df.value diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala index 4305903616bd..6f092c9b9619 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala @@ -84,6 +84,11 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable { override private[sql] def acceptsType(dataType: DataType) = this.getClass == dataType.getClass + + override def equals(other: Any): Boolean = other match { + case that: UserDefinedType[_] => this.acceptsType(that) + case _ => false + } } /** @@ -110,4 +115,9 @@ private[sql] class PythonUserDefinedType( ("serializedClass" -> serializedPyClass) ~ ("sqlType" -> sqlType.jsonValue) } + + override def equals(other: Any): Boolean = other match { + case that: PythonUserDefinedType => this.pyUDT.equals(that.pyUDT) + case _ => false + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala index 8d4854b698ed..4aaf32145d00 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala @@ -26,7 +26,12 @@ import org.apache.spark.sql.types._ * @param y y coordinate */ @SQLUserDefinedType(udt = classOf[ExamplePointUDT]) -private[sql] class ExamplePoint(val x: Double, val y: Double) extends Serializable +private[sql] class ExamplePoint(val x: Double, val y: Double) extends Serializable { + override def equals(other: Any): Boolean = other match { + case that: ExamplePoint => this.x == that.x && this.y == that.y + case _ => false + } +} /** * User-defined type for [[ExamplePoint]]. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index f2ac56072700..8994556cde63 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -516,6 +516,22 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } } + test("unionAll should union DataFrames with UDTs (SPARK-13410)") { + val rowRDD1 = sparkContext.parallelize(Seq(Row(1, new ExamplePoint(1.0, 2.0)))) + val schema1 = StructType(Array(StructField("label", IntegerType, false), + StructField("point", new ExamplePointUDT(), false))) + val rowRDD2 = sparkContext.parallelize(Seq(Row(2, new ExamplePoint(3.0, 4.0)))) + val schema2 = StructType(Array(StructField("label", IntegerType, false), + StructField("point", new ExamplePointUDT(), false))) + val df1 = sqlContext.createDataFrame(rowRDD1, schema1) + val df2 = sqlContext.createDataFrame(rowRDD2, schema2) + + checkAnswer( + df1.unionAll(df2).orderBy("label"), + Seq(Row(1, new ExamplePoint(1.0, 2.0)), Row(2, new ExamplePoint(3.0, 4.0))) + ) + } + ignore("show") { // This test case is intended ignored, but to make sure it compiles correctly testData.select($"*").show()