Skip to content

Commit c14d1ba

Browse files
support unionAll for dataframes with UDT columns
1 parent 0784e02 commit c14d1ba

File tree

4 files changed

+50
-1
lines changed

4 files changed

+50
-1
lines changed

python/pyspark/sql/tests.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -601,6 +601,24 @@ def test_parquet_with_udt(self):
601601
point = df1.head().point
602602
self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
603603

604+
def test_unionAll_with_udt(self):
    """unionAll of two single-row DataFrames sharing a UDT column returns both rows."""
    from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
    # Both DataFrames are created against the same two-column schema.
    schema = StructType([
        StructField("label", DoubleType(), False),
        StructField("point", ExamplePointUDT(), False),
    ])
    first = self.sqlCtx.createDataFrame([(1.0, ExamplePoint(1.0, 2.0))], schema)
    second = self.sqlCtx.createDataFrame([(2.0, ExamplePoint(3.0, 4.0))], schema)

    expected = [
        Row(label=1.0, point=ExamplePoint(1.0, 2.0)),
        Row(label=2.0, point=ExamplePoint(3.0, 4.0)),
    ]
    # Sort by label so the comparison does not depend on union output order.
    actual = first.unionAll(second).orderBy("label").collect()
    self.assertEqual(actual, expected)
621+
604622
def test_column_operators(self):
605623
ci = self.df.key
606624
cs = self.df.value

sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,11 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable {
8484

8585
// A type is accepted only when its runtime class is exactly this UDT's runtime class.
override private[sql] def acceptsType(dataType: DataType): Boolean =
  dataType.getClass == this.getClass
87+
88+
/**
 * Two user-defined types are equal when `acceptsType` holds, i.e. when both
 * instances have the same runtime class. Anything that is not a
 * `UserDefinedType` is never equal to this one.
 */
override def equals(other: Any): Boolean = other match {
  case that: UserDefinedType[_] => this.acceptsType(that)
  case _ => false
}

/**
 * Hashes on the runtime class, the only thing `equals` compares, so that
 * equal instances always produce the same hash code and behave correctly in
 * hash-based collections.
 */
override def hashCode(): Int = getClass.hashCode()
8792
}
8893

8994
/**
@@ -110,4 +115,9 @@ private[sql] class PythonUserDefinedType(
110115
("serializedClass" -> serializedPyClass) ~
111116
("sqlType" -> sqlType.jsonValue)
112117
}
118+
119+
/**
 * Two Python UDT wrappers are equal when their `pyUDT` values are equal.
 * Any other kind of object is never equal to this one.
 */
override def equals(other: Any): Boolean = other match {
  case that: PythonUserDefinedType => this.pyUDT.equals(that.pyUDT)
  case _ => false
}

/**
 * Hashes on `pyUDT`, the only field `equals` compares, so that equal
 * instances always produce the same hash code.
 */
override def hashCode(): Int = pyUDT.hashCode()
113123
}

sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,12 @@ import org.apache.spark.sql.types._
2626
* @param y y coordinate
2727
*/
2828
@SQLUserDefinedType(udt = classOf[ExamplePointUDT])
29-
private[sql] class ExamplePoint(val x: Double, val y: Double) extends Serializable
29+
private[sql] class ExamplePoint(val x: Double, val y: Double) extends Serializable {

  /**
   * Hash code consistent with `equals`.
   *
   * `equals` compares the coordinates with `==`, under which `0.0` and `-0.0`
   * are equal. `##` hashes numeric values consistently with `==`, so points
   * that compare equal always hash to the same value.
   */
  override def hashCode(): Int = 31 * x.## + y.##

  /** Points are equal when both coordinates compare equal with `==`. */
  override def equals(other: Any): Boolean = other match {
    case that: ExamplePoint => this.x == that.x && this.y == that.y
    case _ => false
  }
}
3035

3136
/**
3237
* User-defined type for [[ExamplePoint]].

sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -516,6 +516,22 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
516516
}
517517
}
518518

519+
test("unionAll should union DataFrames with UDTs (SPARK-13410)") {
  // Evaluated once per DataFrame on purpose: each call creates a distinct
  // ExamplePointUDT instance, so the two schemas never share a UDT object.
  def pointSchema = StructType(Array(
    StructField("label", IntegerType, nullable = false),
    StructField("point", new ExamplePointUDT(), nullable = false)))

  // Builds a single-row DataFrame holding one labelled point.
  def frameOf(label: Int, x: Double, y: Double) = {
    val rows = sparkContext.parallelize(Seq(Row(label, new ExamplePoint(x, y))))
    sqlContext.createDataFrame(rows, pointSchema)
  }

  val unioned = frameOf(1, 1.0, 2.0).unionAll(frameOf(2, 3.0, 4.0)).orderBy("label")
  val expected = Seq(Row(1, new ExamplePoint(1.0, 2.0)), Row(2, new ExamplePoint(3.0, 4.0)))

  checkAnswer(unioned, expected)
}
534+
519535
ignore("show") {
520536
// This test case is intended ignored, but to make sure it compiles correctly
521537
testData.select($"*").show()

0 commit comments

Comments
 (0)