From ef357e1bd15ab9cd52cd66bcf2bce8b70aef22cb Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Mon, 18 May 2015 23:30:29 -0700 Subject: [PATCH 1/7] Add Cube / Rollup for dataframe --- .../org/apache/spark/sql/DataFrame.scala | 92 +++++++++++++++++++ .../org/apache/spark/sql/GroupedData.scala | 67 ++++++++++---- .../hive/HiveDataFrameAnalyticsSuite.scala | 77 ++++++++++++++++ 3 files changed, 219 insertions(+), 17 deletions(-) create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index adad85806d1ea..6a61a68cd5ec0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -687,6 +687,46 @@ class DataFrame private[sql]( @scala.annotation.varargs def groupBy(cols: Column*): GroupedData = new GroupedData(this, cols.map(_.expr)) + /** + * Rollup the [[DataFrame]] using the specified columns, so we can run aggregation on them. + * See [[GroupedData]] for all the available aggregate functions. + * + * {{{ + * // Compute the average for all numeric columns rolluped by department and group. + * df.rollup($"department", $"group").avg() + * + * // Compute the max age and average salary, rolluped by department and gender. + * df.rollup($"department", $"gender").agg(Map( + * "salary" -> "avg", + * "age" -> "max" + * )) + * }}} + * @group dfops + * @since 1.4.0 + */ + @scala.annotation.varargs + def rollup(cols: Column*): GroupedData = new RollupedData(this, cols.map(_.expr)) + + /** + * Cube the [[DataFrame]] using the specified columns, so we can run aggregation on them. + * See [[GroupedData]] for all the available aggregate functions. + * + * {{{ + * // Compute the average for all numeric columns cubed by department and group. + * df.cube($"department", $"group").avg() + * + * // Compute the max age and average salary, cubed by department and gender. + * df.cube($"department", $"gender").agg(Map( + * "salary" -> "avg", + * "age" -> "max" + * )) + * }}} + * @group dfops + * @since 1.4.0 + */ + @scala.annotation.varargs + def cube(cols: Column*): GroupedData = new CubedData(this, cols.map(_.expr)) + /** * Groups the [[DataFrame]] using the specified columns, so we can run aggregation on them. * See [[GroupedData]] for all the available aggregate functions. @@ -713,6 +753,58 @@ class DataFrame private[sql]( new GroupedData(this, colNames.map(colName => resolve(colName))) } + /** + * Rollup the [[DataFrame]] using the specified columns, so we can run aggregation on them. + * See [[GroupedData]] for all the available aggregate functions. + * + * This is a variant of groupBy that can only group by existing columns using column names + * (i.e. cannot construct expressions). + * + * {{{ + * // Compute the average for all numeric columns rolluped by department and group. + * df.rollup("department", "group").avg() + * + * // Compute the max age and average salary, rolluped by department and gender. + * df.rollup($"department", $"gender").agg(Map( + * "salary" -> "avg", + * "age" -> "max" + * )) + * }}} + * @group dfops + * @since 1.4.0 + */ + @scala.annotation.varargs + def rollup(col1: String, cols: String*): GroupedData = { + val colNames: Seq[String] = col1 +: cols + new RollupedData(this, colNames.map(colName => resolve(colName))) + } + + /** + * Cube the [[DataFrame]] using the specified columns, so we can run aggregation on them. 
+ * See [[GroupedData]] for all the available aggregate functions. + * + * This is a variant of groupBy that can only group by existing columns using column names + * (i.e. cannot construct expressions). + * + * {{{ + * // Compute the average for all numeric columns cubed by department and group. + * df.cube("department", "group").avg() + * + * // Compute the max age and average salary, cubed by department and gender. + * df.cube($"department", $"gender").agg(Map( + * "salary" -> "avg", + * "age" -> "max" + * )) + * }}} + * @group dfops + * @since 1.4.0 + */ + @scala.annotation.varargs + def cube(col1: String, cols: String*): GroupedData = { + val colNames: Seq[String] = col1 +: cols + new CubedData(this, colNames.map(colName => resolve(colName))) + } + /** * (Scala-specific) Aggregates on the entire [[DataFrame]] without groups. * {{{ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index 1381b9f1a6080..e2ca010facc2e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -23,7 +23,7 @@ import scala.language.implicitConversions import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.analysis.Star import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.Aggregate +import org.apache.spark.sql.catalyst.plans.logical.{Rollup, Cube, Aggregate} import org.apache.spark.sql.types.NumericType @@ -36,13 +36,22 @@ import org.apache.spark.sql.types.NumericType @Experimental class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression]) { - private[sql] implicit def toDF(aggExprs: Seq[NamedExpression]): DataFrame = { - val namedGroupingExprs = groupingExprs.map { - case expr: NamedExpression => expr - case expr: Expression => Alias(expr, expr.prettyString)() + protected def aggregateExpressions(aggrExprs: Seq[NamedExpression]) + : Seq[NamedExpression] = { + if (df.sqlContext.conf.dataFrameRetainGroupColumns) { + val retainedExprs = groupingExprs.map { + case expr: NamedExpression => expr + case expr: Expression => Alias(expr, expr.prettyString)() + } + retainedExprs ++ aggrExprs + } else { + aggrExprs } + } + + protected[sql] implicit def toDF(aggExprs: Seq[NamedExpression]): DataFrame = { DataFrame( - df.sqlContext, Aggregate(groupingExprs, namedGroupingExprs ++ aggExprs, df.logicalPlan)) + df.sqlContext, Aggregate(groupingExprs, aggregateExpressions(aggExprs), df.logicalPlan)) } private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => Expression) @@ -175,19 +184,10 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression]) */ @scala.annotation.varargs def agg(expr: Column, exprs: Column*): DataFrame = { - val aggExprs = (expr +: exprs).map(_.expr).map { + (expr +: exprs).map(_.expr).map { case expr: NamedExpression => expr case expr: Expression => Alias(expr, expr.prettyString)() } - if (df.sqlContext.conf.dataFrameRetainGroupColumns) { - val retainedExprs = groupingExprs.map { - case expr: NamedExpression => expr - case expr: Expression => Alias(expr, expr.prettyString)() - } - DataFrame(df.sqlContext, Aggregate(groupingExprs, retainedExprs ++ aggExprs, df.logicalPlan)) - } else { - DataFrame(df.sqlContext, Aggregate(groupingExprs, aggExprs, df.logicalPlan)) - } } /** @@ -256,5 +256,38 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression]) 
@scala.annotation.varargs def sum(colNames: String*): DataFrame = { aggregateNumericColumns(colNames:_*)(Sum) - } + } + +} + +/** + * :: Experimental :: + * A set of methods for aggregations on a [[DataFrame]] cube, created by [[DataFrame.cube]]. + * + * @since 1.4.0 + */ +@Experimental +class CubedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression]) + extends GroupedData(df, groupingExprs) { + + protected[sql] implicit override def toDF(aggExprs: Seq[NamedExpression]): DataFrame = { + DataFrame( + df.sqlContext, Cube(groupingExprs, df.logicalPlan, aggregateExpressions(aggExprs))) + } +} + +/** + * :: Experimental :: + * A set of methods for aggregations on a [[DataFrame]] rollup, created by [[DataFrame.rollup]]. + * + * @since 1.4.0 + */ +@Experimental +class RollupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression]) + extends GroupedData(df, groupingExprs) { + + protected[sql] implicit override def toDF(aggExprs: Seq[NamedExpression]): DataFrame = { + DataFrame( + df.sqlContext, Rollup(groupingExprs, df.logicalPlan, aggregateExpressions(aggExprs))) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala new file mode 100644 index 0000000000000..267a60f46783c --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.sql.hive
+
+import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.hive.test.TestHive
+import org.apache.spark.sql.hive.test.TestHive._
+import org.apache.spark.sql.hive.test.TestHive.implicits._
+
+case class TestData2Int(a: Int, b: Int)
+
+class HiveDataFrameAnalyticsSuiteSuite extends QueryTest {
+  val testData =
+    TestHive.sparkContext.parallelize(
+      TestData2Int(1, 2) ::
+      TestData2Int(2, 4) :: Nil).toDF()
+
+  testData.registerTempTable("mytable")
+
+  test("rollup") {
+    checkAnswer(
+      testData.rollup($"a" + $"b", $"b").agg(sum($"a" - $"b")),
+      sql("select a + b, b, sum(a - b) from mytable group by a + b, b with rollup").collect()
+    )
+
+    checkAnswer(
+      testData.rollup("a", "b").agg(sum("b")),
+      sql("select a, b, sum(b) from mytable group by a, b with rollup").collect()
+    )
+  }
+
+  test("cube") {
+    checkAnswer(
+      testData.cube($"a" + $"b", $"b").agg(sum($"a" - $"b")),
+      sql("select a + b, b, sum(a - b) from mytable group by a + b, b with cube").collect()
+    )
+
+    checkAnswer(
+      testData.cube("a", "b").agg(sum("b")),
+      sql("select a, b, sum(b) from mytable group by a, b with cube").collect()
+    )
+  }
+
+  test("spark.sql.retainGroupColumns config") {
+    val oldConf = conf.getConf("spark.sql.retainGroupColumns", "true")
+    try {
+      conf.setConf("spark.sql.retainGroupColumns", "false")
+      checkAnswer(
+        testData.rollup($"a" + $"b", $"b").agg(sum($"a" - $"b")),
+        sql("select sum(a-b) from mytable group by a + b, b with rollup").collect()
+      )
+
+      checkAnswer(
+        testData.cube($"a" + $"b", $"b").agg(sum($"a" - $"b")),
+        sql("select sum(a-b) from mytable group by a + b, b with cube").collect()
+      )
+    } finally {
+      conf.setConf("spark.sql.retainGroupColumns", oldConf)
+    }
+  }
+}

From 279584cd6d5f675b29556d2118aeb3b394dfe126 Mon Sep 17 00:00:00 2001
From: Cheng Hao
Date: Tue, 19 May 2015 17:23:37 -0700
Subject: [PATCH 2/7] hide the CubedData & RollupedData

---
 .../scala/org/apache/spark/sql/GroupedData.scala   | 12 ++----------
 1 file changed, 2 insertions(+), 10 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
index e2ca010facc2e..c57fd93b17ac5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
@@ -261,13 +261,9 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
 }
 
 /**
- * :: Experimental ::
  * A set of methods for aggregations on a [[DataFrame]] cube, created by [[DataFrame.cube]].
- *
- * @since 1.4.0
  */
-@Experimental
-class CubedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
+private[sql] class CubedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
   extends GroupedData(df, groupingExprs) {
 
   protected[sql] implicit override def toDF(aggExprs: Seq[NamedExpression]): DataFrame = {
@@ -277,13 +273,9 @@ class CubedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
 }
 
 /**
- * :: Experimental ::
  * A set of methods for aggregations on a [[DataFrame]] rollup, created by [[DataFrame.rollup]].
- * - * @since 1.4.0 */ -@Experimental -class RollupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression]) +private[sql] class RollupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression]) extends GroupedData(df, groupingExprs) { protected[sql] implicit override def toDF(aggExprs: Seq[NamedExpression]): DataFrame = { From 84c95642731636d8e6c36037b87608cf3123b60f Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Tue, 19 May 2015 17:34:52 -0700 Subject: [PATCH 3/7] Remove the CubedData & RollupedData --- .../org/apache/spark/sql/DataFrame.scala | 12 ++-- .../org/apache/spark/sql/GroupedData.scala | 63 ++++++++++--------- 2 files changed, 41 insertions(+), 34 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 6a61a68cd5ec0..334d79816141f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -685,7 +685,7 @@ class DataFrame private[sql]( * @since 1.3.0 */ @scala.annotation.varargs - def groupBy(cols: Column*): GroupedData = new GroupedData(this, cols.map(_.expr)) + def groupBy(cols: Column*): GroupedData = new GroupedData(this, cols.map(_.expr), GroupByType) /** * Rollup the [[DataFrame]] using the specified columns, so we can run aggregation on them. @@ -705,7 +705,7 @@ class DataFrame private[sql]( * @since 1.4.0 */ @scala.annotation.varargs - def rollup(cols: Column*): GroupedData = new RollupedData(this, cols.map(_.expr)) + def rollup(cols: Column*): GroupedData = new GroupedData(this, cols.map(_.expr), RollupType) /** * Cube the [[DataFrame]] using the specified columns, so we can run aggregation on them. @@ -725,7 +725,7 @@ class DataFrame private[sql]( * @since 1.4.0 */ @scala.annotation.varargs - def cube(cols: Column*): GroupedData = new CubedData(this, cols.map(_.expr)) + def cube(cols: Column*): GroupedData = new GroupedData(this, cols.map(_.expr), CubeType) /** * Groups the [[DataFrame]] using the specified columns, so we can run aggregation on them. 
@@ -750,7 +750,7 @@ class DataFrame private[sql]( @scala.annotation.varargs def groupBy(col1: String, cols: String*): GroupedData = { val colNames: Seq[String] = col1 +: cols - new GroupedData(this, colNames.map(colName => resolve(colName))) + new GroupedData(this, colNames.map(colName => resolve(colName)), GroupByType) } /** @@ -776,7 +776,7 @@ class DataFrame private[sql]( @scala.annotation.varargs def rollup(col1: String, cols: String*): GroupedData = { val colNames: Seq[String] = col1 +: cols - new RollupedData(this, colNames.map(colName => resolve(colName))) + new GroupedData(this, colNames.map(colName => resolve(colName)), RollupType) } /** @@ -802,7 +802,7 @@ class DataFrame private[sql]( @scala.annotation.varargs def cube(col1: String, cols: String*): GroupedData = { val colNames: Seq[String] = col1 +: cols - new CubedData(this, colNames.map(colName => resolve(colName))) + new GroupedData(this, colNames.map(colName => resolve(colName)), CubeType) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index c57fd93b17ac5..540d18b440d3d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -26,6 +26,25 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{Rollup, Cube, Aggregate} import org.apache.spark.sql.types.NumericType +/** + * The Grouping Type + */ +sealed private[sql] trait GroupType + +/** + * To indicate it's the GroupBy + */ +private[sql] object GroupByType extends GroupType + +/** + * To indicate it's the CUBE + */ +private[sql] object CubeType extends GroupType + +/** + * To indicate it's the ROLLUP + */ +private[sql] object RollupType extends GroupType /** * :: Experimental :: @@ -34,10 +53,13 @@ import org.apache.spark.sql.types.NumericType * @since 1.3.0 */ @Experimental -class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression]) { +class GroupedData protected[sql]( + df: DataFrame, + groupingExprs: Seq[Expression], + groupType: GroupType) { protected def aggregateExpressions(aggrExprs: Seq[NamedExpression]) - : Seq[NamedExpression] = { + : Seq[NamedExpression] = { if (df.sqlContext.conf.dataFrameRetainGroupColumns) { val retainedExprs = groupingExprs.map { case expr: NamedExpression => expr @@ -50,8 +72,17 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression]) } protected[sql] implicit def toDF(aggExprs: Seq[NamedExpression]): DataFrame = { - DataFrame( - df.sqlContext, Aggregate(groupingExprs, aggregateExpressions(aggExprs), df.logicalPlan)) + groupType match { + case GroupByType => + DataFrame( + df.sqlContext, Aggregate(groupingExprs, aggregateExpressions(aggExprs), df.logicalPlan)) + case RollupType => + DataFrame( + df.sqlContext, Rollup(groupingExprs, df.logicalPlan, aggregateExpressions(aggExprs))) + case CubeType => + DataFrame( + df.sqlContext, Cube(groupingExprs, df.logicalPlan, aggregateExpressions(aggExprs))) + } } private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => Expression) @@ -259,27 +290,3 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression]) } } - -/** - * A set of methods for aggregations on a [[DataFrame]] cube, created by [[DataFrame.cube]]. 
- */ -private[sql] class CubedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression]) - extends GroupedData(df, groupingExprs) { - - protected[sql] implicit override def toDF(aggExprs: Seq[NamedExpression]): DataFrame = { - DataFrame( - df.sqlContext, Cube(groupingExprs, df.logicalPlan, aggregateExpressions(aggExprs))) - } -} - -/** - * A set of methods for aggregations on a [[DataFrame]] rollup, created by [[DataFrame.rollup]]. - */ -private[sql] class RollupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression]) - extends GroupedData(df, groupingExprs) { - - protected[sql] implicit override def toDF(aggExprs: Seq[NamedExpression]): DataFrame = { - DataFrame( - df.sqlContext, Rollup(groupingExprs, df.logicalPlan, aggregateExpressions(aggExprs))) - } -} From c441777bf0d6d7d3b6c94161802ca9fe894aa0b7 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Wed, 20 May 2015 00:00:01 -0700 Subject: [PATCH 4/7] update the code as suggested --- .../org/apache/spark/sql/DataFrame.scala | 32 +++++--- .../org/apache/spark/sql/GroupedData.scala | 78 +++++++++++-------- .../hive/HiveDataFrameAnalyticsSuite.scala | 23 +----- 3 files changed, 69 insertions(+), 64 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 334d79816141f..d78b4c2f8909c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -685,10 +685,13 @@ class DataFrame private[sql]( * @since 1.3.0 */ @scala.annotation.varargs - def groupBy(cols: Column*): GroupedData = new GroupedData(this, cols.map(_.expr), GroupByType) + def groupBy(cols: Column*): GroupedData = { + GroupedData(this, cols.map(_.expr), GroupedData.GroupByType) + } /** - * Rollup the [[DataFrame]] using the specified columns, so we can run aggregation on them. + * Create a multi-dimensional rollup for the current [[DataFrame]] using the specified columns, + * so we can run aggregation on them. * See [[GroupedData]] for all the available aggregate functions. * * {{{ @@ -705,10 +708,13 @@ class DataFrame private[sql]( * @since 1.4.0 */ @scala.annotation.varargs - def rollup(cols: Column*): GroupedData = new GroupedData(this, cols.map(_.expr), RollupType) + def rollup(cols: Column*): GroupedData = { + GroupedData(this, cols.map(_.expr), GroupedData.RollupType) + } /** - * Cube the [[DataFrame]] using the specified columns, so we can run aggregation on them. + * Create a multi-dimensional cube for the current [[DataFrame]] using the specified columns, + * so we can run aggregation on them. * See [[GroupedData]] for all the available aggregate functions. * * {{{ @@ -725,7 +731,7 @@ class DataFrame private[sql]( * @since 1.4.0 */ @scala.annotation.varargs - def cube(cols: Column*): GroupedData = new GroupedData(this, cols.map(_.expr), CubeType) + def cube(cols: Column*): GroupedData = GroupedData(this, cols.map(_.expr), GroupedData.CubeType) /** * Groups the [[DataFrame]] using the specified columns, so we can run aggregation on them. @@ -750,14 +756,15 @@ class DataFrame private[sql]( @scala.annotation.varargs def groupBy(col1: String, cols: String*): GroupedData = { val colNames: Seq[String] = col1 +: cols - new GroupedData(this, colNames.map(colName => resolve(colName)), GroupByType) + GroupedData(this, colNames.map(colName => resolve(colName)), GroupedData.GroupByType) } /** - * Rollup the [[DataFrame]] using the specified columns, so we can run aggregation on them. 
+ * Create a multi-dimensional rollup for the current [[DataFrame]] using the specified columns, + * so we can run aggregation on them. * See [[GroupedData]] for all the available aggregate functions. * - * This is a variant of groupBy that can only group by existing columns using column names + * This is a variant of rollup that can only group by existing columns using column names * (i.e. cannot construct expressions). * * {{{ @@ -776,14 +783,15 @@ class DataFrame private[sql]( @scala.annotation.varargs def rollup(col1: String, cols: String*): GroupedData = { val colNames: Seq[String] = col1 +: cols - new GroupedData(this, colNames.map(colName => resolve(colName)), RollupType) + GroupedData(this, colNames.map(colName => resolve(colName)), GroupedData.RollupType) } /** - * Cube the [[DataFrame]] using the specified columns, so we can run aggregation on them. + * Create a multi-dimensional cube for the current [[DataFrame]] using the specified columns, + * so we can run aggregation on them. * See [[GroupedData]] for all the available aggregate functions. * - * This is a variant of groupBy that can only group by existing columns using column names + * This is a variant of cube that can only group by existing columns using column names * (i.e. cannot construct expressions). * * {{{ @@ -802,7 +810,7 @@ class DataFrame private[sql]( @scala.annotation.varargs def cube(col1: String, cols: String*): GroupedData = { val colNames: Seq[String] = col1 +: cols - new GroupedData(this, colNames.map(colName => resolve(colName)), CubeType) + GroupedData(this, colNames.map(colName => resolve(colName)), GroupedData.CubeType) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index 540d18b440d3d..b59f42fd3cb78 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -32,19 +32,31 @@ import org.apache.spark.sql.types.NumericType sealed private[sql] trait GroupType /** - * To indicate it's the GroupBy + * Companion object for GroupedData */ -private[sql] object GroupByType extends GroupType +private[sql] object GroupedData { + def apply( + df: DataFrame, + groupingExprs: Seq[Expression], + groupType: GroupType): GroupedData = { + new GroupedData(df, groupingExprs).withNewGroupType(groupType) + } -/** - * To indicate it's the CUBE - */ -private[sql] object CubeType extends GroupType + /** + * To indicate it's the GroupBy + */ + private[sql] object GroupByType extends GroupType -/** - * To indicate it's the ROLLUP - */ -private[sql] object RollupType extends GroupType + /** + * To indicate it's the CUBE + */ + private[sql] object CubeType extends GroupType + + /** + * To indicate it's the ROLLUP + */ + private[sql] object RollupType extends GroupType +} /** * :: Experimental :: @@ -53,35 +65,36 @@ private[sql] object RollupType extends GroupType * @since 1.3.0 */ @Experimental -class GroupedData protected[sql]( - df: DataFrame, - groupingExprs: Seq[Expression], - groupType: GroupType) { +class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression]) { - protected def aggregateExpressions(aggrExprs: Seq[NamedExpression]) - : Seq[NamedExpression] = { - if (df.sqlContext.conf.dataFrameRetainGroupColumns) { - val retainedExprs = groupingExprs.map { - case expr: NamedExpression => expr - case expr: Expression => Alias(expr, expr.prettyString)() - } - retainedExprs ++ aggrExprs - } else { - aggrExprs - } + private var 
groupType: GroupType = _ + + private[sql] def withNewGroupType(groupType: GroupType): GroupedData = { + this.groupType = groupType + this } - protected[sql] implicit def toDF(aggExprs: Seq[NamedExpression]): DataFrame = { + private[sql] implicit def toDF(aggExprs: Seq[NamedExpression]): DataFrame = { + val aggregates = if (df.sqlContext.conf.dataFrameRetainGroupColumns) { + val retainedExprs = groupingExprs.map { + case expr: NamedExpression => expr + case expr: Expression => Alias(expr, expr.prettyString)() + } + retainedExprs ++ aggExprs + } else { + aggExprs + } + groupType match { - case GroupByType => + case GroupedData.GroupByType => DataFrame( - df.sqlContext, Aggregate(groupingExprs, aggregateExpressions(aggExprs), df.logicalPlan)) - case RollupType => + df.sqlContext, Aggregate(groupingExprs, aggregates, df.logicalPlan)) + case GroupedData.RollupType => DataFrame( - df.sqlContext, Rollup(groupingExprs, df.logicalPlan, aggregateExpressions(aggExprs))) - case CubeType => + df.sqlContext, Rollup(groupingExprs, df.logicalPlan, aggregates)) + case GroupedData.CubeType => DataFrame( - df.sqlContext, Cube(groupingExprs, df.logicalPlan, aggregateExpressions(aggExprs))) + df.sqlContext, Cube(groupingExprs, df.logicalPlan, aggregates)) } } @@ -288,5 +301,4 @@ class GroupedData protected[sql]( def sum(colNames: String*): DataFrame = { aggregateNumericColumns(colNames:_*)(Sum) } - } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala index 267a60f46783c..3ad05f482504c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala @@ -25,7 +25,10 @@ import org.apache.spark.sql.hive.test.TestHive.implicits._ case class TestData2Int(a: Int, b: Int) -class HiveDataFrameAnalyticsSuiteSuite extends QueryTest { +// TODO ideally we should put the test suite into the package `sql`, as +// `hive` package is optional in compiling, however, `SQLContext.sql` doesn't +// support the `cube` or `rollup` yet. 
+class HiveDataFrameAnalyticsSuite extends QueryTest { val testData = TestHive.sparkContext.parallelize( TestData2Int(1, 2) :: @@ -56,22 +59,4 @@ class HiveDataFrameAnalyticsSuiteSuite extends QueryTest { sql("select a, b, sum(b) from mytable group by a, b with cube").collect() ) } - - test("spark.sql.retainGroupColumns config") { - val oldConf = conf.getConf("spark.sql.retainGroupColumns", "true") - try { - conf.setConf("spark.sql.retainGroupColumns", "false") - checkAnswer( - testData.rollup($"a" + $"b", $"b").agg(sum($"a" - $"b")), - sql("select sum(a-b) from mytable group by a + b, b with rollup").collect() - ) - - checkAnswer( - testData.cube($"a" + $"b", $"b").agg(sum($"a" - $"b")), - sql("select sum(a-b) from mytable group by a + b, b with cube").collect() - ) - } finally { - conf.setConf("spark.sql.retainGroupColumns", oldConf) - } - } } From a2869d4a70621cc7daa336f5d81f5b4a68a795f0 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Wed, 20 May 2015 00:16:17 -0700 Subject: [PATCH 5/7] update the code as comments --- .../org/apache/spark/sql/GroupedData.scala | 25 ++++++++----------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index b59f42fd3cb78..dedc54927214c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -26,11 +26,6 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{Rollup, Cube, Aggregate} import org.apache.spark.sql.types.NumericType -/** - * The Grouping Type - */ -sealed private[sql] trait GroupType - /** * Companion object for GroupedData */ @@ -39,9 +34,14 @@ private[sql] object GroupedData { df: DataFrame, groupingExprs: Seq[Expression], groupType: GroupType): GroupedData = { - new GroupedData(df, groupingExprs).withNewGroupType(groupType) + new GroupedData(df, groupingExprs, groupType: GroupType) } + /** + * The Grouping Type + */ + private[sql] trait GroupType + /** * To indicate it's the GroupBy */ @@ -65,15 +65,10 @@ private[sql] object GroupedData { * @since 1.3.0 */ @Experimental -class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression]) { - - private var groupType: GroupType = _ - - private[sql] def withNewGroupType(groupType: GroupType): GroupedData = { - this.groupType = groupType - this - } - +class GroupedData protected[sql]( + df: DataFrame, + groupingExprs: Seq[Expression], + private val groupType: GroupedData.GroupType) { private[sql] implicit def toDF(aggExprs: Seq[NamedExpression]): DataFrame = { val aggregates = if (df.sqlContext.conf.dataFrameRetainGroupColumns) { val retainedExprs = groupingExprs.map { From a66e38fa6bde0f154d8fb6761cfc9bdbf8c94f2c Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Wed, 20 May 2015 00:32:09 -0700 Subject: [PATCH 6/7] remove the unnecessary code changes --- .../src/main/scala/org/apache/spark/sql/GroupedData.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index dedc54927214c..997634491c30d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -40,22 +40,22 @@ private[sql] object GroupedData { /** * The Grouping Type */ - private[sql] trait GroupType + trait 
GroupType /** * To indicate it's the GroupBy */ - private[sql] object GroupByType extends GroupType + object GroupByType extends GroupType /** * To indicate it's the CUBE */ - private[sql] object CubeType extends GroupType + object CubeType extends GroupType /** * To indicate it's the ROLLUP */ - private[sql] object RollupType extends GroupType + object RollupType extends GroupType } /** From 73023197fa46c2c62a7fb2ee25d2896eaa9b339e Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Wed, 20 May 2015 00:52:25 -0700 Subject: [PATCH 7/7] cancel the implicit keyword --- .../org/apache/spark/sql/GroupedData.scala | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index 997634491c30d..f730e4ae00e2b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -69,7 +69,8 @@ class GroupedData protected[sql]( df: DataFrame, groupingExprs: Seq[Expression], private val groupType: GroupedData.GroupType) { - private[sql] implicit def toDF(aggExprs: Seq[NamedExpression]): DataFrame = { + + private[this] def toDF(aggExprs: Seq[NamedExpression]): DataFrame = { val aggregates = if (df.sqlContext.conf.dataFrameRetainGroupColumns) { val retainedExprs = groupingExprs.map { case expr: NamedExpression => expr @@ -94,7 +95,7 @@ class GroupedData protected[sql]( } private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => Expression) - : Seq[NamedExpression] = { + : DataFrame = { val columnExprs = if (colNames.isEmpty) { // No columns specified. Use all numeric columns. @@ -111,10 +112,10 @@ class GroupedData protected[sql]( namedExpr } } - columnExprs.map { c => + toDF(columnExprs.map { c => val a = f(c) Alias(a, a.prettyString)() - } + }) } private[this] def strToExpr(expr: String): (Expression => Expression) = { @@ -167,10 +168,10 @@ class GroupedData protected[sql]( * @since 1.3.0 */ def agg(exprs: Map[String, String]): DataFrame = { - exprs.map { case (colName, expr) => + toDF(exprs.map { case (colName, expr) => val a = strToExpr(expr)(df(colName).expr) Alias(a, a.prettyString)() - }.toSeq + }.toSeq) } /** @@ -223,10 +224,10 @@ class GroupedData protected[sql]( */ @scala.annotation.varargs def agg(expr: Column, exprs: Column*): DataFrame = { - (expr +: exprs).map(_.expr).map { + toDF((expr +: exprs).map(_.expr).map { case expr: NamedExpression => expr case expr: Expression => Alias(expr, expr.prettyString)() - } + }) } /** @@ -235,7 +236,7 @@ class GroupedData protected[sql]( * * @since 1.3.0 */ - def count(): DataFrame = Seq(Alias(Count(Literal(1)), "count")()) + def count(): DataFrame = toDF(Seq(Alias(Count(Literal(1)), "count")())) /** * Compute the average value for each numeric columns for each group. This is an alias for `avg`.
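Taken together, the series leaves a single public entry point: `groupBy`, `rollup`, and `cube` on `DataFrame` all build one `GroupedData`, and the `GroupedData.GroupType` passed to its factory decides whether `toDF` plans an `Aggregate`, a `Rollup`, or a `Cube` logical node. The sketch below shows how the finished 1.4-era API is driven end to end. It is illustrative only: the `hiveContext` handle, the `Record` case class, and the sample rows are assumptions made for the example, not part of the patches.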
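// A minimal usage sketch (assumed setup: a Spark 1.4-era HiveContext already
// constructed as `hiveContext`; Record and its rows are made up for illustration).
import org.apache.spark.sql.functions._
import hiveContext.implicits._

case class Record(department: String, group: String, salary: Int)

val df = hiveContext.createDataFrame(Seq(
  Record("eng", "a", 100),
  Record("eng", "b", 200),
  Record("ops", "a", 150)))

// Plain GROUP BY: one output row per (department, group) combination.
df.groupBy("department", "group").agg(sum("salary")).show()

// ROLLUP: the (department, group) rows plus per-department subtotals and a
// grand total, matching "GROUP BY ... WITH ROLLUP" in the HiveQL tests above.
df.rollup("department", "group").agg(sum("salary")).show()

// CUBE: aggregates over every subset of the grouping columns, so it also emits
// per-group subtotals; subtotal rows carry null in the rolled-up columns.
df.cube($"department", $"group").agg(avg("salary"), max("salary")).show()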
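Two review-driven choices shape the final form of the code: folding `CubedData` and `RollupedData` into one `GroupedData` keyed by `GroupedData.GroupType` keeps the public API to a single class, and patch 7's removal of the `implicit` keyword turns the conversion from aggregate expressions to a `DataFrame` into an explicit, `private[this]` `toDF(...)` call inside each aggregation method, instead of an implicit `Seq[NamedExpression] => DataFrame` conversion that could fire where it was not intended.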