[SPARK-11743][SQL] Add UserDefinedType support to RowEncoder #9712
Changes from all commits
@@ -19,14 +19,62 @@ package org.apache.spark.sql.catalyst.encoders

```scala
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.{RandomDataGenerator, Row}
import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

@SQLUserDefinedType(udt = classOf[ExamplePointUDT])
class ExamplePoint(val x: Double, val y: Double) extends Serializable {
  override def hashCode: Int = 41 * (41 + x.toInt) + y.toInt
  override def equals(that: Any): Boolean = {
    if (that.isInstanceOf[ExamplePoint]) {
      val e = that.asInstanceOf[ExamplePoint]
      (this.x == e.x || (this.x.isNaN && e.x.isNaN) || (this.x.isInfinity && e.x.isInfinity)) &&
        (this.y == e.y || (this.y.isNaN && e.y.isNaN) || (this.y.isInfinity && e.y.isInfinity))
    } else {
      false
    }
  }
}

/**
 * User-defined type for [[ExamplePoint]].
 */
class ExamplePointUDT extends UserDefinedType[ExamplePoint] {

  override def sqlType: DataType = ArrayType(DoubleType, false)

  override def pyUDT: String = "pyspark.sql.tests.ExamplePointUDT"

  override def serialize(obj: Any): GenericArrayData = {
    obj match {
      case p: ExamplePoint =>
        val output = new Array[Any](2)
        output(0) = p.x
        output(1) = p.y
        new GenericArrayData(output)
    }
  }

  override def deserialize(datum: Any): ExamplePoint = {
    datum match {
      case values: ArrayData =>
        new ExamplePoint(values.getDouble(0), values.getDouble(1))
    }
  }

  override def userClass: Class[ExamplePoint] = classOf[ExamplePoint]

  private[spark] override def asNullable: ExamplePointUDT = this
}
```
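For context, here is roughly how a value round-trips through the UDT's `serialize`/`deserialize` pair (a minimal sketch, assuming the `ExamplePoint` and `ExamplePointUDT` definitions above are in scope; the coordinates are illustrative):

```scala
// Convert a user-space ExamplePoint to its Catalyst representation and back.
val udt = new ExamplePointUDT
val catalystValue = udt.serialize(new ExamplePoint(1.0, 2.0)) // GenericArrayData over Array(1.0, 2.0)
val roundTripped = udt.deserialize(catalystValue)             // an ExamplePoint again
assert(roundTripped.x == 1.0 && roundTripped.y == 2.0)
```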
|
|
```scala
class RowEncoderSuite extends SparkFunSuite {

  private val structOfString = new StructType().add("str", StringType)
  private val structOfUDT = new StructType().add("udt", new ExamplePointUDT, false)
  private val arrayOfString = ArrayType(StringType)
  private val mapOfString = MapType(StringType, StringType)
  private val arrayOfUDT = ArrayType(new ExamplePointUDT, false)

  encodeDecodeTest(
    new StructType()
```
@@ -41,7 +89,8 @@ class RowEncoderSuite extends SparkFunSuite {

```scala
      .add("string", StringType)
      .add("binary", BinaryType)
      .add("date", DateType)
      .add("timestamp", TimestampType)
      .add("udt", new ExamplePointUDT, false))

  encodeDecodeTest(
    new StructType()
```
@@ -68,7 +117,36 @@ class RowEncoderSuite extends SparkFunSuite {

```scala
    .add("structOfArray", new StructType().add("array", arrayOfString))
    .add("structOfMap", new StructType().add("map", mapOfString))
    .add("structOfArrayAndMap",
      new StructType().add("array", arrayOfString).add("map", mapOfString))
    .add("structOfUDT", structOfUDT))

  test(s"encode/decode: arrayOfUDT") {
    val schema = new StructType()
      .add("arrayOfUDT", arrayOfUDT)

    val encoder = RowEncoder(schema)

    val input: Row = Row(Seq(new ExamplePoint(0.1, 0.2), new ExamplePoint(0.3, 0.4)))
    val row = encoder.toRow(input)
    val convertedBack = encoder.fromRow(row)
    assert(input.getSeq[ExamplePoint](0) == convertedBack.getSeq[ExamplePoint](0))
  }

  test(s"encode/decode: Product") {
    val schema = new StructType()
      .add("structAsProduct",
        new StructType()
          .add("int", IntegerType)
          .add("string", StringType)
          .add("double", DoubleType))

    val encoder = RowEncoder(schema)

    val input: Row = Row((100, "test", 0.123))
```
|
Review thread on `val input: Row = Row((100, "test", 0.123))`:

Contributor: I have a question here. According to the Javadoc of …

Member (Author): Actually I found this problem when working on ScalaUDF. ScalaUDF will use … So for a …

Contributor: If one of the input parameters is …

Member (Author): If we have an input parameter mapping to a …

Contributor: Ah I see. We need the type tag to infer the return type of the UDF, and if a …

Member (Author): Hmm, as I tried, I found a problem: we may not always be able to get the type …

Contributor: I think the reason we still support Product for StructType is backward compatibility: we did not enforce the inbound type before, and someone may rely on it (because it's easier than Row in Scala).
```scala
    val row = encoder.toRow(input)
    val convertedBack = encoder.fromRow(row)
    assert(input.getStruct(0) == convertedBack.getStruct(0))
  }
```
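To make the thread concrete, here is a hedged sketch of the two points discussed. `ScalaReflection` is Spark's internal reflection helper (shown only for illustration; it is not a public API and may differ across versions), and the `Point2` case class, schema, and values below are illustrative, not part of this patch:

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.types._

// 1) Return-type inference for a Scala UDF relies on a TypeTag. A case class
//    carries its field structure, so a schema can be inferred from it...
case class Point2(x: Double, y: Double)
val inferred = ScalaReflection.schemaFor[Point2].dataType // struct<x: double, y: double>

// ...whereas the TypeTag of Row says nothing about its fields, so no schema
// can be inferred for a UDF that returns Row.

// 2) Backward compatibility: for a nested StructType, the external value may
//    be either a Row or any Product (e.g. a tuple), as the test above exercises.
val nested = new StructType().add("int", IntegerType).add("string", StringType)
val asRow: Row = Row(Row(1, "a"))  // explicit Row for the nested struct
val asTuple: Row = Row((1, "a"))   // a tuple also works, and is easier in Scala
```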
|
|
||
```scala
  private def encodeDecodeTest(schema: StructType): Unit = {
    test(s"encode/decode: ${schema.simpleString}") {
```
Review thread on the `encodeDecodeTest` helper:

Reviewer: Any reason why we put this test here instead of adding an `arrayOfUDT` type in `encodeDecodeTest` like `structOfUDT`?

Member (Author): No. I moved it in #10538.
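For reference, the suggestion amounts to exercising the array-of-UDT case through the shared round-trip helper instead of a hand-written test (a sketch using the `arrayOfUDT` value defined in this suite):

```scala
// Equivalent coverage via the generic encode/decode helper.
encodeDecodeTest(new StructType().add("array", arrayOfUDT))
```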