diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala index 8bdfa48a30c9..2cdf4703a5d7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -51,7 +51,8 @@ object TypedAggregateExpression { bufferDeserializer, outputEncoder.serializer, outputEncoder.deserializer.dataType, - outputType) + outputType, + !outputEncoder.flat || outputEncoder.schema.head.nullable) } } @@ -65,9 +66,8 @@ case class TypedAggregateExpression( bufferDeserializer: Expression, outputSerializer: Seq[Expression], outputExternalType: DataType, - dataType: DataType) extends DeclarativeAggregate with NonSQLExpression { - - override def nullable: Boolean = true + dataType: DataType, + nullable: Boolean) extends DeclarativeAggregate with NonSQLExpression { override def deterministic: Boolean = true diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala index 32fcf84b02f9..ddc4dcd2395b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -305,4 +305,13 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { val ds = Seq(1, 2, 3).toDS() checkDataset(ds.select(MapTypeBufferAgg.toColumn), 1) } + + test("SPARK-15204 improve nullability inference for Aggregator") { + val ds1 = Seq(1, 3, 2, 5).toDS() + assert(ds1.select(typed.sum((i: Int) => i)).schema.head.nullable === false) + val ds2 = Seq(AggData(1, "a"), AggData(2, "a")).toDS() + assert(ds2.select(SeqAgg.toColumn).schema.head.nullable === true) + val ds3 = sql("SELECT 'Some String' AS b, 1279869254 AS a").as[AggData] + assert(ds3.select(NameAgg.toColumn).schema.head.nullable === true) + } }