From e1b5deebe715479125c8878f0c90a55dc9ab3e85 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 9 Jul 2018 03:42:04 +0000 Subject: [PATCH 01/21] Aggregator should be able to use Option of Product encoder. --- .../catalyst/encoders/ExpressionEncoder.scala | 11 +++- .../spark/sql/DatasetAggregatorSuite.scala | 51 +++++++++++++++++++ 2 files changed, 60 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index cbea3c017a26..1b357698d2ec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -43,12 +43,19 @@ import org.apache.spark.util.Utils * to the name `value`. */ object ExpressionEncoder { - def apply[T : TypeTag](): ExpressionEncoder[T] = { + // Constructs an encoder for top-level row. + def apply[T : TypeTag](): ExpressionEncoder[T] = apply(topLevel = true) + + /** + * @param topLevel whether the encoders to construct are for top-level row. + */ + def apply[T : TypeTag](topLevel: Boolean): ExpressionEncoder[T] = { // We convert the not-serializable TypeTag into StructType and ClassTag. val mirror = ScalaReflection.mirror val tpe = typeTag[T].in(mirror).tpe - if (ScalaReflection.optionOfProductType(tpe)) { + // For non top-level encodes, we allow using Option of Product type. + if (topLevel && ScalaReflection.optionOfProductType(tpe)) { throw new UnsupportedOperationException( "Cannot create encoder for Option of Product type, because Product type is represented " + "as a row, and the entire row can not be null in Spark SQL like normal databases. 
" + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala index 538ea3c66c40..d31d6d345a7a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -149,6 +149,7 @@ object VeryComplexResultAgg extends Aggregator[Row, String, ComplexAggData] { case class OptionBooleanData(name: String, isGood: Option[Boolean]) +case class OptionBooleanIntData(name: String, isGood: Option[(Boolean, Int)]) case class OptionBooleanAggregator(colName: String) extends Aggregator[Row, Option[Boolean], Option[Boolean]] { @@ -183,6 +184,43 @@ case class OptionBooleanAggregator(colName: String) def OptionalBoolEncoder: Encoder[Option[Boolean]] = ExpressionEncoder() } +case class OptionBooleanIntAggregator(colName: String) + extends Aggregator[Row, Option[(Boolean, Int)], Option[(Boolean, Int)]] { + + override def zero: Option[(Boolean, Int)] = None + + override def reduce(buffer: Option[(Boolean, Int)], row: Row): Option[(Boolean, Int)] = { + val index = row.fieldIndex(colName) + val value = if (row.isNullAt(index)) { + Option.empty[(Boolean, Int)] + } else { + val nestedRow = row.getStruct(index) + Some((nestedRow.getBoolean(0), nestedRow.getInt(1))) + } + merge(buffer, value) + } + + override def merge( + b1: Option[(Boolean, Int)], + b2: Option[(Boolean, Int)]): Option[(Boolean, Int)] = { + if ((b1.isDefined && b1.get._1) || (b2.isDefined && b2.get._1)) { + val newInt = b1.map(_._2).getOrElse(0) + b2.map(_._2).getOrElse(0) + Some((true, newInt)) + } else if (b1.isDefined) { + b1 + } else { + b2 + } + } + + override def finish(reduction: Option[(Boolean, Int)]): Option[(Boolean, Int)] = reduction + + override def bufferEncoder: Encoder[Option[(Boolean, Int)]] = OptionalBoolIntEncoder + override def outputEncoder: Encoder[Option[(Boolean, Int)]] = OptionalBoolIntEncoder + + def OptionalBoolIntEncoder: Encoder[Option[(Boolean, Int)]] = ExpressionEncoder(topLevel = false) +} + class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -393,4 +431,17 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { assert(grouped.schema == df.schema) checkDataset(grouped.as[OptionBooleanData], OptionBooleanData("bob", Some(true))) } + + test("SPARK-24762: Aggregator should be able to use Option of Product encoder") { + val df = Seq( + OptionBooleanIntData("bob", Some((true, 1))), + OptionBooleanIntData("bob", Some((false, 2))), + OptionBooleanIntData("bob", None)).toDF() + val group = df + .groupBy("name") + .agg(OptionBooleanIntAggregator("isGood").toColumn.alias("isGood")) + assert(df.schema == group.schema) + checkAnswer(group, Row("bob", Row(true, 3)) :: Nil) + checkDataset(group.as[OptionBooleanIntData], OptionBooleanIntData("bob", Some((true, 3)))) + } } From 80506f4e98184ccd66dbaac14ec52d69c358020d Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 13 Jul 2018 04:40:55 +0000 Subject: [PATCH 02/21] Enable top-level Option of Product encoders. 
--- .../spark/sql/catalyst/ScalaReflection.scala | 92 +++++++++++-------- .../catalyst/encoders/ExpressionEncoder.scala | 14 +-- .../sql/catalyst/ScalaReflectionSuite.scala | 74 ++++++++++----- .../spark/sql/DatasetAggregatorSuite.scala | 1 + .../org/apache/spark/sql/DatasetSuite.scala | 27 ++++-- 5 files changed, 127 insertions(+), 81 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 4543bba8f6ed..3f2505d5d689 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -135,12 +135,20 @@ object ScalaReflection extends ScalaReflection { * from ordinal 0 (since there are no names to map to). The actual location can be moved by * calling resolve/bind with a new schema. */ - def deserializerFor[T : TypeTag]: Expression = { + def deserializerFor[T : TypeTag](topLevel: Boolean): Expression = cleanUpReflectionObjects { val tpe = localTypeOf[T] val clsName = getClassNameFromType(tpe) val walkedTypePath = s"""- root class: "$clsName"""" :: Nil - val expr = deserializerFor(tpe, None, walkedTypePath) - val Schema(_, nullable) = schemaFor(tpe) + val Schema(dataType, tpeNullable) = schemaFor(tpe) + val isOptionOfProduct = tpe.dealias <:< localTypeOf[Option[_]] && + definedByConstructorParams(tpe) + val (optTypePath, nullable) = if (isOptionOfProduct && topLevel) { + // Top-level Option of Product is encoded as single struct column at top-level row. + (Some(addToPathOrdinal(None, 0, dataType, walkedTypePath)), true) + } else { + (None, tpeNullable) + } + val expr = deserializerFor(tpe, optTypePath, walkedTypePath) if (nullable) { expr } else { @@ -148,6 +156,40 @@ object ScalaReflection extends ScalaReflection { } } + /** + * When we build the `deserializer` for an encoder, we set up a lot of "unresolved" stuff + * and lost the required data type, which may lead to runtime error if the real type doesn't + * match the encoder's schema. + * For example, we build an encoder for `case class Data(a: Int, b: String)` and the real type + * is [a: int, b: long], then we will hit runtime error and say that we can't construct class + * `Data` with int and long, because we lost the information that `b` should be a string. + * + * This method help us "remember" the required data type by adding a `UpCast`. Note that we + * only need to do this for leaf nodes. + */ + def upCastToExpectedType( + expr: Expression, + expected: DataType, + walkedTypePath: Seq[String]): Expression = expected match { + case _: StructType => expr + case _: ArrayType => expr + // TODO: ideally we should also skip MapType, but nested StructType inside MapType is rare and + // it's not trivial to support by-name resolution for StructType inside MapType. + case _ => UpCast(expr, expected, walkedTypePath) + } + + /** Returns the current path with a field at ordinal extracted. 
*/ + def addToPathOrdinal( + path: Option[Expression], + ordinal: Int, + dataType: DataType, + walkedTypePath: Seq[String]): Expression = { + val newPath = path + .map(p => GetStructField(p, ordinal)) + .getOrElse(GetColumnByOrdinal(ordinal, dataType)) + upCastToExpectedType(newPath, dataType, walkedTypePath) + } + private def deserializerFor( tpe: `Type`, path: Option[Expression], @@ -161,17 +203,6 @@ object ScalaReflection extends ScalaReflection { upCastToExpectedType(newPath, dataType, walkedTypePath) } - /** Returns the current path with a field at ordinal extracted. */ - def addToPathOrdinal( - ordinal: Int, - dataType: DataType, - walkedTypePath: Seq[String]): Expression = { - val newPath = path - .map(p => GetStructField(p, ordinal)) - .getOrElse(GetColumnByOrdinal(ordinal, dataType)) - upCastToExpectedType(newPath, dataType, walkedTypePath) - } - /** Returns the current path or `GetColumnByOrdinal`. */ def getPath: Expression = { val dataType = schemaFor(tpe).dataType @@ -182,28 +213,6 @@ object ScalaReflection extends ScalaReflection { } } - /** - * When we build the `deserializer` for an encoder, we set up a lot of "unresolved" stuff - * and lost the required data type, which may lead to runtime error if the real type doesn't - * match the encoder's schema. - * For example, we build an encoder for `case class Data(a: Int, b: String)` and the real type - * is [a: int, b: long], then we will hit runtime error and say that we can't construct class - * `Data` with int and long, because we lost the information that `b` should be a string. - * - * This method help us "remember" the required data type by adding a `UpCast`. Note that we - * only need to do this for leaf nodes. - */ - def upCastToExpectedType( - expr: Expression, - expected: DataType, - walkedTypePath: Seq[String]): Expression = expected match { - case _: StructType => expr - case _: ArrayType => expr - // TODO: ideally we should also skip MapType, but nested StructType inside MapType is rare and - // it's not trivial to support by-name resolution for StructType inside MapType. - case _ => UpCast(expr, expected, walkedTypePath) - } - tpe.dealias match { case t if !dataTypeFor(t).isInstanceOf[ObjectType] => getPath @@ -389,7 +398,7 @@ object ScalaReflection extends ScalaReflection { val constructor = if (cls.getName startsWith "scala.Tuple") { deserializerFor( fieldType, - Some(addToPathOrdinal(i, dataType, newTypePath)), + Some(addToPathOrdinal(path, i, dataType, newTypePath)), newTypePath) } else { deserializerFor( @@ -431,11 +440,18 @@ object ScalaReflection extends ScalaReflection { * * the element type of [[Array]] or [[Seq]]: `array element class: "abc.xyz.MyClass"` * * the field of [[Product]]: `field (class: "abc.xyz.MyClass", name: "myField")` */ - def serializerFor[T : TypeTag](inputObject: Expression): CreateNamedStruct = { + def serializerFor[T : TypeTag]( + inputObject: Expression, + topLevel: Boolean): CreateNamedStruct = cleanUpReflectionObjects { val tpe = localTypeOf[T] val clsName = getClassNameFromType(tpe) val walkedTypePath = s"""- root class: "$clsName"""" :: Nil serializerFor(inputObject, tpe, walkedTypePath) match { + case i @ expressions.If(_, _, _: CreateNamedStruct) + if tpe.dealias <:< localTypeOf[Option[_]] && + definedByConstructorParams(tpe) && topLevel => + // We encode top-level Option of Product as a single struct column. 
+ CreateNamedStruct(expressions.Literal("value") :: i :: Nil) case expressions.If(_, _, s: CreateNamedStruct) if definedByConstructorParams(tpe) => s case other => CreateNamedStruct(expressions.Literal("value") :: other :: Nil) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 1b357698d2ec..861451012ff7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -54,16 +54,6 @@ object ExpressionEncoder { val mirror = ScalaReflection.mirror val tpe = typeTag[T].in(mirror).tpe - // For non top-level encodes, we allow using Option of Product type. - if (topLevel && ScalaReflection.optionOfProductType(tpe)) { - throw new UnsupportedOperationException( - "Cannot create encoder for Option of Product type, because Product type is represented " + - "as a row, and the entire row can not be null in Spark SQL like normal databases. " + - "You can wrap your type with Tuple1 if you do want top level null Product objects, " + - "e.g. instead of creating `Dataset[Option[MyClass]]`, you can do something like " + - "`val ds: Dataset[Tuple1[MyClass]] = Seq(Tuple1(MyClass(...)), Tuple1(null)).toDS`") - } - val cls = mirror.runtimeClass(tpe) val flat = !ScalaReflection.definedByConstructorParams(tpe) @@ -75,8 +65,8 @@ object ExpressionEncoder { // doesn't allow top-level row to be null, only its columns can be null. AssertNotNull(inputObject, Seq("top level Product input object")) } - val serializer = ScalaReflection.serializerFor[T](nullSafeInput) - val deserializer = ScalaReflection.deserializerFor[T] + val serializer = ScalaReflection.serializerFor[T](nullSafeInput, topLevel) + val deserializer = ScalaReflection.deserializerFor[T](topLevel) val schema = serializer.dataType diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 353b8344658f..0cb188ef9bb7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -20,9 +20,9 @@ package org.apache.spark.sql.catalyst import java.sql.{Date, Timestamp} import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, Literal, SpecificInternalRow, UpCast} -import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, NewInstance} +import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAttribute} +import org.apache.spark.sql.catalyst.expressions.{BoundReference, CreateNamedStruct, Expression, If, IsNull, Literal, SpecificInternalRow, UpCast} +import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, NewInstance, WrapOption} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -281,7 +281,7 @@ class ScalaReflectionSuite extends SparkFunSuite { test("SPARK-15062: Get correct serializer for List[_]") { val list = List(1, 2, 3) val serializer = serializerFor[List[Int]](BoundReference( - 0, ObjectType(list.getClass), nullable = false)) + 0, ObjectType(list.getClass), nullable = false), topLevel = true) 
assert(serializer.children.size == 2) assert(serializer.children.head.isInstanceOf[Literal]) assert(serializer.children.head.asInstanceOf[Literal].value === UTF8String.fromString("value")) @@ -291,57 +291,57 @@ class ScalaReflectionSuite extends SparkFunSuite { } test("SPARK 16792: Get correct deserializer for List[_]") { - val listDeserializer = deserializerFor[List[Int]] + val listDeserializer = deserializerFor[List[Int]](topLevel = true) assert(listDeserializer.dataType == ObjectType(classOf[List[_]])) } test("serialize and deserialize arbitrary sequence types") { import scala.collection.immutable.Queue val queueSerializer = serializerFor[Queue[Int]](BoundReference( - 0, ObjectType(classOf[Queue[Int]]), nullable = false)) + 0, ObjectType(classOf[Queue[Int]]), nullable = false), topLevel = true) assert(queueSerializer.dataType.head.dataType == ArrayType(IntegerType, containsNull = false)) - val queueDeserializer = deserializerFor[Queue[Int]] + val queueDeserializer = deserializerFor[Queue[Int]](topLevel = true) assert(queueDeserializer.dataType == ObjectType(classOf[Queue[_]])) import scala.collection.mutable.ArrayBuffer val arrayBufferSerializer = serializerFor[ArrayBuffer[Int]](BoundReference( - 0, ObjectType(classOf[ArrayBuffer[Int]]), nullable = false)) + 0, ObjectType(classOf[ArrayBuffer[Int]]), nullable = false), topLevel = true) assert(arrayBufferSerializer.dataType.head.dataType == ArrayType(IntegerType, containsNull = false)) - val arrayBufferDeserializer = deserializerFor[ArrayBuffer[Int]] + val arrayBufferDeserializer = deserializerFor[ArrayBuffer[Int]](topLevel = true) assert(arrayBufferDeserializer.dataType == ObjectType(classOf[ArrayBuffer[_]])) } test("serialize and deserialize arbitrary map types") { val mapSerializer = serializerFor[Map[Int, Int]](BoundReference( - 0, ObjectType(classOf[Map[Int, Int]]), nullable = false)) + 0, ObjectType(classOf[Map[Int, Int]]), nullable = false), topLevel = true) assert(mapSerializer.dataType.head.dataType == MapType(IntegerType, IntegerType, valueContainsNull = false)) - val mapDeserializer = deserializerFor[Map[Int, Int]] + val mapDeserializer = deserializerFor[Map[Int, Int]](topLevel = true) assert(mapDeserializer.dataType == ObjectType(classOf[Map[_, _]])) import scala.collection.immutable.HashMap val hashMapSerializer = serializerFor[HashMap[Int, Int]](BoundReference( - 0, ObjectType(classOf[HashMap[Int, Int]]), nullable = false)) + 0, ObjectType(classOf[HashMap[Int, Int]]), nullable = false), topLevel = true) assert(hashMapSerializer.dataType.head.dataType == MapType(IntegerType, IntegerType, valueContainsNull = false)) - val hashMapDeserializer = deserializerFor[HashMap[Int, Int]] + val hashMapDeserializer = deserializerFor[HashMap[Int, Int]](topLevel = true) assert(hashMapDeserializer.dataType == ObjectType(classOf[HashMap[_, _]])) import scala.collection.mutable.{LinkedHashMap => LHMap} val linkedHashMapSerializer = serializerFor[LHMap[Long, String]](BoundReference( - 0, ObjectType(classOf[LHMap[Long, String]]), nullable = false)) + 0, ObjectType(classOf[LHMap[Long, String]]), nullable = false), topLevel = true) assert(linkedHashMapSerializer.dataType.head.dataType == MapType(LongType, StringType, valueContainsNull = true)) - val linkedHashMapDeserializer = deserializerFor[LHMap[Long, String]] + val linkedHashMapDeserializer = deserializerFor[LHMap[Long, String]](topLevel = true) assert(linkedHashMapDeserializer.dataType == ObjectType(classOf[LHMap[_, _]])) } test("SPARK-22442: Generate correct field names for special 
characters") { val serializer = serializerFor[SpecialCharAsFieldData](BoundReference( - 0, ObjectType(classOf[SpecialCharAsFieldData]), nullable = false)) - val deserializer = deserializerFor[SpecialCharAsFieldData] + 0, ObjectType(classOf[SpecialCharAsFieldData]), nullable = false), topLevel = true) + val deserializer = deserializerFor[SpecialCharAsFieldData](topLevel = true) assert(serializer.dataType(0).name == "field.1") assert(serializer.dataType(1).name == "field 2") @@ -353,8 +353,8 @@ class ScalaReflectionSuite extends SparkFunSuite { } test("SPARK-22472: add null check for top-level primitive values") { - assert(deserializerFor[Int].isInstanceOf[AssertNotNull]) - assert(!deserializerFor[String].isInstanceOf[AssertNotNull]) + assert(deserializerFor[Int](topLevel = true).isInstanceOf[AssertNotNull]) + assert(!deserializerFor[String](topLevel = true).isInstanceOf[AssertNotNull]) } test("SPARK-23025: schemaFor should support Null type") { @@ -371,8 +371,38 @@ class ScalaReflectionSuite extends SparkFunSuite { assert(deserializer.isInstanceOf[NewInstance]) deserializer.asInstanceOf[NewInstance].arguments.count(_.isInstanceOf[AssertNotNull]) } - assert(numberOfCheckedArguments(deserializerFor[(Double, Double)]) == 2) - assert(numberOfCheckedArguments(deserializerFor[(java.lang.Double, Int)]) == 1) - assert(numberOfCheckedArguments(deserializerFor[(java.lang.Integer, java.lang.Integer)]) == 0) + assert(numberOfCheckedArguments(deserializerFor[(Double, Double)](topLevel = true)) == 2) + assert(numberOfCheckedArguments(deserializerFor[(java.lang.Double, Int)](topLevel = true)) == 1) + assert(numberOfCheckedArguments( + deserializerFor[(java.lang.Integer, java.lang.Integer)](topLevel = true)) == 0) + } + + test("SPARK-24762: serializer for Option of Product") { + val optionOfProduct = Some((1, "a")) + val topLevelSerializer = serializerFor[Option[(Int, String)]](BoundReference( + 0, ObjectType(optionOfProduct.getClass), nullable = true), topLevel = true) + val nonTopLevelSerializer = serializerFor[Option[(Int, String)]](BoundReference( + 0, ObjectType(optionOfProduct.getClass), nullable = true), topLevel = false) + + topLevelSerializer match { + case CreateNamedStruct(Seq(Literal(_, _), If(_, _, optEncoder))) => + assert(optEncoder.semanticEquals(nonTopLevelSerializer)) + case _ => + fail("top-level Option of Product should be encoded as single struct column.") + } + } + + test("SPARK-24762: deserializer for Option of Product") { + val topLevelDeserializer = deserializerFor[Option[(Int, String)]](topLevel = true) + val nonTopLevelDeserializer = deserializerFor[Option[(Int, String)]](topLevel = false) + .asInstanceOf[WrapOption] + + topLevelDeserializer match { + case WrapOption(If(IsNull(GetColumnByOrdinal(0, _)), _, n: NewInstance), optType) => + assert(n.cls == nonTopLevelDeserializer.child.asInstanceOf[NewInstance].cls) + assert(optType == nonTopLevelDeserializer.optType) + case _ => + fail("top-level Option of Product should be decoded from a single struct column.") + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala index d31d6d345a7a..33241671dbcf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -437,6 +437,7 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { OptionBooleanIntData("bob", Some((true, 1))), 
OptionBooleanIntData("bob", Some((false, 2))), OptionBooleanIntData("bob", None)).toDF() + val group = df .groupBy("name") .agg(OptionBooleanIntAggregator("isGood").toColumn.alias("isGood")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index ce8db99d4e2f..843eb224dfb9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1253,15 +1253,6 @@ class DatasetSuite extends QueryTest with SharedSQLContext { checkDataset(dsString, arrayString) } - test("SPARK-18251: the type of Dataset can't be Option of Product type") { - checkDataset(Seq(Some(1), None).toDS(), Some(1), None) - - val e = intercept[UnsupportedOperationException] { - Seq(Some(1 -> "a"), None).toDS() - } - assert(e.getMessage.contains("Cannot create encoder for Option of Product type")) - } - test ("SPARK-17460: the sizeInBytes in Statistics shouldn't overflow to a negative number") { // Since the sizeInBytes in Statistics could exceed the limit of an Int, we should use BigInt // instead of Int for avoiding possible overflow. @@ -1498,6 +1489,24 @@ class DatasetSuite extends QueryTest with SharedSQLContext { df.where($"city".contains(new java.lang.Character('A'))), Seq(Row("Amsterdam"))) } + + test("SPARK-24762: Enable top-level Option of Product encoders") { + val data = Seq(Some((1, "a")), Some((2, "b")), None) + val ds = data.toDS() + + checkDataset( + ds, + data: _*) + + val schema = StructType(Seq( + StructField("value", StructType(Seq( + StructField("_1", IntegerType, nullable = false), + StructField("_2", StringType, nullable = true) + )), nullable = true) + )) + + assert(ds.schema == schema) + } } case class TestDataUnion(x: Int, y: Int, z: Int) From ed3d5cb697b10af2e2cf4c78ab521d4d0b2f3c9b Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 24 Aug 2018 04:26:28 +0000 Subject: [PATCH 03/21] Remove topLevel parameter. --- .../spark/sql/catalyst/ScalaReflection.scala | 10 ++- .../catalyst/encoders/ExpressionEncoder.scala | 12 +--- .../sql/catalyst/ScalaReflectionSuite.scala | 68 +++++++++--------- .../aggregate/TypedAggregateExpression.scala | 72 +++++++++++++++++-- .../spark/sql/DatasetAggregatorSuite.scala | 2 +- 5 files changed, 107 insertions(+), 57 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 3f2505d5d689..ebd2d3bf0dc3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -135,14 +135,14 @@ object ScalaReflection extends ScalaReflection { * from ordinal 0 (since there are no names to map to). The actual location can be moved by * calling resolve/bind with a new schema. 
*/ - def deserializerFor[T : TypeTag](topLevel: Boolean): Expression = cleanUpReflectionObjects { + def deserializerFor[T : TypeTag](): Expression = cleanUpReflectionObjects { val tpe = localTypeOf[T] val clsName = getClassNameFromType(tpe) val walkedTypePath = s"""- root class: "$clsName"""" :: Nil val Schema(dataType, tpeNullable) = schemaFor(tpe) val isOptionOfProduct = tpe.dealias <:< localTypeOf[Option[_]] && definedByConstructorParams(tpe) - val (optTypePath, nullable) = if (isOptionOfProduct && topLevel) { + val (optTypePath, nullable) = if (isOptionOfProduct) { // Top-level Option of Product is encoded as single struct column at top-level row. (Some(addToPathOrdinal(None, 0, dataType, walkedTypePath)), true) } else { @@ -441,15 +441,13 @@ object ScalaReflection extends ScalaReflection { * * the field of [[Product]]: `field (class: "abc.xyz.MyClass", name: "myField")` */ def serializerFor[T : TypeTag]( - inputObject: Expression, - topLevel: Boolean): CreateNamedStruct = cleanUpReflectionObjects { + inputObject: Expression): CreateNamedStruct = cleanUpReflectionObjects { val tpe = localTypeOf[T] val clsName = getClassNameFromType(tpe) val walkedTypePath = s"""- root class: "$clsName"""" :: Nil serializerFor(inputObject, tpe, walkedTypePath) match { case i @ expressions.If(_, _, _: CreateNamedStruct) - if tpe.dealias <:< localTypeOf[Option[_]] && - definedByConstructorParams(tpe) && topLevel => + if tpe.dealias <:< localTypeOf[Option[_]] && definedByConstructorParams(tpe) => // We encode top-level Option of Product as a single struct column. CreateNamedStruct(expressions.Literal("value") :: i :: Nil) case expressions.If(_, _, s: CreateNamedStruct) if definedByConstructorParams(tpe) => s diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 861451012ff7..a90137d0029d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -43,13 +43,7 @@ import org.apache.spark.util.Utils * to the name `value`. */ object ExpressionEncoder { - // Constructs an encoder for top-level row. - def apply[T : TypeTag](): ExpressionEncoder[T] = apply(topLevel = true) - - /** - * @param topLevel whether the encoders to construct are for top-level row. - */ - def apply[T : TypeTag](topLevel: Boolean): ExpressionEncoder[T] = { + def apply[T : TypeTag](): ExpressionEncoder[T] = { // We convert the not-serializable TypeTag into StructType and ClassTag. val mirror = ScalaReflection.mirror val tpe = typeTag[T].in(mirror).tpe @@ -65,8 +59,8 @@ object ExpressionEncoder { // doesn't allow top-level row to be null, only its columns can be null. 
AssertNotNull(inputObject, Seq("top level Product input object")) } - val serializer = ScalaReflection.serializerFor[T](nullSafeInput, topLevel) - val deserializer = ScalaReflection.deserializerFor[T](topLevel) + val serializer = ScalaReflection.serializerFor[T](nullSafeInput) + val deserializer = ScalaReflection.deserializerFor[T]() val schema = serializer.dataType diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 0cb188ef9bb7..da9f2f2d0929 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -281,7 +281,7 @@ class ScalaReflectionSuite extends SparkFunSuite { test("SPARK-15062: Get correct serializer for List[_]") { val list = List(1, 2, 3) val serializer = serializerFor[List[Int]](BoundReference( - 0, ObjectType(list.getClass), nullable = false), topLevel = true) + 0, ObjectType(list.getClass), nullable = false)) assert(serializer.children.size == 2) assert(serializer.children.head.isInstanceOf[Literal]) assert(serializer.children.head.asInstanceOf[Literal].value === UTF8String.fromString("value")) @@ -291,57 +291,57 @@ class ScalaReflectionSuite extends SparkFunSuite { } test("SPARK 16792: Get correct deserializer for List[_]") { - val listDeserializer = deserializerFor[List[Int]](topLevel = true) + val listDeserializer = deserializerFor[List[Int]]() assert(listDeserializer.dataType == ObjectType(classOf[List[_]])) } test("serialize and deserialize arbitrary sequence types") { import scala.collection.immutable.Queue val queueSerializer = serializerFor[Queue[Int]](BoundReference( - 0, ObjectType(classOf[Queue[Int]]), nullable = false), topLevel = true) + 0, ObjectType(classOf[Queue[Int]]), nullable = false)) assert(queueSerializer.dataType.head.dataType == ArrayType(IntegerType, containsNull = false)) - val queueDeserializer = deserializerFor[Queue[Int]](topLevel = true) + val queueDeserializer = deserializerFor[Queue[Int]]() assert(queueDeserializer.dataType == ObjectType(classOf[Queue[_]])) import scala.collection.mutable.ArrayBuffer val arrayBufferSerializer = serializerFor[ArrayBuffer[Int]](BoundReference( - 0, ObjectType(classOf[ArrayBuffer[Int]]), nullable = false), topLevel = true) + 0, ObjectType(classOf[ArrayBuffer[Int]]), nullable = false)) assert(arrayBufferSerializer.dataType.head.dataType == ArrayType(IntegerType, containsNull = false)) - val arrayBufferDeserializer = deserializerFor[ArrayBuffer[Int]](topLevel = true) + val arrayBufferDeserializer = deserializerFor[ArrayBuffer[Int]]() assert(arrayBufferDeserializer.dataType == ObjectType(classOf[ArrayBuffer[_]])) } test("serialize and deserialize arbitrary map types") { val mapSerializer = serializerFor[Map[Int, Int]](BoundReference( - 0, ObjectType(classOf[Map[Int, Int]]), nullable = false), topLevel = true) + 0, ObjectType(classOf[Map[Int, Int]]), nullable = false)) assert(mapSerializer.dataType.head.dataType == MapType(IntegerType, IntegerType, valueContainsNull = false)) - val mapDeserializer = deserializerFor[Map[Int, Int]](topLevel = true) + val mapDeserializer = deserializerFor[Map[Int, Int]]() assert(mapDeserializer.dataType == ObjectType(classOf[Map[_, _]])) import scala.collection.immutable.HashMap val hashMapSerializer = serializerFor[HashMap[Int, Int]](BoundReference( - 0, ObjectType(classOf[HashMap[Int, Int]]), nullable = false), 
topLevel = true) + 0, ObjectType(classOf[HashMap[Int, Int]]), nullable = false)) assert(hashMapSerializer.dataType.head.dataType == MapType(IntegerType, IntegerType, valueContainsNull = false)) - val hashMapDeserializer = deserializerFor[HashMap[Int, Int]](topLevel = true) + val hashMapDeserializer = deserializerFor[HashMap[Int, Int]]() assert(hashMapDeserializer.dataType == ObjectType(classOf[HashMap[_, _]])) import scala.collection.mutable.{LinkedHashMap => LHMap} val linkedHashMapSerializer = serializerFor[LHMap[Long, String]](BoundReference( - 0, ObjectType(classOf[LHMap[Long, String]]), nullable = false), topLevel = true) + 0, ObjectType(classOf[LHMap[Long, String]]), nullable = false)) assert(linkedHashMapSerializer.dataType.head.dataType == MapType(LongType, StringType, valueContainsNull = true)) - val linkedHashMapDeserializer = deserializerFor[LHMap[Long, String]](topLevel = true) + val linkedHashMapDeserializer = deserializerFor[LHMap[Long, String]]() assert(linkedHashMapDeserializer.dataType == ObjectType(classOf[LHMap[_, _]])) } test("SPARK-22442: Generate correct field names for special characters") { val serializer = serializerFor[SpecialCharAsFieldData](BoundReference( - 0, ObjectType(classOf[SpecialCharAsFieldData]), nullable = false), topLevel = true) - val deserializer = deserializerFor[SpecialCharAsFieldData](topLevel = true) + 0, ObjectType(classOf[SpecialCharAsFieldData]), nullable = false)) + val deserializer = deserializerFor[SpecialCharAsFieldData]() assert(serializer.dataType(0).name == "field.1") assert(serializer.dataType(1).name == "field 2") @@ -353,8 +353,8 @@ class ScalaReflectionSuite extends SparkFunSuite { } test("SPARK-22472: add null check for top-level primitive values") { - assert(deserializerFor[Int](topLevel = true).isInstanceOf[AssertNotNull]) - assert(!deserializerFor[String](topLevel = true).isInstanceOf[AssertNotNull]) + assert(deserializerFor[Int]().isInstanceOf[AssertNotNull]) + assert(!deserializerFor[String]().isInstanceOf[AssertNotNull]) } test("SPARK-23025: schemaFor should support Null type") { @@ -371,36 +371,34 @@ class ScalaReflectionSuite extends SparkFunSuite { assert(deserializer.isInstanceOf[NewInstance]) deserializer.asInstanceOf[NewInstance].arguments.count(_.isInstanceOf[AssertNotNull]) } - assert(numberOfCheckedArguments(deserializerFor[(Double, Double)](topLevel = true)) == 2) - assert(numberOfCheckedArguments(deserializerFor[(java.lang.Double, Int)](topLevel = true)) == 1) + assert(numberOfCheckedArguments(deserializerFor[(Double, Double)]()) == 2) + assert(numberOfCheckedArguments(deserializerFor[(java.lang.Double, Int)]()) == 1) assert(numberOfCheckedArguments( - deserializerFor[(java.lang.Integer, java.lang.Integer)](topLevel = true)) == 0) + deserializerFor[(java.lang.Integer, java.lang.Integer)]()) == 0) } test("SPARK-24762: serializer for Option of Product") { val optionOfProduct = Some((1, "a")) - val topLevelSerializer = serializerFor[Option[(Int, String)]](BoundReference( - 0, ObjectType(optionOfProduct.getClass), nullable = true), topLevel = true) - val nonTopLevelSerializer = serializerFor[Option[(Int, String)]](BoundReference( - 0, ObjectType(optionOfProduct.getClass), nullable = true), topLevel = false) - - topLevelSerializer match { - case CreateNamedStruct(Seq(Literal(_, _), If(_, _, optEncoder))) => - assert(optEncoder.semanticEquals(nonTopLevelSerializer)) + val serializer = serializerFor[Option[(Int, String)]](BoundReference( + 0, ObjectType(optionOfProduct.getClass), nullable = true)) + + serializer match { 
+ case CreateNamedStruct(Seq(_: Literal, If(_, _, encoder: CreateNamedStruct))) => + val fields = encoder.flatten + assert(fields.length == 2) + assert(fields(0).dataType == IntegerType) + assert(fields(1).dataType == StringType) case _ => fail("top-level Option of Product should be encoded as single struct column.") } } test("SPARK-24762: deserializer for Option of Product") { - val topLevelDeserializer = deserializerFor[Option[(Int, String)]](topLevel = true) - val nonTopLevelDeserializer = deserializerFor[Option[(Int, String)]](topLevel = false) - .asInstanceOf[WrapOption] - - topLevelDeserializer match { - case WrapOption(If(IsNull(GetColumnByOrdinal(0, _)), _, n: NewInstance), optType) => - assert(n.cls == nonTopLevelDeserializer.child.asInstanceOf[NewInstance].cls) - assert(optType == nonTopLevelDeserializer.optType) + val deserializer = deserializerFor[Option[(Int, String)]]() + + deserializer match { + case WrapOption(If(IsNull(GetColumnByOrdinal(0, _)), _, n: NewInstance), _) => + assert(n.cls == classOf[Tuple2[Int, String]]) case _ => fail("top-level Option of Product should be decoded from a single struct column.") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala index 6d44890704f4..c22d0b19a437 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -19,25 +19,85 @@ package org.apache.spark.sql.execution.aggregate import scala.language.existentials -import org.apache.spark.sql.Encoder +import org.apache.spark.sql.{AnalysisException, Encoder} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer -import org.apache.spark.sql.catalyst.encoders.encoderFor +import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedDeserializer} +import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, DeclarativeAggregate, TypedImperativeAggregate} import org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection -import org.apache.spark.sql.catalyst.expressions.objects.Invoke +import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, Invoke, NewInstance, WrapOption} import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.types._ import org.apache.spark.util.Utils object TypedAggregateExpression { + + // Checks if given encoder is for `Option[Product]`. + def isOptProductEncoder(encoder: ExpressionEncoder[_]): Boolean = { + // Only Option[Product] is non-flat. + encoder.clsTag.runtimeClass == classOf[Option[_]] && !encoder.flat + } + + /** + * Flattens serializers and deserializer of given encoder. We only flatten encoder + * of `Option[Product]` class. 
+ */ + def flattenOptProductEncoder(encoder: ExpressionEncoder[_]): ExpressionEncoder[_] = { + val serializer = encoder.serializer + val deserializer = encoder.deserializer + + assert(serializer.length == 1, + "We can only flatten encoder of Option of Product class which has single serializer.") + + val flattenSerializers = serializer(0).collect { + case c: CreateNamedStruct => c.flatten + }.head + + val flattenDeserializer = deserializer match { + case w @ WrapOption(If(_, _, child: NewInstance), optType) => + val newInstance = child match { + case oldNewInstance: NewInstance => + val newArguments = oldNewInstance.arguments.zipWithIndex.map { case (arg, idx) => + arg match { + case a @ AssertNotNull( + UpCast(GetStructField( + child @ GetColumnByOrdinal(0, _), _, _), dt, walkedTypePath), _) => + a.copy(child = UpCast(GetColumnByOrdinal(idx, dt), dt, walkedTypePath.tail)) + } + } + oldNewInstance.copy(arguments = newArguments) + } + w.copy(child = newInstance) + case _ => + throw new AnalysisException( + "On top of deserializer of Option[Product] should be `WrapOption`.") + } + + // `Option[Product]` is encoded as single column of struct type in a row. + val newSchema = encoder.schema.asInstanceOf[StructType].fields(0) + .dataType.asInstanceOf[StructType] + encoder.copy(serializer = flattenSerializers, deserializer = flattenDeserializer, + schema = newSchema) + } + def apply[BUF : Encoder, OUT : Encoder]( aggregator: Aggregator[_, BUF, OUT]): TypedAggregateExpression = { - val bufferEncoder = encoderFor[BUF] + val rawBufferEncoder = encoderFor[BUF] + + val bufferEncoder = if (isOptProductEncoder(rawBufferEncoder)) { + flattenOptProductEncoder(rawBufferEncoder) + } else { + rawBufferEncoder + } val bufferSerializer = bufferEncoder.namedExpressions - val outputEncoder = encoderFor[OUT] + val rawOutputEncoder = encoderFor[OUT] + val outputEncoder = if (isOptProductEncoder(rawOutputEncoder)) { + flattenOptProductEncoder(rawOutputEncoder) + } else { + rawOutputEncoder + } val outputType = if (outputEncoder.flat) { outputEncoder.schema.head.dataType } else { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala index 33241671dbcf..0446bd9097b6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -218,7 +218,7 @@ case class OptionBooleanIntAggregator(colName: String) override def bufferEncoder: Encoder[Option[(Boolean, Int)]] = OptionalBoolIntEncoder override def outputEncoder: Encoder[Option[(Boolean, Int)]] = OptionalBoolIntEncoder - def OptionalBoolIntEncoder: Encoder[Option[(Boolean, Int)]] = ExpressionEncoder(topLevel = false) + def OptionalBoolIntEncoder: Encoder[Option[(Boolean, Int)]] = ExpressionEncoder() } class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { From 5f95bd0cf1bd308c7df55c41caef7a9f19368f5d Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 24 Aug 2018 04:42:33 +0000 Subject: [PATCH 04/21] Remove useless change. 
--- .../spark/sql/catalyst/ScalaReflection.scala | 2 +- .../catalyst/encoders/ExpressionEncoder.scala | 2 +- .../sql/catalyst/ScalaReflectionSuite.scala | 27 +++++++++---------- 3 files changed, 15 insertions(+), 16 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index bc11191e6959..fbcb1ada1a17 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -135,7 +135,7 @@ object ScalaReflection extends ScalaReflection { * from ordinal 0 (since there are no names to map to). The actual location can be moved by * calling resolve/bind with a new schema. */ - def deserializerFor[T : TypeTag](): Expression = cleanUpReflectionObjects { + def deserializerFor[T : TypeTag]: Expression = cleanUpReflectionObjects { val tpe = localTypeOf[T] val clsName = getClassNameFromType(tpe) val walkedTypePath = s"""- root class: "$clsName"""" :: Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index a90137d0029d..0a1c23886159 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -60,7 +60,7 @@ object ExpressionEncoder { AssertNotNull(inputObject, Seq("top level Product input object")) } val serializer = ScalaReflection.serializerFor[T](nullSafeInput) - val deserializer = ScalaReflection.deserializerFor[T]() + val deserializer = ScalaReflection.deserializerFor[T] val schema = serializer.dataType diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index da9f2f2d0929..750f0b03e46a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -291,7 +291,7 @@ class ScalaReflectionSuite extends SparkFunSuite { } test("SPARK 16792: Get correct deserializer for List[_]") { - val listDeserializer = deserializerFor[List[Int]]() + val listDeserializer = deserializerFor[List[Int]] assert(listDeserializer.dataType == ObjectType(classOf[List[_]])) } @@ -301,7 +301,7 @@ class ScalaReflectionSuite extends SparkFunSuite { 0, ObjectType(classOf[Queue[Int]]), nullable = false)) assert(queueSerializer.dataType.head.dataType == ArrayType(IntegerType, containsNull = false)) - val queueDeserializer = deserializerFor[Queue[Int]]() + val queueDeserializer = deserializerFor[Queue[Int]] assert(queueDeserializer.dataType == ObjectType(classOf[Queue[_]])) import scala.collection.mutable.ArrayBuffer @@ -309,7 +309,7 @@ class ScalaReflectionSuite extends SparkFunSuite { 0, ObjectType(classOf[ArrayBuffer[Int]]), nullable = false)) assert(arrayBufferSerializer.dataType.head.dataType == ArrayType(IntegerType, containsNull = false)) - val arrayBufferDeserializer = deserializerFor[ArrayBuffer[Int]]() + val arrayBufferDeserializer = deserializerFor[ArrayBuffer[Int]] assert(arrayBufferDeserializer.dataType == ObjectType(classOf[ArrayBuffer[_]])) } @@ -318,7 +318,7 @@ class ScalaReflectionSuite extends SparkFunSuite { 0, 
ObjectType(classOf[Map[Int, Int]]), nullable = false)) assert(mapSerializer.dataType.head.dataType == MapType(IntegerType, IntegerType, valueContainsNull = false)) - val mapDeserializer = deserializerFor[Map[Int, Int]]() + val mapDeserializer = deserializerFor[Map[Int, Int]] assert(mapDeserializer.dataType == ObjectType(classOf[Map[_, _]])) import scala.collection.immutable.HashMap @@ -326,7 +326,7 @@ class ScalaReflectionSuite extends SparkFunSuite { 0, ObjectType(classOf[HashMap[Int, Int]]), nullable = false)) assert(hashMapSerializer.dataType.head.dataType == MapType(IntegerType, IntegerType, valueContainsNull = false)) - val hashMapDeserializer = deserializerFor[HashMap[Int, Int]]() + val hashMapDeserializer = deserializerFor[HashMap[Int, Int]] assert(hashMapDeserializer.dataType == ObjectType(classOf[HashMap[_, _]])) import scala.collection.mutable.{LinkedHashMap => LHMap} @@ -334,14 +334,14 @@ class ScalaReflectionSuite extends SparkFunSuite { 0, ObjectType(classOf[LHMap[Long, String]]), nullable = false)) assert(linkedHashMapSerializer.dataType.head.dataType == MapType(LongType, StringType, valueContainsNull = true)) - val linkedHashMapDeserializer = deserializerFor[LHMap[Long, String]]() + val linkedHashMapDeserializer = deserializerFor[LHMap[Long, String]] assert(linkedHashMapDeserializer.dataType == ObjectType(classOf[LHMap[_, _]])) } test("SPARK-22442: Generate correct field names for special characters") { val serializer = serializerFor[SpecialCharAsFieldData](BoundReference( 0, ObjectType(classOf[SpecialCharAsFieldData]), nullable = false)) - val deserializer = deserializerFor[SpecialCharAsFieldData]() + val deserializer = deserializerFor[SpecialCharAsFieldData] assert(serializer.dataType(0).name == "field.1") assert(serializer.dataType(1).name == "field 2") @@ -353,8 +353,8 @@ class ScalaReflectionSuite extends SparkFunSuite { } test("SPARK-22472: add null check for top-level primitive values") { - assert(deserializerFor[Int]().isInstanceOf[AssertNotNull]) - assert(!deserializerFor[String]().isInstanceOf[AssertNotNull]) + assert(deserializerFor[Int].isInstanceOf[AssertNotNull]) + assert(!deserializerFor[String].isInstanceOf[AssertNotNull]) } test("SPARK-23025: schemaFor should support Null type") { @@ -371,10 +371,9 @@ class ScalaReflectionSuite extends SparkFunSuite { assert(deserializer.isInstanceOf[NewInstance]) deserializer.asInstanceOf[NewInstance].arguments.count(_.isInstanceOf[AssertNotNull]) } - assert(numberOfCheckedArguments(deserializerFor[(Double, Double)]()) == 2) - assert(numberOfCheckedArguments(deserializerFor[(java.lang.Double, Int)]()) == 1) - assert(numberOfCheckedArguments( - deserializerFor[(java.lang.Integer, java.lang.Integer)]()) == 0) + assert(numberOfCheckedArguments(deserializerFor[(Double, Double)]) == 2) + assert(numberOfCheckedArguments(deserializerFor[(java.lang.Double, Int)]) == 1) + assert(numberOfCheckedArguments(deserializerFor[(java.lang.Integer, java.lang.Integer)]) == 0) } test("SPARK-24762: serializer for Option of Product") { @@ -394,7 +393,7 @@ class ScalaReflectionSuite extends SparkFunSuite { } test("SPARK-24762: deserializer for Option of Product") { - val deserializer = deserializerFor[Option[(Int, String)]]() + val deserializer = deserializerFor[Option[(Int, String)]] deserializer match { case WrapOption(If(IsNull(GetColumnByOrdinal(0, _)), _, n: NewInstance), _) => From a4f04055b2ba22f371663565710328791942855a Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 24 Aug 2018 14:38:16 +0000 Subject: [PATCH 05/21] Add more 
tests. --- .../aggregate/TypedAggregateExpression.scala | 2 +- .../TypedAggregateExpressionSuite.scala | 63 +++++++++++++++++++ 2 files changed, 64 insertions(+), 1 deletion(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpressionSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala index c22d0b19a437..27a50a270b09 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -43,7 +43,7 @@ object TypedAggregateExpression { * Flattens serializers and deserializer of given encoder. We only flatten encoder * of `Option[Product]` class. */ - def flattenOptProductEncoder(encoder: ExpressionEncoder[_]): ExpressionEncoder[_] = { + def flattenOptProductEncoder[T](encoder: ExpressionEncoder[T]): ExpressionEncoder[T] = { val serializer = encoder.serializer val deserializer = encoder.deserializer diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpressionSuite.scala new file mode 100644 index 000000000000..f54557b1e0f5 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpressionSuite.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.encoders._ +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ + + +class TypedAggregateExpressionSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + private def testOptProductEncoder(encoder: ExpressionEncoder[_], expected: Boolean): Unit = { + assert(TypedAggregateExpression.isOptProductEncoder(encoder) == expected) + } + + test("check an encoder is for option of product") { + testOptProductEncoder(encoderFor[Int], false) + testOptProductEncoder(encoderFor[(Long, Long)], false) + testOptProductEncoder(encoderFor[Option[Int]], false) + testOptProductEncoder(encoderFor[Option[(Int, Long)]], true) + testOptProductEncoder(encoderFor[Option[SimpleCaseClass]], true) + } + + test("flatten encoders of option of product") { + // Option[Product] is encoded as a struct column in a row. 
+ val optProductEncoder: ExpressionEncoder[Option[(Int, Long)]] = encoderFor[Option[(Int, Long)]] + val optProductSchema = StructType(StructField("value", StructType( + StructField("_1", IntegerType) :: StructField("_2", LongType) :: Nil)) :: Nil) + + assert(optProductEncoder.schema.length == 1) + assert(DataType.equalsIgnoreCaseAndNullability(optProductEncoder.schema, optProductSchema)) + + val flattenEncoder = TypedAggregateExpression.flattenOptProductEncoder(optProductEncoder) + .resolveAndBind() + assert(flattenEncoder.schema.length == 2) + assert(DataType.equalsIgnoreCaseAndNullability(flattenEncoder.schema, + optProductSchema.fields(0).dataType)) + + val row = flattenEncoder.toRow(Some((1, 2L))) + val expected = flattenEncoder.fromRow(row) + assert(Some((1, 2L)) == expected) + } +} + +case class SimpleCaseClass(a: Int) From c1f798f7e9cba0d04223eed06f1b1f547ec29dc5 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 25 Aug 2018 01:52:01 +0000 Subject: [PATCH 06/21] Add test. --- .../org/apache/spark/sql/DatasetSuite.scala | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index e292795d6b7c..ce8c5ea6cc00 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1506,6 +1506,23 @@ class DatasetSuite extends QueryTest with SharedSQLContext { )) assert(ds.schema == schema) + + val nestedOptData = Seq(Some((Some((1, "a")), 2.0)), Some((Some((2, "b")), 3.0))) + val nestedDs = nestedOptData.toDS() + + checkDataset( + nestedDs, + nestedOptData: _*) + + val nestedSchema = StructType(Seq( + StructField("value", StructType(Seq( + StructField("_1", StructType(Seq( + StructField("_1", IntegerType, nullable = false), + StructField("_2", StringType, nullable = true)))), + StructField("_2", DoubleType, nullable = false) + )), nullable = true) + )) + assert(nestedDs.schema == nestedSchema) } test("SPARK-23034 show rdd names in RDD scan nodes") { From 0f029b0a28700334dc6334f1ad89b3124f235a51 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 6 Oct 2018 04:40:07 +0000 Subject: [PATCH 07/21] Improve code comments. --- .../spark/sql/catalyst/ScalaReflection.scala | 8 ++++++-- .../aggregate/TypedAggregateExpression.scala | 14 ++++++++++---- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 6e06a755b6f8..2ce7b02afd8f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -143,7 +143,8 @@ object ScalaReflection extends ScalaReflection { val isOptionOfProduct = tpe.dealias <:< localTypeOf[Option[_]] && definedByConstructorParams(tpe) val (optTypePath, nullable) = if (isOptionOfProduct) { - // Top-level Option of Product is encoded as single struct column at top-level row. + // Because we encode top-level Option[Product] as a struct at the first column of the row, + // we add zero ordinal as the path to access it when to deserialize it. 
(Some(addToPathOrdinal(None, 0, dataType, walkedTypePath)), true) } else { (None, tpeNullable) @@ -448,7 +449,10 @@ object ScalaReflection extends ScalaReflection { serializerFor(inputObject, tpe, walkedTypePath) match { case i @ expressions.If(_, _, _: CreateNamedStruct) if tpe.dealias <:< localTypeOf[Option[_]] && definedByConstructorParams(tpe) => - // We encode top-level Option of Product as a single struct column. + // When we are going to serialize an Option[Product] at top-level of row, because + // Spark doesn't support top-level row as null, we encode the Option[Product] as a + // struct at the first column of the row. So here we add an extra named struct wrapping + // the serialized Option[Product] which is the first and only column named `value`. CreateNamedStruct(expressions.Literal("value") :: i :: Nil) case expressions.If(_, _, s: CreateNamedStruct) if definedByConstructorParams(tpe) => s case other => CreateNamedStruct(expressions.Literal("value") :: other :: Nil) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala index 27a50a270b09..1cca43e77785 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -35,7 +35,7 @@ object TypedAggregateExpression { // Checks if given encoder is for `Option[Product]`. def isOptProductEncoder(encoder: ExpressionEncoder[_]): Boolean = { - // Only Option[Product] is non-flat. + // For all Option[_] classes, only Option[Product] is reported as not flat. encoder.clsTag.runtimeClass == classOf[Option[_]] && !encoder.flat } @@ -47,8 +47,10 @@ object TypedAggregateExpression { val serializer = encoder.serializer val deserializer = encoder.deserializer + // This is just a sanity check. Encoders of Option[Product] has only one `CreateNamedStruct` + // serializer expression. assert(serializer.length == 1, - "We can only flatten encoder of Option of Product class which has single serializer.") + "We only flatten encoder of Option[Product] class which has single serializer.") val flattenSerializers = serializer(0).collect { case c: CreateNamedStruct => c.flatten @@ -74,8 +76,7 @@ object TypedAggregateExpression { "On top of deserializer of Option[Product] should be `WrapOption`.") } - // `Option[Product]` is encoded as single column of struct type in a row. - val newSchema = encoder.schema.asInstanceOf[StructType].fields(0) + val newSchema = encoder.schema.fields(0) .dataType.asInstanceOf[StructType] encoder.copy(serializer = flattenSerializers, deserializer = flattenDeserializer, schema = newSchema) @@ -85,6 +86,11 @@ object TypedAggregateExpression { aggregator: Aggregator[_, BUF, OUT]): TypedAggregateExpression = { val rawBufferEncoder = encoderFor[BUF] + // When `BUF` or `OUT` is an Option[Product], we need to flatten serializers and deserializer + // of original encoder. It is because we wrap serializers of Option[Product] inside an extra + // struct in order to support encoding of Option[Product] at top-level row. But here we use + // the encoder to encode Option[Product] for a column, we need to get rid of this extra + // struct. 
val bufferEncoder = if (isOptProductEncoder(rawBufferEncoder)) { flattenOptProductEncoder(rawBufferEncoder) } else { From 84f3ce07f2f6a9236bd27f927fbb877e937f6917 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 15 Oct 2018 09:55:03 +0000 Subject: [PATCH 08/21] Refactoring ExpressionEncoder. --- .../scala/org/apache/spark/sql/Encoders.scala | 8 +- .../sql/catalyst/JavaTypeInference.scala | 78 +++---- .../spark/sql/catalyst/ScalaReflection.scala | 206 +++++++++-------- .../catalyst/encoders/ExpressionEncoder.scala | 207 +++++++++++------- .../sql/catalyst/encoders/RowEncoder.scala | 12 +- .../sql/catalyst/ScalaReflectionSuite.scala | 125 +++++------ .../encoders/ExpressionEncoderSuite.scala | 6 +- .../scala/org/apache/spark/sql/Dataset.scala | 10 +- .../spark/sql/KeyValueGroupedDataset.scala | 4 +- .../aggregate/TypedAggregateExpression.scala | 90 +------- .../FlatMapGroupsWithStateExecHelper.scala | 2 +- .../spark/sql/DatasetAggregatorSuite.scala | 52 ----- .../org/apache/spark/sql/DatasetSuite.scala | 44 +--- .../TypedAggregateExpressionSuite.scala | 63 ------ 14 files changed, 353 insertions(+), 554 deletions(-) delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpressionSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala index b47ec0b72c63..8a30c81912fe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala @@ -203,12 +203,10 @@ object Encoders { validatePublicClass[T]() ExpressionEncoder[T]( - schema = new StructType().add("value", BinaryType), - flat = true, - serializer = Seq( + objSerializer = EncodeUsingSerializer( - BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true), kryo = useKryo)), - deserializer = + BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true), kryo = useKryo), + objDeserializer = DecodeUsingSerializer[T]( Cast(GetColumnByOrdinal(0, BinaryType), BinaryType), classTag[T], diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 3ecc137c8cd7..6a276b3055ac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -187,26 +187,23 @@ object JavaTypeInference { } /** - * Returns an expression that can be used to deserialize an internal row to an object of java bean - * `T` with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes - * of the same name as the constructor arguments. Nested classes will have their fields accessed - * using UnresolvedExtractValue. + * Returns an expression that can be used to deserialize a Spark SQL representation to an object + * of java bean `T` with a compatible schema. The Spark SQL representation is located at ordinal + * 0 of a row, i.e., `GetColumnByOrdinal(0, _)`. Nested classes will have their fields accessed + * using `UnresolvedExtractValue`. 
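   * For example (editor's sketch, not part of this patch): for a bean with a single
   * `String name` property, the field value is read with roughly
   * `Invoke(UnresolvedExtractValue(GetColumnByOrdinal(0, beanSchema), "name"), "toString", ...)`
   * and assigned through `InitializeJavaBean` on a `NewInstance` of the bean class, all wrapped
   * in a null check on the input column.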
*/ def deserializerFor(beanClass: Class[_]): Expression = { - deserializerFor(TypeToken.of(beanClass), None) + val typeToken = TypeToken.of(beanClass) + deserializerFor(typeToken, GetColumnByOrdinal(0, inferDataType(typeToken)._1)) } - private def deserializerFor(typeToken: TypeToken[_], path: Option[Expression]): Expression = { + private def deserializerFor(typeToken: TypeToken[_], path: Expression): Expression = { /** Returns the current path with a sub-field extracted. */ - def addToPath(part: String): Expression = path - .map(p => UnresolvedExtractValue(p, expressions.Literal(part))) - .getOrElse(UnresolvedAttribute(part)) - - /** Returns the current path or `GetColumnByOrdinal`. */ - def getPath: Expression = path.getOrElse(GetColumnByOrdinal(0, inferDataType(typeToken)._1)) + def addToPath(part: String): Expression = UnresolvedExtractValue(path, + expressions.Literal(part)) typeToken.getRawType match { - case c if !inferExternalType(c).isInstanceOf[ObjectType] => getPath + case c if !inferExternalType(c).isInstanceOf[ObjectType] => path case c if c == classOf[java.lang.Short] || c == classOf[java.lang.Integer] || @@ -219,7 +216,7 @@ object JavaTypeInference { c, ObjectType(c), "valueOf", - getPath :: Nil, + path :: Nil, returnNullable = false) case c if c == classOf[java.sql.Date] => @@ -227,7 +224,7 @@ object JavaTypeInference { DateTimeUtils.getClass, ObjectType(c), "toJavaDate", - getPath :: Nil, + path :: Nil, returnNullable = false) case c if c == classOf[java.sql.Timestamp] => @@ -235,14 +232,14 @@ object JavaTypeInference { DateTimeUtils.getClass, ObjectType(c), "toJavaTimestamp", - getPath :: Nil, + path :: Nil, returnNullable = false) case c if c == classOf[java.lang.String] => - Invoke(getPath, "toString", ObjectType(classOf[String])) + Invoke(path, "toString", ObjectType(classOf[String])) case c if c == classOf[java.math.BigDecimal] => - Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal])) + Invoke(path, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal])) case c if c.isArray => val elementType = c.getComponentType @@ -258,12 +255,12 @@ object JavaTypeInference { } primitiveMethod.map { method => - Invoke(getPath, method, ObjectType(c)) + Invoke(path, method, ObjectType(c)) }.getOrElse { Invoke( MapObjects( - p => deserializerFor(typeToken.getComponentType, Some(p)), - getPath, + p => deserializerFor(typeToken.getComponentType, p), + path, inferDataType(elementType)._1), "array", ObjectType(c)) @@ -272,8 +269,8 @@ object JavaTypeInference { case c if listType.isAssignableFrom(typeToken) => val et = elementType(typeToken) MapObjects( - p => deserializerFor(et, Some(p)), - getPath, + p => deserializerFor(et, p), + path, inferDataType(et)._1, customCollectionCls = Some(c)) @@ -285,8 +282,8 @@ object JavaTypeInference { val keyData = Invoke( MapObjects( - p => deserializerFor(keyType, Some(p)), - Invoke(getPath, "keyArray", ArrayType(keyDataType)), + p => deserializerFor(keyType, p), + Invoke(path, "keyArray", ArrayType(keyDataType)), keyDataType), "array", ObjectType(classOf[Array[Any]])) @@ -294,8 +291,8 @@ object JavaTypeInference { val valueData = Invoke( MapObjects( - p => deserializerFor(valueType, Some(p)), - Invoke(getPath, "valueArray", ArrayType(valueDataType)), + p => deserializerFor(valueType, p), + Invoke(path, "valueArray", ArrayType(valueDataType)), valueDataType), "array", ObjectType(classOf[Array[Any]])) @@ -312,7 +309,7 @@ object JavaTypeInference { other, ObjectType(other), "valueOf", - Invoke(getPath, "toString", 
ObjectType(classOf[String]), returnNullable = false) :: Nil, + Invoke(path, "toString", ObjectType(classOf[String]), returnNullable = false) :: Nil, returnNullable = false) case other => @@ -321,7 +318,7 @@ object JavaTypeInference { val fieldName = p.getName val fieldType = typeToken.method(p.getReadMethod).getReturnType val (_, nullable) = inferDataType(fieldType) - val constructor = deserializerFor(fieldType, Some(addToPath(fieldName))) + val constructor = deserializerFor(fieldType, addToPath(fieldName)) val setter = if (nullable) { constructor } else { @@ -333,28 +330,23 @@ object JavaTypeInference { val newInstance = NewInstance(other, Nil, ObjectType(other), propagateNull = false) val result = InitializeJavaBean(newInstance, setters) - if (path.nonEmpty) { - expressions.If( - IsNull(getPath), - expressions.Literal.create(null, ObjectType(other)), - result - ) - } else { + expressions.If( + IsNull(path), + expressions.Literal.create(null, ObjectType(other)), result - } + ) } } /** - * Returns an expression for serializing an object of the given type to an internal row. + * Returns an expression for serializing an object of the given type to a Spark SQL + * representation. The input object is located at ordinal 0 of a row, i.e., + * `BoundReference(0, _)`. */ - def serializerFor(beanClass: Class[_]): CreateNamedStruct = { + def serializerFor(beanClass: Class[_]): Expression = { val inputObject = BoundReference(0, ObjectType(beanClass), nullable = true) val nullSafeInput = AssertNotNull(inputObject, Seq("top level input bean")) - serializerFor(nullSafeInput, TypeToken.of(beanClass)) match { - case expressions.If(_, _, s: CreateNamedStruct) => s - case other => CreateNamedStruct(expressions.Literal("value") :: other :: Nil) - } + serializerFor(nullSafeInput, TypeToken.of(beanClass)) } private def serializerFor(inputObject: Expression, typeToken: TypeToken[_]): Expression = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 2ce7b02afd8f..6491e7dc0649 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -125,38 +125,6 @@ object ScalaReflection extends ScalaReflection { case _ => false } - /** - * Returns an expression that can be used to deserialize an input row to an object of type `T` - * with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes - * of the same name as the constructor arguments. Nested classes will have their fields accessed - * using UnresolvedExtractValue. - * - * When used on a primitive type, the constructor will instead default to extracting the value - * from ordinal 0 (since there are no names to map to). The actual location can be moved by - * calling resolve/bind with a new schema. - */ - def deserializerFor[T : TypeTag]: Expression = cleanUpReflectionObjects { - val tpe = localTypeOf[T] - val clsName = getClassNameFromType(tpe) - val walkedTypePath = s"""- root class: "$clsName"""" :: Nil - val Schema(dataType, tpeNullable) = schemaFor(tpe) - val isOptionOfProduct = tpe.dealias <:< localTypeOf[Option[_]] && - definedByConstructorParams(tpe) - val (optTypePath, nullable) = if (isOptionOfProduct) { - // Because we encode top-level Option[Product] as a struct at the first column of the row, - // we add zero ordinal as the path to access it when to deserialize it. 
- (Some(addToPathOrdinal(None, 0, dataType, walkedTypePath)), true) - } else { - (None, tpeNullable) - } - val expr = deserializerFor(tpe, optTypePath, walkedTypePath) - if (nullable) { - expr - } else { - AssertNotNull(expr, walkedTypePath) - } - } - /** * When we build the `deserializer` for an encoder, we set up a lot of "unresolved" stuff * and lost the required data type, which may lead to runtime error if the real type doesn't @@ -168,9 +136,7 @@ object ScalaReflection extends ScalaReflection { * This method help us "remember" the required data type by adding a `UpCast`. Note that we * only need to do this for leaf nodes. */ - def upCastToExpectedType( - expr: Expression, - expected: DataType, + private def upCastToExpectedType(expr: Expression, expected: DataType, walkedTypePath: Seq[String]): Expression = expected match { case _: StructType => expr case _: ArrayType => expr @@ -179,43 +145,63 @@ object ScalaReflection extends ScalaReflection { case _ => UpCast(expr, expected, walkedTypePath) } - /** Returns the current path with a field at ordinal extracted. */ - def addToPathOrdinal( - path: Option[Expression], - ordinal: Int, - dataType: DataType, - walkedTypePath: Seq[String]): Expression = { - val newPath = path - .map(p => GetStructField(p, ordinal)) - .getOrElse(GetColumnByOrdinal(ordinal, dataType)) - upCastToExpectedType(newPath, dataType, walkedTypePath) + /** + * Returns an expression that can be used to deserialize a Spark SQL representation to an object + * of type `T` with a compatible schema. The Spark SQL representation is located at ordinal 0 of + * a row, i.e., `GetColumnByOrdinal(0, _)`. Nested classes will have their fields accessed using + * `UnresolvedExtractValue`. + * + * The returned expression is used by `ExpressionEncoder`. The encoder will resolve and bind this + * deserializer expression when using it. + */ + def deserializerForType(tpe: `Type`): Expression = { + val clsName = getClassNameFromType(tpe) + val walkedTypePath = s"""- root class: "$clsName"""" :: Nil + val Schema(dataType, nullable) = schemaFor(tpe) + + // Assumes we are deserializing the first column of a row. + val input = upCastToExpectedType(GetColumnByOrdinal(0, dataType), dataType, + walkedTypePath) + + val expr = deserializerFor(tpe, input, walkedTypePath) + if (nullable) { + expr + } else { + AssertNotNull(expr, walkedTypePath) + } } + /** + * Returns an expression that can be used to deserialize an input expression to an object of type + * `T` with a compatible schema. + * + * @param tpe The `Type` of deserialized object. + * @param path The expression which can be used to extract serialized value. + * @param walkedTypePath The paths from top to bottom to access current field when deserializing. + */ private def deserializerFor( tpe: `Type`, - path: Option[Expression], + path: Expression, walkedTypePath: Seq[String]): Expression = cleanUpReflectionObjects { /** Returns the current path with a sub-field extracted. */ def addToPath(part: String, dataType: DataType, walkedTypePath: Seq[String]): Expression = { - val newPath = path - .map(p => UnresolvedExtractValue(p, expressions.Literal(part))) - .getOrElse(UnresolvedAttribute.quoted(part)) + val newPath = UnresolvedExtractValue(path, expressions.Literal(part)) upCastToExpectedType(newPath, dataType, walkedTypePath) } - /** Returns the current path or `GetColumnByOrdinal`. 
*/ - def getPath: Expression = { - val dataType = schemaFor(tpe).dataType - if (path.isDefined) { - path.get - } else { - upCastToExpectedType(GetColumnByOrdinal(0, dataType), dataType, walkedTypePath) - } + /** Returns the current path with a field at ordinal extracted. */ + def addToPathOrdinal( + path: Expression, + ordinal: Int, + dataType: DataType, + walkedTypePath: Seq[String]): Expression = { + val newPath = GetStructField(path, ordinal) + upCastToExpectedType(newPath, dataType, walkedTypePath) } tpe.dealias match { - case t if !dataTypeFor(t).isInstanceOf[ObjectType] => getPath + case t if !dataTypeFor(t).isInstanceOf[ObjectType] => path case t if t <:< localTypeOf[Option[_]] => val TypeRef(_, _, Seq(optType)) = t @@ -226,44 +212,44 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[java.lang.Integer] => val boxedType = classOf[java.lang.Integer] val objectType = ObjectType(boxedType) - StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false) + StaticInvoke(boxedType, objectType, "valueOf", path :: Nil, returnNullable = false) case t if t <:< localTypeOf[java.lang.Long] => val boxedType = classOf[java.lang.Long] val objectType = ObjectType(boxedType) - StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false) + StaticInvoke(boxedType, objectType, "valueOf", path :: Nil, returnNullable = false) case t if t <:< localTypeOf[java.lang.Double] => val boxedType = classOf[java.lang.Double] val objectType = ObjectType(boxedType) - StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false) + StaticInvoke(boxedType, objectType, "valueOf", path :: Nil, returnNullable = false) case t if t <:< localTypeOf[java.lang.Float] => val boxedType = classOf[java.lang.Float] val objectType = ObjectType(boxedType) - StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false) + StaticInvoke(boxedType, objectType, "valueOf", path :: Nil, returnNullable = false) case t if t <:< localTypeOf[java.lang.Short] => val boxedType = classOf[java.lang.Short] val objectType = ObjectType(boxedType) - StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false) + StaticInvoke(boxedType, objectType, "valueOf", path :: Nil, returnNullable = false) case t if t <:< localTypeOf[java.lang.Byte] => val boxedType = classOf[java.lang.Byte] val objectType = ObjectType(boxedType) - StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false) + StaticInvoke(boxedType, objectType, "valueOf", path :: Nil, returnNullable = false) case t if t <:< localTypeOf[java.lang.Boolean] => val boxedType = classOf[java.lang.Boolean] val objectType = ObjectType(boxedType) - StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false) + StaticInvoke(boxedType, objectType, "valueOf", path :: Nil, returnNullable = false) case t if t <:< localTypeOf[java.sql.Date] => StaticInvoke( DateTimeUtils.getClass, ObjectType(classOf[java.sql.Date]), "toJavaDate", - getPath :: Nil, + path :: Nil, returnNullable = false) case t if t <:< localTypeOf[java.sql.Timestamp] => @@ -271,25 +257,25 @@ object ScalaReflection extends ScalaReflection { DateTimeUtils.getClass, ObjectType(classOf[java.sql.Timestamp]), "toJavaTimestamp", - getPath :: Nil, + path :: Nil, returnNullable = false) case t if t <:< localTypeOf[java.lang.String] => - Invoke(getPath, "toString", ObjectType(classOf[String]), returnNullable = false) + Invoke(path, "toString", 
ObjectType(classOf[String]), returnNullable = false) case t if t <:< localTypeOf[java.math.BigDecimal] => - Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]), + Invoke(path, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]), returnNullable = false) case t if t <:< localTypeOf[BigDecimal] => - Invoke(getPath, "toBigDecimal", ObjectType(classOf[BigDecimal]), returnNullable = false) + Invoke(path, "toBigDecimal", ObjectType(classOf[BigDecimal]), returnNullable = false) case t if t <:< localTypeOf[java.math.BigInteger] => - Invoke(getPath, "toJavaBigInteger", ObjectType(classOf[java.math.BigInteger]), + Invoke(path, "toJavaBigInteger", ObjectType(classOf[java.math.BigInteger]), returnNullable = false) case t if t <:< localTypeOf[scala.math.BigInt] => - Invoke(getPath, "toScalaBigInt", ObjectType(classOf[scala.math.BigInt]), + Invoke(path, "toScalaBigInt", ObjectType(classOf[scala.math.BigInt]), returnNullable = false) case t if t <:< localTypeOf[Array[_]] => @@ -301,7 +287,7 @@ object ScalaReflection extends ScalaReflection { val mapFunction: Expression => Expression = element => { // upcast the array element to the data type the encoder expected. val casted = upCastToExpectedType(element, dataType, newTypePath) - val converter = deserializerFor(elementType, Some(casted), newTypePath) + val converter = deserializerFor(elementType, casted, newTypePath) if (elementNullable) { converter } else { @@ -309,7 +295,7 @@ object ScalaReflection extends ScalaReflection { } } - val arrayData = UnresolvedMapObjects(mapFunction, getPath) + val arrayData = UnresolvedMapObjects(mapFunction, path) val arrayCls = arrayClassFor(elementType) if (elementNullable) { @@ -341,7 +327,7 @@ object ScalaReflection extends ScalaReflection { val mapFunction: Expression => Expression = element => { // upcast the array element to the data type the encoder expected. 
val casted = upCastToExpectedType(element, dataType, newTypePath) - val converter = deserializerFor(elementType, Some(casted), newTypePath) + val converter = deserializerFor(elementType, casted, newTypePath) if (elementNullable) { converter } else { @@ -356,16 +342,16 @@ object ScalaReflection extends ScalaReflection { classOf[scala.collection.Set[_]] case _ => mirror.runtimeClass(t.typeSymbol.asClass) } - UnresolvedMapObjects(mapFunction, getPath, Some(cls)) + UnresolvedMapObjects(mapFunction, path, Some(cls)) case t if t <:< localTypeOf[Map[_, _]] => // TODO: add walked type path for map val TypeRef(_, _, Seq(keyType, valueType)) = t CatalystToExternalMap( - p => deserializerFor(keyType, Some(p), walkedTypePath), - p => deserializerFor(valueType, Some(p), walkedTypePath), - getPath, + p => deserializerFor(keyType, p, walkedTypePath), + p => deserializerFor(valueType, p, walkedTypePath), + path, mirror.runtimeClass(t.typeSymbol.asClass) ) @@ -375,7 +361,7 @@ object ScalaReflection extends ScalaReflection { udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(), Nil, dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt())) - Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil) + Invoke(obj, "deserialize", ObjectType(udt.userClass), path :: Nil) case t if UDTRegistration.exists(getClassNameFromType(t)) => val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance() @@ -384,7 +370,7 @@ object ScalaReflection extends ScalaReflection { udt.getClass, Nil, dataType = ObjectType(udt.getClass)) - Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil) + Invoke(obj, "deserialize", ObjectType(udt.userClass), path :: Nil) case t if definedByConstructorParams(t) => val params = getConstructorParameters(t) @@ -399,12 +385,12 @@ object ScalaReflection extends ScalaReflection { val constructor = if (cls.getName startsWith "scala.Tuple") { deserializerFor( fieldType, - Some(addToPathOrdinal(path, i, dataType, newTypePath)), + addToPathOrdinal(path, i, dataType, newTypePath), newTypePath) } else { deserializerFor( fieldType, - Some(addToPath(fieldName, dataType, newTypePath)), + addToPath(fieldName, dataType, newTypePath), newTypePath) } @@ -417,20 +403,17 @@ object ScalaReflection extends ScalaReflection { val newInstance = NewInstance(cls, arguments, ObjectType(cls), propagateNull = false) - if (path.nonEmpty) { - expressions.If( - IsNull(getPath), - expressions.Literal.create(null, ObjectType(cls)), - newInstance - ) - } else { + expressions.If( + IsNull(path), + expressions.Literal.create(null, ObjectType(cls)), newInstance - } + ) } } /** - * Returns an expression for serializing an object of type T to an internal row. + * Returns an expression for serializing an object of type T to Spark SQL representation. The + * input object is located at ordinal 0 of a row, i.e., `BoundReference(0, _)`. * * If the given type is not supported, i.e. 
there is no encoder can be built for this type, * an [[UnsupportedOperationException]] will be thrown with detailed error message to explain @@ -441,25 +424,34 @@ object ScalaReflection extends ScalaReflection { * * the element type of [[Array]] or [[Seq]]: `array element class: "abc.xyz.MyClass"` * * the field of [[Product]]: `field (class: "abc.xyz.MyClass", name: "myField")` */ - def serializerFor[T : TypeTag]( - inputObject: Expression): CreateNamedStruct = cleanUpReflectionObjects { - val tpe = localTypeOf[T] + def serializerForType(tpe: `Type`, + cls: RuntimeClass): Expression = ScalaReflection.cleanUpReflectionObjects { val clsName = getClassNameFromType(tpe) val walkedTypePath = s"""- root class: "$clsName"""" :: Nil - serializerFor(inputObject, tpe, walkedTypePath) match { - case i @ expressions.If(_, _, _: CreateNamedStruct) - if tpe.dealias <:< localTypeOf[Option[_]] && definedByConstructorParams(tpe) => - // When we are going to serialize an Option[Product] at top-level of row, because - // Spark doesn't support top-level row as null, we encode the Option[Product] as a - // struct at the first column of the row. So here we add an extra named struct wrapping - // the serialized Option[Product] which is the first and only column named `value`. - CreateNamedStruct(expressions.Literal("value") :: i :: Nil) - case expressions.If(_, _, s: CreateNamedStruct) if definedByConstructorParams(tpe) => s - case other => CreateNamedStruct(expressions.Literal("value") :: other :: Nil) - } + + // The input object to `ExpressionEncoder` is located at first column of an row. + val inputObject = BoundReference(0, dataTypeFor(tpe), + nullable = !cls.isPrimitive) + + serializerFor(inputObject, tpe, walkedTypePath) } - /** Helper for extracting internal fields from a case class. */ + /** + * Returns an expression for serializing the value of an input expression into Spark SQL + * internal representation. + * + * The expression generated by this method will be used by `ExpressionEncoder` as serializer + * to convert a JVM object to Spark SQL representation. + * + * The returned serializer generally converts a JVM object to corresponding Spark SQL + * representation. For example, `Seq[_]` is converted to a Spark SQL array, `Product` is + * converted to a Spark SQL struct. + * + * If input object is not of ObjectType, it means that the input object is already in a form + * of Spark's internal representation. We simply return the input object. + * + * For unsupported types, an `UnsupportedOperationException` will be thrown. 
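 * For example (editor's sketch, not part of this patch): serializing a `Seq[Int]` input
 * produces an expression of `ArrayType(IntegerType, containsNull = false)`, and serializing a
 * `Map[Int, Int]` produces `MapType(IntegerType, IntegerType, valueContainsNull = false)`,
 * matching the expectations asserted in ScalaReflectionSuite later in this patch.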
+ */ private def serializerFor( inputObject: Expression, tpe: `Type`, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 0a1c23886159..938b5948406a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -18,17 +18,18 @@ package org.apache.spark.sql.catalyst.encoders import scala.reflect.ClassTag -import scala.reflect.runtime.universe.{typeTag, TypeTag} +import scala.reflect.runtime.universe.{`Type`, RuntimeClass, TypeTag} import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.{InternalRow, JavaTypeInference, ScalaReflection} import org.apache.spark.sql.catalyst.analysis.{Analyzer, GetColumnByOrdinal, SimpleAnalyzer, UnresolvedAttribute, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection} -import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, Invoke, NewInstance} +import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, InitializeJavaBean, Invoke, NewInstance} import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts import org.apache.spark.sql.catalyst.plans.logical.{CatalystSerde, DeserializeToObject, LocalRelation} -import org.apache.spark.sql.types.{BooleanType, ObjectType, StructField, StructType} +import org.apache.spark.sql.types.{ObjectType, StringType, StructField, StructType} +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils /** @@ -43,31 +44,26 @@ import org.apache.spark.util.Utils * to the name `value`. */ object ExpressionEncoder { + def apply[T : TypeTag](): ExpressionEncoder[T] = { - // We convert the not-serializable TypeTag into StructType and ClassTag. val mirror = ScalaReflection.mirror - val tpe = typeTag[T].in(mirror).tpe - + val tpe = ScalaReflection.localTypeOf[T] val cls = mirror.runtimeClass(tpe) - val flat = !ScalaReflection.definedByConstructorParams(tpe) - val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = !cls.isPrimitive) - val nullSafeInput = if (flat) { - inputObject - } else { - // For input object of Product type, we can't encode it to row if it's null, as Spark SQL - // doesn't allow top-level row to be null, only its columns can be null. - AssertNotNull(inputObject, Seq("top level Product input object")) + if (ScalaReflection.optionOfProductType(tpe)) { + throw new UnsupportedOperationException( + "Cannot create encoder for Option of Product type, because Product type is represented " + + "as a row, and the entire row can not be null in Spark SQL like normal databases. " + + "You can wrap your type with Tuple1 if you do want top level null Product objects, " + + "e.g. 
instead of creating `Dataset[Option[MyClass]]`, you can do something like " + + "`val ds: Dataset[Tuple1[MyClass]] = Seq(Tuple1(MyClass(...)), Tuple1(null)).toDS`") } - val serializer = ScalaReflection.serializerFor[T](nullSafeInput) - val deserializer = ScalaReflection.deserializerFor[T] - val schema = serializer.dataType + val serializer = ScalaReflection.serializerForType(tpe, cls) + val deserializer = ScalaReflection.deserializerForType(tpe) new ExpressionEncoder[T]( - schema, - flat, - serializer.flatten, + serializer, deserializer, ClassTag[T](cls)) } @@ -77,14 +73,12 @@ object ExpressionEncoder { val schema = JavaTypeInference.inferDataType(beanClass)._1 assert(schema.isInstanceOf[StructType]) - val serializer = JavaTypeInference.serializerFor(beanClass) - val deserializer = JavaTypeInference.deserializerFor(beanClass) + val objSerializer = JavaTypeInference.serializerFor(beanClass) + val objDeserializer = JavaTypeInference.deserializerFor(beanClass) new ExpressionEncoder[T]( - schema.asInstanceOf[StructType], - flat = false, - serializer.flatten, - deserializer, + objSerializer, + objDeserializer, ClassTag[T](beanClass)) } @@ -94,75 +88,61 @@ object ExpressionEncoder { * name/positional binding is preserved. */ def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = { + if (encoders.length > 22) { + throw new RuntimeException("Can't construct a tuple encoder for more than 22 encoders.") + } + encoders.foreach(_.assertUnresolved()) val schema = StructType(encoders.zipWithIndex.map { case (e, i) => - val (dataType, nullable) = if (e.flat) { - e.schema.head.dataType -> e.schema.head.nullable - } else { - e.schema -> true - } - StructField(s"_${i + 1}", dataType, nullable) + StructField(s"_${i + 1}", e.objSerializer.dataType, e.objSerializer.nullable) }) val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}") - val serializer = encoders.zipWithIndex.map { case (enc, index) => - val originalInputObject = enc.serializer.head.collect { case b: BoundReference => b }.head + val serializers = encoders.zipWithIndex.map { case (enc, index) => + val boundRefs = enc.objSerializer.collect { case b: BoundReference => b }.distinct + assert(boundRefs.size == 1, "object serializer should have only one bound reference but " + + s"there are ${boundRefs.size}") + + val originalInputObject = boundRefs.head val newInputObject = Invoke( BoundReference(0, ObjectType(cls), nullable = true), s"_${index + 1}", - originalInputObject.dataType) + originalInputObject.dataType, + returnNullable = originalInputObject.nullable) - val newSerializer = enc.serializer.map(_.transformUp { + val newSerializer = enc.objSerializer.transformUp { case b: BoundReference if b == originalInputObject => newInputObject - }) - - val serializerExpr = if (enc.flat) { - newSerializer.head - } else { - // For non-flat encoder, the input object is not top level anymore after being combined to - // a tuple encoder, thus it can be null and we should wrap the `CreateStruct` with `If` and - // null check to handle null case correctly. - // e.g. for Encoder[(Int, String)], the serializer expressions will create 2 columns, and is - // not able to handle the case when the input tuple is null. This is not a problem as there - // is a check to make sure the input object won't be null. However, if this encoder is used - // to create a bigger tuple encoder, the original input object becomes a filed of the new - // input tuple and can be null. 
So instead of creating a struct directly here, we should add - // a null/None check and return a null struct if the null/None check fails. - val struct = CreateStruct(newSerializer) - val nullCheck = Or( - IsNull(newInputObject), - Invoke(Literal.fromObject(None), "equals", BooleanType, newInputObject :: Nil)) - If(nullCheck, Literal.create(null, struct.dataType), struct) } - Alias(serializerExpr, s"_${index + 1}")() + + Alias(newSerializer, s"_${index + 1}")() } val childrenDeserializers = encoders.zipWithIndex.map { case (enc, index) => - if (enc.flat) { - enc.deserializer.transform { - case g: GetColumnByOrdinal => g.copy(ordinal = index) - } + val getColumnsByOrdinals = enc.objDeserializer.collect { case c: GetColumnByOrdinal => c } + .distinct + assert(getColumnsByOrdinals.size == 1, "object deserializer should have only one " + + s"`GetColumnByOrdinal`, but there are ${getColumnsByOrdinals.size}") + + val input = GetStructField(GetColumnByOrdinal(0, schema), index) + val newDeserializer = enc.objDeserializer.transformUp { + case GetColumnByOrdinal(0, _) => input + } + if (schema(index).nullable) { + If(IsNull(input), Literal.create(null, newDeserializer.dataType), newDeserializer) } else { - val input = GetColumnByOrdinal(index, enc.schema) - val deserialized = enc.deserializer.transformUp { - case UnresolvedAttribute(nameParts) => - assert(nameParts.length == 1) - UnresolvedExtractValue(input, Literal(nameParts.head)) - case GetColumnByOrdinal(ordinal, _) => GetStructField(input, ordinal) - } - If(IsNull(input), Literal.create(null, deserialized.dataType), deserialized) + newDeserializer } } + val serializer = If(IsNull(BoundReference(0, ObjectType(cls), nullable = true)), + Literal.create(null, schema), CreateStruct(serializers)) val deserializer = NewInstance(cls, childrenDeserializers, ObjectType(cls), propagateNull = false) new ExpressionEncoder[Any]( - schema, - flat = false, serializer, deserializer, ClassTag(cls)) @@ -203,21 +183,88 @@ object ExpressionEncoder { * A generic encoder for JVM objects that uses Catalyst Expressions for a `serializer` * and a `deserializer`. * - * @param schema The schema after converting `T` to a Spark SQL row. - * @param serializer A set of expressions, one for each top-level field that can be used to - * extract the values from a raw object into an [[InternalRow]]. - * @param deserializer An expression that will construct an object given an [[InternalRow]]. + * @param objSerializer An expression that can be used to encode a raw object to corresponding + * Spark SQL representation that can be a primitive column, array, map or a + * struct. This represents how Spark SQL generally serializes an object of + * type `T`. + * @param objDeserializer An expression that will construct an object given a Spark SQL + * representation. This represents how Spark SQL generally deserializes + * a serialized value in Spark SQL representation back to an object of + * type `T`. * @param clsTag A classtag for `T`. */ case class ExpressionEncoder[T]( - schema: StructType, - flat: Boolean, - serializer: Seq[Expression], - deserializer: Expression, + objSerializer: Expression, + objDeserializer: Expression, clsTag: ClassTag[T]) extends Encoder[T] { - if (flat) require(serializer.size == 1) + /** + * A set of expressions, one for each top-level field that can be used to + * extract the values from a raw object into an [[InternalRow]]: + * 1. If `serializer` encodes a raw object to a struct, we directly use the `serializer`. + * 2. 
For other cases, we create a struct to wrap the `serializer`. + */ + val serializer: Seq[NamedExpression] = { + val serializedAsStruct = objSerializer.dataType.isInstanceOf[StructType] + val clsName = clsTag.runtimeClass.getCanonicalName + + if (serializedAsStruct) { + val nullSafeSerializer = objSerializer.transformUp { + case r: BoundReference => + // For input object of Product type, we can't encode it to row if it's null, as Spark SQL + // doesn't allow top-level row to be null, only its columns can be null. + AssertNotNull(r, Seq("top level Product input object")) + } + nullSafeSerializer match { + case If(_, _, s: CreateNamedStruct) => s + case s: CreateNamedStruct => s + case _ => + throw new RuntimeException(s"class $clsName has unexpected serializer: $objSerializer") + } + } else { + // For other input objects like primitive, array, map, etc., we construct a struct to wrap + // the serializer which is a column of an row. + CreateNamedStruct(Literal("value") :: objSerializer :: Nil) + } + }.flatten + + /** + * Returns an expression that can be used to deserialize an input row to an object of type `T` + * with a compatible schema. Fields of the row will be extracted using `UnresolvedAttribute`. + * of the same name as the constructor arguments. + * + * For complex objects that are encoded to structs, Fields of the struct will be extracted using + * `GetColumnByOrdinal` with corresponding ordinal. + */ + val deserializer: Expression = { + val serializedAsStruct = objSerializer.dataType.isInstanceOf[StructType] + + if (serializedAsStruct) { + // We serialized this kind of objects to root-level row. The input of general deserializer + // is a `GetColumnByOrdinal(0)` expression to extract first column of a row. We need to + // transform attributes accessors. + objDeserializer.transform { + case UnresolvedExtractValue(GetColumnByOrdinal(0, _), + Literal(part: UTF8String, StringType)) => + UnresolvedAttribute.quoted(part.toString) + case GetStructField(GetColumnByOrdinal(0, dt), ordinal, _) => + GetColumnByOrdinal(ordinal, dt) + case If(IsNull(GetColumnByOrdinal(0, _)), _, n: NewInstance) => n + case If(IsNull(GetColumnByOrdinal(0, _)), _, i: InitializeJavaBean) => i + } + } else { + // For other input objects like primitive, array, map, etc., we deserialize the first column + // of a row to the object. + objDeserializer + } + } + + // The schema after converting `T` to a Spark SQL row. This schema is dependent on the given + // serialier. + val schema: StructType = StructType(serializer.map { s => + StructField(s.name, s.dataType, s.nullable) + }) // serializer expressions are used to encode an object to a row, while the object is usually an // intermediate value produced inside an operator, not from the output of the child operator. 
This @@ -249,7 +296,7 @@ case class ExpressionEncoder[T]( analyzer.checkAnalysis(analyzedPlan) val resolved = SimplifyCasts(analyzedPlan).asInstanceOf[DeserializeToObject].deserializer val bound = BindReferences.bindReference(resolved, attrs) - copy(deserializer = bound) + copy(objDeserializer = bound) } @transient diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 3340789398f9..b65fc2252807 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -58,12 +58,10 @@ object RowEncoder { def apply(schema: StructType): ExpressionEncoder[Row] = { val cls = classOf[Row] val inputObject = BoundReference(0, ObjectType(cls), nullable = true) - val serializer = serializerFor(AssertNotNull(inputObject, Seq("top level row object")), schema) - val deserializer = deserializerFor(schema) + val serializer = serializerFor(inputObject, schema) + val deserializer = deserializerFor(GetColumnByOrdinal(0, serializer.dataType), schema) new ExpressionEncoder[Row]( - schema, - flat = false, - serializer.asInstanceOf[CreateNamedStruct].flatten, + serializer, deserializer, ClassTag(cls)) } @@ -235,13 +233,13 @@ object RowEncoder { case udt: UserDefinedType[_] => ObjectType(udt.userClass) } - private def deserializerFor(schema: StructType): Expression = { + private def deserializerFor(input: Expression, schema: StructType): Expression = { val fields = schema.zipWithIndex.map { case (f, i) => val dt = f.dataType match { case p: PythonUserDefinedType => p.sqlType case other => other } - deserializerFor(GetColumnByOrdinal(i, dt)) + deserializerFor(GetStructField(input, i)) } CreateExternalRow(fields, schema) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 6be89108d472..93704866f079 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -20,11 +20,10 @@ package org.apache.spark.sql.catalyst import java.sql.{Date, Timestamp} import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAttribute} -import org.apache.spark.sql.catalyst.expressions.{BoundReference, CreateNamedStruct, Expression, If, IsNull, Literal, SpecificInternalRow, UpCast} -import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, NewInstance, WrapOption} +import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue +import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, Expression, If, SpecificInternalRow, UpCast} +import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, NewInstance} import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String case class PrimitiveData( intField: Int, @@ -263,81 +262,84 @@ class ScalaReflectionSuite extends SparkFunSuite { test("SPARK-15062: Get correct serializer for List[_]") { val list = List(1, 2, 3) - val serializer = serializerFor[List[Int]](BoundReference( - 0, ObjectType(list.getClass), nullable = false)) - assert(serializer.children.size == 2) - assert(serializer.children.head.isInstanceOf[Literal]) - 
assert(serializer.children.head.asInstanceOf[Literal].value === UTF8String.fromString("value")) - assert(serializer.children.last.isInstanceOf[NewInstance]) - assert(serializer.children.last.asInstanceOf[NewInstance] + val serializer = serializerForType(ScalaReflection.localTypeOf[List[Int]], + classOf[List[Int]]) + assert(serializer.isInstanceOf[NewInstance]) + assert(serializer.asInstanceOf[NewInstance] .cls.isAssignableFrom(classOf[org.apache.spark.sql.catalyst.util.GenericArrayData])) } test("SPARK 16792: Get correct deserializer for List[_]") { - val listDeserializer = deserializerFor[List[Int]] + val listDeserializer = deserializerForType(ScalaReflection.localTypeOf[List[Int]]) assert(listDeserializer.dataType == ObjectType(classOf[List[_]])) } test("serialize and deserialize arbitrary sequence types") { import scala.collection.immutable.Queue - val queueSerializer = serializerFor[Queue[Int]](BoundReference( - 0, ObjectType(classOf[Queue[Int]]), nullable = false)) - assert(queueSerializer.dataType.head.dataType == + val queueSerializer = serializerForType(ScalaReflection.localTypeOf[Queue[Int]], + classOf[Queue[Int]]) + assert(queueSerializer.dataType == ArrayType(IntegerType, containsNull = false)) - val queueDeserializer = deserializerFor[Queue[Int]] + val queueDeserializer = deserializerForType(ScalaReflection.localTypeOf[Queue[Int]]) assert(queueDeserializer.dataType == ObjectType(classOf[Queue[_]])) import scala.collection.mutable.ArrayBuffer - val arrayBufferSerializer = serializerFor[ArrayBuffer[Int]](BoundReference( - 0, ObjectType(classOf[ArrayBuffer[Int]]), nullable = false)) - assert(arrayBufferSerializer.dataType.head.dataType == + val arrayBufferSerializer = serializerForType(ScalaReflection.localTypeOf[ArrayBuffer[Int]], + classOf[ArrayBuffer[Int]]) + assert(arrayBufferSerializer.dataType == ArrayType(IntegerType, containsNull = false)) - val arrayBufferDeserializer = deserializerFor[ArrayBuffer[Int]] + val arrayBufferDeserializer = deserializerForType(ScalaReflection.localTypeOf[ArrayBuffer[Int]]) assert(arrayBufferDeserializer.dataType == ObjectType(classOf[ArrayBuffer[_]])) } test("serialize and deserialize arbitrary map types") { - val mapSerializer = serializerFor[Map[Int, Int]](BoundReference( - 0, ObjectType(classOf[Map[Int, Int]]), nullable = false)) - assert(mapSerializer.dataType.head.dataType == + val mapSerializer = serializerForType(ScalaReflection.localTypeOf[Map[Int, Int]], + classOf[Map[Int, Int]]) + assert(mapSerializer.dataType == MapType(IntegerType, IntegerType, valueContainsNull = false)) - val mapDeserializer = deserializerFor[Map[Int, Int]] + val mapDeserializer = deserializerForType(ScalaReflection.localTypeOf[Map[Int, Int]]) assert(mapDeserializer.dataType == ObjectType(classOf[Map[_, _]])) import scala.collection.immutable.HashMap - val hashMapSerializer = serializerFor[HashMap[Int, Int]](BoundReference( - 0, ObjectType(classOf[HashMap[Int, Int]]), nullable = false)) - assert(hashMapSerializer.dataType.head.dataType == + val hashMapSerializer = serializerForType(ScalaReflection.localTypeOf[HashMap[Int, Int]], + classOf[HashMap[Int, Int]]) + assert(hashMapSerializer.dataType == MapType(IntegerType, IntegerType, valueContainsNull = false)) - val hashMapDeserializer = deserializerFor[HashMap[Int, Int]] + val hashMapDeserializer = deserializerForType(ScalaReflection.localTypeOf[HashMap[Int, Int]]) assert(hashMapDeserializer.dataType == ObjectType(classOf[HashMap[_, _]])) import scala.collection.mutable.{LinkedHashMap => LHMap} - val 
linkedHashMapSerializer = serializerFor[LHMap[Long, String]](BoundReference( - 0, ObjectType(classOf[LHMap[Long, String]]), nullable = false)) - assert(linkedHashMapSerializer.dataType.head.dataType == + val linkedHashMapSerializer = serializerForType( + ScalaReflection.localTypeOf[LHMap[Long, String]], + classOf[LHMap[Long, String]]) + assert(linkedHashMapSerializer.dataType == MapType(LongType, StringType, valueContainsNull = true)) - val linkedHashMapDeserializer = deserializerFor[LHMap[Long, String]] + val linkedHashMapDeserializer = deserializerForType( + ScalaReflection.localTypeOf[LHMap[Long, String]]) assert(linkedHashMapDeserializer.dataType == ObjectType(classOf[LHMap[_, _]])) } test("SPARK-22442: Generate correct field names for special characters") { - val serializer = serializerFor[SpecialCharAsFieldData](BoundReference( - 0, ObjectType(classOf[SpecialCharAsFieldData]), nullable = false)) - val deserializer = deserializerFor[SpecialCharAsFieldData] + val serializer = serializerForType(ScalaReflection.localTypeOf[SpecialCharAsFieldData], + classOf[SpecialCharAsFieldData]).collect { + case If(_, _, s: CreateNamedStruct) => s + }.head + val deserializer = deserializerForType(ScalaReflection.localTypeOf[SpecialCharAsFieldData]) assert(serializer.dataType(0).name == "field.1") assert(serializer.dataType(1).name == "field 2") - val argumentsFields = deserializer.asInstanceOf[NewInstance].arguments.flatMap { _.collect { - case UpCast(u: UnresolvedAttribute, _, _) => u.nameParts + val newInstance = deserializer.collect { case n: NewInstance => n }.head + + val argumentsFields = newInstance.arguments.flatMap { _.collect { + case UpCast(u: UnresolvedExtractValue, _, _) => u.extraction.toString }} - assert(argumentsFields(0) == Seq("field.1")) - assert(argumentsFields(1) == Seq("field 2")) + assert(argumentsFields(0) == "field.1") + assert(argumentsFields(1) == "field 2") } test("SPARK-22472: add null check for top-level primitive values") { - assert(deserializerFor[Int].isInstanceOf[AssertNotNull]) - assert(!deserializerFor[String].isInstanceOf[AssertNotNull]) + assert(deserializerForType(ScalaReflection.localTypeOf[Int]).isInstanceOf[AssertNotNull]) + assert(!deserializerForType(ScalaReflection.localTypeOf[String]).isInstanceOf[AssertNotNull]) } test("SPARK-23025: schemaFor should support Null type") { @@ -351,38 +353,15 @@ class ScalaReflectionSuite extends SparkFunSuite { test("SPARK-23835: add null check to non-nullable types in Tuples") { def numberOfCheckedArguments(deserializer: Expression): Int = { - assert(deserializer.isInstanceOf[NewInstance]) - deserializer.asInstanceOf[NewInstance].arguments.count(_.isInstanceOf[AssertNotNull]) - } - assert(numberOfCheckedArguments(deserializerFor[(Double, Double)]) == 2) - assert(numberOfCheckedArguments(deserializerFor[(java.lang.Double, Int)]) == 1) - assert(numberOfCheckedArguments(deserializerFor[(java.lang.Integer, java.lang.Integer)]) == 0) - } - - test("SPARK-24762: serializer for Option of Product") { - val optionOfProduct = Some((1, "a")) - val serializer = serializerFor[Option[(Int, String)]](BoundReference( - 0, ObjectType(optionOfProduct.getClass), nullable = true)) - - serializer match { - case CreateNamedStruct(Seq(_: Literal, If(_, _, encoder: CreateNamedStruct))) => - val fields = encoder.flatten - assert(fields.length == 2) - assert(fields(0).dataType == IntegerType) - assert(fields(1).dataType == StringType) - case _ => - fail("top-level Option of Product should be encoded as single struct column.") - } - } - - 
test("SPARK-24762: deserializer for Option of Product") { - val deserializer = deserializerFor[Option[(Int, String)]] - - deserializer match { - case WrapOption(If(IsNull(GetColumnByOrdinal(0, _)), _, n: NewInstance), _) => - assert(n.cls == classOf[Tuple2[Int, String]]) - case _ => - fail("top-level Option of Product should be decoded from a single struct column.") + val newInstance = deserializer.collect { case n: NewInstance => n}.head + newInstance.arguments.count(_.isInstanceOf[AssertNotNull]) } + assert(numberOfCheckedArguments( + deserializerForType(ScalaReflection.localTypeOf[(Double, Double)])) == 2) + assert(numberOfCheckedArguments( + deserializerForType(ScalaReflection.localTypeOf[(java.lang.Double, Int)])) == 1) + assert(numberOfCheckedArguments( + deserializerForType( + ScalaReflection.localTypeOf[(java.lang.Integer, java.lang.Integer)])) == 0) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index f0d61de97ffc..e9b100b3b30d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -28,9 +28,9 @@ import org.apache.spark.sql.{Encoder, Encoders} import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData} import org.apache.spark.sql.catalyst.analysis.AnalysisTest import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference} +import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.plans.CodegenInterpretedPlanTest -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -348,7 +348,7 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes test("nullable of encoder serializer") { def checkNullable[T: Encoder](nullable: Boolean): Unit = { - assert(encoderFor[T].serializer.forall(_.nullable === nullable)) + assert(encoderFor[T].objSerializer.nullable === nullable) } // test for flat encoders diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index fa14aa14ee96..25877aa9d389 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1087,7 +1087,7 @@ class Dataset[T] private[sql]( // Note that we do this before joining them, to enable the join operator to return null for one // side, in cases like outer-join. val left = { - val combined = if (this.exprEnc.flat) { + val combined = if (!this.exprEnc.objSerializer.dataType.isInstanceOf[StructType]) { assert(joined.left.output.length == 1) Alias(joined.left.output.head, "_1")() } else { @@ -1097,7 +1097,7 @@ class Dataset[T] private[sql]( } val right = { - val combined = if (other.exprEnc.flat) { + val combined = if (!other.exprEnc.objSerializer.dataType.isInstanceOf[StructType]) { assert(joined.right.output.length == 1) Alias(joined.right.output.head, "_2")() } else { @@ -1110,14 +1110,14 @@ class Dataset[T] private[sql]( // combine the outputs of each join side. 
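// [Editor's note, illustrative only] After this refactoring, "flat vs. struct" is derived from
// the object serializer's data type rather than the removed `flat` flag: e.g. a Dataset[Int]
// serializes to IntegerType (not a StructType), so its single output column is used as-is,
// while a Dataset of a case class serializes to a StructType and is accessed field-by-field
// via GetStructField, as in the branches below.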
val conditionExpr = joined.condition.get transformUp { case a: Attribute if joined.left.outputSet.contains(a) => - if (this.exprEnc.flat) { + if (!this.exprEnc.objSerializer.dataType.isInstanceOf[StructType]) { left.output.head } else { val index = joined.left.output.indexWhere(_.exprId == a.exprId) GetStructField(left.output.head, index) } case a: Attribute if joined.right.outputSet.contains(a) => - if (other.exprEnc.flat) { + if (!other.exprEnc.objSerializer.dataType.isInstanceOf[StructType]) { right.output.head } else { val index = joined.right.output.indexWhere(_.exprId == a.exprId) @@ -1390,7 +1390,7 @@ class Dataset[T] private[sql]( implicit val encoder = c1.encoder val project = Project(c1.withInputType(exprEnc, logicalPlan.output).named :: Nil, logicalPlan) - if (encoder.flat) { + if (!encoder.objSerializer.dataType.isInstanceOf[StructType]) { new Dataset[U1](sparkSession, project, encoder) } else { // Flattens inner fields of U1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 6bab21dca0cb..b0232c0703b6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.expressions.ReduceAggregator import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode} +import org.apache.spark.sql.types.StructType /** * :: Experimental :: @@ -457,7 +458,8 @@ class KeyValueGroupedDataset[K, V] private[sql]( val encoders = columns.map(_.encoder) val namedColumns = columns.map(_.withInputType(vExprEnc, dataAttributes).named) - val keyColumn = if (kExprEnc.flat) { + + val keyColumn = if (!kExprEnc.objSerializer.dataType.isInstanceOf[StructType]) { assert(groupingAttributes.length == 1) groupingAttributes.head } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala index 1cca43e77785..d1881de50257 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -19,103 +19,35 @@ package org.apache.spark.sql.execution.aggregate import scala.language.existentials -import org.apache.spark.sql.{AnalysisException, Encoder} +import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedDeserializer} -import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} +import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer +import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, DeclarativeAggregate, TypedImperativeAggregate} import org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection -import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, Invoke, NewInstance, WrapOption} +import org.apache.spark.sql.catalyst.expressions.objects.Invoke import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.types._ import 
org.apache.spark.util.Utils object TypedAggregateExpression { - // Checks if given encoder is for `Option[Product]`. - def isOptProductEncoder(encoder: ExpressionEncoder[_]): Boolean = { - // For all Option[_] classes, only Option[Product] is reported as not flat. - encoder.clsTag.runtimeClass == classOf[Option[_]] && !encoder.flat - } - - /** - * Flattens serializers and deserializer of given encoder. We only flatten encoder - * of `Option[Product]` class. - */ - def flattenOptProductEncoder[T](encoder: ExpressionEncoder[T]): ExpressionEncoder[T] = { - val serializer = encoder.serializer - val deserializer = encoder.deserializer - - // This is just a sanity check. Encoders of Option[Product] has only one `CreateNamedStruct` - // serializer expression. - assert(serializer.length == 1, - "We only flatten encoder of Option[Product] class which has single serializer.") - - val flattenSerializers = serializer(0).collect { - case c: CreateNamedStruct => c.flatten - }.head - - val flattenDeserializer = deserializer match { - case w @ WrapOption(If(_, _, child: NewInstance), optType) => - val newInstance = child match { - case oldNewInstance: NewInstance => - val newArguments = oldNewInstance.arguments.zipWithIndex.map { case (arg, idx) => - arg match { - case a @ AssertNotNull( - UpCast(GetStructField( - child @ GetColumnByOrdinal(0, _), _, _), dt, walkedTypePath), _) => - a.copy(child = UpCast(GetColumnByOrdinal(idx, dt), dt, walkedTypePath.tail)) - } - } - oldNewInstance.copy(arguments = newArguments) - } - w.copy(child = newInstance) - case _ => - throw new AnalysisException( - "On top of deserializer of Option[Product] should be `WrapOption`.") - } - - val newSchema = encoder.schema.fields(0) - .dataType.asInstanceOf[StructType] - encoder.copy(serializer = flattenSerializers, deserializer = flattenDeserializer, - schema = newSchema) - } - def apply[BUF : Encoder, OUT : Encoder]( aggregator: Aggregator[_, BUF, OUT]): TypedAggregateExpression = { - val rawBufferEncoder = encoderFor[BUF] - - // When `BUF` or `OUT` is an Option[Product], we need to flatten serializers and deserializer - // of original encoder. It is because we wrap serializers of Option[Product] inside an extra - // struct in order to support encoding of Option[Product] at top-level row. But here we use - // the encoder to encode Option[Product] for a column, we need to get rid of this extra - // struct. - val bufferEncoder = if (isOptProductEncoder(rawBufferEncoder)) { - flattenOptProductEncoder(rawBufferEncoder) - } else { - rawBufferEncoder - } + val bufferEncoder = encoderFor[BUF] val bufferSerializer = bufferEncoder.namedExpressions - val rawOutputEncoder = encoderFor[OUT] - val outputEncoder = if (isOptProductEncoder(rawOutputEncoder)) { - flattenOptProductEncoder(rawOutputEncoder) - } else { - rawOutputEncoder - } - val outputType = if (outputEncoder.flat) { - outputEncoder.schema.head.dataType - } else { - outputEncoder.schema - } + val outputEncoder = encoderFor[OUT] + val outputType = outputEncoder.objSerializer.dataType // Checks if the buffer object is simple, i.e. the buffer encoder is flat and the serializer // expression is an alias of `BoundReference`, which means the buffer object doesn't need // serialization. 
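// [Editor's note, illustrative only] For a primitive buffer such as Long or Double (e.g. a
// typed sum-style aggregator), the buffer serializer is a single Alias over a BoundReference
// and the object serializer's data type is not a StructType, so isSimpleBuffer is true; an
// Option[Product] or case-class buffer serializes through a struct and takes the
// ComplexTypedAggregateExpression branch shown below.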
val isSimpleBuffer = { bufferSerializer.head match { - case Alias(_: BoundReference, _) if bufferEncoder.flat => true + case Alias(_: BoundReference, _) + if !bufferEncoder.objSerializer.dataType.isInstanceOf[StructType] => true case _ => false } } @@ -137,7 +69,7 @@ object TypedAggregateExpression { outputEncoder.serializer, outputEncoder.deserializer.dataType, outputType, - !outputEncoder.flat || outputEncoder.schema.head.nullable) + outputEncoder.objSerializer.nullable) } else { ComplexTypedAggregateExpression( aggregator.asInstanceOf[Aggregator[Any, Any, Any]], @@ -148,7 +80,7 @@ object TypedAggregateExpression { bufferEncoder.resolveAndBind().deserializer, outputEncoder.serializer, outputType, - !outputEncoder.flat || outputEncoder.schema.head.nullable) + outputEncoder.objSerializer.nullable) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala index 0a16a3819b77..86484b0008dd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala @@ -215,7 +215,7 @@ object FlatMapGroupsWithStateExecHelper { override val stateSerializerExprs: Seq[Expression] = { val boundRefToSpecificInternalRow = BoundReference( - 0, stateEncoder.serializer.head.collect { case b: BoundReference => b.dataType }.head, true) + 0, stateEncoder.serializer.collect { case b: BoundReference => b.dataType }.head, true) val nestedStateSerExpr = CreateNamedStruct(stateEncoder.namedExpressions.flatMap(e => Seq(Literal(e.name), e))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala index 0446bd9097b6..538ea3c66c40 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -149,7 +149,6 @@ object VeryComplexResultAgg extends Aggregator[Row, String, ComplexAggData] { case class OptionBooleanData(name: String, isGood: Option[Boolean]) -case class OptionBooleanIntData(name: String, isGood: Option[(Boolean, Int)]) case class OptionBooleanAggregator(colName: String) extends Aggregator[Row, Option[Boolean], Option[Boolean]] { @@ -184,43 +183,6 @@ case class OptionBooleanAggregator(colName: String) def OptionalBoolEncoder: Encoder[Option[Boolean]] = ExpressionEncoder() } -case class OptionBooleanIntAggregator(colName: String) - extends Aggregator[Row, Option[(Boolean, Int)], Option[(Boolean, Int)]] { - - override def zero: Option[(Boolean, Int)] = None - - override def reduce(buffer: Option[(Boolean, Int)], row: Row): Option[(Boolean, Int)] = { - val index = row.fieldIndex(colName) - val value = if (row.isNullAt(index)) { - Option.empty[(Boolean, Int)] - } else { - val nestedRow = row.getStruct(index) - Some((nestedRow.getBoolean(0), nestedRow.getInt(1))) - } - merge(buffer, value) - } - - override def merge( - b1: Option[(Boolean, Int)], - b2: Option[(Boolean, Int)]): Option[(Boolean, Int)] = { - if ((b1.isDefined && b1.get._1) || (b2.isDefined && b2.get._1)) { - val newInt = b1.map(_._2).getOrElse(0) + b2.map(_._2).getOrElse(0) - Some((true, newInt)) - } else if (b1.isDefined) { - b1 - } else { - b2 - } - } - - override def finish(reduction: 
Option[(Boolean, Int)]): Option[(Boolean, Int)] = reduction - - override def bufferEncoder: Encoder[Option[(Boolean, Int)]] = OptionalBoolIntEncoder - override def outputEncoder: Encoder[Option[(Boolean, Int)]] = OptionalBoolIntEncoder - - def OptionalBoolIntEncoder: Encoder[Option[(Boolean, Int)]] = ExpressionEncoder() -} - class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -431,18 +393,4 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { assert(grouped.schema == df.schema) checkDataset(grouped.as[OptionBooleanData], OptionBooleanData("bob", Some(true))) } - - test("SPARK-24762: Aggregator should be able to use Option of Product encoder") { - val df = Seq( - OptionBooleanIntData("bob", Some((true, 1))), - OptionBooleanIntData("bob", Some((false, 2))), - OptionBooleanIntData("bob", None)).toDF() - - val group = df - .groupBy("name") - .agg(OptionBooleanIntAggregator("isGood").toColumn.alias("isGood")) - assert(df.schema == group.schema) - checkAnswer(group, Row("bob", Row(true, 3)) :: Nil) - checkDataset(group.as[OptionBooleanIntData], OptionBooleanIntData("bob", Some((true, 3)))) - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 1b1347738db7..4e593ff046a5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1302,6 +1302,15 @@ class DatasetSuite extends QueryTest with SharedSQLContext { checkDataset(dsString, arrayString) } + test("SPARK-18251: the type of Dataset can't be Option of Product type") { + checkDataset(Seq(Some(1), None).toDS(), Some(1), None) + + val e = intercept[UnsupportedOperationException] { + Seq(Some(1 -> "a"), None).toDS() + } + assert(e.getMessage.contains("Cannot create encoder for Option of Product type")) + } + test ("SPARK-17460: the sizeInBytes in Statistics shouldn't overflow to a negative number") { // Since the sizeInBytes in Statistics could exceed the limit of an Int, we should use BigInt // instead of Int for avoiding possible overflow. 
@@ -1538,41 +1547,6 @@ class DatasetSuite extends QueryTest with SharedSQLContext { df.where($"city".contains(new java.lang.Character('A'))), Seq(Row("Amsterdam"))) } - - test("SPARK-24762: Enable top-level Option of Product encoders") { - val data = Seq(Some((1, "a")), Some((2, "b")), None) - val ds = data.toDS() - - checkDataset( - ds, - data: _*) - - val schema = StructType(Seq( - StructField("value", StructType(Seq( - StructField("_1", IntegerType, nullable = false), - StructField("_2", StringType, nullable = true) - )), nullable = true) - )) - - assert(ds.schema == schema) - - val nestedOptData = Seq(Some((Some((1, "a")), 2.0)), Some((Some((2, "b")), 3.0))) - val nestedDs = nestedOptData.toDS() - - checkDataset( - nestedDs, - nestedOptData: _*) - - val nestedSchema = StructType(Seq( - StructField("value", StructType(Seq( - StructField("_1", StructType(Seq( - StructField("_1", IntegerType, nullable = false), - StructField("_2", StringType, nullable = true)))), - StructField("_2", DoubleType, nullable = false) - )), nullable = true) - )) - assert(nestedDs.schema == nestedSchema) - } } case class TestDataUnion(x: Int, y: Int, z: Int) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpressionSuite.scala deleted file mode 100644 index f54557b1e0f5..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpressionSuite.scala +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.aggregate - -import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.catalyst.encoders._ -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types._ - - -class TypedAggregateExpressionSuite extends QueryTest with SharedSQLContext { - import testImplicits._ - - private def testOptProductEncoder(encoder: ExpressionEncoder[_], expected: Boolean): Unit = { - assert(TypedAggregateExpression.isOptProductEncoder(encoder) == expected) - } - - test("check an encoder is for option of product") { - testOptProductEncoder(encoderFor[Int], false) - testOptProductEncoder(encoderFor[(Long, Long)], false) - testOptProductEncoder(encoderFor[Option[Int]], false) - testOptProductEncoder(encoderFor[Option[(Int, Long)]], true) - testOptProductEncoder(encoderFor[Option[SimpleCaseClass]], true) - } - - test("flatten encoders of option of product") { - // Option[Product] is encoded as a struct column in a row. 
- val optProductEncoder: ExpressionEncoder[Option[(Int, Long)]] = encoderFor[Option[(Int, Long)]] - val optProductSchema = StructType(StructField("value", StructType( - StructField("_1", IntegerType) :: StructField("_2", LongType) :: Nil)) :: Nil) - - assert(optProductEncoder.schema.length == 1) - assert(DataType.equalsIgnoreCaseAndNullability(optProductEncoder.schema, optProductSchema)) - - val flattenEncoder = TypedAggregateExpression.flattenOptProductEncoder(optProductEncoder) - .resolveAndBind() - assert(flattenEncoder.schema.length == 2) - assert(DataType.equalsIgnoreCaseAndNullability(flattenEncoder.schema, - optProductSchema.fields(0).dataType)) - - val row = flattenEncoder.toRow(Some((1, 2L))) - val expected = flattenEncoder.fromRow(row) - assert(Some((1, 2L)) == expected) - } -} - -case class SimpleCaseClass(a: Int) From 6a6fa454e22728cc2ad8e5515cd587fe0be84b26 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 17 Oct 2018 02:07:40 +0000 Subject: [PATCH 09/21] Fix Malformed class name. --- .../apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 938b5948406a..b6f9f555b2ec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -207,7 +207,7 @@ case class ExpressionEncoder[T]( */ val serializer: Seq[NamedExpression] = { val serializedAsStruct = objSerializer.dataType.isInstanceOf[StructType] - val clsName = clsTag.runtimeClass.getCanonicalName + val clsName = Utils.getSimpleName(clsTag.runtimeClass) if (serializedAsStruct) { val nullSafeSerializer = objSerializer.transformUp { From 25a616286075ca4f0a7d528095b387172b05c6c3 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 17 Oct 2018 05:11:10 +0000 Subject: [PATCH 10/21] Fix error message. --- .../apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala | 2 +- .../apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index b6f9f555b2ec..e7aa60601f62 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -214,7 +214,7 @@ case class ExpressionEncoder[T]( case r: BoundReference => // For input object of Product type, we can't encode it to row if it's null, as Spark SQL // doesn't allow top-level row to be null, only its columns can be null. 
- AssertNotNull(r, Seq("top level Product input object")) + AssertNotNull(r, Seq("top level Product or row object")) } nullSafeSerializer match { case If(_, _, s: CreateNamedStruct) => s diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala index 8d89f9c6c41d..93f9fc12e7f1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala @@ -239,7 +239,7 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest { val encoder = RowEncoder(schema) val e = intercept[RuntimeException](encoder.toRow(null)) assert(e.getMessage.contains("Null value appeared in non-nullable field")) - assert(e.getMessage.contains("top level row object")) + assert(e.getMessage.contains("top level Product or row object")) } test("RowEncoder should validate external type") { From 295ecde8103c26dda169d931f939f8a2fe641c4c Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 18 Oct 2018 15:58:03 +0000 Subject: [PATCH 11/21] Fix test. --- .../streaming/state/FlatMapGroupsWithStateExecHelper.scala | 2 +- sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala index 86484b0008dd..0a16a3819b77 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala @@ -215,7 +215,7 @@ object FlatMapGroupsWithStateExecHelper { override val stateSerializerExprs: Seq[Expression] = { val boundRefToSpecificInternalRow = BoundReference( - 0, stateEncoder.serializer.collect { case b: BoundReference => b.dataType }.head, true) + 0, stateEncoder.serializer.head.collect { case b: BoundReference => b.dataType }.head, true) val nestedStateSerExpr = CreateNamedStruct(stateEncoder.namedExpressions.flatMap(e => Seq(Literal(e.name), e))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 4e593ff046a5..27b3b3d78d2b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1065,7 +1065,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("Dataset should throw RuntimeException if top-level product input object is null") { val e = intercept[RuntimeException](Seq(ClassData("a", 1), null).toDS()) assert(e.getMessage.contains("Null value appeared in non-nullable field")) - assert(e.getMessage.contains("top level Product input object")) + assert(e.getMessage.contains("top level Product or row object")) } test("dropDuplicates") { From 35700f4a0f36fb397ac028a68011a2753c5c2c75 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 19 Oct 2018 00:07:29 +0000 Subject: [PATCH 12/21] Fix rebase error. 
--- .../scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index a19d23004714..040b2112507f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -269,7 +269,7 @@ object JavaTypeInference { case c if listType.isAssignableFrom(typeToken) => val et = elementType(typeToken) UnresolvedMapObjects( - p => deserializerFor(et, Some(p)), + p => deserializerFor(et, p), path, customCollectionCls = Some(c)) From b211ed069dceb33c45cf6caf12c19527334d4ad8 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 19 Oct 2018 00:16:24 +0000 Subject: [PATCH 13/21] Fix unintentional style change. --- .../org/apache/spark/sql/catalyst/ScalaReflection.scala | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 6491e7dc0649..8c8026713f61 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -192,10 +192,9 @@ object ScalaReflection extends ScalaReflection { /** Returns the current path with a field at ordinal extracted. */ def addToPathOrdinal( - path: Expression, - ordinal: Int, - dataType: DataType, - walkedTypePath: Seq[String]): Expression = { + ordinal: Int, + dataType: DataType, + walkedTypePath: Seq[String]): Expression = { val newPath = GetStructField(path, ordinal) upCastToExpectedType(newPath, dataType, walkedTypePath) } @@ -385,7 +384,7 @@ object ScalaReflection extends ScalaReflection { val constructor = if (cls.getName startsWith "scala.Tuple") { deserializerFor( fieldType, - addToPathOrdinal(path, i, dataType, newTypePath), + addToPathOrdinal(i, dataType, newTypePath), newTypePath) } else { deserializerFor( From 0c78b73e5abce2a51763c860e43aab214c8634d9 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 19 Oct 2018 00:51:52 +0000 Subject: [PATCH 14/21] Address comments. --- .../sql/catalyst/encoders/ExpressionEncoder.scala | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index e7aa60601f62..b3550a1f98a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -88,9 +88,7 @@ object ExpressionEncoder { * name/positional binding is preserved. */ def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = { - if (encoders.length > 22) { - throw new RuntimeException("Can't construct a tuple encoder for more than 22 encoders.") - } + // TODO: check if encoders length is more than 22 and throw exception for it. 
encoders.foreach(_.assertUnresolved()) @@ -114,7 +112,7 @@ object ExpressionEncoder { returnNullable = originalInputObject.nullable) val newSerializer = enc.objSerializer.transformUp { - case b: BoundReference if b == originalInputObject => newInputObject + case b: BoundReference => newInputObject } Alias(newSerializer, s"_${index + 1}")() @@ -200,7 +198,7 @@ case class ExpressionEncoder[T]( extends Encoder[T] { /** - * A set of expressions, one for each top-level field that can be used to + * A sequence of expressions, one for each top-level field that can be used to * extract the values from a raw object into an [[InternalRow]]: * 1. If `serializer` encodes a raw object to a struct, we directly use the `serializer`. * 2. For other cases, we create a struct to wrap the `serializer`. @@ -217,7 +215,7 @@ case class ExpressionEncoder[T]( AssertNotNull(r, Seq("top level Product or row object")) } nullSafeSerializer match { - case If(_, _, s: CreateNamedStruct) => s + case If(_: IsNull, _, s: CreateNamedStruct) => s case s: CreateNamedStruct => s case _ => throw new RuntimeException(s"class $clsName has unexpected serializer: $objSerializer") From 5b9abb67907dfdb0c0c64751db3525564f832422 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 20 Oct 2018 02:26:07 +0000 Subject: [PATCH 15/21] Address ComplexTypeMergingExpression issue. --- .../sql/catalyst/encoders/RowEncoder.scala | 6 +---- .../sql/catalyst/expressions/Expression.scala | 22 +++++++++++++++++-- 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index b65fc2252807..f6838bdab9e7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -59,7 +59,7 @@ object RowEncoder { val cls = classOf[Row] val inputObject = BoundReference(0, ObjectType(cls), nullable = true) val serializer = serializerFor(inputObject, schema) - val deserializer = deserializerFor(GetColumnByOrdinal(0, serializer.dataType), schema) + val deserializer = deserializerFor(GetColumnByOrdinal(0, schema), schema) new ExpressionEncoder[Row]( serializer, deserializer, @@ -235,10 +235,6 @@ object RowEncoder { private def deserializerFor(input: Expression, schema: StructType): Expression = { val fields = schema.zipWithIndex.map { case (f, i) => - val dt = f.dataType match { - case p: PythonUserDefinedType => p.sqlType - case other => other - } deserializerFor(GetStructField(input, i)) } CreateExternalRow(fields, schema) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index c215735ab1c9..e4ff53f507a7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -709,19 +709,37 @@ trait ComplexTypeMergingExpression extends Expression { @transient lazy val inputTypesForMerging: Seq[DataType] = children.map(_.dataType) + // If there is `PythonUserDefinedType`, `TypeCoercion.haveSameType` checks will fail. + // This method converts Python UDT to its underlying storage data type. 
+ private def convertPythonUserDefinedType(dt: DataType): DataType = dt match { + case u: PythonUserDefinedType => u.sqlType + case ArrayType(et, containsNull) => ArrayType(convertPythonUserDefinedType(et), containsNull) + case MapType(kt, vt, valueContainsNull) => + MapType(convertPythonUserDefinedType(kt), convertPythonUserDefinedType(vt), valueContainsNull) + case s: StructType => + val fields = s.map { field => + StructField(field.name, convertPythonUserDefinedType(field.dataType), field.nullable, + field.metadata) + } + StructType(fields) + case o => o + } + + private lazy val actualInputTypes = inputTypesForMerging.map(convertPythonUserDefinedType(_)) + def dataTypeCheck: Unit = { require( inputTypesForMerging.nonEmpty, "The collection of input data types must not be empty.") require( - TypeCoercion.haveSameType(inputTypesForMerging), + TypeCoercion.haveSameType(actualInputTypes), "All input types must be the same except nullable, containsNull, valueContainsNull flags." + s" The input types found are\n\t${inputTypesForMerging.mkString("\n\t")}") } override def dataType: DataType = { dataTypeCheck - inputTypesForMerging.reduceLeft(TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(_, _).get) + actualInputTypes.reduceLeft(TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(_, _).get) } } From 7432344143fb4889ed3d5cbde21872c8fdd6d3f1 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 20 Oct 2018 12:47:37 +0000 Subject: [PATCH 16/21] Try more reasonable solution. --- .../sql/catalyst/encoders/RowEncoder.scala | 6 ++--- .../sql/catalyst/expressions/Expression.scala | 22 ++----------------- 2 files changed, 5 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index f6838bdab9e7..b184b34a6ab6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -169,7 +169,7 @@ object RowEncoder { if (inputObject.nullable) { If(IsNull(inputObject), - Literal.create(null, inputType), + Literal.create(null, nonNullOutput.dataType), nonNullOutput) } else { nonNullOutput @@ -185,7 +185,7 @@ object RowEncoder { val convertedField = if (field.nullable) { If( Invoke(inputObject, "isNullAt", BooleanType, Literal(index) :: Nil), - Literal.create(null, field.dataType), + Literal.create(null, fieldValue.dataType), fieldValue ) } else { @@ -196,7 +196,7 @@ object RowEncoder { if (inputObject.nullable) { If(IsNull(inputObject), - Literal.create(null, inputType), + Literal.create(null, nonNullOutput.dataType), nonNullOutput) } else { nonNullOutput diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index e4ff53f507a7..c215735ab1c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -709,37 +709,19 @@ trait ComplexTypeMergingExpression extends Expression { @transient lazy val inputTypesForMerging: Seq[DataType] = children.map(_.dataType) - // If there is `PythonUserDefinedType`, `TypeCoercion.haveSameType` checks will fail. - // This method converts Python UDT to its underlying storage data type. 
- private def convertPythonUserDefinedType(dt: DataType): DataType = dt match { - case u: PythonUserDefinedType => u.sqlType - case ArrayType(et, containsNull) => ArrayType(convertPythonUserDefinedType(et), containsNull) - case MapType(kt, vt, valueContainsNull) => - MapType(convertPythonUserDefinedType(kt), convertPythonUserDefinedType(vt), valueContainsNull) - case s: StructType => - val fields = s.map { field => - StructField(field.name, convertPythonUserDefinedType(field.dataType), field.nullable, - field.metadata) - } - StructType(fields) - case o => o - } - - private lazy val actualInputTypes = inputTypesForMerging.map(convertPythonUserDefinedType(_)) - def dataTypeCheck: Unit = { require( inputTypesForMerging.nonEmpty, "The collection of input data types must not be empty.") require( - TypeCoercion.haveSameType(actualInputTypes), + TypeCoercion.haveSameType(inputTypesForMerging), "All input types must be the same except nullable, containsNull, valueContainsNull flags." + s" The input types found are\n\t${inputTypesForMerging.mkString("\n\t")}") } override def dataType: DataType = { dataTypeCheck - actualInputTypes.reduceLeft(TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(_, _).get) + inputTypesForMerging.reduceLeft(TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(_, _).get) } } From 400f87817183640006140e2db1839f8d78a13856 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 22 Oct 2018 10:56:20 +0800 Subject: [PATCH 17/21] Address comment. --- .../spark/sql/catalyst/encoders/ExpressionEncoder.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index b3550a1f98a2..61693a4de8a6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.encoders import scala.reflect.ClassTag -import scala.reflect.runtime.universe.{`Type`, RuntimeClass, TypeTag} +import scala.reflect.runtime.universe.{`Type`, typeTag, RuntimeClass, TypeTag} import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.{InternalRow, JavaTypeInference, ScalaReflection} @@ -47,7 +47,7 @@ object ExpressionEncoder { def apply[T : TypeTag](): ExpressionEncoder[T] = { val mirror = ScalaReflection.mirror - val tpe = ScalaReflection.localTypeOf[T] + val tpe = typeTag[T].in(mirror).tpe val cls = mirror.runtimeClass(tpe) if (ScalaReflection.optionOfProductType(tpe)) { From 8cb710b5c7b329468c320b59bb0625866fd8d836 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 24 Oct 2018 16:45:37 +0800 Subject: [PATCH 18/21] Address comments. 
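The main feedback addressed here: `ScalaReflection.serializerForType` no longer
needs the separately passed runtime class (primitiveness is read off the type
itself via `tpe.typeSymbol.asClass.isPrimitive`), and the repeated
`objSerializer.dataType.isInstanceOf[StructType]` checks are replaced by a new
`ExpressionEncoder.isSerializedAsStruct` helper. A rough sketch of how the
helper is expected to behave (illustrative only, not new tests):

    import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

    ExpressionEncoder[Int]().isSerializedAsStruct            // false: one non-struct column
    ExpressionEncoder[Option[Double]]().isSerializedAsStruct // false: still a single nullable column
    ExpressionEncoder[(Int, String)]().isSerializedAsStruct  // true: Product types map to a struct
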
--- .../spark/sql/catalyst/ScalaReflection.scala | 19 +++----------- .../catalyst/encoders/ExpressionEncoder.scala | 18 +++++++------ .../sql/catalyst/ScalaReflectionSuite.scala | 26 +++++++------------ .../scala/org/apache/spark/sql/Dataset.scala | 10 +++---- .../spark/sql/KeyValueGroupedDataset.scala | 4 +-- .../aggregate/TypedAggregateExpression.scala | 4 +-- 6 files changed, 30 insertions(+), 51 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 7f16c1130be8..40074b36f6a9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -24,7 +24,7 @@ import scala.util.Properties import org.apache.commons.lang3.reflect.ConstructorUtils import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAttribute, UnresolvedExtractValue} +import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects._ import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils, GenericArrayData, MapData} @@ -426,14 +426,13 @@ object ScalaReflection extends ScalaReflection { * * the element type of [[Array]] or [[Seq]]: `array element class: "abc.xyz.MyClass"` * * the field of [[Product]]: `field (class: "abc.xyz.MyClass", name: "myField")` */ - def serializerForType(tpe: `Type`, - cls: RuntimeClass): Expression = ScalaReflection.cleanUpReflectionObjects { + def serializerForType(tpe: `Type`): Expression = ScalaReflection.cleanUpReflectionObjects { val clsName = getClassNameFromType(tpe) val walkedTypePath = s"""- root class: "$clsName"""" :: Nil // The input object to `ExpressionEncoder` is located at first column of an row. val inputObject = BoundReference(0, dataTypeFor(tpe), - nullable = !cls.isPrimitive) + nullable = !tpe.typeSymbol.asClass.isPrimitive) serializerFor(inputObject, tpe, walkedTypePath) } @@ -441,18 +440,6 @@ object ScalaReflection extends ScalaReflection { /** * Returns an expression for serializing the value of an input expression into Spark SQL * internal representation. - * - * The expression generated by this method will be used by `ExpressionEncoder` as serializer - * to convert a JVM object to Spark SQL representation. - * - * The returned serializer generally converts a JVM object to corresponding Spark SQL - * representation. For example, `Seq[_]` is converted to a Spark SQL array, `Product` is - * converted to a Spark SQL struct. - * - * If input object is not of ObjectType, it means that the input object is already in a form - * of Spark's internal representation. We simply return the input object. - * - * For unsupported types, an `UnsupportedOperationException` will be thrown. 
*/ private def serializerFor( inputObject: Expression, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 61693a4de8a6..4845b569683f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.encoders import scala.reflect.ClassTag -import scala.reflect.runtime.universe.{`Type`, typeTag, RuntimeClass, TypeTag} +import scala.reflect.runtime.universe.{typeTag, TypeTag} import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.{InternalRow, JavaTypeInference, ScalaReflection} @@ -48,7 +48,6 @@ object ExpressionEncoder { def apply[T : TypeTag](): ExpressionEncoder[T] = { val mirror = ScalaReflection.mirror val tpe = typeTag[T].in(mirror).tpe - val cls = mirror.runtimeClass(tpe) if (ScalaReflection.optionOfProductType(tpe)) { throw new UnsupportedOperationException( @@ -59,7 +58,8 @@ object ExpressionEncoder { "`val ds: Dataset[Tuple1[MyClass]] = Seq(Tuple1(MyClass(...)), Tuple1(null)).toDS`") } - val serializer = ScalaReflection.serializerForType(tpe, cls) + val cls = mirror.runtimeClass(tpe) + val serializer = ScalaReflection.serializerForType(tpe) val deserializer = ScalaReflection.deserializerForType(tpe) new ExpressionEncoder[T]( @@ -204,10 +204,9 @@ case class ExpressionEncoder[T]( * 2. For other cases, we create a struct to wrap the `serializer`. */ val serializer: Seq[NamedExpression] = { - val serializedAsStruct = objSerializer.dataType.isInstanceOf[StructType] val clsName = Utils.getSimpleName(clsTag.runtimeClass) - if (serializedAsStruct) { + if (isSerializedAsStruct) { val nullSafeSerializer = objSerializer.transformUp { case r: BoundReference => // For input object of Product type, we can't encode it to row if it's null, as Spark SQL @@ -236,9 +235,7 @@ case class ExpressionEncoder[T]( * `GetColumnByOrdinal` with corresponding ordinal. */ val deserializer: Expression = { - val serializedAsStruct = objSerializer.dataType.isInstanceOf[StructType] - - if (serializedAsStruct) { + if (isSerializedAsStruct) { // We serialized this kind of objects to root-level row. The input of general deserializer // is a `GetColumnByOrdinal(0)` expression to extract first column of a row. We need to // transform attributes accessors. @@ -264,6 +261,11 @@ case class ExpressionEncoder[T]( StructField(s.name, s.dataType, s.nullable) }) + /** + * Returns true if the type `T` is serialized as a struct. + */ + def isSerializedAsStruct: Boolean = objSerializer.dataType.isInstanceOf[StructType] + // serializer expressions are used to encode an object to a row, while the object is usually an // intermediate value produced inside an operator, not from the output of the child operator. 
This // is quite different from normal expressions, and `AttributeReference` doesn't work here diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 93704866f079..17280add5020 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -262,8 +262,7 @@ class ScalaReflectionSuite extends SparkFunSuite { test("SPARK-15062: Get correct serializer for List[_]") { val list = List(1, 2, 3) - val serializer = serializerForType(ScalaReflection.localTypeOf[List[Int]], - classOf[List[Int]]) + val serializer = serializerForType(ScalaReflection.localTypeOf[List[Int]]) assert(serializer.isInstanceOf[NewInstance]) assert(serializer.asInstanceOf[NewInstance] .cls.isAssignableFrom(classOf[org.apache.spark.sql.catalyst.util.GenericArrayData])) @@ -276,16 +275,14 @@ class ScalaReflectionSuite extends SparkFunSuite { test("serialize and deserialize arbitrary sequence types") { import scala.collection.immutable.Queue - val queueSerializer = serializerForType(ScalaReflection.localTypeOf[Queue[Int]], - classOf[Queue[Int]]) + val queueSerializer = serializerForType(ScalaReflection.localTypeOf[Queue[Int]]) assert(queueSerializer.dataType == ArrayType(IntegerType, containsNull = false)) val queueDeserializer = deserializerForType(ScalaReflection.localTypeOf[Queue[Int]]) assert(queueDeserializer.dataType == ObjectType(classOf[Queue[_]])) import scala.collection.mutable.ArrayBuffer - val arrayBufferSerializer = serializerForType(ScalaReflection.localTypeOf[ArrayBuffer[Int]], - classOf[ArrayBuffer[Int]]) + val arrayBufferSerializer = serializerForType(ScalaReflection.localTypeOf[ArrayBuffer[Int]]) assert(arrayBufferSerializer.dataType == ArrayType(IntegerType, containsNull = false)) val arrayBufferDeserializer = deserializerForType(ScalaReflection.localTypeOf[ArrayBuffer[Int]]) @@ -293,16 +290,14 @@ class ScalaReflectionSuite extends SparkFunSuite { } test("serialize and deserialize arbitrary map types") { - val mapSerializer = serializerForType(ScalaReflection.localTypeOf[Map[Int, Int]], - classOf[Map[Int, Int]]) + val mapSerializer = serializerForType(ScalaReflection.localTypeOf[Map[Int, Int]]) assert(mapSerializer.dataType == MapType(IntegerType, IntegerType, valueContainsNull = false)) val mapDeserializer = deserializerForType(ScalaReflection.localTypeOf[Map[Int, Int]]) assert(mapDeserializer.dataType == ObjectType(classOf[Map[_, _]])) import scala.collection.immutable.HashMap - val hashMapSerializer = serializerForType(ScalaReflection.localTypeOf[HashMap[Int, Int]], - classOf[HashMap[Int, Int]]) + val hashMapSerializer = serializerForType(ScalaReflection.localTypeOf[HashMap[Int, Int]]) assert(hashMapSerializer.dataType == MapType(IntegerType, IntegerType, valueContainsNull = false)) val hashMapDeserializer = deserializerForType(ScalaReflection.localTypeOf[HashMap[Int, Int]]) @@ -310,8 +305,7 @@ class ScalaReflectionSuite extends SparkFunSuite { import scala.collection.mutable.{LinkedHashMap => LHMap} val linkedHashMapSerializer = serializerForType( - ScalaReflection.localTypeOf[LHMap[Long, String]], - classOf[LHMap[Long, String]]) + ScalaReflection.localTypeOf[LHMap[Long, String]]) assert(linkedHashMapSerializer.dataType == MapType(LongType, StringType, valueContainsNull = true)) val linkedHashMapDeserializer = deserializerForType( @@ -320,10 
+314,10 @@ class ScalaReflectionSuite extends SparkFunSuite { } test("SPARK-22442: Generate correct field names for special characters") { - val serializer = serializerForType(ScalaReflection.localTypeOf[SpecialCharAsFieldData], - classOf[SpecialCharAsFieldData]).collect { - case If(_, _, s: CreateNamedStruct) => s - }.head + val serializer = serializerForType(ScalaReflection.localTypeOf[SpecialCharAsFieldData]) + .collect { + case If(_, _, s: CreateNamedStruct) => s + }.head val deserializer = deserializerForType(ScalaReflection.localTypeOf[SpecialCharAsFieldData]) assert(serializer.dataType(0).name == "field.1") assert(serializer.dataType(1).name == "field 2") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 1fd8d01ad1f5..c91b0d778fab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1087,7 +1087,7 @@ class Dataset[T] private[sql]( // Note that we do this before joining them, to enable the join operator to return null for one // side, in cases like outer-join. val left = { - val combined = if (!this.exprEnc.objSerializer.dataType.isInstanceOf[StructType]) { + val combined = if (!this.exprEnc.isSerializedAsStruct) { assert(joined.left.output.length == 1) Alias(joined.left.output.head, "_1")() } else { @@ -1097,7 +1097,7 @@ class Dataset[T] private[sql]( } val right = { - val combined = if (!other.exprEnc.objSerializer.dataType.isInstanceOf[StructType]) { + val combined = if (!other.exprEnc.isSerializedAsStruct) { assert(joined.right.output.length == 1) Alias(joined.right.output.head, "_2")() } else { @@ -1110,14 +1110,14 @@ class Dataset[T] private[sql]( // combine the outputs of each join side. 
val conditionExpr = joined.condition.get transformUp { case a: Attribute if joined.left.outputSet.contains(a) => - if (!this.exprEnc.objSerializer.dataType.isInstanceOf[StructType]) { + if (!this.exprEnc.isSerializedAsStruct) { left.output.head } else { val index = joined.left.output.indexWhere(_.exprId == a.exprId) GetStructField(left.output.head, index) } case a: Attribute if joined.right.outputSet.contains(a) => - if (!other.exprEnc.objSerializer.dataType.isInstanceOf[StructType]) { + if (!other.exprEnc.isSerializedAsStruct) { right.output.head } else { val index = joined.right.output.indexWhere(_.exprId == a.exprId) @@ -1390,7 +1390,7 @@ class Dataset[T] private[sql]( implicit val encoder = c1.encoder val project = Project(c1.withInputType(exprEnc, logicalPlan.output).named :: Nil, logicalPlan) - if (!encoder.objSerializer.dataType.isInstanceOf[StructType]) { + if (!encoder.isSerializedAsStruct) { new Dataset[U1](sparkSession, project, encoder) } else { // Flattens inner fields of U1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index b0232c0703b6..555bcdffb6ee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.expressions.ReduceAggregator import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode} -import org.apache.spark.sql.types.StructType /** * :: Experimental :: @@ -458,8 +457,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( val encoders = columns.map(_.encoder) val namedColumns = columns.map(_.withInputType(vExprEnc, dataAttributes).named) - - val keyColumn = if (!kExprEnc.objSerializer.dataType.isInstanceOf[StructType]) { + val keyColumn = if (!kExprEnc.isSerializedAsStruct) { assert(groupingAttributes.length == 1) groupingAttributes.head } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala index d1881de50257..39200ec00e15 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -32,7 +32,6 @@ import org.apache.spark.sql.types._ import org.apache.spark.util.Utils object TypedAggregateExpression { - def apply[BUF : Encoder, OUT : Encoder]( aggregator: Aggregator[_, BUF, OUT]): TypedAggregateExpression = { val bufferEncoder = encoderFor[BUF] @@ -46,8 +45,7 @@ object TypedAggregateExpression { // serialization. val isSimpleBuffer = { bufferSerializer.head match { - case Alias(_: BoundReference, _) - if !bufferEncoder.objSerializer.dataType.isInstanceOf[StructType] => true + case Alias(_: BoundReference, _) if !bufferEncoder.isSerializedAsStruct => true case _ => false } } From 682fa4b2b5638d88be01854d5ae41bbd1eb54eee Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 24 Oct 2018 19:01:25 +0800 Subject: [PATCH 19/21] Make comment more precise. 
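The doc on `serializer` now spells out the two shapes it handles: when `T` is
serialized as a struct, the nested `objSerializer` is an
`if (isnull(input)) null else named_struct(...)` tree and only the
`CreateNamedStruct` children are kept as the top-level fields; otherwise the
single serializer expression is wrapped in a struct and surfaces as one column
named `value`. A rough illustration via the resulting schemas (comments show
the approximate simpleString, for orientation only):

    import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

    ExpressionEncoder[(Int, String)]().schema  // struct<_1:int,_2:string>
    ExpressionEncoder[Long]().schema           // struct<value:bigint>
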
--- .../spark/sql/catalyst/encoders/ExpressionEncoder.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 4845b569683f..29f6136a75ee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -200,8 +200,9 @@ case class ExpressionEncoder[T]( /** * A sequence of expressions, one for each top-level field that can be used to * extract the values from a raw object into an [[InternalRow]]: - * 1. If `serializer` encodes a raw object to a struct, we directly use the `serializer`. - * 2. For other cases, we create a struct to wrap the `serializer`. + * 1. If `serializer` encodes a raw object to a struct, strip the outer If-IsNull and get + * the `CreateNamedStruct`. + * 2. For other cases, wrap the single serializer with `CreateNamedStruct`. */ val serializer: Seq[NamedExpression] = { val clsName = Utils.getSimpleName(clsTag.runtimeClass) From 078a071a72e0d39cece49ff73c09ec65a387b8af Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 24 Oct 2018 19:21:35 +0800 Subject: [PATCH 20/21] Simplify test change. --- .../sql/catalyst/ScalaReflectionSuite.scala | 35 ++++++++++--------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 17280add5020..dbf39a9abac4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst import java.sql.{Date, Timestamp} +import scala.reflect.runtime.universe.TypeTag + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, Expression, If, SpecificInternalRow, UpCast} @@ -111,6 +113,10 @@ object TestingUDT { class ScalaReflectionSuite extends SparkFunSuite { import org.apache.spark.sql.catalyst.ScalaReflection._ + // A helper method used to test `ScalaReflection.deserializerForType`. 
+ private def deserializerFor[T: TypeTag]: Expression = + deserializerForType(ScalaReflection.localTypeOf[T]) + test("SQLUserDefinedType annotation on Scala structure") { val schema = schemaFor[TestingUDT.NestedStruct] assert(schema === Schema( @@ -269,7 +275,7 @@ class ScalaReflectionSuite extends SparkFunSuite { } test("SPARK 16792: Get correct deserializer for List[_]") { - val listDeserializer = deserializerForType(ScalaReflection.localTypeOf[List[Int]]) + val listDeserializer = deserializerFor[List[Int]] assert(listDeserializer.dataType == ObjectType(classOf[List[_]])) } @@ -278,14 +284,14 @@ class ScalaReflectionSuite extends SparkFunSuite { val queueSerializer = serializerForType(ScalaReflection.localTypeOf[Queue[Int]]) assert(queueSerializer.dataType == ArrayType(IntegerType, containsNull = false)) - val queueDeserializer = deserializerForType(ScalaReflection.localTypeOf[Queue[Int]]) + val queueDeserializer = deserializerFor[Queue[Int]] assert(queueDeserializer.dataType == ObjectType(classOf[Queue[_]])) import scala.collection.mutable.ArrayBuffer val arrayBufferSerializer = serializerForType(ScalaReflection.localTypeOf[ArrayBuffer[Int]]) assert(arrayBufferSerializer.dataType == ArrayType(IntegerType, containsNull = false)) - val arrayBufferDeserializer = deserializerForType(ScalaReflection.localTypeOf[ArrayBuffer[Int]]) + val arrayBufferDeserializer = deserializerFor[ArrayBuffer[Int]] assert(arrayBufferDeserializer.dataType == ObjectType(classOf[ArrayBuffer[_]])) } @@ -293,14 +299,14 @@ class ScalaReflectionSuite extends SparkFunSuite { val mapSerializer = serializerForType(ScalaReflection.localTypeOf[Map[Int, Int]]) assert(mapSerializer.dataType == MapType(IntegerType, IntegerType, valueContainsNull = false)) - val mapDeserializer = deserializerForType(ScalaReflection.localTypeOf[Map[Int, Int]]) + val mapDeserializer = deserializerFor[Map[Int, Int]] assert(mapDeserializer.dataType == ObjectType(classOf[Map[_, _]])) import scala.collection.immutable.HashMap val hashMapSerializer = serializerForType(ScalaReflection.localTypeOf[HashMap[Int, Int]]) assert(hashMapSerializer.dataType == MapType(IntegerType, IntegerType, valueContainsNull = false)) - val hashMapDeserializer = deserializerForType(ScalaReflection.localTypeOf[HashMap[Int, Int]]) + val hashMapDeserializer = deserializerFor[HashMap[Int, Int]] assert(hashMapDeserializer.dataType == ObjectType(classOf[HashMap[_, _]])) import scala.collection.mutable.{LinkedHashMap => LHMap} @@ -308,8 +314,7 @@ class ScalaReflectionSuite extends SparkFunSuite { ScalaReflection.localTypeOf[LHMap[Long, String]]) assert(linkedHashMapSerializer.dataType == MapType(LongType, StringType, valueContainsNull = true)) - val linkedHashMapDeserializer = deserializerForType( - ScalaReflection.localTypeOf[LHMap[Long, String]]) + val linkedHashMapDeserializer = deserializerFor[LHMap[Long, String]] assert(linkedHashMapDeserializer.dataType == ObjectType(classOf[LHMap[_, _]])) } @@ -318,7 +323,7 @@ class ScalaReflectionSuite extends SparkFunSuite { .collect { case If(_, _, s: CreateNamedStruct) => s }.head - val deserializer = deserializerForType(ScalaReflection.localTypeOf[SpecialCharAsFieldData]) + val deserializer = deserializerFor[SpecialCharAsFieldData] assert(serializer.dataType(0).name == "field.1") assert(serializer.dataType(1).name == "field 2") @@ -332,8 +337,8 @@ class ScalaReflectionSuite extends SparkFunSuite { } test("SPARK-22472: add null check for top-level primitive values") { - 
assert(deserializerForType(ScalaReflection.localTypeOf[Int]).isInstanceOf[AssertNotNull]) - assert(!deserializerForType(ScalaReflection.localTypeOf[String]).isInstanceOf[AssertNotNull]) + assert(deserializerFor[Int].isInstanceOf[AssertNotNull]) + assert(!deserializerFor[String].isInstanceOf[AssertNotNull]) } test("SPARK-23025: schemaFor should support Null type") { @@ -350,12 +355,8 @@ class ScalaReflectionSuite extends SparkFunSuite { val newInstance = deserializer.collect { case n: NewInstance => n}.head newInstance.arguments.count(_.isInstanceOf[AssertNotNull]) } - assert(numberOfCheckedArguments( - deserializerForType(ScalaReflection.localTypeOf[(Double, Double)])) == 2) - assert(numberOfCheckedArguments( - deserializerForType(ScalaReflection.localTypeOf[(java.lang.Double, Int)])) == 1) - assert(numberOfCheckedArguments( - deserializerForType( - ScalaReflection.localTypeOf[(java.lang.Integer, java.lang.Integer)])) == 0) + assert(numberOfCheckedArguments(deserializerFor[(Double, Double)]) == 2) + assert(numberOfCheckedArguments(deserializerFor[(java.lang.Double, Int)]) == 1) + assert(numberOfCheckedArguments(deserializerFor[(java.lang.Integer, java.lang.Integer)]) == 0) } } From c00d5e44a21f8053a97db755f7a705872d4121eb Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 25 Oct 2018 07:45:51 +0800 Subject: [PATCH 21/21] Address comment. --- .../sql/catalyst/ScalaReflectionSuite.scala | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index dbf39a9abac4..d98589db323c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -113,6 +113,10 @@ object TestingUDT { class ScalaReflectionSuite extends SparkFunSuite { import org.apache.spark.sql.catalyst.ScalaReflection._ + // A helper method used to test `ScalaReflection.serializerForType`. + private def serializerFor[T: TypeTag]: Expression = + serializerForType(ScalaReflection.localTypeOf[T]) + // A helper method used to test `ScalaReflection.deserializerForType`. 
private def deserializerFor[T: TypeTag]: Expression = deserializerForType(ScalaReflection.localTypeOf[T]) @@ -268,7 +272,7 @@ class ScalaReflectionSuite extends SparkFunSuite { test("SPARK-15062: Get correct serializer for List[_]") { val list = List(1, 2, 3) - val serializer = serializerForType(ScalaReflection.localTypeOf[List[Int]]) + val serializer = serializerFor[List[Int]] assert(serializer.isInstanceOf[NewInstance]) assert(serializer.asInstanceOf[NewInstance] .cls.isAssignableFrom(classOf[org.apache.spark.sql.catalyst.util.GenericArrayData])) @@ -281,14 +285,14 @@ class ScalaReflectionSuite extends SparkFunSuite { test("serialize and deserialize arbitrary sequence types") { import scala.collection.immutable.Queue - val queueSerializer = serializerForType(ScalaReflection.localTypeOf[Queue[Int]]) + val queueSerializer = serializerFor[Queue[Int]] assert(queueSerializer.dataType == ArrayType(IntegerType, containsNull = false)) val queueDeserializer = deserializerFor[Queue[Int]] assert(queueDeserializer.dataType == ObjectType(classOf[Queue[_]])) import scala.collection.mutable.ArrayBuffer - val arrayBufferSerializer = serializerForType(ScalaReflection.localTypeOf[ArrayBuffer[Int]]) + val arrayBufferSerializer = serializerFor[ArrayBuffer[Int]] assert(arrayBufferSerializer.dataType == ArrayType(IntegerType, containsNull = false)) val arrayBufferDeserializer = deserializerFor[ArrayBuffer[Int]] @@ -296,22 +300,21 @@ class ScalaReflectionSuite extends SparkFunSuite { } test("serialize and deserialize arbitrary map types") { - val mapSerializer = serializerForType(ScalaReflection.localTypeOf[Map[Int, Int]]) + val mapSerializer = serializerFor[Map[Int, Int]] assert(mapSerializer.dataType == MapType(IntegerType, IntegerType, valueContainsNull = false)) val mapDeserializer = deserializerFor[Map[Int, Int]] assert(mapDeserializer.dataType == ObjectType(classOf[Map[_, _]])) import scala.collection.immutable.HashMap - val hashMapSerializer = serializerForType(ScalaReflection.localTypeOf[HashMap[Int, Int]]) + val hashMapSerializer = serializerFor[HashMap[Int, Int]] assert(hashMapSerializer.dataType == MapType(IntegerType, IntegerType, valueContainsNull = false)) val hashMapDeserializer = deserializerFor[HashMap[Int, Int]] assert(hashMapDeserializer.dataType == ObjectType(classOf[HashMap[_, _]])) import scala.collection.mutable.{LinkedHashMap => LHMap} - val linkedHashMapSerializer = serializerForType( - ScalaReflection.localTypeOf[LHMap[Long, String]]) + val linkedHashMapSerializer = serializerFor[LHMap[Long, String]] assert(linkedHashMapSerializer.dataType == MapType(LongType, StringType, valueContainsNull = true)) val linkedHashMapDeserializer = deserializerFor[LHMap[Long, String]] @@ -319,7 +322,7 @@ class ScalaReflectionSuite extends SparkFunSuite { } test("SPARK-22442: Generate correct field names for special characters") { - val serializer = serializerForType(ScalaReflection.localTypeOf[SpecialCharAsFieldData]) + val serializer = serializerFor[SpecialCharAsFieldData] .collect { case If(_, _, s: CreateNamedStruct) => s }.head