From cb34a95e3dea152250b6409827fc869bd7fae407 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 4 Sep 2015 00:47:21 +0800 Subject: [PATCH 01/10] Add corr aggregate function. --- .../expressions/aggregate/functions.scala | 100 ++++++++++++++++++ .../expressions/aggregate/utils.scala | 6 ++ .../sql/catalyst/expressions/aggregates.scala | 13 +++ .../org/apache/spark/sql/functions.scala | 18 ++++ .../execution/AggregationQuerySuite.scala | 24 +++++ 5 files changed, 161 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala index a73024d6adba..c12a934e831c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate +import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ @@ -302,3 +303,102 @@ case class Sum(child: Expression) extends AlgebraicAggregate { override val evaluateExpression = Cast(currentSum, resultType) } + +case class Corr(left: Expression, right: Expression) extends AggregateFunction2 { + + def children: Seq[Expression] = Seq(left, right) + + def nullable: Boolean = false + + def dataType: DataType = DoubleType + + def inputTypes: Seq[AbstractDataType] = Seq(DoubleType) + + def bufferSchema: StructType = StructType.fromAttributes(bufferAttributes) + + def cloneBufferAttributes: Seq[Attribute] = bufferAttributes.map(_.newInstance()) + + val bufferAttributes: Seq[AttributeReference] = Seq( + AttributeReference("xAvg", DoubleType)(), + AttributeReference("yAvg", DoubleType)(), + AttributeReference("Ck", DoubleType)(), + AttributeReference("MkX", DoubleType)(), + AttributeReference("MkY", DoubleType)(), + AttributeReference("count", LongType)()) + + override def initialize(buffer: MutableRow): Unit = { + (0 until 5).map(idx => buffer.setDouble(mutableBufferOffset + idx, 0.0)) + buffer.setLong(mutableBufferOffset + 5, 0L) + } + + override def update(buffer: MutableRow, input: InternalRow): Unit = { + val x = left.eval(input).asInstanceOf[Double] + val y = right.eval(input).asInstanceOf[Double] + + var xAvg = buffer.getDouble(mutableBufferOffset) + var yAvg = buffer.getDouble(mutableBufferOffset + 1) + var Ck = buffer.getDouble(mutableBufferOffset + 2) + var MkX = buffer.getDouble(mutableBufferOffset + 3) + var MkY = buffer.getDouble(mutableBufferOffset + 4) + var count = buffer.getLong(mutableBufferOffset + 5) + + val deltaX = x - xAvg + val deltaY = y - yAvg + count += 1 + xAvg += deltaX / count + yAvg += deltaY / count + Ck += deltaX * (y - yAvg) + MkX += deltaX * (x - xAvg) + MkY += deltaY * (y - yAvg) + + buffer.setDouble(mutableBufferOffset, xAvg) + buffer.setDouble(mutableBufferOffset + 1, yAvg) + buffer.setDouble(mutableBufferOffset + 2, Ck) + buffer.setDouble(mutableBufferOffset + 3, MkX) + buffer.setDouble(mutableBufferOffset + 4, MkY) + buffer.setLong(mutableBufferOffset + 5, count) + } + + override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { + val count2 = buffer2.getLong(inputBufferOffset + 5) + + if (count2 > 0) { + var xAvg = buffer1.getDouble(mutableBufferOffset) + var yAvg = buffer1.getDouble(mutableBufferOffset + 1) + var Ck = 
buffer1.getDouble(mutableBufferOffset + 2) + var MkX = buffer1.getDouble(mutableBufferOffset + 3) + var MkY = buffer1.getDouble(mutableBufferOffset + 4) + var count = buffer1.getLong(mutableBufferOffset + 5) + + val xAvg2 = buffer2.getDouble(inputBufferOffset) + val yAvg2 = buffer2.getDouble(inputBufferOffset + 1) + val Ck2 = buffer2.getDouble(inputBufferOffset + 2) + val MkX2 = buffer2.getDouble(inputBufferOffset + 3) + val MkY2 = buffer2.getDouble(inputBufferOffset + 4) + + val totalCount = count + count2 + val deltaX = xAvg - xAvg2 + val deltaY = yAvg - yAvg2 + Ck += Ck2 + deltaX * deltaY * count / totalCount * count2 + xAvg = (xAvg * count + xAvg2 * count2) / totalCount + yAvg = (yAvg * count + yAvg2 * count2) / totalCount + MkX += MkX2 + deltaX * deltaX * count / totalCount * count2 + MkY += MkY2 + deltaY * deltaY * count / totalCount * count2 + count = totalCount + + buffer1.setDouble(mutableBufferOffset, xAvg) + buffer1.setDouble(mutableBufferOffset + 1, yAvg) + buffer1.setDouble(mutableBufferOffset + 2, Ck) + buffer1.setDouble(mutableBufferOffset + 3, MkX) + buffer1.setDouble(mutableBufferOffset + 4, MkY) + buffer1.setLong(mutableBufferOffset + 5, count) + } + } + + override def eval(buffer: InternalRow): Any = { + val Ck = buffer.getDouble(mutableBufferOffset + 2) + val MkX = buffer.getDouble(mutableBufferOffset + 3) + val MkY = buffer.getDouble(mutableBufferOffset + 4) + Ck / math.sqrt(MkX * MkY) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala index 4a43318a9549..7cceee707cf9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala @@ -96,6 +96,12 @@ object Utils { aggregateFunction = aggregate.Sum(child), mode = aggregate.Complete, isDistinct = true) + + case expressions.Corr(left, right) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Corr(left, right), + mode = aggregate.Complete, + isDistinct = false) } // Check if there is any expressions.AggregateExpression1 left. // If so, we cannot convert this plan. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 5e8298aaaa9c..5f84afe9fd9b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -691,3 +691,16 @@ case class LastFunction(expr: Expression, base: AggregateExpression1) extends Ag result } } + +/** + * Calculate Pearson Correlation Coefficient for the given columns. + * Only support AggregateExpression2. 
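+ * When spark.sql.useAggregate2 is disabled this expression cannot be evaluated,
+ * since no AggregateExpression1-based implementation is provided.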
+ * + */ +case class Corr( + left: Expression, + right: Expression) extends BinaryExpression with AggregateExpression { + override def nullable: Boolean = false + override def dataType: DoubleType.type = DoubleType + override def toString: String = s"CORRELATION($left, $right)" +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 435e6319a64c..4771e16d23df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -172,6 +172,24 @@ object functions { */ def avg(columnName: String): Column = avg(Column(columnName)) + /** + * Aggregate function: returns the Pearson Correlation Coefficient for two columns. + * + * @group agg_funcs + * @since 1.6.0 + */ + def corr(column1: Column, column2: Column): Column = + Corr(column1.expr, column2.expr) + + /** + * Aggregate function: returns the Pearson Correlation Coefficient for two columns. + * + * @group agg_funcs + * @since 1.6.0 + */ + def corr(columnName1: String, columnName2: String): Column = + corr(Column(columnName1), Column(columnName2)) + /** * Aggregate function: returns the number of items in a group. * diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 4886a8594836..cac5a1198e3f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -21,6 +21,7 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql._ import org.apache.spark.sql.execution.aggregate +import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} @@ -480,6 +481,29 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be Row(0, null, 1, 1, null, 0) :: Nil) } + test("pearson correlation") { + val df = Seq.tabulate(10)(i => (1.0 * i, 2.0 * i, i * -1.0)).toDF("a", "b", "c") + val corr1 = df.repartition(2).groupBy().agg(corr("a", "b")).collect()(0).getDouble(0) + assert(math.abs(corr1 - 1.0) < 1e-12) + val corr2 = df.groupBy().agg(corr("a", "c")).collect()(0).getDouble(0) + assert(math.abs(corr2 + 1.0) < 1e-12) + // non-trivial example. To reproduce in python, use: + // >>> from scipy.stats import pearsonr + // >>> import numpy as np + // >>> a = np.array(range(20)) + // >>> b = np.array([x * x - 2 * x + 3.5 for x in range(20)]) + // >>> pearsonr(a, b) + // (0.95723391394758572, 3.8902121417802199e-11) + // In R, use: + // > a <- 0:19 + // > b <- mapply(function(x) x * x - 2 * x + 3.5, a) + // > cor(a, b) + // [1] 0.957233913947585835 + val df2 = Seq.tabulate(20)(x => (1.0 * x, x * x - 2 * x + 3.5)).toDF("a", "b") + val corr3 = df2.groupBy().agg(corr("a", "b")).collect()(0).getDouble(0) + assert(math.abs(corr3 - 0.95723391394758572) < 1e-12) + } + test("test Last implemented based on AggregateExpression1") { // TODO: Remove this test once we remove AggregateExpression1. import org.apache.spark.sql.functions._ From d3e441457f8c0243170fa2f6a8408c0c1ed6bc99 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 22 Oct 2015 08:13:19 +0800 Subject: [PATCH 02/10] Fix merging error. 
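For reference, the single-pass update and pairwise merge used by the Corr function added in the previous patch can be sketched standalone in plain Scala (CorrState is a hypothetical name, independent of Spark's aggregate interfaces):

    // Running means, the co-moment Ck, the squared deviations MkX/MkY, and the count.
    case class CorrState(
        xAvg: Double, yAvg: Double, ck: Double, mkX: Double, mkY: Double, n: Long) {

      // Online update for a single (x, y) observation, mirroring Corr.update.
      def update(x: Double, y: Double): CorrState = {
        val n1 = n + 1
        val dx = x - xAvg
        val dy = y - yAvg
        val xA = xAvg + dx / n1
        val yA = yAvg + dy / n1
        CorrState(xA, yA, ck + dx * (y - yA), mkX + dx * (x - xA), mkY + dy * (y - yA), n1)
      }

      // Pairwise combination of two partial states, mirroring Corr.merge.
      def merge(o: CorrState): CorrState = if (o.n == 0) this else {
        val total = n + o.n
        val dx = xAvg - o.xAvg
        val dy = yAvg - o.yAvg
        CorrState(
          (xAvg * n + o.xAvg * o.n) / total,
          (yAvg * n + o.yAvg * o.n) / total,
          ck + o.ck + dx * dy * n / total * o.n,
          mkX + o.mkX + dx * dx * n / total * o.n,
          mkY + o.mkY + dy * dy * n / total * o.n,
          total)
      }

      def corr: Double = ck / math.sqrt(mkX * mkY) // NaN when n == 0
    }

Splitting the input across several states and merging them agrees with a single pass over all rows, up to floating-point rounding.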
--- .../expressions/aggregate/functions.scala | 103 +++++++++++++++--- .../expressions/aggregate/utils.scala | 8 +- .../sql/catalyst/expressions/aggregates.scala | 95 ++++++++++++---- .../spark/sql/expressions/WindowSpec.scala | 1 - .../execution/AggregationQuerySuite.scala | 38 +++++++ 5 files changed, 204 insertions(+), 41 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala index 6db253900b81..2907e29f0a47 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala @@ -121,7 +121,23 @@ case class Count(child: Expression) extends DeclarativeAggregate { override val evaluateExpression = Cast(currentCount, LongType) } -case class First(child: Expression) extends DeclarativeAggregate { +/** + * Returns the first value of `child` for a group of rows. If the first value of `child` + * is `null`, it returns `null` (respecting nulls). Even if [[First]] is used on a already + * sorted column, if we do partial aggregation and final aggregation (when mergeExpression + * is used) its result will not be deterministic (unless the input table is sorted and has + * a single partition, and we use a single reducer to do the aggregation.). + * @param child + */ +case class First(child: Expression, ignoreNullsExpr: Expression) extends DeclarativeAggregate { + + def this(child: Expression) = this(child, Literal.create(false, BooleanType)) + + private val ignoreNulls: Boolean = ignoreNullsExpr match { + case Literal(b: Boolean, BooleanType) => b + case _ => + throw new AnalysisException("The second argument of First should be a boolean literal.") + } override def children: Seq[Expression] = child :: Nil @@ -138,24 +154,61 @@ case class First(child: Expression) extends DeclarativeAggregate { private val first = AttributeReference("first", child.dataType)() - override val aggBufferAttributes = first :: Nil + private val valueSet = AttributeReference("valueSet", BooleanType)() + + override val aggBufferAttributes = first :: valueSet :: Nil override val initialValues = Seq( - /* first = */ Literal.create(null, child.dataType) + /* first = */ Literal.create(null, child.dataType), + /* valueSet = */ Literal.create(false, BooleanType) ) - override val updateExpressions = Seq( - /* first = */ If(IsNull(first), child, first) - ) + override val updateExpressions = { + if (ignoreNulls) { + Seq( + /* first = */ If(Or(valueSet, IsNull(child)), first, child), + /* valueSet = */ Or(valueSet, IsNotNull(child)) + ) + } else { + Seq( + /* first = */ If(valueSet, first, child), + /* valueSet = */ Literal.create(true, BooleanType) + ) + } + } - override val mergeExpressions = Seq( - /* first = */ If(IsNull(first.left), first.right, first.left) - ) + override val mergeExpressions = { + // For first, we can just check if valueSet.left is set to true. If it is set + // to true, we use first.right. If not, we use first.right (even if valueSet.right is + // false, we are safe to do so because first.right will be null in this case). 
+ Seq( + /* first = */ If(valueSet.left, first.left, first.right), + /* valueSet = */ Or(valueSet.left, valueSet.right) + ) + } override val evaluateExpression = first + + override def toString: String = s"FIRST($child)${if (ignoreNulls) " IGNORE NULLS"}" } -case class Last(child: Expression) extends DeclarativeAggregate { +/** + * Returns the last value of `child` for a group of rows. If the last value of `child` + * is `null`, it returns `null` (respecting nulls). Even if [[Last]] is used on a already + * sorted column, if we do partial aggregation and final aggregation (when mergeExpression + * is used) its result will not be deterministic (unless the input table is sorted and has + * a single partition, and we use a single reducer to do the aggregation.). + * @param child + */ +case class Last(child: Expression, ignoreNullsExpr: Expression) extends DeclarativeAggregate { + + def this(child: Expression) = this(child, Literal.create(false, BooleanType)) + + private val ignoreNulls: Boolean = ignoreNullsExpr match { + case Literal(b: Boolean, BooleanType) => b + case _ => + throw new AnalysisException("The second argument of First should be a boolean literal.") + } override def children: Seq[Expression] = child :: Nil @@ -178,15 +231,33 @@ case class Last(child: Expression) extends DeclarativeAggregate { /* last = */ Literal.create(null, child.dataType) ) - override val updateExpressions = Seq( - /* last = */ If(IsNull(child), last, child) - ) + override val updateExpressions = { + if (ignoreNulls) { + Seq( + /* last = */ If(IsNull(child), last, child) + ) + } else { + Seq( + /* last = */ child + ) + } + } - override val mergeExpressions = Seq( - /* last = */ If(IsNull(last.right), last.left, last.right) - ) + override val mergeExpressions = { + if (ignoreNulls) { + Seq( + /* last = */ If(IsNull(last.right), last.left, last.right) + ) + } else { + Seq( + /* last = */ last.right + ) + } + } override val evaluateExpression = last + + override def toString: String = s"LAST($child)${if (ignoreNulls) " IGNORE NULLS"}" } case class Max(child: Expression) extends DeclarativeAggregate { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala index 540fc2310172..a70d5f3ca4d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala @@ -61,15 +61,15 @@ object Utils { mode = aggregate.Complete, isDistinct = true) - case expressions.First(child) => + case expressions.First(child, ignoreNulls) => aggregate.AggregateExpression2( - aggregateFunction = aggregate.First(child), + aggregateFunction = aggregate.First(child, ignoreNulls), mode = aggregate.Complete, isDistinct = false) - case expressions.Last(child) => + case expressions.Last(child, ignoreNulls) => aggregate.AggregateExpression2( - aggregateFunction = aggregate.Last(child), + aggregateFunction = aggregate.Last(child, ignoreNulls), mode = aggregate.Complete, isDistinct = false) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 31a906c41563..7ddfcae168aa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import com.clearspring.analytics.stream.cardinality.HyperLogLog +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} @@ -630,59 +631,113 @@ case class CombineSetsAndSumFunction( } } -case class First(child: Expression) extends UnaryExpression with PartialAggregate1 { +case class First( + child: Expression, + ignoreNullsExpr: Expression) + extends UnaryExpression with PartialAggregate1 { + + def this(child: Expression) = this(child, Literal.create(false, BooleanType)) + + private val ignoreNulls: Boolean = ignoreNullsExpr match { + case Literal(b: Boolean, BooleanType) => b + case _ => + throw new AnalysisException("The second argument of First should be a boolean literal.") + } + override def nullable: Boolean = true override def dataType: DataType = child.dataType - override def toString: String = s"FIRST($child)" + override def toString: String = s"FIRST(${child}${if (ignoreNulls) " IGNORE NULLS"})" override def asPartial: SplitEvaluation = { - val partialFirst = Alias(First(child), "PartialFirst")() + val partialFirst = Alias(First(child, ignoreNulls), "PartialFirst")() SplitEvaluation( - First(partialFirst.toAttribute), + First(partialFirst.toAttribute, ignoreNulls), partialFirst :: Nil) } - override def newInstance(): FirstFunction = new FirstFunction(child, this) + override def newInstance(): FirstFunction = new FirstFunction(child, ignoreNulls, this) } -case class FirstFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 { - def this() = this(null, null) // Required for serialization. +object First { + def apply(child: Expression): First = First(child, ignoreNulls = false) - var result: Any = null + def apply(child: Expression, ignoreNulls: Boolean): First = + First(child, Literal.create(ignoreNulls, BooleanType)) +} + +case class FirstFunction( + expr: Expression, + ignoreNulls: Boolean, + base: AggregateExpression1) + extends AggregateFunction1 { + + def this() = this(null, null.asInstanceOf[Boolean], null) // Required for serialization. + + private[this] var result: Any = null + + private[this] var valueSet: Boolean = false override def update(input: InternalRow): Unit = { - // We ignore null values. - if (result == null) { - result = expr.eval(input) + if (!valueSet) { + val value = expr.eval(input) + // When we have not set the result, we will set the result if we respect nulls + // (i.e. ignoreNulls is false), or we ignore nulls and the evaluated value is not null. 
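+      // Note that (!ignoreNulls || (ignoreNulls && value != null)) is equivalent to
+      // (!ignoreNulls || value != null); the second ignoreNulls test is redundant.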
+      if (!ignoreNulls || (ignoreNulls && value != null)) {
+        result = value
+        valueSet = true
+      }
+    }
   }

   override def eval(input: InternalRow): Any = result
 }

-case class Last(child: Expression) extends UnaryExpression with PartialAggregate1 {
+case class Last(
+    child: Expression,
+    ignoreNullsExpr: Expression)
+  extends UnaryExpression with PartialAggregate1 {
+
+  def this(child: Expression) = this(child, Literal.create(false, BooleanType))
+
+  private val ignoreNulls: Boolean = ignoreNullsExpr match {
+    case Literal(b: Boolean, BooleanType) => b
+    case _ =>
+      throw new AnalysisException("The second argument of Last should be a boolean literal.")
+  }
+
   override def references: AttributeSet = child.references

   override def nullable: Boolean = true

   override def dataType: DataType = child.dataType

-  override def toString: String = s"LAST($child)"
+  override def toString: String = s"LAST($child)${if (ignoreNulls) " IGNORE NULLS" else ""}"

   override def asPartial: SplitEvaluation = {
-    val partialLast = Alias(Last(child), "PartialLast")()
+    val partialLast = Alias(Last(child, ignoreNulls), "PartialLast")()
     SplitEvaluation(
-      Last(partialLast.toAttribute),
+      Last(partialLast.toAttribute, ignoreNulls),
       partialLast :: Nil)
   }

-  override def newInstance(): LastFunction = new LastFunction(child, this)
+  override def newInstance(): LastFunction = new LastFunction(child, ignoreNulls, this)
 }

-case class LastFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 {
-  def this() = this(null, null) // Required for serialization.
+object Last {
+  def apply(child: Expression): Last = Last(child, ignoreNulls = false)
+
+  def apply(child: Expression, ignoreNulls: Boolean): Last =
+    Last(child, Literal.create(ignoreNulls, BooleanType))
+}
+
+case class LastFunction(
+    expr: Expression,
+    ignoreNulls: Boolean,
+    base: AggregateExpression1)
+  extends AggregateFunction1 {
+
+  def this() = this(null, null.asInstanceOf[Boolean], null) // Required for serialization.

   var result: Any = null

   override def update(input: InternalRow): Unit = {
     val value = expr.eval(input)
-    // We ignore null values.
-    if (value != null) {
+    if (!ignoreNulls || (ignoreNulls && value != null)) {
       result = value
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala
index 8b9247adea20..78e9c5ebd4c4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala
@@ -18,7 +18,6 @@
 package org.apache.spark.sql.expressions

 import org.apache.spark.annotation.Experimental
-import org.apache.spark.sql.types.BooleanType
 import org.apache.spark.sql.{Column, catalyst}
 import org.apache.spark.sql.catalyst.expressions._

diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index 65244da103b7..9f92b0e02586 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -323,6 +323,44 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
       Row(11.125) :: Nil)
   }

+  test("first_value and last_value") {
+    // We force the sort and the aggregate to use a single partition to make the
+    // result deterministic.
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer( + sqlContext.sql( + """ + |SELECT + | first_valUE(key), + | lasT_value(key), + | firSt(key), + | lASt(key), + | first_valUE(key, true), + | lasT_value(key, true), + | firSt(key, true), + | lASt(key, true) + |FROM (SELECT key FROM agg1 ORDER BY key) tmp + """.stripMargin), + Row(null, 3, null, 3, 1, 3, 1, 3) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT + | first_valUE(key), + | lasT_value(key), + | firSt(key), + | lASt(key), + | first_valUE(key, true), + | lasT_value(key, true), + | firSt(key, true), + | lASt(key, true) + |FROM (SELECT key FROM agg1 ORDER BY key DESC) tmp + """.stripMargin), + Row(3, null, 3, null, 3, 1, 3, 1) :: Nil) + } + } + test("udaf") { checkAnswer( sqlContext.sql( From d10afbefddacb8f779414948fc2217bdc6a9e791 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 22 Oct 2015 08:14:32 +0800 Subject: [PATCH 03/10] Don't modify WindowSpec. --- .../main/scala/org/apache/spark/sql/expressions/WindowSpec.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala index 78e9c5ebd4c4..8b9247adea20 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.expressions import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.types.BooleanType import org.apache.spark.sql.{Column, catalyst} import org.apache.spark.sql.catalyst.expressions._ From cc1657b9cf5da4f07804958420bb94a58c86593f Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 22 Oct 2015 08:29:52 +0800 Subject: [PATCH 04/10] Fix scala style. 
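For reference, the FIRST/LAST semantics exercised by the first_value/last_value test added in PATCH 02 can be modeled in a few lines of plain Scala (hypothetical helpers, not Spark code):

    // FIRST/LAST over an ordered sequence, with and without IGNORE NULLS.
    def first[A](xs: Seq[Option[A]], ignoreNulls: Boolean): Option[A] =
      if (ignoreNulls) xs.find(_.isDefined).flatten else xs.headOption.flatten

    def last[A](xs: Seq[Option[A]], ignoreNulls: Boolean): Option[A] =
      first(xs.reverse, ignoreNulls)

    val keys = Seq(None, Some(1), Some(2), Some(3)) // sorted ascending, null first
    first(keys, ignoreNulls = false) // None    -> first_value(key)
    first(keys, ignoreNulls = true)  // Some(1) -> first_value(key, true)
    last(keys, ignoreNulls = false)  // Some(3) -> last_value(key)

This model reproduces the expected Row(null, 3, null, 3, 1, 3, 1, 3) of the ascending case in that test.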
--- .../expressions/aggregate/functions.scala | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala index 2907e29f0a47..37ee1aa214cd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala @@ -531,19 +531,19 @@ case class Corr( mutableAggBufferOffset: Int = 0, inputAggBufferOffset: Int = 0) extends ImperativeAggregate { - + def children: Seq[Expression] = Seq(left, right) - + def nullable: Boolean = false - + def dataType: DataType = DoubleType - + def inputTypes: Seq[AbstractDataType] = Seq(DoubleType) - + def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) - + def inputAggBufferAttributes: Seq[AttributeReference] = aggBufferAttributes.map(_.newInstance()) - + val aggBufferAttributes: Seq[AttributeReference] = Seq( AttributeReference("xAvg", DoubleType)(), AttributeReference("yAvg", DoubleType)(), @@ -562,18 +562,18 @@ case class Corr( (0 until 5).map(idx => buffer.setDouble(mutableAggBufferOffset + idx, 0.0)) buffer.setLong(mutableAggBufferOffset + 5, 0L) } - + override def update(buffer: MutableRow, input: InternalRow): Unit = { val x = left.eval(input).asInstanceOf[Double] val y = right.eval(input).asInstanceOf[Double] - + var xAvg = buffer.getDouble(mutableAggBufferOffset) var yAvg = buffer.getDouble(mutableAggBufferOffset + 1) var Ck = buffer.getDouble(mutableAggBufferOffset + 2) var MkX = buffer.getDouble(mutableAggBufferOffset + 3) var MkY = buffer.getDouble(mutableAggBufferOffset + 4) var count = buffer.getLong(mutableAggBufferOffset + 5) - + val deltaX = x - xAvg val deltaY = y - yAvg count += 1 @@ -582,7 +582,7 @@ case class Corr( Ck += deltaX * (y - yAvg) MkX += deltaX * (x - xAvg) MkY += deltaY * (y - yAvg) - + buffer.setDouble(mutableAggBufferOffset, xAvg) buffer.setDouble(mutableAggBufferOffset + 1, yAvg) buffer.setDouble(mutableAggBufferOffset + 2, Ck) @@ -590,10 +590,10 @@ case class Corr( buffer.setDouble(mutableAggBufferOffset + 4, MkY) buffer.setLong(mutableAggBufferOffset + 5, count) } - + override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { val count2 = buffer2.getLong(inputAggBufferOffset + 5) - + if (count2 > 0) { var xAvg = buffer1.getDouble(mutableAggBufferOffset) var yAvg = buffer1.getDouble(mutableAggBufferOffset + 1) @@ -601,13 +601,13 @@ case class Corr( var MkX = buffer1.getDouble(mutableAggBufferOffset + 3) var MkY = buffer1.getDouble(mutableAggBufferOffset + 4) var count = buffer1.getLong(mutableAggBufferOffset + 5) - + val xAvg2 = buffer2.getDouble(inputAggBufferOffset) val yAvg2 = buffer2.getDouble(inputAggBufferOffset + 1) val Ck2 = buffer2.getDouble(inputAggBufferOffset + 2) val MkX2 = buffer2.getDouble(inputAggBufferOffset + 3) val MkY2 = buffer2.getDouble(inputAggBufferOffset + 4) - + val totalCount = count + count2 val deltaX = xAvg - xAvg2 val deltaY = yAvg - yAvg2 @@ -617,7 +617,7 @@ case class Corr( MkX += MkX2 + deltaX * deltaX * count / totalCount * count2 MkY += MkY2 + deltaY * deltaY * count / totalCount * count2 count = totalCount - + buffer1.setDouble(mutableAggBufferOffset, xAvg) buffer1.setDouble(mutableAggBufferOffset + 1, yAvg) buffer1.setDouble(mutableAggBufferOffset + 2, Ck) From 
02562f3a9ab9cda941b260a834d505f7cacd46f2 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 27 Oct 2015 17:09:45 +0800 Subject: [PATCH 05/10] Add document. Return NaN when count is zero. --- .../expressions/aggregate/functions.scala | 25 ++++++++++++++++--- .../execution/AggregationQuerySuite.scala | 4 +++ 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala index 37ee1aa214cd..5d2d73b782ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala @@ -525,6 +525,16 @@ case class Sum(child: Expression) extends DeclarativeAggregate { override val evaluateExpression = Cast(currentSum, resultType) } +/** + * Compute Pearson correlation between two expressions. + * When applied on empty data (i.e., count is zero), it returns NaN. + * + * Definition of Pearson correlation can be found at + * http://en.wikipedia.org/wiki/Pearson_product-moment_correlation_coefficient + * + * @param left one of the expressions to compute correlation with. + * @param right another expression to compute correlation with. + */ case class Corr( left: Expression, right: Expression, @@ -591,6 +601,8 @@ case class Corr( buffer.setLong(mutableAggBufferOffset + 5, count) } + // Merge counters from other partitions. Formula can be found at: + // http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { val count2 = buffer2.getLong(inputAggBufferOffset + 5) @@ -628,10 +640,15 @@ case class Corr( } override def eval(buffer: InternalRow): Any = { - val Ck = buffer.getDouble(mutableAggBufferOffset + 2) - val MkX = buffer.getDouble(mutableAggBufferOffset + 3) - val MkY = buffer.getDouble(mutableAggBufferOffset + 4) - Ck / math.sqrt(MkX * MkY) + val count = buffer.getLong(mutableAggBufferOffset + 5) + if (count > 0) { + val Ck = buffer.getDouble(mutableAggBufferOffset + 2) + val MkX = buffer.getDouble(mutableAggBufferOffset + 3) + val MkY = buffer.getDouble(mutableAggBufferOffset + 4) + Ck / math.sqrt(MkX * MkY) + } else { + Double.NaN + } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 9f92b0e02586..6835ad01f0de 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -577,6 +577,10 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te val df2 = Seq.tabulate(20)(x => (1.0 * x, x * x - 2 * x + 3.5)).toDF("a", "b") val corr3 = df2.groupBy().agg(corr("a", "b")).collect()(0).getDouble(0) assert(math.abs(corr3 - 0.95723391394758572) < 1e-12) + + val df3 = Seq.tabulate(0)(i => (1.0 * i, 2.0 * i)).toDF("a", "b") + val corr4 = df3.groupBy().agg(corr("a", "b")).collect()(0).getDouble(0) + assert(corr4.isNaN) } test("test Last implemented based on AggregateExpression1") { From 5fbcf9115e8e9677ea49e621804e18ae4a7a41df Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 30 Oct 2015 01:37:21 +0800 Subject: [PATCH 06/10] For comments. 
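As a cross-check on the single-pass co-moments, the textbook two-pass Pearson formula computes the same Ck, MkX and MkY quantities directly. A plain-Scala sketch (not part of the patch) that likewise yields NaN for empty or constant input:

    def pearson(xs: Seq[Double], ys: Seq[Double]): Double = {
      require(xs.size == ys.size)
      val n = xs.size
      val xAvg = xs.sum / n // 0.0 / 0 == NaN for empty input
      val yAvg = ys.sum / n
      val ck  = xs.zip(ys).map { case (x, y) => (x - xAvg) * (y - yAvg) }.sum
      val mkX = xs.map(x => (x - xAvg) * (x - xAvg)).sum
      val mkY = ys.map(y => (y - yAvg) * (y - yAvg)).sum
      ck / math.sqrt(mkX * mkY)
    }

    pearson(Seq(1.0, 2.0, 3.0), Seq(2.0, 4.0, 6.0)) // 1.0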
--- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/aggregate/functions.scala | 85 ++++++++++++------- .../expressions/aggregate/utils.scala | 26 ++++++ .../sql/catalyst/expressions/aggregates.scala | 6 +- .../spark/sql/execution/SparkStrategies.scala | 44 +++++++--- .../execution/AggregationQuerySuite.scala | 14 +++ 6 files changed, 129 insertions(+), 47 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 3dce6c1a27e8..e4e6b9469198 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -178,6 +178,7 @@ object FunctionRegistry { // aggregate functions expression[Average]("avg"), + expression[Corr]("corr"), expression[Count]("count"), expression[First]("first"), expression[First]("first_value"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala index 5d2d73b782ba..4b3b10ccb956 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala @@ -548,7 +548,7 @@ case class Corr( def dataType: DataType = DoubleType - def inputTypes: Seq[AbstractDataType] = Seq(DoubleType) + override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType) def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) @@ -562,6 +562,20 @@ case class Corr( AttributeReference("MkY", DoubleType)(), AttributeReference("count", LongType)()) + // Local cache of mutableAggBufferOffset(s) that will be used in update and merge + private[this] val mutableAggBufferOffsetPlus1 = mutableAggBufferOffset + 1 + private[this] val mutableAggBufferOffsetPlus2 = mutableAggBufferOffset + 2 + private[this] val mutableAggBufferOffsetPlus3 = mutableAggBufferOffset + 3 + private[this] val mutableAggBufferOffsetPlus4 = mutableAggBufferOffset + 4 + private[this] val mutableAggBufferOffsetPlus5 = mutableAggBufferOffset + 5 + + // Local cache of inputAggBufferOffset(s) that will be used in update and merge + private[this] val inputAggBufferOffsetPlus1 = inputAggBufferOffset + 1 + private[this] val inputAggBufferOffsetPlus2 = inputAggBufferOffset + 2 + private[this] val inputAggBufferOffsetPlus3 = inputAggBufferOffset + 3 + private[this] val inputAggBufferOffsetPlus4 = inputAggBufferOffset + 4 + private[this] val inputAggBufferOffsetPlus5 = inputAggBufferOffset + 5 + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = copy(mutableAggBufferOffset = newMutableAggBufferOffset) @@ -569,8 +583,12 @@ case class Corr( copy(inputAggBufferOffset = newInputAggBufferOffset) override def initialize(buffer: MutableRow): Unit = { - (0 until 5).map(idx => buffer.setDouble(mutableAggBufferOffset + idx, 0.0)) - buffer.setLong(mutableAggBufferOffset + 5, 0L) + buffer.setDouble(mutableAggBufferOffset, 0.0) + buffer.setDouble(mutableAggBufferOffsetPlus1, 0.0) + buffer.setDouble(mutableAggBufferOffsetPlus2, 0.0) + buffer.setDouble(mutableAggBufferOffsetPlus3, 0.0) + buffer.setDouble(mutableAggBufferOffsetPlus4, 0.0) + buffer.setLong(mutableAggBufferOffsetPlus5, 0L) } 
override def update(buffer: MutableRow, input: InternalRow): Unit = { @@ -578,11 +596,11 @@ case class Corr( val y = right.eval(input).asInstanceOf[Double] var xAvg = buffer.getDouble(mutableAggBufferOffset) - var yAvg = buffer.getDouble(mutableAggBufferOffset + 1) - var Ck = buffer.getDouble(mutableAggBufferOffset + 2) - var MkX = buffer.getDouble(mutableAggBufferOffset + 3) - var MkY = buffer.getDouble(mutableAggBufferOffset + 4) - var count = buffer.getLong(mutableAggBufferOffset + 5) + var yAvg = buffer.getDouble(mutableAggBufferOffsetPlus1) + var Ck = buffer.getDouble(mutableAggBufferOffsetPlus2) + var MkX = buffer.getDouble(mutableAggBufferOffsetPlus3) + var MkY = buffer.getDouble(mutableAggBufferOffsetPlus4) + var count = buffer.getLong(mutableAggBufferOffsetPlus5) val deltaX = x - xAvg val deltaY = y - yAvg @@ -594,31 +612,34 @@ case class Corr( MkY += deltaY * (y - yAvg) buffer.setDouble(mutableAggBufferOffset, xAvg) - buffer.setDouble(mutableAggBufferOffset + 1, yAvg) - buffer.setDouble(mutableAggBufferOffset + 2, Ck) - buffer.setDouble(mutableAggBufferOffset + 3, MkX) - buffer.setDouble(mutableAggBufferOffset + 4, MkY) - buffer.setLong(mutableAggBufferOffset + 5, count) + buffer.setDouble(mutableAggBufferOffsetPlus1, yAvg) + buffer.setDouble(mutableAggBufferOffsetPlus2, Ck) + buffer.setDouble(mutableAggBufferOffsetPlus3, MkX) + buffer.setDouble(mutableAggBufferOffsetPlus4, MkY) + buffer.setLong(mutableAggBufferOffsetPlus5, count) } // Merge counters from other partitions. Formula can be found at: // http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { - val count2 = buffer2.getLong(inputAggBufferOffset + 5) + val count2 = buffer2.getLong(inputAggBufferOffsetPlus5) + // We only go to merge two buffers if there is at least one record aggregated in buffer2. + // We don't need to check count in buffer1 because if count2 is more than zero, totalCount + // is more than zero too, then we won't get a divide by zero exception. 
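+    // With n = count, m = count2, dx = xAvg - xAvg2 and dy = yAvg - yAvg2, the
+    // updates below compute Ck += Ck2 + dx * dy * n * m / (n + m) and
+    // MkX += MkX2 + dx * dx * n * m / (n + m) (and similarly for MkY).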
if (count2 > 0) { var xAvg = buffer1.getDouble(mutableAggBufferOffset) - var yAvg = buffer1.getDouble(mutableAggBufferOffset + 1) - var Ck = buffer1.getDouble(mutableAggBufferOffset + 2) - var MkX = buffer1.getDouble(mutableAggBufferOffset + 3) - var MkY = buffer1.getDouble(mutableAggBufferOffset + 4) - var count = buffer1.getLong(mutableAggBufferOffset + 5) + var yAvg = buffer1.getDouble(mutableAggBufferOffsetPlus1) + var Ck = buffer1.getDouble(mutableAggBufferOffsetPlus2) + var MkX = buffer1.getDouble(mutableAggBufferOffsetPlus3) + var MkY = buffer1.getDouble(mutableAggBufferOffsetPlus4) + var count = buffer1.getLong(mutableAggBufferOffsetPlus5) val xAvg2 = buffer2.getDouble(inputAggBufferOffset) - val yAvg2 = buffer2.getDouble(inputAggBufferOffset + 1) - val Ck2 = buffer2.getDouble(inputAggBufferOffset + 2) - val MkX2 = buffer2.getDouble(inputAggBufferOffset + 3) - val MkY2 = buffer2.getDouble(inputAggBufferOffset + 4) + val yAvg2 = buffer2.getDouble(inputAggBufferOffsetPlus1) + val Ck2 = buffer2.getDouble(inputAggBufferOffsetPlus2) + val MkX2 = buffer2.getDouble(inputAggBufferOffsetPlus3) + val MkY2 = buffer2.getDouble(inputAggBufferOffsetPlus4) val totalCount = count + count2 val deltaX = xAvg - xAvg2 @@ -631,20 +652,20 @@ case class Corr( count = totalCount buffer1.setDouble(mutableAggBufferOffset, xAvg) - buffer1.setDouble(mutableAggBufferOffset + 1, yAvg) - buffer1.setDouble(mutableAggBufferOffset + 2, Ck) - buffer1.setDouble(mutableAggBufferOffset + 3, MkX) - buffer1.setDouble(mutableAggBufferOffset + 4, MkY) - buffer1.setLong(mutableAggBufferOffset + 5, count) + buffer1.setDouble(mutableAggBufferOffsetPlus1, yAvg) + buffer1.setDouble(mutableAggBufferOffsetPlus2, Ck) + buffer1.setDouble(mutableAggBufferOffsetPlus3, MkX) + buffer1.setDouble(mutableAggBufferOffsetPlus4, MkY) + buffer1.setLong(mutableAggBufferOffsetPlus5, count) } } override def eval(buffer: InternalRow): Any = { - val count = buffer.getLong(mutableAggBufferOffset + 5) + val count = buffer.getLong(mutableAggBufferOffsetPlus5) if (count > 0) { - val Ck = buffer.getDouble(mutableAggBufferOffset + 2) - val MkX = buffer.getDouble(mutableAggBufferOffset + 3) - val MkY = buffer.getDouble(mutableAggBufferOffset + 4) + val Ck = buffer.getDouble(mutableAggBufferOffsetPlus2) + val MkX = buffer.getDouble(mutableAggBufferOffsetPlus3) + val MkY = buffer.getDouble(mutableAggBufferOffsetPlus4) Ck / math.sqrt(MkX * MkY) } else { Double.NaN diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala index a70d5f3ca4d3..ff45aa8977c1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala @@ -194,4 +194,30 @@ object Utils { } case other => None } + + def mustNewAggregation(aggregate: Aggregate): Unit = { + val onlyForAggregateExpression2 = aggregate.aggregateExpressions.flatMap { expr => + expr.collect { + // If an aggregate expression only extends AggregateExpression + // without AggregateExpression1, it indicates it only supports AggregateExpression2 + case agg: expressions.AggregateExpression + if !agg.isInstanceOf[expressions.AggregateExpression1] => + agg + } + } + if (onlyForAggregateExpression2.nonEmpty) { + val invalidFunctions = { + if (onlyForAggregateExpression2.length > 1) { + s"${onlyForAggregateExpression2.tail.map(_.nodeName).mkString(",")} 
" + + s"and ${onlyForAggregateExpression2.head.nodeName} are" + } else { + s"${onlyForAggregateExpression2.head.nodeName} is" + } + } + val errorMessage = + s"${invalidFunctions} only implemented based on the new Aggregate Function " + + s"interface and it cannot be used when spark.sql.useAggregate2 = false." + throw new AnalysisException(errorMessage) + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 9641fc9a815a..4f19f54a07ab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -752,12 +752,12 @@ case class LastFunction( * Only support AggregateExpression2. * */ -case class Corr( - left: Expression, - right: Expression) extends BinaryExpression with AggregateExpression { +case class Corr(left: Expression, right: Expression) + extends BinaryExpression with AggregateExpression with ImplicitCastInputTypes { override def nullable: Boolean = false override def dataType: DoubleType.type = DoubleType override def toString: String = s"CORRELATION($left, $right)" + override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType) } // Compute standard deviation based on online algorithm specified here: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 637deff4e220..f9c89556dd73 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -156,6 +156,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { groupingExpressions, partialComputation, child) if !canBeConvertedToNewAggregation(plan) => + Utils.mustNewAggregation(plan.asInstanceOf[logical.Aggregate]) execution.Aggregate( partial = false, namedGroupingAttributes, @@ -294,25 +295,24 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } - - object BroadcastNestedLoopJoin extends Strategy { + object BroadcastNestedLoop extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case logical.Join(left, right, joinType, condition) => - val buildSide = - if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) { - joins.BuildRight - } else { - joins.BuildLeft - } - joins.BroadcastNestedLoopJoin( - planLater(left), planLater(right), buildSide, joinType, condition) :: Nil + case logical.Join( + CanBroadcast(left), right, joinType, condition) if joinType != LeftSemi => + execution.joins.BroadcastNestedLoopJoin( + planLater(left), planLater(right), joins.BuildLeft, joinType, condition) :: Nil + case logical.Join( + left, CanBroadcast(right), joinType, condition) if joinType != LeftSemi => + execution.joins.BroadcastNestedLoopJoin( + planLater(left), planLater(right), joins.BuildRight, joinType, condition) :: Nil case _ => Nil } } object CartesianProduct extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case logical.Join(left, right, _, None) => + // TODO CartesianProduct doesn't support the Left Semi Join + case logical.Join(left, right, joinType, None) if joinType != LeftSemi => execution.joins.CartesianProduct(planLater(left), planLater(right)) :: Nil case logical.Join(left, right, Inner, 
Some(condition)) => execution.Filter(condition, @@ -321,6 +321,21 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } + object DefaultJoin extends Strategy { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case logical.Join(left, right, joinType, condition) => + val buildSide = + if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) { + joins.BuildRight + } else { + joins.BuildLeft + } + joins.BroadcastNestedLoopJoin( + planLater(left), planLater(right), buildSide, joinType, condition) :: Nil + case _ => Nil + } + } + protected lazy val singleRowRdd = sparkContext.parallelize(Seq(InternalRow()), 1) object TakeOrderedAndProject extends Strategy { @@ -379,6 +394,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.AppendColumns(f, tEnc, uEnc, newCol, planLater(child)) :: Nil case logical.MapGroups(f, kEnc, tEnc, uEnc, grouping, output, child) => execution.MapGroups(f, kEnc, tEnc, uEnc, grouping, output, planLater(child)) :: Nil + case logical.CoGroup(f, kEnc, leftEnc, rightEnc, rEnc, output, + leftGroup, rightGroup, left, right) => + execution.CoGroup(f, kEnc, leftEnc, rightEnc, rEnc, output, leftGroup, rightGroup, + planLater(left), planLater(right)) :: Nil case logical.Repartition(numPartitions, shuffle, child) => if (shuffle) { @@ -414,6 +433,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { Nil } else { Utils.checkInvalidAggregateFunction2(a) + Utils.mustNewAggregation(a) execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 6835ad01f0de..da8726d9f190 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -581,6 +581,20 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te val df3 = Seq.tabulate(0)(i => (1.0 * i, 2.0 * i)).toDF("a", "b") val corr4 = df3.groupBy().agg(corr("a", "b")).collect()(0).getDouble(0) assert(corr4.isNaN) + + val df4 = Seq.tabulate(10)(i => (1 * i, 2 * i, i * -1)).toDF("a", "b", "c") + val corr5 = df4.repartition(2).groupBy().agg(corr("a", "b")).collect()(0).getDouble(0) + assert(math.abs(corr5 - 1.0) < 1e-12) + val corr6 = df4.groupBy().agg(corr("a", "c")).collect()(0).getDouble(0) + assert(math.abs(corr6 + 1.0) < 1e-12) + + withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "false") { + val errorMessage = intercept[AnalysisException] { + val df = Seq.tabulate(10)(i => (1.0 * i, 2.0 * i, i * -1.0)).toDF("a", "b", "c") + val corr1 = df.repartition(2).groupBy().agg(corr("a", "b")).collect()(0).getDouble(0) + }.getMessage + assert(errorMessage.contains("Corr is only implemented based on the new Aggregate Function")) + } } test("test Last implemented based on AggregateExpression1") { From 3b731e2c9b08dbade38da73bcff94cf1b2cd7636 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 30 Oct 2015 07:13:54 +0800 Subject: [PATCH 07/10] Make Corr extends AggregateExpression1. 
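With corr registered in FunctionRegistry and exposed through functions.scala, both entry points can be exercised roughly as follows (a sketch that assumes a SQLContext named sqlContext with its implicits imported; not part of the patch):

    import org.apache.spark.sql.functions.corr
    import sqlContext.implicits._

    val df = Seq.tabulate(10)(i => (1.0 * i, 2.0 * i)).toDF("a", "b")
    df.groupBy().agg(corr("a", "b")).show() // DataFrame API: 1.0 for linear data

    df.registerTempTable("t")
    sqlContext.sql("SELECT corr(a, b) FROM t").show() // SQL path via the registry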
--- .../expressions/aggregate/functions.scala | 4 +-- .../expressions/aggregate/utils.scala | 26 ------------------- .../sql/catalyst/expressions/aggregates.scala | 7 ++++- .../spark/sql/execution/SparkStrategies.scala | 2 -- .../execution/AggregationQuerySuite.scala | 10 ++++--- 5 files changed, 14 insertions(+), 35 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala index 4b3b10ccb956..35322d154be0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala @@ -527,7 +527,7 @@ case class Sum(child: Expression) extends DeclarativeAggregate { /** * Compute Pearson correlation between two expressions. - * When applied on empty data (i.e., count is zero), it returns NaN. + * When applied on empty data (i.e., count is zero), it returns NULL. * * Definition of Pearson correlation can be found at * http://en.wikipedia.org/wiki/Pearson_product-moment_correlation_coefficient @@ -668,7 +668,7 @@ case class Corr( val MkY = buffer.getDouble(mutableAggBufferOffsetPlus4) Ck / math.sqrt(MkX * MkY) } else { - Double.NaN + null } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala index ff45aa8977c1..a70d5f3ca4d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala @@ -194,30 +194,4 @@ object Utils { } case other => None } - - def mustNewAggregation(aggregate: Aggregate): Unit = { - val onlyForAggregateExpression2 = aggregate.aggregateExpressions.flatMap { expr => - expr.collect { - // If an aggregate expression only extends AggregateExpression - // without AggregateExpression1, it indicates it only supports AggregateExpression2 - case agg: expressions.AggregateExpression - if !agg.isInstanceOf[expressions.AggregateExpression1] => - agg - } - } - if (onlyForAggregateExpression2.nonEmpty) { - val invalidFunctions = { - if (onlyForAggregateExpression2.length > 1) { - s"${onlyForAggregateExpression2.tail.map(_.nodeName).mkString(",")} " + - s"and ${onlyForAggregateExpression2.head.nodeName} are" - } else { - s"${onlyForAggregateExpression2.head.nodeName} is" - } - } - val errorMessage = - s"${invalidFunctions} only implemented based on the new Aggregate Function " + - s"interface and it cannot be used when spark.sql.useAggregate2 = false." 
- throw new AnalysisException(errorMessage) - } - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 4f19f54a07ab..47c4ee8658ae 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -753,11 +753,16 @@ case class LastFunction( * */ case class Corr(left: Expression, right: Expression) - extends BinaryExpression with AggregateExpression with ImplicitCastInputTypes { + extends BinaryExpression with AggregateExpression1 with ImplicitCastInputTypes { override def nullable: Boolean = false override def dataType: DoubleType.type = DoubleType override def toString: String = s"CORRELATION($left, $right)" override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType) + override def newInstance(): AggregateFunction1 = { + throw new UnsupportedOperationException( + "Corr only supports the new AggregateExpression2 and can only be used " + + "when spark.sql.useAggregate2 = true") + } } // Compute standard deviation based on online algorithm specified here: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index f9c89556dd73..86d1d390f191 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -156,7 +156,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { groupingExpressions, partialComputation, child) if !canBeConvertedToNewAggregation(plan) => - Utils.mustNewAggregation(plan.asInstanceOf[logical.Aggregate]) execution.Aggregate( partial = false, namedGroupingAttributes, @@ -433,7 +432,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { Nil } else { Utils.checkInvalidAggregateFunction2(a) - Utils.mustNewAggregation(a) execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index da8726d9f190..e3d2ee19448a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.hive.execution import scala.collection.JavaConverters._ +import org.apache.spark.SparkException import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.aggregate @@ -579,8 +580,8 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te assert(math.abs(corr3 - 0.95723391394758572) < 1e-12) val df3 = Seq.tabulate(0)(i => (1.0 * i, 2.0 * i)).toDF("a", "b") - val corr4 = df3.groupBy().agg(corr("a", "b")).collect()(0).getDouble(0) - assert(corr4.isNaN) + val corr4 = df3.groupBy().agg(corr("a", "b")).collect()(0) + assert(corr4 == Row(null)) val df4 = Seq.tabulate(10)(i => (1 * i, 2 * i, i * -1)).toDF("a", "b", "c") val corr5 = df4.repartition(2).groupBy().agg(corr("a", "b")).collect()(0).getDouble(0) @@ -589,11 +590,12 @@ abstract class AggregationQuerySuite extends 
QueryTest with SQLTestUtils with Te assert(math.abs(corr6 + 1.0) < 1e-12) withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "false") { - val errorMessage = intercept[AnalysisException] { + val errorMessage = intercept[SparkException] { val df = Seq.tabulate(10)(i => (1.0 * i, 2.0 * i, i * -1.0)).toDF("a", "b", "c") val corr1 = df.repartition(2).groupBy().agg(corr("a", "b")).collect()(0).getDouble(0) }.getMessage - assert(errorMessage.contains("Corr is only implemented based on the new Aggregate Function")) + assert(errorMessage.contains("java.lang.UnsupportedOperationException: " + + "Corr only supports the new AggregateExpression2")) } } From 4f8c381c48244e46ac6437b83e70eccbf42ac907 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 30 Oct 2015 12:59:18 +0800 Subject: [PATCH 08/10] Fix null case. --- .../expressions/aggregate/functions.scala | 55 ++++++++++--------- .../execution/AggregationQuerySuite.scala | 7 +++ 2 files changed, 37 insertions(+), 25 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala index 35322d154be0..67e6f4504c1d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala @@ -592,31 +592,36 @@ case class Corr( } override def update(buffer: MutableRow, input: InternalRow): Unit = { - val x = left.eval(input).asInstanceOf[Double] - val y = right.eval(input).asInstanceOf[Double] - - var xAvg = buffer.getDouble(mutableAggBufferOffset) - var yAvg = buffer.getDouble(mutableAggBufferOffsetPlus1) - var Ck = buffer.getDouble(mutableAggBufferOffsetPlus2) - var MkX = buffer.getDouble(mutableAggBufferOffsetPlus3) - var MkY = buffer.getDouble(mutableAggBufferOffsetPlus4) - var count = buffer.getLong(mutableAggBufferOffsetPlus5) - - val deltaX = x - xAvg - val deltaY = y - yAvg - count += 1 - xAvg += deltaX / count - yAvg += deltaY / count - Ck += deltaX * (y - yAvg) - MkX += deltaX * (x - xAvg) - MkY += deltaY * (y - yAvg) - - buffer.setDouble(mutableAggBufferOffset, xAvg) - buffer.setDouble(mutableAggBufferOffsetPlus1, yAvg) - buffer.setDouble(mutableAggBufferOffsetPlus2, Ck) - buffer.setDouble(mutableAggBufferOffsetPlus3, MkX) - buffer.setDouble(mutableAggBufferOffsetPlus4, MkY) - buffer.setLong(mutableAggBufferOffsetPlus5, count) + val leftEval = left.eval(input) + val rightEval = right.eval(input) + + if (leftEval != null && rightEval != null) { + val x = leftEval.asInstanceOf[Double] + val y = rightEval.asInstanceOf[Double] + + var xAvg = buffer.getDouble(mutableAggBufferOffset) + var yAvg = buffer.getDouble(mutableAggBufferOffsetPlus1) + var Ck = buffer.getDouble(mutableAggBufferOffsetPlus2) + var MkX = buffer.getDouble(mutableAggBufferOffsetPlus3) + var MkY = buffer.getDouble(mutableAggBufferOffsetPlus4) + var count = buffer.getLong(mutableAggBufferOffsetPlus5) + + val deltaX = x - xAvg + val deltaY = y - yAvg + count += 1 + xAvg += deltaX / count + yAvg += deltaY / count + Ck += deltaX * (y - yAvg) + MkX += deltaX * (x - xAvg) + MkY += deltaY * (y - yAvg) + + buffer.setDouble(mutableAggBufferOffset, xAvg) + buffer.setDouble(mutableAggBufferOffsetPlus1, yAvg) + buffer.setDouble(mutableAggBufferOffsetPlus2, Ck) + buffer.setDouble(mutableAggBufferOffsetPlus3, MkX) + buffer.setDouble(mutableAggBufferOffsetPlus4, MkY) + 
From 7dcf689cca8ae4bd89a1f7c0b28373c8cfb6b8f0 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Fri, 30 Oct 2015 15:51:52 +0800
Subject: [PATCH 09/10] Fix udaf_corr test.

---
 .../sql/catalyst/expressions/aggregate/functions.scala   | 7 ++++++-
 .../spark/sql/hive/execution/AggregationQuerySuite.scala | 6 ++++++
 2 files changed, 12 insertions(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
index 67e6f4504c1d..6ed348ac6836 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
@@ -671,7 +671,12 @@ case class Corr(
       val Ck = buffer.getDouble(mutableAggBufferOffsetPlus2)
       val MkX = buffer.getDouble(mutableAggBufferOffsetPlus3)
       val MkY = buffer.getDouble(mutableAggBufferOffsetPlus4)
-      Ck / math.sqrt(MkX * MkY)
+      val corr = Ck / math.sqrt(MkX * MkY)
+      if (corr.isNaN) {
+        null
+      } else {
+        corr
+      }
     } else {
       null
     }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index 61b688b38d0b..a4871ad73513 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -596,6 +596,12 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
     val corr7 = df5.groupBy().agg(corr("a", "b")).collect()(0)
     assert(corr7 == Row(null))
 
+    val df6 = Seq[(Integer, Integer)](
+      (7, 12)).toDF("a", "b")
+
+    val corr8 = df6.groupBy().agg(corr("a", "b")).collect()(0)
+    assert(corr8 == Row(null))
+
     withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "false") {
       val errorMessage = intercept[SparkException] {
         val df = Seq.tabulate(10)(i => (1.0 * i, 2.0 * i, i * -1.0)).toDF("a", "b", "c")
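The final patch below addresses a consequence of distributing this computation: per-partition partial states are combined with the pairwise co-moment merge referenced above ("Merge counters from other partitions"), and the order in which partitions are merged perturbs the last bits of the double result. A standalone sketch of that merge, mirroring the merge() formula used by Corr (hypothetical names, illustrative only):

  // Pairwise merge of two partial correlation states.
  case class Partial(xAvg: Double, yAvg: Double, ck: Double,
                     mkX: Double, mkY: Double, count: Long)

  def merged(a: Partial, b: Partial): Partial =
    if (b.count == 0L) a               // an empty side contributes nothing
    else if (a.count == 0L) b
    else {
      val n = a.count + b.count
      val deltaX = a.xAvg - b.xAvg     // distance between the partial means
      val deltaY = a.yAvg - b.yAvg
      Partial(
        (a.xAvg * a.count + b.xAvg * b.count) / n,   // count-weighted means
        (a.yAvg * a.count + b.yAvg * b.count) / n,
        a.ck + b.ck + deltaX * deltaY * a.count * b.count / n,
        a.mkX + b.mkX + deltaX * deltaX * a.count * b.count / n,
        a.mkY + b.mkY + deltaY * deltaY * a.count * b.count / n,
        n)
    }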
From 2de76b444456bc7e751fa9ccb85a6e8f0662ff76 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Sat, 31 Oct 2015 01:32:20 +0800
Subject: [PATCH 10/10] Due to numerical errors, move udaf_corr to the
 HiveCompatibilitySuite blacklist and add the corresponding tests to
 AggregationQuerySuite.

---
 .../execution/HiveCompatibilitySuite.scala    |  7 +-
 .../execution/AggregationQuerySuite.scala     | 66 ++++++++++++++++---
 2 files changed, 62 insertions(+), 11 deletions(-)

diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
index eed9e436f9af..110f6d1ffd89 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
@@ -304,7 +304,11 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
 
     // classpath problems
     "compute_stats.*",
-    "udf_bitmap_.*"
+    "udf_bitmap_.*",
+
+    // The difference between the double values generated by Hive and Spark
+    // can be ignored (e.g., 0.6633880657639323 vs. 0.6633880657639322)
+    "udaf_corr"
   )
 
   /**
@@ -858,7 +862,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
     "type_cast_1",
     "type_widening",
     "udaf_collect_set",
-    "udaf_corr",
     "udaf_covar_pop",
     "udaf_covar_samp",
     "udaf_histogram_numeric",
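The numerical errors in question are ordinary floating-point non-associativity: Hive accumulates the co-moments in one sequential pass, while Spark merges per-partition states, so the two can disagree in the last decimal digit (0.6633880657639323 vs. 0.6633880657639322). That is also why the replacement test below asserts with a 1e-12 tolerance rather than exact equality. A deliberately exaggerated illustration of order-dependent summation:

  // Double addition is not associative; regrouping changes the result.
  val xs = Seq(0.1, 0.2, 0.3, 1e16, -1e16)
  val leftToRight = xs.sum        // the 0.6 partial sum is absorbed into 1e16: 0.0
  val reversed = xs.reverse.sum   // 1e16 and -1e16 cancel first: about 0.6
  assert(leftToRight != reversed)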
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index a4871ad73513..0cf0e0aab9eb 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -589,18 +589,66 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
     val corr6 = df4.groupBy().agg(corr("a", "c")).collect()(0).getDouble(0)
     assert(math.abs(corr6 + 1.0) < 1e-12)
 
-    val df5 = Seq[(Integer, Integer)](
-      (1, null),
-      (null, -60)).toDF("a", "b")
+    // Test for udaf_corr in HiveCompatibilitySuite
+    // udaf_corr has been blacklisted due to numerical errors
+    // We test it here:
+    // SELECT corr(b, c) FROM covar_tab WHERE a < 1; => NULL
+    // SELECT corr(b, c) FROM covar_tab WHERE a < 3; => NULL
+    // SELECT corr(b, c) FROM covar_tab WHERE a = 3; => NULL
+    // SELECT a, corr(b, c) FROM covar_tab GROUP BY a ORDER BY a; =>
+    //   1 NULL
+    //   2 NULL
+    //   3 NULL
+    //   4 NULL
+    //   5 NULL
+    //   6 NULL
+    // SELECT corr(b, c) FROM covar_tab; => 0.6633880657639323
+
+    val covar_tab = Seq[(Integer, Integer, Integer)](
+      (1, null, 15),
+      (2, 3, null),
+      (3, 7, 12),
+      (4, 4, 14),
+      (5, 8, 17),
+      (6, 2, 11)).toDF("a", "b", "c")
+
+    covar_tab.registerTempTable("covar_tab")
 
-    val corr7 = df5.groupBy().agg(corr("a", "b")).collect()(0)
-    assert(corr7 == Row(null))
+    checkAnswer(
+      sqlContext.sql(
+        """
+          |SELECT corr(b, c) FROM covar_tab WHERE a < 1
+        """.stripMargin),
+      Row(null) :: Nil)
+
+    checkAnswer(
+      sqlContext.sql(
+        """
+          |SELECT corr(b, c) FROM covar_tab WHERE a < 3
+        """.stripMargin),
+      Row(null) :: Nil)
 
-    val df6 = Seq[(Integer, Integer)](
-      (7, 12)).toDF("a", "b")
+    checkAnswer(
+      sqlContext.sql(
+        """
+          |SELECT corr(b, c) FROM covar_tab WHERE a = 3
+        """.stripMargin),
+      Row(null) :: Nil)
 
-    val corr8 = df6.groupBy().agg(corr("a", "b")).collect()(0)
-    assert(corr8 == Row(null))
+    checkAnswer(
+      sqlContext.sql(
+        """
+          |SELECT a, corr(b, c) FROM covar_tab GROUP BY a ORDER BY a
+        """.stripMargin),
+      Row(1, null) ::
+      Row(2, null) ::
+      Row(3, null) ::
+      Row(4, null) ::
+      Row(5, null) ::
+      Row(6, null) :: Nil)
+
+    val corr7 = sqlContext.sql("SELECT corr(b, c) FROM covar_tab").collect()(0).getDouble(0)
+    assert(math.abs(corr7 - 0.6633880657639323) < 1e-12)
 
     withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "false") {
       val errorMessage = intercept[SparkException] {