diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index d700fb83b9b7..199acc23b79d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.util.toPrettySQL
 import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
+import org.apache.spark.sql.functions.{lit, struct}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{NumericType, StructType}
 
@@ -72,7 +73,8 @@ class RelationalGroupedDataset protected[sql](
       case RelationalGroupedDataset.PivotType(pivotCol, values) =>
         val aliasedGrps = groupingExprs.map(alias)
         Dataset.ofRows(
-          df.sparkSession, Pivot(Some(aliasedGrps), pivotCol, values, aggExprs, df.logicalPlan))
+          df.sparkSession,
+          Pivot(Some(aliasedGrps), pivotCol, values.map(_.expr), aggExprs, df.logicalPlan))
     }
   }
 
@@ -335,7 +337,7 @@ class RelationalGroupedDataset protected[sql](
    * @since 1.6.0
    */
   def pivot(pivotColumn: String, values: Seq[Any]): RelationalGroupedDataset = {
-    pivot(Column(pivotColumn), values)
+    pivot(Column(pivotColumn), values.map(lit))
   }
 
   /**
@@ -359,7 +361,7 @@ class RelationalGroupedDataset protected[sql](
    * @since 1.6.0
    */
   def pivot(pivotColumn: String, values: java.util.List[Any]): RelationalGroupedDataset = {
-    pivot(Column(pivotColumn), values)
+    pivot(Column(pivotColumn), values.asScala.map(lit))
   }
 
   /**
@@ -371,6 +373,12 @@ class RelationalGroupedDataset protected[sql](
    *   df.groupBy($"year").pivot($"course").sum($"earnings");
    * }}}
    *
+   * For pivoting by multiple columns, use the `struct` function to combine the columns:
+   *
+   * {{{
+   *   df.groupBy($"year").pivot(struct($"course", $"training")).agg(sum($"earnings"))
+   * }}}
+   *
    * @param pivotColumn the column to pivot.
    * @since 2.4.0
    */
@@ -384,6 +392,10 @@ class RelationalGroupedDataset protected[sql](
       .sort(pivotColumn)  // ensure that the output columns are in a consistent logical order
       .collect()
       .map(_.get(0))
+      .collect {
+        case row: GenericRow => struct(row.values.map(lit): _*)
+        case value => lit(value)
+      }
       .toSeq
 
     if (values.length > maxValues) {
@@ -403,20 +415,29 @@ class RelationalGroupedDataset protected[sql](
    *
    * {{{
    *   // Compute the sum of earnings for each year by course with each course as a separate column
-   *   df.groupBy($"year").pivot($"course", Seq("dotNET", "Java")).sum($"earnings")
+   *   df.groupBy($"year").pivot($"course", Seq(lit("dotNET"), lit("Java"))).sum($"earnings")
+   * }}}
+   *
+   * For pivoting by multiple columns, use the `struct` function to combine the columns and values:
+   *
+   * {{{
+   *   df
+   *     .groupBy($"year")
+   *     .pivot(struct($"course", $"training"), Seq(struct(lit("java"), lit("Experts"))))
+   *     .agg(sum($"earnings"))
    * }}}
    *
    * @param pivotColumn the column to pivot.
   * @param values List of values that will be translated to columns in the output DataFrame.
   * @since 2.4.0
   */
-  def pivot(pivotColumn: Column, values: Seq[Any]): RelationalGroupedDataset = {
+  def pivot(pivotColumn: Column, values: Seq[Column]): RelationalGroupedDataset = {
     groupType match {
       case RelationalGroupedDataset.GroupByType =>
         new RelationalGroupedDataset(
           df,
           groupingExprs,
-          RelationalGroupedDataset.PivotType(pivotColumn.expr, values.map(Literal.apply)))
+          RelationalGroupedDataset.PivotType(pivotColumn.expr, values))
       case _: RelationalGroupedDataset.PivotType =>
         throw new UnsupportedOperationException("repeated pivots are not supported")
       case _ =>
@@ -433,7 +454,7 @@ class RelationalGroupedDataset protected[sql](
    * @param values List of values that will be translated to columns in the output DataFrame.
    * @since 2.4.0
    */
-  def pivot(pivotColumn: Column, values: java.util.List[Any]): RelationalGroupedDataset = {
+  def pivot(pivotColumn: Column, values: java.util.List[Column]): RelationalGroupedDataset = {
     pivot(pivotColumn, values.asScala)
   }
 
@@ -561,5 +582,5 @@ private[sql] object RelationalGroupedDataset {
   /**
    * To indicate it's the PIVOT
    */
-  private[sql] case class PivotType(pivotCol: Expression, values: Seq[Literal]) extends GroupType
+  private[sql] case class PivotType(pivotCol: Expression, values: Seq[Column]) extends GroupType
 }
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index 69a2904f5f3f..8ae6de3935d1 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -306,6 +306,22 @@ public void pivot() {
     Assert.assertEquals(30000.0, actual.get(1).getDouble(2), 0.01);
   }
 
+  @Test
+  public void pivotColumnValues() {
+    Dataset<Row> df = spark.table("courseSales");
+    List<Row> actual = df.groupBy("year")
+      .pivot(col("course"), Arrays.asList(lit("dotNET"), lit("Java")))
+      .agg(sum("earnings")).orderBy("year").collectAsList();
+
+    Assert.assertEquals(2012, actual.get(0).getInt(0));
+    Assert.assertEquals(15000.0, actual.get(0).getDouble(1), 0.01);
+    Assert.assertEquals(20000.0, actual.get(0).getDouble(2), 0.01);
+
+    Assert.assertEquals(2013, actual.get(1).getInt(0));
+    Assert.assertEquals(48000.0, actual.get(1).getDouble(1), 0.01);
+    Assert.assertEquals(30000.0, actual.get(1).getDouble(2), 0.01);
+  }
+
   private String getResource(String resource) {
     try {
       // The following "getResource" has different behaviors in SBT and Maven.
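For reviewers who want to try the new overload outside the test suites, here is a minimal, self-contained Scala sketch (not part of the patch). The inline `courseSales` rows are an assumption, shaped to reproduce the sums asserted in `pivotColumnValues` above; only `pivot(Column, Seq[Column])` itself comes from this change.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{lit, sum}

object PivotColumnValuesSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("pivot-column-values-sketch")
      .getOrCreate()
    import spark.implicits._

    // Inline stand-in for the courseSales test table (values assumed from the asserts).
    val courseSales = Seq(
      (2012, "dotNET", 10000.0),
      (2012, "dotNET", 5000.0),
      (2012, "Java", 20000.0),
      (2013, "dotNET", 48000.0),
      (2013, "Java", 30000.0)
    ).toDF("year", "course", "earnings")

    // With this patch, explicit pivot values are Columns (literals here) instead of Seq[Any].
    courseSales
      .groupBy($"year")
      .pivot($"course", Seq(lit("dotNET"), lit("Java")))
      .agg(sum($"earnings"))
      .orderBy($"year")
      .show()
    // Expected rows: (2012, 15000.0, 20000.0) and (2013, 48000.0, 30000.0),
    // matching the Java asserts above.

    spark.stop()
  }
}
```

Note that the `String`-based overloads still accept `Seq[Any]` and now simply wrap each value with `lit`, so existing callers of those overloads compile unchanged.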
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
index b972b9ef93e5..4a7589a03c23 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
@@ -33,7 +33,7 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext {
         .agg(sum($"earnings")),
       expected)
     checkAnswer(
-      courseSales.groupBy($"year").pivot($"course", Seq("dotNET", "Java"))
+      courseSales.groupBy($"year").pivot($"course", Seq(lit("dotNET"), lit("Java")))
         .agg(sum($"earnings")),
       expected)
   }
@@ -44,7 +44,10 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext {
     checkAnswer(
       courseSales.groupBy("course").pivot("year", Seq(2012, 2013)).agg(sum($"earnings")),
       expected)
     checkAnswer(
-      courseSales.groupBy('course).pivot('year, Seq(2012, 2013)).agg(sum('earnings)),
+      courseSales
+        .groupBy('course)
+        .pivot('year, Seq(lit(2012), lit(2013)))
+        .agg(sum('earnings)),
       expected)
   }
@@ -58,7 +61,7 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext {
       expected)
     checkAnswer(
       courseSales.groupBy($"year")
-        .pivot($"course", Seq("dotNET", "Java"))
+        .pivot($"course", Seq(lit("dotNET"), lit("Java")))
         .agg(sum($"earnings"), avg($"earnings")),
       expected)
   }
@@ -204,7 +207,7 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext {
       complexData.groupBy().pivot("b", Seq(true, false)).agg(max("a")),
       expected)
     checkAnswer(
-      complexData.groupBy().pivot('b, Seq(true, false)).agg(max('a)),
+      complexData.groupBy().pivot('b, Seq(lit(true), lit(false))).agg(max('a)),
       expected)
   }
@@ -272,7 +275,7 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext {
     val expected = Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil
     val df = trainingSales
       .groupBy($"sales.year")
-      .pivot(lower($"sales.course"), Seq("dotNet", "Java").map(_.toLowerCase))
+      .pivot(lower($"sales.course"), Seq("dotNet", "Java").map(_.toLowerCase).map(lit))
       .agg(sum($"sales.earnings"))
 
     checkAnswer(df, expected)
@@ -282,7 +285,7 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext {
     val expected = Row(2012, 10000.0) :: Row(2013, 48000.0) :: Nil
     val df = trainingSales
       .groupBy($"sales.year")
-      .pivot(concat_ws("-", $"training", $"sales.course"), Seq("Experts-dotNET"))
+      .pivot(concat_ws("-", $"training", $"sales.course"), Seq(lit("Experts-dotNET")))
       .agg(sum($"sales.earnings"))
 
     checkAnswer(df, expected)
@@ -292,7 +295,7 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext {
     val expected = Row(2012, 35000.0) :: Row(2013, 78000.0) :: Nil
     val df1 = trainingSales
       .groupBy($"sales.year")
-      .pivot(lit(123), Seq(123))
+      .pivot(lit(123), Seq(lit(123)))
       .agg(sum($"sales.earnings"))
 
     checkAnswer(df1, expected)
@@ -302,10 +305,34 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext {
     val exception = intercept[AnalysisException] {
       trainingSales
         .groupBy($"sales.year")
-        .pivot(min($"training"), Seq("Experts"))
+        .pivot(min($"training"), Seq(lit("Experts")))
         .agg(sum($"sales.earnings"))
     }
     assert(exception.getMessage.contains("aggregate functions are not allowed"))
   }
+
+  test("pivoting column list with values") {
+    val expected = Row(2012, 10000.0, null) :: Row(2013, 48000.0, 30000.0) :: Nil
+    val df = trainingSales
+      .groupBy($"sales.year")
+      .pivot(struct(lower($"sales.course"), $"training"), Seq(
+        struct(lit("dotnet"), lit("Experts")),
+        struct(lit("java"), lit("Dummies")))
).agg(sum($"sales.earnings")) + + checkAnswer(df, expected) + } + + test("pivoting column list") { + val expected = Seq( + Row(2012, 5000.0, 10000.0, null, 20000.0), + Row(2013, null, 48000.0, 30000.0, null)) + val df = trainingSales + .groupBy($"sales.year") + .pivot(struct(lower($"sales.course"), $"training")) + .agg(sum($"sales.earnings")) + + checkAnswer(df, expected) + } }