diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 2cf8312ea59aa..65ff66f6e7f18 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -647,15 +647,15 @@ case class First(child: Expression) extends UnaryExpression with PartialAggregat case class FirstFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. - var result: Any = null + var result: MutableLiteral = MutableLiteral(null, expr.dataType) override def update(input: InternalRow): Unit = { - if (result == null) { - result = expr.eval(input) + if (result.value == null) { + result.value = expr.eval(input) } } - override def eval(input: InternalRow): Any = result + override def eval(input: InternalRow): Any = result.value } case class Last(child: Expression) extends UnaryExpression with PartialAggregate1 { @@ -676,13 +676,11 @@ case class Last(child: Expression) extends UnaryExpression with PartialAggregate case class LastFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. - var result: Any = null + var result: MutableLiteral = MutableLiteral(null, expr.dataType) override def update(input: InternalRow): Unit = { - result = input + result.value = expr.eval(input) } - override def eval(input: InternalRow): Any = { - if (result != null) expr.eval(result.asInstanceOf[InternalRow]) else null - } + override def eval(input: InternalRow): Any = result.value }