14 changes: 11 additions & 3 deletions sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
@@ -152,7 +152,7 @@ trait Row extends Serializable {
* BinaryType -> byte array
* ArrayType -> scala.collection.Seq (use getList for java.util.List)
* MapType -> scala.collection.Map (use getJavaMap for java.util.Map)
* StructType -> org.apache.spark.sql.Row
* StructType -> org.apache.spark.sql.Row (or Product)
* }}}
*/
def apply(i: Int): Any = get(i)
@@ -177,7 +177,7 @@ trait Row extends Serializable {
* BinaryType -> byte array
* ArrayType -> scala.collection.Seq (use getList for java.util.List)
* MapType -> scala.collection.Map (use getJavaMap for java.util.Map)
* StructType -> org.apache.spark.sql.Row
* StructType -> org.apache.spark.sql.Row (or Product)
* }}}
*/
def get(i: Int): Any
@@ -306,7 +306,15 @@ trait Row extends Serializable {
*
* @throws ClassCastException when data type does not match.
*/
def getStruct(i: Int): Row = getAs[Row](i)
def getStruct(i: Int): Row = {
// Both Product and Row are recognized as StructType values in a Row
val t = get(i)
if (t.isInstanceOf[Product]) {
Row.fromTuple(t.asInstanceOf[Product])
} else {
t.asInstanceOf[Row]
}
}
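
For illustration, a minimal sketch of the new behavior (hand-built values; the names here are ours, not part of the patch):

import org.apache.spark.sql.Row

val r = Row((1, "a"))        // struct field holds a Product (a Tuple2), not a Row
val s: Row = r.getStruct(0)  // now converted via Row.fromTuple instead of throwing ClassCastException
assert(s.getInt(0) == 1 && s.getString(1) == "a")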

/**
* Returns the value at position i.
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -50,6 +50,14 @@ object RowEncoder {
case BooleanType | ByteType | ShortType | IntegerType | LongType |
FloatType | DoubleType | BinaryType => inputObject

case udt: UserDefinedType[_] =>
val obj = NewInstance(
udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
Nil,
false,
dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
Invoke(obj, "serialize", udt.sqlType, inputObject :: Nil)

case TimestampType =>
StaticInvoke(
DateTimeUtils,
@@ -109,11 +117,16 @@ object RowEncoder {

case StructType(fields) =>
val convertedFields = fields.zipWithIndex.map { case (f, i) =>
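// Nested structs must be read with getStruct so that Product values (e.g.
// tuples) are converted to Rows first; a plain get would hand the Product
// to code that expects a Row.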
val method = if (f.dataType.isInstanceOf[StructType]) {
"getStruct"
} else {
"get"
}
If(
Invoke(inputObject, "isNullAt", BooleanType, Literal(i) :: Nil),
Literal.create(null, f.dataType),
extractorsFor(
Invoke(inputObject, "get", externalDataTypeFor(f.dataType), Literal(i) :: Nil),
Invoke(inputObject, method, externalDataTypeFor(f.dataType), Literal(i) :: Nil),
f.dataType))
}
CreateStruct(convertedFields)
@@ -137,6 +150,7 @@ object RowEncoder {
case _: ArrayType => ObjectType(classOf[scala.collection.Seq[_]])
case _: MapType => ObjectType(classOf[scala.collection.Map[_, _]])
case _: StructType => ObjectType(classOf[Row])
case udt: UserDefinedType[_] => ObjectType(udt.userClass)
}

private def constructorFor(schema: StructType): Expression = {
@@ -155,6 +169,14 @@
case BooleanType | ByteType | ShortType | IntegerType | LongType |
FloatType | DoubleType | BinaryType => input

case udt: UserDefinedType[_] =>
val obj = NewInstance(
udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
Nil,
false,
dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
Invoke(obj, "deserialize", ObjectType(udt.userClass), input :: Nil)

case TimestampType =>
StaticInvoke(
DateTimeUtils,
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -113,7 +113,7 @@ case class Invoke(
arguments: Seq[Expression] = Nil) extends Expression {

override def nullable: Boolean = true
override def children: Seq[Expression] = targetObject :: Nil
override def children: Seq[Expression] = arguments.+:(targetObject)

override def eval(input: InternalRow): Any =
throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
@@ -343,33 +343,35 @@ case class MapObjects(
private lazy val loopAttribute = AttributeReference("loopVar", elementType)()
private lazy val completeFunction = function(loopAttribute)

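// Maps an element type to the codegen snippet that reads item `i` from the
// container; UDTs recurse on their underlying sqlType.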
private def itemAccessorMethod(dataType: DataType): String => String = dataType match {
case IntegerType => (i: String) => s".getInt($i)"
case LongType => (i: String) => s".getLong($i)"
case FloatType => (i: String) => s".getFloat($i)"
case DoubleType => (i: String) => s".getDouble($i)"
case ByteType => (i: String) => s".getByte($i)"
case ShortType => (i: String) => s".getShort($i)"
case BooleanType => (i: String) => s".getBoolean($i)"
case StringType => (i: String) => s".getUTF8String($i)"
case s: StructType => (i: String) => s".getStruct($i, ${s.size})"
case a: ArrayType => (i: String) => s".getArray($i)"
case _: MapType => (i: String) => s".getMap($i)"
case udt: UserDefinedType[_] => itemAccessorMethod(udt.sqlType)
}

private lazy val (lengthFunction, itemAccessor, primitiveElement) = inputData.dataType match {
case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) =>
(".size()", (i: String) => s".apply($i)", false)
case ObjectType(cls) if cls.isArray =>
(".length", (i: String) => s"[$i]", false)
case ArrayType(s: StructType, _) =>
(".numElements()", (i: String) => s".getStruct($i, ${s.size})", false)
case ArrayType(a: ArrayType, _) =>
(".numElements()", (i: String) => s".getArray($i)", true)
case ArrayType(IntegerType, _) =>
(".numElements()", (i: String) => s".getInt($i)", true)
case ArrayType(LongType, _) =>
(".numElements()", (i: String) => s".getLong($i)", true)
case ArrayType(FloatType, _) =>
(".numElements()", (i: String) => s".getFloat($i)", true)
case ArrayType(DoubleType, _) =>
(".numElements()", (i: String) => s".getDouble($i)", true)
case ArrayType(ByteType, _) =>
(".numElements()", (i: String) => s".getByte($i)", true)
case ArrayType(ShortType, _) =>
(".numElements()", (i: String) => s".getShort($i)", true)
case ArrayType(BooleanType, _) =>
(".numElements()", (i: String) => s".getBoolean($i)", true)
case ArrayType(StringType, _) =>
(".numElements()", (i: String) => s".getUTF8String($i)", false)
case ArrayType(_: MapType, _) =>
(".numElements()", (i: String) => s".getMap($i)", false)
case ArrayType(t, _) =>
val (sqlType, primitiveElement) = t match {
case m: MapType => (m, false)
case s: StructType => (s, false)
case s: StringType => (s, false)
case udt: UserDefinedType[_] => (udt.sqlType, false)
case o => (o, true)
}
(".numElements()", itemAccessorMethod(sqlType), primitiveElement)
}

override def nullable: Boolean = true
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
@@ -19,14 +19,62 @@ package org.apache.spark.sql.catalyst.encoders

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.{RandomDataGenerator, Row}
import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

@SQLUserDefinedType(udt = classOf[ExamplePointUDT])
class ExamplePoint(val x: Double, val y: Double) extends Serializable {
override def hashCode: Int = 41 * (41 + x.toInt) + y.toInt
override def equals(that: Any): Boolean = {
if (that.isInstanceOf[ExamplePoint]) {
val e = that.asInstanceOf[ExamplePoint]
(this.x == e.x || (this.x.isNaN && e.x.isNaN) || (this.x.isInfinity && e.x.isInfinity)) &&
(this.y == e.y || (this.y.isNaN && e.y.isNaN) || (this.y.isInfinity && e.y.isInfinity))
} else {
false
}
}
}

/**
* User-defined type for [[ExamplePoint]].
*/
class ExamplePointUDT extends UserDefinedType[ExamplePoint] {

override def sqlType: DataType = ArrayType(DoubleType, false)

override def pyUDT: String = "pyspark.sql.tests.ExamplePointUDT"

override def serialize(obj: Any): GenericArrayData = {
obj match {
case p: ExamplePoint =>
val output = new Array[Any](2)
output(0) = p.x
output(1) = p.y
new GenericArrayData(output)
}
}

override def deserialize(datum: Any): ExamplePoint = {
datum match {
case values: ArrayData =>
new ExamplePoint(values.getDouble(0), values.getDouble(1))
}
}

override def userClass: Class[ExamplePoint] = classOf[ExamplePoint]

private[spark] override def asNullable: ExamplePointUDT = this
}
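
For intuition, this UDT round-trips without any Spark machinery (a quick sketch using only the classes defined above):

val udt = new ExamplePointUDT
val p = new ExamplePoint(1.0, 2.0)
assert(udt.deserialize(udt.serialize(p)) == p)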

class RowEncoderSuite extends SparkFunSuite {

private val structOfString = new StructType().add("str", StringType)
private val structOfUDT = new StructType().add("udt", new ExamplePointUDT, false)
private val arrayOfString = ArrayType(StringType)
private val mapOfString = MapType(StringType, StringType)
private val arrayOfUDT = ArrayType(new ExamplePointUDT, false)

encodeDecodeTest(
new StructType()
@@ -41,7 +89,8 @@ class RowEncoderSuite extends SparkFunSuite {
.add("string", StringType)
.add("binary", BinaryType)
.add("date", DateType)
.add("timestamp", TimestampType))
.add("timestamp", TimestampType)
.add("udt", new ExamplePointUDT, false))

encodeDecodeTest(
new StructType()
@@ -68,7 +117,36 @@ class RowEncoderSuite extends SparkFunSuite {
.add("structOfArray", new StructType().add("array", arrayOfString))
.add("structOfMap", new StructType().add("map", mapOfString))
.add("structOfArrayAndMap",
new StructType().add("array", arrayOfString).add("map", mapOfString)))
new StructType().add("array", arrayOfString).add("map", mapOfString))
.add("structOfUDT", structOfUDT))

test(s"encode/decode: arrayOfUDT") {
[Review thread]
Contributor: any reason why we put this test here instead of adding arrayOfUDT type in encodeDecodeTest like structOfUDT?
Member (Author): No. I moved it in #10538.
val schema = new StructType()
.add("arrayOfUDT", arrayOfUDT)

val encoder = RowEncoder(schema)

val input: Row = Row(Seq(new ExamplePoint(0.1, 0.2), new ExamplePoint(0.3, 0.4)))
val row = encoder.toRow(input)
val convertedBack = encoder.fromRow(row)
assert(input.getSeq[ExamplePoint](0) == convertedBack.getSeq[ExamplePoint](0))
}

test(s"encode/decode: Product") {
val schema = new StructType()
.add("structAsProduct",
new StructType()
.add("int", IntegerType)
.add("string", StringType)
.add("double", DoubleType))

val encoder = RowEncoder(schema)

val input: Row = Row((100, "test", 0.123))
[Review thread]
Contributor: I have a question here. According to the Javadoc of Row, a user should use and only use Row for a StructType field. It's OK to support Product too, but do we have a reason for this? Is it needed for the UDT stuff? Sorry, I'm not familiar with UDT handling; it would be good if you could explain it in detail, thanks!

Member (Author): Actually, I found this problem when working on ScalaUDF. ScalaUDF uses schemaFor to obtain the catalyst type for UDF input and output. The catalyst type returned by schemaFor for a Product is StructType, which is reasonable since we have no other type to represent a Product. So for a StructType field in an external Row, both Row and Product are possible values. When we call extractorsFor on the external Row, externalDataTypeFor will return ObjectType(classOf[Row]) for this field, but the get accessor on the inputObject (i.e., the Row) can return a Product in the ScalaUDF case, and an exception will be thrown.

Contributor: If one of the input parameters is a Tuple2, then we need to use the encoder to decode a catalyst value to an external value, i.e., decode an InternalRow object into a Tuple2 object. I think this is hard for a RowEncoder (your change only makes it possible to encode a Product into an InternalRow, not vice versa); we should use ProductEncoder for this case.

Member (Author): If we have an input parameter mapping to a StructType field in an InternalRow, we use Row as its input type, e.g. sqlContext.udf.register("udfFunc", (ns: Row) => { (ns.getInt(0), ns.getString(1)) }). But we can't use Row as the output type of a UDF: if we can't infer the input types correctly with schemaFor, we can still get the input schema from the ScalaUDF's children expressions later, whereas the output type of the UDF can only be inferred by schemaFor.

Contributor: Ah, I see. We need the type tag to infer the return type of the UDF, and if a Row is returned, there is no type information we can get. How about using ProductEncoder or FlatEncoder for the return value?

Member (Author): Hmm, as I tried that, I found that we may not always be able to get the type T needed to construct a ProductEncoder or FlatEncoder. Even when we can get it, we can't keep it in ScalaUDF due to serialization issues. So I think using RowEncoder is more reasonable.

Contributor: I think the reason we still support Product for StructType is backward compatibility: we did not enforce the inbound type before, and someone may rely on it (because it's easier than Row in Scala).

val row = encoder.toRow(input)
val convertedBack = encoder.fromRow(row)
assert(input.getStruct(0) == convertedBack.getStruct(0))
}
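
A sketch of the ScalaUDF scenario from the thread above (assumes a SQLContext named sqlContext is in scope; the UDF name is arbitrary):

// The StructType input arrives as a Row; the Tuple2 result is inferred by
// schemaFor as StructType, so the backing struct field holds a Product.
// The getStruct change in this PR lets callers read such a field uniformly as a Row.
sqlContext.udf.register("udfFunc", (ns: Row) => (ns.getInt(0), ns.getString(1)))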

private def encodeDecodeTest(schema: StructType): Unit = {
test(s"encode/decode: ${schema.simpleString}") {