35 changes: 35 additions & 0 deletions python/pyspark/sql/tests.py
@@ -575,6 +575,41 @@ def check_datatype(datatype):
_verify_type(PythonOnlyPoint(1.0, 2.0), PythonOnlyUDT())
self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], PythonOnlyUDT()))

def test_simple_udt_in_df(self):
schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT())

This should be in its own test method - it's no longer merely a test_udt but rather a test_simple_udt_in_df.

df = self.spark.createDataFrame(
[(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)],
schema=schema)
df.show()

DataFrame.show() adds an unnecessary stringification step, so this test ends up testing unrelated behavior (in fact, it would fail if the UDT didn't have __str__). I would use collect() to force materialization instead.

Member Author
@viirya viirya Jun 22, 2016
This test only fails when using show(), as I mentioned on the JIRA ticket SPARK-16062.
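A minimal sketch of the reviewer's suggestion, for illustration only (not part of this patch), assuming the PythonOnlyUDT and PythonOnlyPoint helpers already defined in this file: a collect()-based check compares plain values and so does not depend on the UDT defining __str__. As the author notes, the SPARK-16062 regression only surfaces through show(), so a check like this would complement rather than replace the test above.

def test_simple_udt_in_df_values(self):
    # Hypothetical companion check: materialize with collect() and compare
    # plain x/y values, so nothing relies on the UDT's __str__.
    schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT())
    df = self.spark.createDataFrame(
        [(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)],
        schema=schema)
    rows = df.collect()
    self.assertEqual(len(rows), 10)
    vals = sorted((row.val.x, row.val.y) for row in rows)
    self.assertEqual(vals[0], (0.0, 0.0))
    self.assertEqual(vals[9], (9.0, 9.0))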


def test_nested_udt_in_df(self):
schema = StructType().add("key", LongType()).add("val", ArrayType(PythonOnlyUDT()))

Same here: this needs its own named unit test (so that it's easier to identify the problem if the test fails). Unit tests should test only one thing, and the thing tested here is test_nested_udt_in_df. (Perhaps it's also worthwhile to check that Map works?)

df = self.spark.createDataFrame(
[(i % 3, [PythonOnlyPoint(float(i), float(i))]) for i in range(10)],
schema=schema)
df.collect()

schema = StructType().add("key", LongType()).add("val",
MapType(LongType(), PythonOnlyUDT()))
df = self.spark.createDataFrame(
[(i % 3, {i % 3: PythonOnlyPoint(float(i + 1), float(i + 1))}) for i in range(10)],
schema=schema)
df.collect()

def test_complex_nested_udt_in_df(self):
from pyspark.sql.functions import udf

schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT())
df = self.spark.createDataFrame(
[(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)],
schema=schema)
df.collect()

gd = df.groupby("key").agg({"val": "collect_list"})
gd.collect()
udf = udf(lambda k, v: [(k, v[0])], ArrayType(df.schema))
gd.select(udf(*gd)).collect()

def test_udt_with_none(self):
df = self.spark.range(0, 10, 1, 1)

@@ -206,6 +206,7 @@ object RowEncoder {
case _: ArrayType => ObjectType(classOf[scala.collection.Seq[_]])
case _: MapType => ObjectType(classOf[scala.collection.Map[_, _]])
case _: StructType => ObjectType(classOf[Row])
case p: PythonUserDefinedType => externalDataTypeFor(p.sqlType)
case udt: UserDefinedType[_] => ObjectType(udt.userClass)
}

@@ -220,9 +221,15 @@
CreateExternalRow(fields, schema)
}

private def deserializerFor(input: Expression): Expression = input.dataType match {
private def deserializerFor(input: Expression): Expression = {
deserializerFor(input, input.dataType)
}
Contributor
Seems that this method is never used?

Member Author
Huh? This is the original deserializerFor method, and it is used both above and below.

Contributor
@liancheng liancheng Jun 20, 2016
Oh, sorry... Confused by the split diff view...


private def deserializerFor(input: Expression, dataType: DataType): Expression = dataType match {
case dt if ScalaReflection.isNativeType(dt) => input

case p: PythonUserDefinedType => deserializerFor(input, p.sqlType)

case udt: UserDefinedType[_] =>
val annotation = udt.userClass.getAnnotation(classOf[SQLUserDefinedType])
val udtClass: Class[_] = if (annotation != null) {
@@ -346,6 +346,13 @@ case class LambdaVariable(value: String, isNull: String, dataType: DataType) ext
object MapObjects {
private val curId = new java.util.concurrent.atomic.AtomicInteger()

/**
* Construct an instance of MapObjects case class.
*
* @param function The function applied on the collection elements.
* @param inputData An expression that when evaluated returns a collection object.
* @param elementType The data type of elements in the collection.
*/
def apply(
function: Expression => Expression,
inputData: Expression,
@@ -433,8 +440,14 @@ case class MapObjects private(
case _ => ""
}

// Data of a PythonUserDefinedType is actually stored with the data type of its sqlType.
// When we apply MapObjects to it, we have to use that sqlType.
val inputDataType = inputData.dataType match {
case p: PythonUserDefinedType => p.sqlType
Contributor
How about Scala UDT?

Contributor
There could be another UDT inside p.sqlType

Member Author
Scala UDTs are already covered by deserializerFor.

Contributor
Why do we need to handle Python UDTs differently from Scala UDTs?

Member Author
A Python UDT has no userClass, so the regular handling for Scala UDTs will fail.

Contributor
I think it's fixed by https://github.com/apache/spark/pull/13778/files#diff-47e9c0787b1c455e5bd4ad7b65df3436R209. Can you double-check it? If we revert this change, will the test fail again?

Member Author
OK, let me check it.

Member Author
No, the test will fail.

Contributor
Is it possible to catch the Python UDT before passing it to MapObjects? I'm kind of worried about leaking Python UDTs into a lot of places; we should handle them in just a few places.

Member Author
OK, let me try it.

case _ => inputData.dataType
}
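
The thread above hinges on a Python-only UDT having no JVM userClass. As a rough illustration (a simplified sketch, not the exact PythonOnlyUDT/PythonOnlyPoint definitions from tests.py), such a UDT declares only a Python value class and an sqlType, with no scalaUDT(), so the JVM side sees a PythonUserDefinedType and can only work with its sqlType:

from pyspark.sql.types import UserDefinedType, ArrayType, DoubleType

class PythonOnlyUDT(UserDefinedType):
    # Sketch: no scalaUDT() is declared, so on the JVM this surfaces as a
    # PythonUserDefinedType whose storage type is sqlType() and which has no
    # userClass to instantiate.
    @classmethod
    def sqlType(cls):
        return ArrayType(DoubleType(), False)

    @classmethod
    def module(cls):
        return "__main__"

    def serialize(self, obj):
        return [obj.x, obj.y]

    def deserialize(self, datum):
        return PythonOnlyPoint(datum[0], datum[1])

class PythonOnlyPoint(object):
    # Sketch of the Python value class paired with the UDT above.
    __UDT__ = PythonOnlyUDT()

    def __init__(self, x, y):
        self.x = x
        self.y = y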

val (getLength, getLoopVar) = inputData.dataType match {
val (getLength, getLoopVar) = inputDataType match {
case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) =>
s"${genInputData.value}.size()" -> s"${genInputData.value}.apply($loopIndex)"
case ObjectType(cls) if cls.isArray =>
Expand All @@ -448,7 +461,7 @@ case class MapObjects private(
s"$seq == null ? $array[$loopIndex] : $seq.apply($loopIndex)"
}

val loopNullCheck = inputData.dataType match {
val loopNullCheck = inputDataType match {
case _: ArrayType => s"$loopIsNull = ${genInputData.value}.isNullAt($loopIndex);"
// The element of primitive array will never be null.
case ObjectType(cls) if cls.isArray && cls.getComponentType.isPrimitive =>