-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-16062][SPARK-15989][SQL] Fix two bugs of Python-only UDTs #13778
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
f26c8dc
cd80f0e
d22dca8
fc9c106
d603cc2
a0b81ba
4c00bb1
1583fe3
65a33b0
1b751af
87a0953
6065364
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -575,6 +575,41 @@ def check_datatype(datatype): | |
| _verify_type(PythonOnlyPoint(1.0, 2.0), PythonOnlyUDT()) | ||
| self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], PythonOnlyUDT())) | ||
|
|
||
| def test_simple_udt_in_df(self): | ||
| schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT()) | ||
| df = self.spark.createDataFrame( | ||
| [(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)], | ||
| schema=schema) | ||
| df.show() | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This test only fails when using |
||
|
|
||
| def test_nested_udt_in_df(self): | ||
| schema = StructType().add("key", LongType()).add("val", ArrayType(PythonOnlyUDT())) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here, need its own named unit test (so that it's easier to identify the problem if the test fails) - unit tests should test only one thing, the thing tested here is |
||
| df = self.spark.createDataFrame( | ||
| [(i % 3, [PythonOnlyPoint(float(i), float(i))]) for i in range(10)], | ||
| schema=schema) | ||
| df.collect() | ||
|
|
||
| schema = StructType().add("key", LongType()).add("val", | ||
| MapType(LongType(), PythonOnlyUDT())) | ||
| df = self.spark.createDataFrame( | ||
| [(i % 3, {i % 3: PythonOnlyPoint(float(i + 1), float(i + 1))}) for i in range(10)], | ||
| schema=schema) | ||
| df.collect() | ||
|
|
||
| def test_complex_nested_udt_in_df(self): | ||
| from pyspark.sql.functions import udf | ||
|
|
||
| schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT()) | ||
| df = self.spark.createDataFrame( | ||
| [(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)], | ||
| schema=schema) | ||
| df.collect() | ||
|
|
||
| gd = df.groupby("key").agg({"val": "collect_list"}) | ||
| gd.collect() | ||
| udf = udf(lambda k, v: [(k, v[0])], ArrayType(df.schema)) | ||
| gd.select(udf(*gd)).collect() | ||
|
|
||
| def test_udt_with_none(self): | ||
| df = self.spark.range(0, 10, 1, 1) | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -206,6 +206,7 @@ object RowEncoder { | |
| case _: ArrayType => ObjectType(classOf[scala.collection.Seq[_]]) | ||
| case _: MapType => ObjectType(classOf[scala.collection.Map[_, _]]) | ||
| case _: StructType => ObjectType(classOf[Row]) | ||
| case p: PythonUserDefinedType => externalDataTypeFor(p.sqlType) | ||
| case udt: UserDefinedType[_] => ObjectType(udt.userClass) | ||
| } | ||
|
|
||
|
|
@@ -220,9 +221,15 @@ object RowEncoder { | |
| CreateExternalRow(fields, schema) | ||
| } | ||
|
|
||
| private def deserializerFor(input: Expression): Expression = input.dataType match { | ||
| private def deserializerFor(input: Expression): Expression = { | ||
| deserializerFor(input, input.dataType) | ||
| } | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Seems that this method is never used?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. uh? It is the original deserializerFor method and is used below and above.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh, sorry... Confused by the split diff view... |
||
|
|
||
| private def deserializerFor(input: Expression, dataType: DataType): Expression = dataType match { | ||
| case dt if ScalaReflection.isNativeType(dt) => input | ||
|
|
||
| case p: PythonUserDefinedType => deserializerFor(input, p.sqlType) | ||
|
|
||
| case udt: UserDefinedType[_] => | ||
| val annotation = udt.userClass.getAnnotation(classOf[SQLUserDefinedType]) | ||
| val udtClass: Class[_] = if (annotation != null) { | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -346,6 +346,13 @@ case class LambdaVariable(value: String, isNull: String, dataType: DataType) ext | |
| object MapObjects { | ||
| private val curId = new java.util.concurrent.atomic.AtomicInteger() | ||
|
|
||
| /** | ||
| * Construct an instance of MapObjects case class. | ||
| * | ||
| * @param function The function applied on the collection elements. | ||
| * @param inputData An expression that when evaluated returns a collection object. | ||
| * @param elementType The data type of elements in the collection. | ||
| */ | ||
| def apply( | ||
| function: Expression => Expression, | ||
| inputData: Expression, | ||
|
|
@@ -433,8 +440,14 @@ case class MapObjects private( | |
| case _ => "" | ||
| } | ||
|
|
||
| // The data with PythonUserDefinedType are actually stored with the data type of its sqlType. | ||
| // When we want to apply MapObjects on it, we have to use its sqlType instead. | ||
| val inputDataType = inputData.dataType match { | ||
| case p: PythonUserDefinedType => p.sqlType | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How about Scala UDT?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There could be another UDT inside p.sqlType
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A Scala UDT is already covered by deserializerFor.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why do we need to handle a Python UDT differently from a Scala UDT?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A Python UDT has no userClass, so the regular handling for a Scala UDT will fail.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it's fixed by https://github.com/apache/spark/pull/13778/files#diff-47e9c0787b1c455e5bd4ad7b65df3436R209 . Can you double check it? If we revert this change, will test fail again?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ok. let me check it.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No. The test will fail.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is it possible to catch the python udt before passing it to
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ok. let me try it. |
||
| case _ => inputData.dataType | ||
| } | ||
|
|
||
| val (getLength, getLoopVar) = inputData.dataType match { | ||
| val (getLength, getLoopVar) = inputDataType match { | ||
| case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) => | ||
| s"${genInputData.value}.size()" -> s"${genInputData.value}.apply($loopIndex)" | ||
| case ObjectType(cls) if cls.isArray => | ||
|
|
@@ -448,7 +461,7 @@ case class MapObjects private( | |
| s"$seq == null ? $array[$loopIndex] : $seq.apply($loopIndex)" | ||
| } | ||
|
|
||
| val loopNullCheck = inputData.dataType match { | ||
| val loopNullCheck = inputDataType match { | ||
| case _: ArrayType => s"$loopIsNull = ${genInputData.value}.isNullAt($loopIndex);" | ||
| // The element of primitive array will never be null. | ||
| case ObjectType(cls) if cls.isArray && cls.getComponentType.isPrimitive => | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be in its own test method - it's no longer merely a
`test_udt` but rather a `test_simple_udt_in_df`.