@@ -49,15 +49,6 @@ object ExpressionEncoder {
val mirror = ScalaReflection.mirror
val tpe = typeTag[T].in(mirror).tpe

if (ScalaReflection.optionOfProductType(tpe)) {
throw new UnsupportedOperationException(
"Cannot create encoder for Option of Product type, because Product type is represented " +
"as a row, and the entire row can not be null in Spark SQL like normal databases. " +
"You can wrap your type with Tuple1 if you do want top level null Product objects, " +
"e.g. instead of creating `Dataset[Option[MyClass]]`, you can do something like " +
"`val ds: Dataset[Tuple1[MyClass]] = Seq(Tuple1(MyClass(...)), Tuple1(null)).toDS`")
}

val cls = mirror.runtimeClass(tpe)
val serializer = ScalaReflection.serializerForType(tpe)
val deserializer = ScalaReflection.deserializerForType(tpe)
@@ -198,7 +189,7 @@ case class ExpressionEncoder[T](
val serializer: Seq[NamedExpression] = {
val clsName = Utils.getSimpleName(clsTag.runtimeClass)

if (isSerializedAsStruct) {
if (isSerializedAsStructForTopLevel) {
val nullSafeSerializer = objSerializer.transformUp {
case r: BoundReference =>
// For input object of Product type, we can't encode it to row if it's null, as Spark SQL
@@ -213,6 +204,9 @@
} else {
      // For other input objects like primitive, array, map, etc., we construct a struct to wrap
      // the serializer, which is a column of a row.
      //
      // Note: Because Spark SQL doesn't allow a top-level row to be null, to encode a
      // top-level Option[Product] type, we make it a top-level struct column.
CreateNamedStruct(Literal("value") :: objSerializer :: Nil)
}
}.flatten
@@ -226,7 +220,7 @@
* `GetColumnByOrdinal` with corresponding ordinal.
*/
val deserializer: Expression = {
if (isSerializedAsStruct) {
if (isSerializedAsStructForTopLevel) {
// We serialized this kind of objects to root-level row. The input of general deserializer
// is a `GetColumnByOrdinal(0)` expression to extract first column of a row. We need to
// transform attributes accessors.
@@ -253,10 +247,24 @@
})

/**
* Returns true if the type `T` is serialized as a struct.
* Returns true if the type `T` is serialized as a struct by `objSerializer`.
*/
def isSerializedAsStruct: Boolean = objSerializer.dataType.isInstanceOf[StructType]

/**
* Returns true if the type `T` is an `Option` type.
*/
def isOptionType: Boolean = classOf[Option[_]].isAssignableFrom(clsTag.runtimeClass)

/**
* If the type `T` is serialized as a struct, when it is encoded to a Spark SQL row, fields in
* the struct are naturally mapped to top-level columns in a row. In other words, the serialized
   * struct is flattened to a row. But when `T` is also an `Option` type, it can't be
   * flattened to a top-level row, because in Spark SQL a top-level row can't be null. This method
   * returns true if `T` is serialized as a struct and is not an `Option` type.
*/
def isSerializedAsStructForTopLevel: Boolean = isSerializedAsStruct && !isOptionType
Contributor:
can you send a followup PR to inline isOptionType if it's only used here?

Member Author:
ok.
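
For reference, a hedged sketch of what that suggested inlining might look like (hypothetical follow-up, not part of this PR; assumes `clsTag` is the encoder's ClassTag as shown above):

```scala
// Hypothetical inlined form of the check, removing the separate isOptionType method.
def isSerializedAsStructForTopLevel: Boolean =
  isSerializedAsStruct && !classOf[Option[_]].isAssignableFrom(clsTag.runtimeClass)
```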


// serializer expressions are used to encode an object to a row, while the object is usually an
// intermediate value produced inside an operator, not from the output of the child operator. This
// is quite different from normal expressions, and `AttributeReference` doesn't work here
10 changes: 5 additions & 5 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1084,7 +1084,7 @@ class Dataset[T] private[sql](
// Note that we do this before joining them, to enable the join operator to return null for one
// side, in cases like outer-join.
val left = {
val combined = if (!this.exprEnc.isSerializedAsStruct) {
val combined = if (!this.exprEnc.isSerializedAsStructForTopLevel) {
assert(joined.left.output.length == 1)
Alias(joined.left.output.head, "_1")()
} else {
@@ -1094,7 +1094,7 @@
}

val right = {
val combined = if (!other.exprEnc.isSerializedAsStruct) {
val combined = if (!other.exprEnc.isSerializedAsStructForTopLevel) {
assert(joined.right.output.length == 1)
Alias(joined.right.output.head, "_2")()
} else {
@@ -1107,14 +1107,14 @@
// combine the outputs of each join side.
val conditionExpr = joined.condition.get transformUp {
case a: Attribute if joined.left.outputSet.contains(a) =>
if (!this.exprEnc.isSerializedAsStruct) {
if (!this.exprEnc.isSerializedAsStructForTopLevel) {
left.output.head
} else {
val index = joined.left.output.indexWhere(_.exprId == a.exprId)
GetStructField(left.output.head, index)
}
case a: Attribute if joined.right.outputSet.contains(a) =>
if (!other.exprEnc.isSerializedAsStruct) {
if (!other.exprEnc.isSerializedAsStructForTopLevel) {
right.output.head
} else {
val index = joined.right.output.indexWhere(_.exprId == a.exprId)
@@ -1387,7 +1387,7 @@ class Dataset[T] private[sql](
implicit val encoder = c1.encoder
val project = Project(c1.withInputType(exprEnc, logicalPlan.output).named :: Nil, logicalPlan)

if (!encoder.isSerializedAsStruct) {
if (!encoder.isSerializedAsStructForTopLevel) {
new Dataset[U1](sparkSession, project, encoder)
} else {
// Flattens inner fields of U1
@@ -457,7 +457,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
val encoders = columns.map(_.encoder)
val namedColumns =
columns.map(_.withInputType(vExprEnc, dataAttributes).named)
val keyColumn = if (!kExprEnc.isSerializedAsStruct) {
val keyColumn = if (!kExprEnc.isSerializedAsStructForTopLevel) {
assert(groupingAttributes.length == 1)
groupingAttributes.head
} else {
@@ -40,9 +40,9 @@ object TypedAggregateExpression {
val outputEncoder = encoderFor[OUT]
val outputType = outputEncoder.objSerializer.dataType

// Checks if the buffer object is simple, i.e. the buffer encoder is flat and the serializer
// expression is an alias of `BoundReference`, which means the buffer object doesn't need
// serialization.
// Checks if the buffer object is simple, i.e. the `BUF` type is not serialized as struct
// and the serializer expression is an alias of `BoundReference`, which means the buffer
// object doesn't need serialization.
val isSimpleBuffer = {
bufferSerializer.head match {
case Alias(_: BoundReference, _) if !bufferEncoder.isSerializedAsStruct => true
@@ -76,7 +76,7 @@ object TypedAggregateExpression {
None,
bufferSerializer,
bufferEncoder.resolveAndBind().deserializer,
outputEncoder.serializer,
outputEncoder.objSerializer,
Contributor:
to confirm, this is an un-related change and just cleans up the code?

Member Author:
This is required. Without it, the output schema of an aggregator that uses Option[Product] as its output encoder is not correct.
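
To illustrate the distinction, here is a hedged sketch (assuming this PR's ExpressionEncoder API; not code from the diff) of why the single objSerializer is needed here instead of the flattened serializer:

```scala
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

// Encoder for an Option[Product] output type.
val enc = ExpressionEncoder[Option[(Boolean, Int)]]()

// objSerializer keeps the value as one struct-typed expression, which is what the
// aggregator's output column should be: struct<_1: boolean, _2: int>.
println(enc.objSerializer.dataType)

// serializer is the flattened, top-level form; with this PR it wraps the struct in a
// single "value" column, which is only appropriate at the top level of a Dataset.
enc.serializer.foreach(field => println(field.name))
```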

outputType,
outputEncoder.objSerializer.nullable)
}
@@ -213,7 +213,7 @@ case class ComplexTypedAggregateExpression(
inputSchema: Option[StructType],
bufferSerializer: Seq[NamedExpression],
bufferDeserializer: Expression,
outputSerializer: Seq[Expression],
outputSerializer: Expression,
dataType: DataType,
nullable: Boolean,
mutableAggBufferOffset: Int = 0,
@@ -245,13 +245,7 @@ case class ComplexTypedAggregateExpression(
aggregator.merge(buffer, input)
}

private lazy val resultObjToRow = dataType match {
case _: StructType =>
UnsafeProjection.create(CreateStruct(outputSerializer))
case _ =>
assert(outputSerializer.length == 1)
UnsafeProjection.create(outputSerializer.head)
}
private lazy val resultObjToRow = UnsafeProjection.create(outputSerializer)

override def eval(buffer: Any): Any = {
val resultObj = aggregator.finish(buffer)
@@ -22,7 +22,7 @@ import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.expressions.scalalang.typed
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.types.{BooleanType, IntegerType, StringType, StructType}


object ComplexResultAgg extends Aggregator[(String, Int), (Long, Long), (Long, Long)] {
@@ -149,6 +149,7 @@ object VeryComplexResultAgg extends Aggregator[Row, String, ComplexAggData] {


case class OptionBooleanData(name: String, isGood: Option[Boolean])
case class OptionBooleanIntData(name: String, isGood: Option[(Boolean, Int)])

case class OptionBooleanAggregator(colName: String)
extends Aggregator[Row, Option[Boolean], Option[Boolean]] {
@@ -183,6 +184,43 @@ case class OptionBooleanAggregator(colName: String)
def OptionalBoolEncoder: Encoder[Option[Boolean]] = ExpressionEncoder()
}

case class OptionBooleanIntAggregator(colName: String)
extends Aggregator[Row, Option[(Boolean, Int)], Option[(Boolean, Int)]] {
Contributor:
what's the expected schema after we apply an aggregator with Option[Product] as buffer/output?

Member Author:
For a non top-level encoder, the output schema of Option[Product] should be a struct column.

Contributor:
assuming non top level, Option[Product] is the same as Product?

Member Author:
Yes. For non top level, Option[Product] is the same as Product. The difference is the additional WrapOption and UnwrapOption around the expressions.
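
A hedged spark-shell sketch of that behavior (assumed output, mirroring the schema asserted in the test below; not taken from the PR): a nested Option[Product] value is encoded as a nullable struct column, with None becoming a null struct.

```scala
import spark.implicits._  // spark: SparkSession, e.g. in spark-shell

val ds = Seq(("bob", Some((true, 1))), ("alice", Option.empty[(Boolean, Int)])).toDS()
ds.printSchema()
// root
//  |-- _1: string (nullable = true)
//  |-- _2: struct (nullable = true)
//  |    |-- _1: boolean (nullable = false)
//  |    |-- _2: integer (nullable = false)
```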


override def zero: Option[(Boolean, Int)] = None

override def reduce(buffer: Option[(Boolean, Int)], row: Row): Option[(Boolean, Int)] = {
val index = row.fieldIndex(colName)
val value = if (row.isNullAt(index)) {
Option.empty[(Boolean, Int)]
} else {
val nestedRow = row.getStruct(index)
Some((nestedRow.getBoolean(0), nestedRow.getInt(1)))
}
merge(buffer, value)
}

override def merge(
b1: Option[(Boolean, Int)],
b2: Option[(Boolean, Int)]): Option[(Boolean, Int)] = {
if ((b1.isDefined && b1.get._1) || (b2.isDefined && b2.get._1)) {
val newInt = b1.map(_._2).getOrElse(0) + b2.map(_._2).getOrElse(0)
Some((true, newInt))
} else if (b1.isDefined) {
b1
} else {
b2
}
}

override def finish(reduction: Option[(Boolean, Int)]): Option[(Boolean, Int)] = reduction

override def bufferEncoder: Encoder[Option[(Boolean, Int)]] = OptionalBoolIntEncoder
override def outputEncoder: Encoder[Option[(Boolean, Int)]] = OptionalBoolIntEncoder

def OptionalBoolIntEncoder: Encoder[Option[(Boolean, Int)]] = ExpressionEncoder()
}

class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
import testImplicits._

@@ -393,4 +431,28 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
assert(grouped.schema == df.schema)
checkDataset(grouped.as[OptionBooleanData], OptionBooleanData("bob", Some(true)))
}

test("SPARK-24762: Aggregator should be able to use Option of Product encoder") {
val df = Seq(
OptionBooleanIntData("bob", Some((true, 1))),
OptionBooleanIntData("bob", Some((false, 2))),
OptionBooleanIntData("bob", None)).toDF()

val group = df
.groupBy("name")
.agg(OptionBooleanIntAggregator("isGood").toColumn.alias("isGood"))

val expectedSchema = new StructType()
.add("name", StringType, nullable = true)
.add("isGood",
new StructType()
.add("_1", BooleanType, nullable = false)
.add("_2", IntegerType, nullable = false),
nullable = true)

assert(df.schema == expectedSchema)
assert(group.schema == expectedSchema)
checkAnswer(group, Row("bob", Row(true, 3)) :: Nil)
checkDataset(group.as[OptionBooleanIntData], OptionBooleanIntData("bob", Some((true, 3))))
}
}
77 changes: 68 additions & 9 deletions sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -1311,15 +1311,6 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
checkDataset(dsString, arrayString)
}

test("SPARK-18251: the type of Dataset can't be Option of Product type") {
checkDataset(Seq(Some(1), None).toDS(), Some(1), None)

val e = intercept[UnsupportedOperationException] {
Seq(Some(1 -> "a"), None).toDS()
}
assert(e.getMessage.contains("Cannot create encoder for Option of Product type"))
}

test ("SPARK-17460: the sizeInBytes in Statistics shouldn't overflow to a negative number") {
// Since the sizeInBytes in Statistics could exceed the limit of an Int, we should use BigInt
// instead of Int for avoiding possible overflow.
@@ -1557,6 +1548,74 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
Seq(Row("Amsterdam")))
}

test("SPARK-24762: Enable top-level Option of Product encoders") {
val data = Seq(Some((1, "a")), Some((2, "b")), None)
val ds = data.toDS()

checkDataset(
ds,
data: _*)

val schema = new StructType().add(
"value",
new StructType()
.add("_1", IntegerType, nullable = false)
.add("_2", StringType, nullable = true),
nullable = true)

assert(ds.schema == schema)

val nestedOptData = Seq(Some((Some((1, "a")), 2.0)), Some((Some((2, "b")), 3.0)))
val nestedDs = nestedOptData.toDS()

checkDataset(
nestedDs,
nestedOptData: _*)

val nestedSchema = StructType(Seq(
StructField("value", StructType(Seq(
StructField("_1", StructType(Seq(
StructField("_1", IntegerType, nullable = false),
StructField("_2", StringType, nullable = true)))),
StructField("_2", DoubleType, nullable = false)
)), nullable = true)
))
assert(nestedDs.schema == nestedSchema)
}

test("SPARK-24762: Resolving Option[Product] field") {
val ds = Seq((1, ("a", 1.0)), (2, ("b", 2.0)), (3, null)).toDS()
.as[(Int, Option[(String, Double)])]
checkDataset(ds,
(1, Some(("a", 1.0))), (2, Some(("b", 2.0))), (3, None))
}

test("SPARK-24762: select Option[Product] field") {
val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()
val ds1 = ds.select(expr("struct(_2, _2 + 1)").as[Option[(Int, Int)]])
checkDataset(ds1,
Some((1, 2)), Some((2, 3)), Some((3, 4)))

val ds2 = ds.select(expr("if(_2 > 2, struct(_2, _2 + 1), null)").as[Option[(Int, Int)]])
checkDataset(ds2,
None, None, Some((3, 4)))
}

test("SPARK-24762: joinWith on Option[Product]") {
val ds1 = Seq(Some((1, 2)), Some((2, 3)), None).toDS().as("a")
val ds2 = Seq(Some((1, 2)), Some((2, 3)), None).toDS().as("b")
val joined = ds1.joinWith(ds2, $"a.value._1" === $"b.value._2", "inner")
checkDataset(joined, (Some((2, 3)), Some((1, 2))))
}

test("SPARK-24762: typed agg on Option[Product] type") {
val ds = Seq(Some((1, 2)), Some((2, 3)), Some((1, 3))).toDS()
assert(ds.groupByKey(_.get._1).count().collect() === Seq((1, 2), (2, 1)))

assert(ds.groupByKey(x => x).count().collect() ===
Seq((Some((1, 2)), 1), (Some((2, 3)), 1), (Some((1, 3)), 1)))
}

test("SPARK-25942: typed aggregation on primitive type") {
val ds = Seq(1, 2, 3).toDS()
