From 9662a2f7c3d58938557c76ee2fefc60b519d1955 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Tue, 11 Aug 2015 14:54:47 -0700 Subject: [PATCH 1/7] Change the default behavior of First/Last to RESPECT NULLS. --- .../catalyst/analysis/FunctionRegistry.scala | 2 ++ .../expressions/aggregate/functions.scala | 33 +++++++++++++++---- .../execution/AggregationQuerySuite.scala | 30 +++++++++++++++++ 3 files changed, 59 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index cd5a90d788151..4b889d43c7470 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -165,7 +165,9 @@ object FunctionRegistry { expression[Average]("avg"), expression[Count]("count"), expression[First]("first"), + expression[First]("first_value"), expression[Last]("last"), + expression[Last]("last_value"), expression[Max]("max"), expression[Min]("min"), expression[Sum]("sum"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala index a73024d6adba1..1f5c0fab08dc8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala @@ -113,6 +113,14 @@ case class Count(child: Expression) extends AlgebraicAggregate { override val evaluateExpression = Cast(currentCount, LongType) } +/** + * Returns the first value of `child` for a group of rows. If the first value of `child` + * is `null`, it returns `null` (respecting nulls). 
Even if [[First]] is used on an already + * sorted column, if we do partial aggregation and final aggregation (when mergeExpression + * is used) its result will not be deterministic (unless the input table is sorted and has + * a single partition, and we use a single reducer to do the aggregation.). + * @param child + */ case class First(child: Expression) extends AlgebraicAggregate { override def children: Seq[Expression] = child :: Nil @@ -130,23 +138,36 @@ case class First(child: Expression) extends AlgebraicAggregate { private val first = AttributeReference("first", child.dataType)() - override val bufferAttributes = first :: Nil + private val valueSet = AttributeReference("valueSet", BooleanType)() + + override val bufferAttributes = first :: valueSet :: Nil override val initialValues = Seq( - /* first = */ Literal.create(null, child.dataType) + /* first = */ Literal.create(null, child.dataType), + /* valueSet = */ Literal.create(false, BooleanType) ) override val updateExpressions = Seq( - /* first = */ If(IsNull(first), child, first) + /* first = */ If(valueSet, first, child), + /* valueSet = */ If(valueSet, valueSet, Literal.create(true, BooleanType)) ) override val mergeExpressions = Seq( - /* first = */ If(IsNull(first.left), first.right, first.left) + /* first = */ If(valueSet, first.left, first.right), + /* valueSet = */ If(valueSet, valueSet, Literal.create(true, BooleanType)) ) override val evaluateExpression = first } +/** + * Returns the last value of `child` for a group of rows. If the last value of `child` + * is `null`, it returns `null` (respecting nulls). Even if [[Last]] is used on an already + * sorted column, if we do partial aggregation and final aggregation (when mergeExpression + * is used) its result will not be deterministic (unless the input table is sorted and has + * a single partition, and we use a single reducer to do the aggregation.). 
+ * @param child + */ case class Last(child: Expression) extends AlgebraicAggregate { override def children: Seq[Expression] = child :: Nil @@ -171,11 +192,11 @@ case class Last(child: Expression) extends AlgebraicAggregate { ) override val updateExpressions = Seq( - /* last = */ If(IsNull(child), last, child) + /* last = */ child ) override val mergeExpressions = Seq( - /* last = */ If(IsNull(last.right), last.left, last.right) + /* last = */ last.right ) override val evaluateExpression = last diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 7b5aa4763fd9e..ccef0f945455d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -284,6 +284,36 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be Row(11.125) :: Nil) } + test("first_value and last_value") { + // We force to use a single partition for the sort and aggregate to make result + // deterministic. + withSQLConf(("spark.sql.shuffle.partitions", "1")) { + checkAnswer( + sqlContext.sql( + """ + |SELECT + | first_valUE(key), + | lasT_value(key), + | firSt(key), + | lASt(key) + |FROM (SELECT key FROM agg1 ORDER BY key) tmp + """.stripMargin), + Row(null, 3, null, 3) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT + | first_valUE(key), + | lasT_value(key), + | firSt(key), + | lASt(key) + |FROM (SELECT key FROM agg1 ORDER BY key DESC) tmp + """.stripMargin), + Row(3, null, 3, null) :: Nil) + } + } + test("udaf") { checkAnswer( sqlContext.sql( From a0431715d9ad472b2cf6580e9e12d434b68d117d Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Tue, 11 Aug 2015 21:18:23 -0700 Subject: [PATCH 2/7] Add ignoreNulls flag. 
--- .../expressions/aggregate/functions.scala | 89 +++++++++++++---- .../expressions/aggregate/utils.scala | 8 +- .../sql/catalyst/expressions/aggregates.scala | 95 +++++++++++++++---- .../spark/sql/expressions/WindowSpec.scala | 13 ++- .../execution/AggregationQuerySuite.scala | 16 +++- 5 files changed, 174 insertions(+), 47 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala index 1f5c0fab08dc8..736d0324c2d04 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ @@ -121,7 +122,15 @@ case class Count(child: Expression) extends AlgebraicAggregate { * a single partition, and we use a single reducer to do the aggregation.). 
* @param child */ -case class First(child: Expression) extends AlgebraicAggregate { +case class First(child: Expression, ignoreNulls: Boolean) extends AlgebraicAggregate { + + def this(child: Expression) = this(child, false) + + def this(child: Expression, ignoreNulls: Expression) = this(child, ignoreNulls match { + case Literal(b: Boolean, BooleanType) => b + case _ => + throw new AnalysisException("The second argument of First should be a boolean literal.") + }) override def children: Seq[Expression] = child :: Nil @@ -147,17 +156,39 @@ case class First(child: Expression) extends AlgebraicAggregate { /* valueSet = */ Literal.create(false, BooleanType) ) - override val updateExpressions = Seq( - /* first = */ If(valueSet, first, child), - /* valueSet = */ If(valueSet, valueSet, Literal.create(true, BooleanType)) - ) + override val updateExpressions = { + val litTrue = Literal.create(true, BooleanType) + if (ignoreNulls) { + Seq( + /* first = */ If(Or(valueSet, IsNull(child)), first, child), + /* valueSet = */ If(Or(valueSet, IsNull(child)), valueSet, litTrue) + ) + } else { + Seq( + /* first = */ If(valueSet, first, child), + /* valueSet = */ litTrue + ) + } + } - override val mergeExpressions = Seq( - /* first = */ If(valueSet, first.left, first.right), - /* valueSet = */ If(valueSet, valueSet, Literal.create(true, BooleanType)) - ) + override val mergeExpressions = { + val litTrue = Literal.create(true, BooleanType) + if (ignoreNulls) { + Seq( + /* first = */ If(Or(valueSet.left, IsNull(first.right)), first.left, first.right), + /* valueSet = */ If(Or(valueSet.left, IsNull(first.right)), valueSet.left, litTrue) + ) + } else { + Seq( + /* first = */ If(valueSet.left, first.left, first.right), + /* valueSet = */ litTrue + ) + } + } override val evaluateExpression = first + + override def toString: String = s"FIRST($child)${if (ignoreNulls) " IGNORE NULLS" else ""}" } /** @@ -168,7 +199,15 @@ case class First(child: Expression) extends AlgebraicAggregate { * a single 
partition, and we use a single reducer to do the aggregation.). * @param child */ -case class Last(child: Expression) extends AlgebraicAggregate { +case class Last(child: Expression, ignoreNulls: Boolean) extends AlgebraicAggregate { + + def this(child: Expression) = this(child, false) + + def this(child: Expression, ignoreNulls: Expression) = this(child, ignoreNulls match { + case Literal(b: Boolean, BooleanType) => b + case _ => + throw new AnalysisException("The second argument of Last should be a boolean literal.") + }) override def children: Seq[Expression] = child :: Nil @@ -191,15 +230,33 @@ case class Last(child: Expression) extends AlgebraicAggregate { /* last = */ Literal.create(null, child.dataType) ) - override val updateExpressions = Seq( - /* last = */ child - ) + override val updateExpressions = { + if (ignoreNulls) { + Seq( + /* last = */ If(IsNull(child), last, child) + ) + } else { + Seq( + /* last = */ child + ) + } + } - override val mergeExpressions = Seq( - /* last = */ last.right - ) + override val mergeExpressions = { + if (ignoreNulls) { + Seq( + /* last = */ If(IsNull(last.right), last.left, last.right) + ) + } else { + Seq( + /* last = */ last.right + ) + } + } override val evaluateExpression = last + + override def toString: String = s"LAST($child)${if (ignoreNulls) " IGNORE NULLS" else ""}" } case class Max(child: Expression) extends AlgebraicAggregate { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala index 4a43318a95490..df1d899898493 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala @@ -61,15 +61,15 @@ object Utils { mode = aggregate.Complete, isDistinct = true) - case expressions.First(child) => + case expressions.First(child, ignoreNulls) => 
aggregate.AggregateExpression2( - aggregateFunction = aggregate.First(child), + aggregateFunction = aggregate.First(child, ignoreNulls), mode = aggregate.Complete, isDistinct = false) - case expressions.Last(child) => + case expressions.Last(child, ignoreNulls) => aggregate.AggregateExpression2( - aggregateFunction = aggregate.Last(child), + aggregateFunction = aggregate.Last(child, ignoreNulls), mode = aggregate.Complete, isDistinct = false) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 2cf8312ea59aa..12f29da87c3f7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -19,12 +19,14 @@ package org.apache.spark.sql.catalyst.expressions import com.clearspring.analytics.stream.cardinality.HyperLogLog +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashSet +import org.codehaus.janino.Java.BooleanLiteral trait AggregateExpression extends Expression with Unevaluable @@ -630,59 +632,114 @@ case class CombineSetsAndSumFunction( } } -case class First(child: Expression) extends UnaryExpression with PartialAggregate1 { +case class First( + child: Expression, + ignoreNulls: Boolean) + extends UnaryExpression with PartialAggregate1 { + + def this(child: Expression) = this(child, false) + + def this(child: Expression, ignoreNulls: Expression) = this(child, ignoreNulls match { + case Literal(b: Boolean, BooleanType) => b + case _ => + throw new 
AnalysisException("The second argument of First should be a boolean literal.") + }) + override def nullable: Boolean = true override def dataType: DataType = child.dataType - override def toString: String = s"FIRST($child)" + override def toString: String = s"FIRST(${child}${if (ignoreNulls) " IGNORE NULLS" else ""})" override def asPartial: SplitEvaluation = { - val partialFirst = Alias(First(child), "PartialFirst")() + val partialFirst = Alias(First(child, ignoreNulls), "PartialFirst")() SplitEvaluation( - First(partialFirst.toAttribute), + First(partialFirst.toAttribute, ignoreNulls), partialFirst :: Nil) } - override def newInstance(): FirstFunction = new FirstFunction(child, this) + override def newInstance(): FirstFunction = new FirstFunction(child, ignoreNulls, this) } -case class FirstFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 { - def this() = this(null, null) // Required for serialization. +object First { + def apply(child: Expression): First = First(child, ignoreNulls = false) +} - var result: Any = null +case class FirstFunction( + expr: Expression, + ignoreNulls: Boolean, + base: AggregateExpression1) + extends AggregateFunction1 { + + def this() = this(null, null.asInstanceOf[Boolean], null) // Required for serialization. + + private[this] var result: Any = null + + private[this] var valueSet: Boolean = false override def update(input: InternalRow): Unit = { - if (result == null) { - result = expr.eval(input) + if (!valueSet) { + val value = expr.eval(input) + // When we have not set the result, we will set the result if we respect nulls + // (i.e. ignoreNulls is false), or we ignore nulls and the evaluated value is not null. 
+ if (!ignoreNulls || (ignoreNulls && value != null)) { + result = value + valueSet = true + } } } override def eval(input: InternalRow): Any = result } -case class Last(child: Expression) extends UnaryExpression with PartialAggregate1 { +case class Last( + child: Expression, + ignoreNulls: Boolean) + extends UnaryExpression with PartialAggregate1 { + + def this(child: Expression) = this(child, false) + + def this(child: Expression, ignoreNulls: Expression) = this(child, ignoreNulls match { + case Literal(b: Boolean, BooleanType) => b + case _ => + throw new AnalysisException("The second argument of Last should be a boolean literal.") + }) + override def references: AttributeSet = child.references override def nullable: Boolean = true override def dataType: DataType = child.dataType - override def toString: String = s"LAST($child)" + override def toString: String = s"LAST($child)${if (ignoreNulls) " IGNORE NULLS" else ""}" override def asPartial: SplitEvaluation = { - val partialLast = Alias(Last(child), "PartialLast")() + val partialLast = Alias(Last(child, ignoreNulls), "PartialLast")() SplitEvaluation( - Last(partialLast.toAttribute), + Last(partialLast.toAttribute, ignoreNulls), partialLast :: Nil) } - override def newInstance(): LastFunction = new LastFunction(child, this) + override def newInstance(): LastFunction = new LastFunction(child, ignoreNulls, this) } -case class LastFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 { - def this() = this(null, null) // Required for serialization. +object Last { + def apply(child: Expression): Last = Last(child, ignoreNulls = false) +} + +case class LastFunction( + expr: Expression, + ignoreNulls: Boolean, + base: AggregateExpression1) + extends AggregateFunction1 { + + def this() = this(null, null.asInstanceOf[Boolean], null) // Required for serialization. 
var result: Any = null override def update(input: InternalRow): Unit = { - result = input + val value = expr.eval(input) + if (ignoreNulls && value != null) { + result = value + } else { + result = value + } } override def eval(input: InternalRow): Any = { - if (result != null) expr.eval(result.asInstanceOf[InternalRow]) else null + result } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala index c3d2246297021..c2a0da4a86dce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.expressions import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.types.BooleanType import org.apache.spark.sql.{Column, catalyst} import org.apache.spark.sql.catalyst.expressions._ @@ -149,13 +150,17 @@ class WindowSpec private[sql]( case Count(child) => WindowExpression( UnresolvedWindowFunction("count", child :: Nil), WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case First(child) => WindowExpression( + case First(child, ignoreNulls) => WindowExpression( // TODO this is a hack for Hive UDAF first_value - UnresolvedWindowFunction("first_value", child :: Nil), + UnresolvedWindowFunction( + "first_value", + child :: Literal.create(ignoreNulls, BooleanType) :: Nil), WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case Last(child) => WindowExpression( + case Last(child, ignoreNulls) => WindowExpression( // TODO this is a hack for Hive UDAF last_value - UnresolvedWindowFunction("last_value", child :: Nil), + UnresolvedWindowFunction( + "last_value", + child :: Literal.create(ignoreNulls, BooleanType) :: Nil), WindowSpecDefinition(partitionSpec, orderSpec, frame)) case Min(child) => WindowExpression( UnresolvedWindowFunction("min", child :: Nil), diff --git 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index ccef0f945455d..4be6cbdc9df2c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -295,10 +295,14 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be | first_valUE(key), | lasT_value(key), | firSt(key), - | lASt(key) + | lASt(key), + | first_valUE(key, true), + | lasT_value(key, true), + | firSt(key, true), + | lASt(key, true) |FROM (SELECT key FROM agg1 ORDER BY key) tmp """.stripMargin), - Row(null, 3, null, 3) :: Nil) + Row(null, 3, null, 3, 1, 3, 1, 3) :: Nil) checkAnswer( sqlContext.sql( @@ -307,10 +311,14 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be | first_valUE(key), | lasT_value(key), | firSt(key), - | lASt(key) + | lASt(key), + | first_valUE(key, true), + | lasT_value(key, true), + | firSt(key, true), + | lASt(key, true) |FROM (SELECT key FROM agg1 ORDER BY key DESC) tmp """.stripMargin), - Row(3, null, 3, null) :: Nil) + Row(3, null, 3, null, 3, 1, 3, 1) :: Nil) } } From cc40a90123ba9c183d2869888175171a44d44dbb Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Tue, 11 Aug 2015 22:03:07 -0700 Subject: [PATCH 3/7] Fix LastFunction's update. 
--- .../apache/spark/sql/catalyst/expressions/aggregates.scala | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 12f29da87c3f7..77b6eca439ce9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -732,9 +732,7 @@ case class LastFunction( override def update(input: InternalRow): Unit = { val value = expr.eval(input) - if (ignoreNulls && value != null) { - result = value - } else { + if (!ignoreNulls || (ignoreNulls && value != null)) { result = value } } From ad0e120f87e47c912b7fa200ff2c55bcf21886a4 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Tue, 11 Aug 2015 22:08:32 -0700 Subject: [PATCH 4/7] Remove unnecessary change. --- .../org/apache/spark/sql/catalyst/expressions/aggregates.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 77b6eca439ce9..b5ee27088d449 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, Genera import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashSet -import org.codehaus.janino.Java.BooleanLiteral trait AggregateExpression extends Expression with Unevaluable From 90a72b7bfaa7caa9c791ea83457e924032329626 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 12 Aug 2015 16:52:49 -0700 Subject: [PATCH 5/7] Update --- 
.../sql/catalyst/analysis/CheckAnalysis.scala | 3 +- .../expressions/aggregate/functions.scala | 42 ++++++++----------- .../sql/catalyst/expressions/aggregates.scala | 24 +++++++---- 3 files changed, 35 insertions(+), 34 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 39f554c137c98..227007bf5fe2f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -110,7 +110,8 @@ trait CheckAnalysis { failAnalysis( s"expression '${e.prettyString}' is neither present in the group by, " + s"nor is it an aggregate function. " + - "Add to group by or wrap in first() if you don't care which value you get.") + "Add to group by or wrap in first() (or first_value) if you don't care " + + "which value you get.") case e if groupingExprs.exists(_.semanticEquals(e)) => // OK case e if e.references.isEmpty => // OK case e => e.children.foreach(checkValidAggregateExpression) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala index 736d0324c2d04..ea7c668af4d95 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala @@ -122,15 +122,15 @@ case class Count(child: Expression) extends AlgebraicAggregate { * a single partition, and we use a single reducer to do the aggregation.). 
* @param child */ -case class First(child: Expression, ignoreNulls: Boolean) extends AlgebraicAggregate { +case class First(child: Expression, ignoreNullsExpr: Expression) extends AlgebraicAggregate { - def this(child: Expression) = this(child, false) + def this(child: Expression) = this(child, Literal.create(false, BooleanType)) - def this(child: Expression, ignoreNulls: Expression) = this(child, ignoreNulls match { + private val ignoreNulls: Boolean = ignoreNullsExpr match { case Literal(b: Boolean, BooleanType) => b case _ => throw new AnalysisException("The second argument of First should be a boolean literal.") - }) + } override def children: Seq[Expression] = child :: Nil @@ -157,33 +157,27 @@ case class First(child: Expression, ignoreNulls: Boolean) extends AlgebraicAggre ) override val updateExpressions = { - val litTrue = Literal.create(true, BooleanType) if (ignoreNulls) { Seq( /* first = */ If(Or(valueSet, IsNull(child)), first, child), - /* valueSet = */ If(Or(valueSet, IsNull(child)), valueSet, litTrue) + /* valueSet = */ Or(valueSet, IsNotNull(child)) ) } else { Seq( /* first = */ If(valueSet, first, child), - /* valueSet = */ litTrue + /* valueSet = */ Literal.create(true, BooleanType) ) } } override val mergeExpressions = { - val litTrue = Literal.create(true, BooleanType) - if (ignoreNulls) { - Seq( - /* first = */ If(Or(valueSet.left, IsNull(first.right)), first.left, first.right), - /* valueSet = */ If(Or(valueSet.left, IsNull(first.right)), valueSet.left, litTrue) - ) - } else { - Seq( - /* first = */ If(valueSet.left, first.left, first.right), - /* valueSet = */ litTrue - ) - } + // For first, we can just check if valueSet.left is set to true. If it is set + // to true, we use first.left. If not, we use first.right (even if valueSet.right is + // false, we are safe to do so because first.right will be null in this case). 
+ Seq( + /* first = */ If(valueSet.left, first.left, first.right), + /* valueSet = */ Or(valueSet.left, valueSet.right) + ) } override val evaluateExpression = first @@ -199,15 +193,15 @@ case class First(child: Expression, ignoreNulls: Boolean) extends AlgebraicAggre * a single partition, and we use a single reducer to do the aggregation.). * @param child */ -case class Last(child: Expression, ignoreNulls: Boolean) extends AlgebraicAggregate { +case class Last(child: Expression, ignoreNullsExpr: Expression) extends AlgebraicAggregate { - def this(child: Expression) = this(child, false) + def this(child: Expression) = this(child, Literal.create(false, BooleanType)) - def this(child: Expression, ignoreNulls: Expression) = this(child, ignoreNulls match { + private val ignoreNulls: Boolean = ignoreNullsExpr match { case Literal(b: Boolean, BooleanType) => b case _ => - throw new AnalysisException("The second argument of Last should be a boolean literal.") - }) + throw new AnalysisException("The second argument of Last should be a boolean literal.") + } override def children: Seq[Expression] = child :: Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index b5ee27088d449..83e42e57f38e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -633,16 +633,16 @@ case class CombineSetsAndSumFunction( case class First( child: Expression, - ignoreNulls: Boolean) + ignoreNullsExpr: Expression) extends UnaryExpression with PartialAggregate1 { - def this(child: Expression) = this(child, false) + def this(child: Expression) = this(child, Literal.create(false, BooleanType)) - def this(child: Expression, ignoreNulls: Expression) = this(child, ignoreNulls match { + private val ignoreNulls: Boolean = 
ignoreNullsExpr match { case Literal(b: Boolean, BooleanType) => b case _ => throw new AnalysisException("The second argument of First should be a boolean literal.") - }) + } override def nullable: Boolean = true override def dataType: DataType = child.dataType @@ -659,6 +659,9 @@ case class First( object First { def apply(child: Expression): First = First(child, ignoreNulls = false) + + def apply(child: Expression, ignoreNulls: Boolean): First = + First(child, Literal.create(ignoreNulls, BooleanType)) } case class FirstFunction( @@ -690,16 +693,16 @@ case class FirstFunction( case class Last( child: Expression, - ignoreNulls: Boolean) + ignoreNullsExpr: Expression) extends UnaryExpression with PartialAggregate1 { - def this(child: Expression) = this(child, false) + def this(child: Expression) = this(child, Literal.create(false, BooleanType)) - def this(child: Expression, ignoreNulls: Expression) = this(child, ignoreNulls match { + private val ignoreNulls: Boolean = ignoreNullsExpr match { case Literal(b: Boolean, BooleanType) => b case _ => - throw new AnalysisException("The second argument of Last should be a boolean literal.") - }) + throw new AnalysisException("The second argument of Last should be a boolean literal.") + } override def references: AttributeSet = child.references override def nullable: Boolean = true @@ -717,6 +720,9 @@ case class Last( object Last { def apply(child: Expression): Last = Last(child, ignoreNulls = false) + + def apply(child: Expression, ignoreNulls: Boolean): Last = + Last(child, Literal.create(ignoreNulls, BooleanType)) } case class LastFunction( From ad0ac67874a9e2a7d266f4a792e7fcfe0f5617ce Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 12 Aug 2015 17:49:09 -0700 Subject: [PATCH 6/7] Update test. 
--- .../apache/spark/sql/hive/execution/AggregationQuerySuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 4be6cbdc9df2c..f38ebaef79e8d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -287,7 +287,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be test("first_value and last_value") { // We force to use a single partition for the sort and aggregate to make result // deterministic. - withSQLConf(("spark.sql.shuffle.partitions", "1")) { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer( sqlContext.sql( """ From f828bdf1612a5fc9466b9a7e80700d0dd94faaf5 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 12 Aug 2015 20:28:45 -0700 Subject: [PATCH 7/7] Fix test. 
--- .../scala/org/apache/spark/sql/expressions/WindowSpec.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala index c2a0da4a86dce..8b9247adea200 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala @@ -154,13 +154,13 @@ class WindowSpec private[sql]( // TODO this is a hack for Hive UDAF first_value UnresolvedWindowFunction( "first_value", - child :: Literal.create(ignoreNulls, BooleanType) :: Nil), + child :: ignoreNulls :: Nil), WindowSpecDefinition(partitionSpec, orderSpec, frame)) case Last(child, ignoreNulls) => WindowExpression( // TODO this is a hack for Hive UDAF last_value UnresolvedWindowFunction( "last_value", - child :: Literal.create(ignoreNulls, BooleanType) :: Nil), + child :: ignoreNulls :: Nil), WindowSpecDefinition(partitionSpec, orderSpec, frame)) case Min(child) => WindowExpression( UnresolvedWindowFunction("min", child :: Nil),