# [SPARK-24762][SQL] Enable Option of Product encoders #21732
**ExpressionEncoder.scala**

```diff
@@ -49,15 +49,6 @@ object ExpressionEncoder {
     val mirror = ScalaReflection.mirror
     val tpe = typeTag[T].in(mirror).tpe
 
-    if (ScalaReflection.optionOfProductType(tpe)) {
-      throw new UnsupportedOperationException(
-        "Cannot create encoder for Option of Product type, because Product type is represented " +
-          "as a row, and the entire row can not be null in Spark SQL like normal databases. " +
-          "You can wrap your type with Tuple1 if you do want top level null Product objects, " +
-          "e.g. instead of creating `Dataset[Option[MyClass]]`, you can do something like " +
-          "`val ds: Dataset[Tuple1[MyClass]] = Seq(Tuple1(MyClass(...)), Tuple1(null)).toDS`")
-    }
-
     val cls = mirror.runtimeClass(tpe)
     val serializer = ScalaReflection.serializerForType(tpe)
     val deserializer = ScalaReflection.deserializerForType(tpe)
```
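With the up-front check removed, an `Option` of Product type can be used directly as a Dataset element type. A minimal sketch of what this enables (`MyClass` is a hypothetical case class; assumes a `SparkSession` available as `spark` with its implicits in scope):

```scala
import org.apache.spark.sql.Dataset
import spark.implicits._

case class MyClass(a: Int, b: String)

// Before this change, encoder creation threw UnsupportedOperationException and
// suggested the Tuple1 workaround quoted in the removed error message.
val ds: Dataset[Option[MyClass]] = Seq(Some(MyClass(1, "x")), None).toDS()
```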
```diff
@@ -198,7 +189,7 @@ case class ExpressionEncoder[T](
   val serializer: Seq[NamedExpression] = {
     val clsName = Utils.getSimpleName(clsTag.runtimeClass)
 
-    if (isSerializedAsStruct) {
+    if (isSerializedAsStructForTopLevel) {
       val nullSafeSerializer = objSerializer.transformUp {
         case r: BoundReference =>
           // For input object of Product type, we can't encode it to row if it's null, as Spark SQL
@@ -213,6 +204,9 @@ case class ExpressionEncoder[T](
     } else {
       // For other input objects like primitive, array, map, etc., we construct a struct to wrap
       // the serializer which is a column of an row.
+      //
+      // Note: Because Spark SQL doesn't allow top-level row to be null, to encode
+      // top-level Option[Product] type, we make it as a top-level struct column.
       CreateNamedStruct(Literal("value") :: objSerializer :: Nil)
     }
   }.flatten
```
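In other words, a top-level `Option[Product]` is no longer flattened into top-level columns; it is wrapped in a single nullable struct column. A rough sketch of the resulting shape (the exact nullability flags shown are assumptions):

```scala
import spark.implicits._

val ds = Seq(Some((1, "a")), None: Option[(Int, String)]).toDS()
ds.printSchema()
// root
//  |-- value: struct (nullable = true)
//  |    |-- _1: integer (nullable = false)
//  |    |-- _2: string (nullable = true)
```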
```diff
@@ -226,7 +220,7 @@ case class ExpressionEncoder[T](
    * `GetColumnByOrdinal` with corresponding ordinal.
    */
   val deserializer: Expression = {
-    if (isSerializedAsStruct) {
+    if (isSerializedAsStructForTopLevel) {
       // We serialized this kind of objects to root-level row. The input of general deserializer
       // is a `GetColumnByOrdinal(0)` expression to extract first column of a row. We need to
       // transform attributes accessors.
```
```diff
@@ -253,10 +247,24 @@ case class ExpressionEncoder[T](
   })
 
   /**
-   * Returns true if the type `T` is serialized as a struct.
+   * Returns true if the type `T` is serialized as a struct by `objSerializer`.
   */
   def isSerializedAsStruct: Boolean = objSerializer.dataType.isInstanceOf[StructType]
 
+  /**
+   * Returns true if the type `T` is an `Option` type.
+   */
+  def isOptionType: Boolean = classOf[Option[_]].isAssignableFrom(clsTag.runtimeClass)
+
+  /**
+   * If the type `T` is serialized as a struct, when it is encoded to a Spark SQL row, fields in
+   * the struct are naturally mapped to top-level columns in a row. In other words, the serialized
+   * struct is flattened to row. But in case of the `T` is also an `Option` type, it can't be
+   * flattened to top-level row, because in Spark SQL top-level row can't be null. This method
+   * returns true if `T` is serialized as struct and is not `Option` type.
+   */
+  def isSerializedAsStructForTopLevel: Boolean = isSerializedAsStruct && !isOptionType
```
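Taken together, the predicates behave as in this sketch (assuming encoders constructed directly via `ExpressionEncoder[T]()`; the commented results follow from the definitions above):

```scala
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

val tupleEnc  = ExpressionEncoder[(Int, String)]()
val optionEnc = ExpressionEncoder[Option[(Int, String)]]()

tupleEnc.isSerializedAsStruct              // true
tupleEnc.isOptionType                      // false
tupleEnc.isSerializedAsStructForTopLevel   // true  -> flattened to top-level columns _1, _2
optionEnc.isSerializedAsStruct             // true  -> its objSerializer produces a struct
optionEnc.isOptionType                     // true
optionEnc.isSerializedAsStructForTopLevel  // false -> kept as one nullable struct column
```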
**Contributor:** can you send a followup PR to inline …

**Member (Author):** ok.
Unchanged context from `ExpressionEncoder.scala`:

```scala
// serializer expressions are used to encode an object to a row, while the object is usually an
// intermediate value produced inside an operator, not from the output of the child operator. This
// is quite different from normal expressions, and `AttributeReference` doesn't work here
```
**TypedAggregateExpression.scala**

```diff
@@ -40,9 +40,9 @@ object TypedAggregateExpression {
     val outputEncoder = encoderFor[OUT]
     val outputType = outputEncoder.objSerializer.dataType
 
-    // Checks if the buffer object is simple, i.e. the buffer encoder is flat and the serializer
-    // expression is an alias of `BoundReference`, which means the buffer object doesn't need
-    // serialization.
+    // Checks if the buffer object is simple, i.e. the `BUF` type is not serialized as struct
+    // and the serializer expression is an alias of `BoundReference`, which means the buffer
+    // object doesn't need serialization.
     val isSimpleBuffer = {
       bufferSerializer.head match {
         case Alias(_: BoundReference, _) if !bufferEncoder.isSerializedAsStruct => true
```
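For intuition, a sketch of which buffer types count as "simple" under this check (same assumptions as the earlier encoder sketch):

```scala
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

// A primitive buffer is "simple": its serializer is just an alias of the bound
// input reference, so the buffer object needs no serialization round-trip.
ExpressionEncoder[Long]().isSerializedAsStruct          // false
// A Product buffer is serialized as a struct and is not "simple".
ExpressionEncoder[(Long, Long)]().isSerializedAsStruct  // true
```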
```diff
@@ -76,7 +76,7 @@ object TypedAggregateExpression {
         None,
         bufferSerializer,
         bufferEncoder.resolveAndBind().deserializer,
-        outputEncoder.serializer,
+        outputEncoder.objSerializer,
```
**Contributor:** to confirm, is this an unrelated change that just cleans up the code?

**Member (Author):** This is required. Without it, the output schema of an aggregator using Option[Product] as the output encoder is not correct.
```diff
         outputType,
         outputEncoder.objSerializer.nullable)
   }
```
```diff
@@ -213,7 +213,7 @@ case class ComplexTypedAggregateExpression(
     inputSchema: Option[StructType],
     bufferSerializer: Seq[NamedExpression],
     bufferDeserializer: Expression,
-    outputSerializer: Seq[Expression],
+    outputSerializer: Expression,
     dataType: DataType,
     nullable: Boolean,
     mutableAggBufferOffset: Int = 0,
```
@@ -245,13 +245,7 @@ case class ComplexTypedAggregateExpression( | |
| aggregator.merge(buffer, input) | ||
| } | ||
|
|
||
| private lazy val resultObjToRow = dataType match { | ||
| case _: StructType => | ||
| UnsafeProjection.create(CreateStruct(outputSerializer)) | ||
| case _ => | ||
| assert(outputSerializer.length == 1) | ||
| UnsafeProjection.create(outputSerializer.head) | ||
| } | ||
| private lazy val resultObjToRow = UnsafeProjection.create(outputSerializer) | ||
|
|
||
| override def eval(buffer: Any): Any = { | ||
| val resultObj = aggregator.finish(buffer) | ||
|
|
||
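Since `outputSerializer` is now the encoder's single `objSerializer` expression, which already evaluates to a struct when the output type is serialized as one, `resultObjToRow` no longer needs to branch on `dataType`.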
**DatasetAggregatorSuite.scala**

```diff
@@ -22,7 +22,7 @@ import org.apache.spark.sql.expressions.Aggregator
 import org.apache.spark.sql.expressions.scalalang.typed
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.sql.types.StringType
+import org.apache.spark.sql.types.{BooleanType, IntegerType, StringType, StructType}
 
 
 object ComplexResultAgg extends Aggregator[(String, Int), (Long, Long), (Long, Long)] {
```
```diff
@@ -149,6 +149,7 @@ object VeryComplexResultAgg extends Aggregator[Row, String, ComplexAggData] {
 
 
 case class OptionBooleanData(name: String, isGood: Option[Boolean])
+case class OptionBooleanIntData(name: String, isGood: Option[(Boolean, Int)])
 
 case class OptionBooleanAggregator(colName: String)
   extends Aggregator[Row, Option[Boolean], Option[Boolean]] {
```
```diff
@@ -183,6 +184,43 @@ case class OptionBooleanAggregator(colName: String)
   def OptionalBoolEncoder: Encoder[Option[Boolean]] = ExpressionEncoder()
 }
 
+case class OptionBooleanIntAggregator(colName: String)
+  extends Aggregator[Row, Option[(Boolean, Int)], Option[(Boolean, Int)]] {
```
**Contributor:** what's the expected schema after we apply an aggregator with …

**Member (Author):** For a non top-level encoder, the output schema of …

**Contributor:** assuming non top level, …

**Member (Author):** Yes. For non top level, …
```diff
+
+  override def zero: Option[(Boolean, Int)] = None
+
+  override def reduce(buffer: Option[(Boolean, Int)], row: Row): Option[(Boolean, Int)] = {
+    val index = row.fieldIndex(colName)
+    val value = if (row.isNullAt(index)) {
+      Option.empty[(Boolean, Int)]
+    } else {
+      val nestedRow = row.getStruct(index)
+      Some((nestedRow.getBoolean(0), nestedRow.getInt(1)))
+    }
+    merge(buffer, value)
+  }
+
+  override def merge(
+      b1: Option[(Boolean, Int)],
+      b2: Option[(Boolean, Int)]): Option[(Boolean, Int)] = {
+    if ((b1.isDefined && b1.get._1) || (b2.isDefined && b2.get._1)) {
+      val newInt = b1.map(_._2).getOrElse(0) + b2.map(_._2).getOrElse(0)
+      Some((true, newInt))
+    } else if (b1.isDefined) {
+      b1
+    } else {
+      b2
+    }
+  }
+
+  override def finish(reduction: Option[(Boolean, Int)]): Option[(Boolean, Int)] = reduction
+
+  override def bufferEncoder: Encoder[Option[(Boolean, Int)]] = OptionalBoolIntEncoder
+  override def outputEncoder: Encoder[Option[(Boolean, Int)]] = OptionalBoolIntEncoder
+
+  def OptionalBoolIntEncoder: Encoder[Option[(Boolean, Int)]] = ExpressionEncoder()
+}
+
 class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
   import testImplicits._
```
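A worked example of the merge semantics above: if either buffer carries `true`, the counters are summed; otherwise the first defined buffer wins.

```scala
val agg = OptionBooleanIntAggregator("isGood")

agg.merge(Some((true, 1)), Some((false, 2)))   // Some((true, 3)): a true side sums the counts
agg.merge(Some((false, 1)), Some((false, 2)))  // Some((false, 1)): neither true, keep b1
agg.merge(None, Some((false, 2)))              // Some((false, 2)): b1 undefined, take b2
```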
```diff
@@ -393,4 +431,28 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
     assert(grouped.schema == df.schema)
     checkDataset(grouped.as[OptionBooleanData], OptionBooleanData("bob", Some(true)))
   }
+
+  test("SPARK-24762: Aggregator should be able to use Option of Product encoder") {
+    val df = Seq(
+      OptionBooleanIntData("bob", Some((true, 1))),
+      OptionBooleanIntData("bob", Some((false, 2))),
+      OptionBooleanIntData("bob", None)).toDF()
+
+    val group = df
+      .groupBy("name")
+      .agg(OptionBooleanIntAggregator("isGood").toColumn.alias("isGood"))
+
+    val expectedSchema = new StructType()
+      .add("name", StringType, nullable = true)
+      .add("isGood",
+        new StructType()
+          .add("_1", BooleanType, nullable = false)
+          .add("_2", IntegerType, nullable = false),
+        nullable = true)
+
+    assert(df.schema == expectedSchema)
+    assert(group.schema == expectedSchema)
+    checkAnswer(group, Row("bob", Row(true, 3)) :: Nil)
+    checkDataset(group.as[OptionBooleanIntData], OptionBooleanIntData("bob", Some((true, 3))))
+  }
 }
```