diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 592520c59a76..d019924711e3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -49,15 +49,6 @@ object ExpressionEncoder { val mirror = ScalaReflection.mirror val tpe = typeTag[T].in(mirror).tpe - if (ScalaReflection.optionOfProductType(tpe)) { - throw new UnsupportedOperationException( - "Cannot create encoder for Option of Product type, because Product type is represented " + - "as a row, and the entire row can not be null in Spark SQL like normal databases. " + - "You can wrap your type with Tuple1 if you do want top level null Product objects, " + - "e.g. instead of creating `Dataset[Option[MyClass]]`, you can do something like " + - "`val ds: Dataset[Tuple1[MyClass]] = Seq(Tuple1(MyClass(...)), Tuple1(null)).toDS`") - } - val cls = mirror.runtimeClass(tpe) val serializer = ScalaReflection.serializerForType(tpe) val deserializer = ScalaReflection.deserializerForType(tpe) @@ -198,7 +189,7 @@ case class ExpressionEncoder[T]( val serializer: Seq[NamedExpression] = { val clsName = Utils.getSimpleName(clsTag.runtimeClass) - if (isSerializedAsStruct) { + if (isSerializedAsStructForTopLevel) { val nullSafeSerializer = objSerializer.transformUp { case r: BoundReference => // For input object of Product type, we can't encode it to row if it's null, as Spark SQL @@ -213,6 +204,9 @@ case class ExpressionEncoder[T]( } else { // For other input objects like primitive, array, map, etc., we construct a struct to wrap // the serializer which is a column of an row. + // + // Note: Because Spark SQL doesn't allow top-level row to be null, to encode + // top-level Option[Product] type, we make it as a top-level struct column. CreateNamedStruct(Literal("value") :: objSerializer :: Nil) } }.flatten @@ -226,7 +220,7 @@ case class ExpressionEncoder[T]( * `GetColumnByOrdinal` with corresponding ordinal. */ val deserializer: Expression = { - if (isSerializedAsStruct) { + if (isSerializedAsStructForTopLevel) { // We serialized this kind of objects to root-level row. The input of general deserializer // is a `GetColumnByOrdinal(0)` expression to extract first column of a row. We need to // transform attributes accessors. @@ -253,10 +247,24 @@ case class ExpressionEncoder[T]( }) /** - * Returns true if the type `T` is serialized as a struct. + * Returns true if the type `T` is serialized as a struct by `objSerializer`. */ def isSerializedAsStruct: Boolean = objSerializer.dataType.isInstanceOf[StructType] + /** + * Returns true if the type `T` is an `Option` type. + */ + def isOptionType: Boolean = classOf[Option[_]].isAssignableFrom(clsTag.runtimeClass) + + /** + * If the type `T` is serialized as a struct, when it is encoded to a Spark SQL row, fields in + * the struct are naturally mapped to top-level columns in a row. In other words, the serialized + * struct is flattened to row. But in case of the `T` is also an `Option` type, it can't be + * flattened to top-level row, because in Spark SQL top-level row can't be null. This method + * returns true if `T` is serialized as struct and is not `Option` type. + */ + def isSerializedAsStructForTopLevel: Boolean = isSerializedAsStruct && !isOptionType + // serializer expressions are used to encode an object to a row, while the object is usually an // intermediate value produced inside an operator, not from the output of the child operator. This // is quite different from normal expressions, and `AttributeReference` doesn't work here diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 0e77ec040625..c78011485479 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1084,7 +1084,7 @@ class Dataset[T] private[sql]( // Note that we do this before joining them, to enable the join operator to return null for one // side, in cases like outer-join. val left = { - val combined = if (!this.exprEnc.isSerializedAsStruct) { + val combined = if (!this.exprEnc.isSerializedAsStructForTopLevel) { assert(joined.left.output.length == 1) Alias(joined.left.output.head, "_1")() } else { @@ -1094,7 +1094,7 @@ class Dataset[T] private[sql]( } val right = { - val combined = if (!other.exprEnc.isSerializedAsStruct) { + val combined = if (!other.exprEnc.isSerializedAsStructForTopLevel) { assert(joined.right.output.length == 1) Alias(joined.right.output.head, "_2")() } else { @@ -1107,14 +1107,14 @@ class Dataset[T] private[sql]( // combine the outputs of each join side. val conditionExpr = joined.condition.get transformUp { case a: Attribute if joined.left.outputSet.contains(a) => - if (!this.exprEnc.isSerializedAsStruct) { + if (!this.exprEnc.isSerializedAsStructForTopLevel) { left.output.head } else { val index = joined.left.output.indexWhere(_.exprId == a.exprId) GetStructField(left.output.head, index) } case a: Attribute if joined.right.outputSet.contains(a) => - if (!other.exprEnc.isSerializedAsStruct) { + if (!other.exprEnc.isSerializedAsStructForTopLevel) { right.output.head } else { val index = joined.right.output.indexWhere(_.exprId == a.exprId) @@ -1387,7 +1387,7 @@ class Dataset[T] private[sql]( implicit val encoder = c1.encoder val project = Project(c1.withInputType(exprEnc, logicalPlan.output).named :: Nil, logicalPlan) - if (!encoder.isSerializedAsStruct) { + if (!encoder.isSerializedAsStructForTopLevel) { new Dataset[U1](sparkSession, project, encoder) } else { // Flattens inner fields of U1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 7a47242f6938..dbb1c313869f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -457,7 +457,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( val encoders = columns.map(_.encoder) val namedColumns = columns.map(_.withInputType(vExprEnc, dataAttributes).named) - val keyColumn = if (!kExprEnc.isSerializedAsStruct) { + val keyColumn = if (!kExprEnc.isSerializedAsStructForTopLevel) { assert(groupingAttributes.length == 1) groupingAttributes.head } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala index 39200ec00e15..b75752945a49 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -40,9 +40,9 @@ object TypedAggregateExpression { val outputEncoder = encoderFor[OUT] val outputType = outputEncoder.objSerializer.dataType - // Checks if the buffer object is simple, i.e. the buffer encoder is flat and the serializer - // expression is an alias of `BoundReference`, which means the buffer object doesn't need - // serialization. + // Checks if the buffer object is simple, i.e. the `BUF` type is not serialized as struct + // and the serializer expression is an alias of `BoundReference`, which means the buffer + // object doesn't need serialization. val isSimpleBuffer = { bufferSerializer.head match { case Alias(_: BoundReference, _) if !bufferEncoder.isSerializedAsStruct => true @@ -76,7 +76,7 @@ object TypedAggregateExpression { None, bufferSerializer, bufferEncoder.resolveAndBind().deserializer, - outputEncoder.serializer, + outputEncoder.objSerializer, outputType, outputEncoder.objSerializer.nullable) } @@ -213,7 +213,7 @@ case class ComplexTypedAggregateExpression( inputSchema: Option[StructType], bufferSerializer: Seq[NamedExpression], bufferDeserializer: Expression, - outputSerializer: Seq[Expression], + outputSerializer: Expression, dataType: DataType, nullable: Boolean, mutableAggBufferOffset: Int = 0, @@ -245,13 +245,7 @@ case class ComplexTypedAggregateExpression( aggregator.merge(buffer, input) } - private lazy val resultObjToRow = dataType match { - case _: StructType => - UnsafeProjection.create(CreateStruct(outputSerializer)) - case _ => - assert(outputSerializer.length == 1) - UnsafeProjection.create(outputSerializer.head) - } + private lazy val resultObjToRow = UnsafeProjection.create(outputSerializer) override def eval(buffer: Any): Any = { val resultObj = aggregator.finish(buffer) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala index 538ea3c66c40..97c3f358c0e7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.expressions.scalalang.typed import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.StringType +import org.apache.spark.sql.types.{BooleanType, IntegerType, StringType, StructType} object ComplexResultAgg extends Aggregator[(String, Int), (Long, Long), (Long, Long)] { @@ -149,6 +149,7 @@ object VeryComplexResultAgg extends Aggregator[Row, String, ComplexAggData] { case class OptionBooleanData(name: String, isGood: Option[Boolean]) +case class OptionBooleanIntData(name: String, isGood: Option[(Boolean, Int)]) case class OptionBooleanAggregator(colName: String) extends Aggregator[Row, Option[Boolean], Option[Boolean]] { @@ -183,6 +184,43 @@ case class OptionBooleanAggregator(colName: String) def OptionalBoolEncoder: Encoder[Option[Boolean]] = ExpressionEncoder() } +case class OptionBooleanIntAggregator(colName: String) + extends Aggregator[Row, Option[(Boolean, Int)], Option[(Boolean, Int)]] { + + override def zero: Option[(Boolean, Int)] = None + + override def reduce(buffer: Option[(Boolean, Int)], row: Row): Option[(Boolean, Int)] = { + val index = row.fieldIndex(colName) + val value = if (row.isNullAt(index)) { + Option.empty[(Boolean, Int)] + } else { + val nestedRow = row.getStruct(index) + Some((nestedRow.getBoolean(0), nestedRow.getInt(1))) + } + merge(buffer, value) + } + + override def merge( + b1: Option[(Boolean, Int)], + b2: Option[(Boolean, Int)]): Option[(Boolean, Int)] = { + if ((b1.isDefined && b1.get._1) || (b2.isDefined && b2.get._1)) { + val newInt = b1.map(_._2).getOrElse(0) + b2.map(_._2).getOrElse(0) + Some((true, newInt)) + } else if (b1.isDefined) { + b1 + } else { + b2 + } + } + + override def finish(reduction: Option[(Boolean, Int)]): Option[(Boolean, Int)] = reduction + + override def bufferEncoder: Encoder[Option[(Boolean, Int)]] = OptionalBoolIntEncoder + override def outputEncoder: Encoder[Option[(Boolean, Int)]] = OptionalBoolIntEncoder + + def OptionalBoolIntEncoder: Encoder[Option[(Boolean, Int)]] = ExpressionEncoder() +} + class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -393,4 +431,28 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { assert(grouped.schema == df.schema) checkDataset(grouped.as[OptionBooleanData], OptionBooleanData("bob", Some(true))) } + + test("SPARK-24762: Aggregator should be able to use Option of Product encoder") { + val df = Seq( + OptionBooleanIntData("bob", Some((true, 1))), + OptionBooleanIntData("bob", Some((false, 2))), + OptionBooleanIntData("bob", None)).toDF() + + val group = df + .groupBy("name") + .agg(OptionBooleanIntAggregator("isGood").toColumn.alias("isGood")) + + val expectedSchema = new StructType() + .add("name", StringType, nullable = true) + .add("isGood", + new StructType() + .add("_1", BooleanType, nullable = false) + .add("_2", IntegerType, nullable = false), + nullable = true) + + assert(df.schema == expectedSchema) + assert(group.schema == expectedSchema) + checkAnswer(group, Row("bob", Row(true, 3)) :: Nil) + checkDataset(group.as[OptionBooleanIntData], OptionBooleanIntData("bob", Some((true, 3)))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index ac677e8ec6bc..624b15f5e98d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1311,15 +1311,6 @@ class DatasetSuite extends QueryTest with SharedSQLContext { checkDataset(dsString, arrayString) } - test("SPARK-18251: the type of Dataset can't be Option of Product type") { - checkDataset(Seq(Some(1), None).toDS(), Some(1), None) - - val e = intercept[UnsupportedOperationException] { - Seq(Some(1 -> "a"), None).toDS() - } - assert(e.getMessage.contains("Cannot create encoder for Option of Product type")) - } - test ("SPARK-17460: the sizeInBytes in Statistics shouldn't overflow to a negative number") { // Since the sizeInBytes in Statistics could exceed the limit of an Int, we should use BigInt // instead of Int for avoiding possible overflow. @@ -1557,6 +1548,74 @@ class DatasetSuite extends QueryTest with SharedSQLContext { Seq(Row("Amsterdam"))) } + test("SPARK-24762: Enable top-level Option of Product encoders") { + val data = Seq(Some((1, "a")), Some((2, "b")), None) + val ds = data.toDS() + + checkDataset( + ds, + data: _*) + + val schema = new StructType().add( + "value", + new StructType() + .add("_1", IntegerType, nullable = false) + .add("_2", StringType, nullable = true), + nullable = true) + + assert(ds.schema == schema) + + val nestedOptData = Seq(Some((Some((1, "a")), 2.0)), Some((Some((2, "b")), 3.0))) + val nestedDs = nestedOptData.toDS() + + checkDataset( + nestedDs, + nestedOptData: _*) + + val nestedSchema = StructType(Seq( + StructField("value", StructType(Seq( + StructField("_1", StructType(Seq( + StructField("_1", IntegerType, nullable = false), + StructField("_2", StringType, nullable = true)))), + StructField("_2", DoubleType, nullable = false) + )), nullable = true) + )) + assert(nestedDs.schema == nestedSchema) + } + + test("SPARK-24762: Resolving Option[Product] field") { + val ds = Seq((1, ("a", 1.0)), (2, ("b", 2.0)), (3, null)).toDS() + .as[(Int, Option[(String, Double)])] + checkDataset(ds, + (1, Some(("a", 1.0))), (2, Some(("b", 2.0))), (3, None)) + } + + test("SPARK-24762: select Option[Product] field") { + val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() + val ds1 = ds.select(expr("struct(_2, _2 + 1)").as[Option[(Int, Int)]]) + checkDataset(ds1, + Some((1, 2)), Some((2, 3)), Some((3, 4))) + + val ds2 = ds.select(expr("if(_2 > 2, struct(_2, _2 + 1), null)").as[Option[(Int, Int)]]) + checkDataset(ds2, + None, None, Some((3, 4))) + } + + test("SPARK-24762: joinWith on Option[Product]") { + val ds1 = Seq(Some((1, 2)), Some((2, 3)), None).toDS().as("a") + val ds2 = Seq(Some((1, 2)), Some((2, 3)), None).toDS().as("b") + val joined = ds1.joinWith(ds2, $"a.value._1" === $"b.value._2", "inner") + checkDataset(joined, (Some((2, 3)), Some((1, 2)))) + } + + test("SPARK-24762: typed agg on Option[Product] type") { + val ds = Seq(Some((1, 2)), Some((2, 3)), Some((1, 3))).toDS() + assert(ds.groupByKey(_.get._1).count().collect() === Seq((1, 2), (2, 1))) + + assert(ds.groupByKey(x => x).count().collect() === + Seq((Some((1, 2)), 1), (Some((2, 3)), 1), (Some((1, 3)), 1))) + } + test("SPARK-25942: typed aggregation on primitive type") { val ds = Seq(1, 2, 3).toDS()