@@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.util.toPrettySQL
import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
+import org.apache.spark.sql.functions.{lit, struct}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{NumericType, StructType}

@@ -72,7 +73,8 @@ class RelationalGroupedDataset protected[sql](
case RelationalGroupedDataset.PivotType(pivotCol, values) =>
val aliasedGrps = groupingExprs.map(alias)
Dataset.ofRows(
-        df.sparkSession, Pivot(Some(aliasedGrps), pivotCol, values, aggExprs, df.logicalPlan))
+        df.sparkSession,
+        Pivot(Some(aliasedGrps), pivotCol, values.map(_.expr), aggExprs, df.logicalPlan))
}
}

@@ -335,7 +337,7 @@ class RelationalGroupedDataset protected[sql](
* @since 1.6.0
*/
def pivot(pivotColumn: String, values: Seq[Any]): RelationalGroupedDataset = {
-    pivot(Column(pivotColumn), values)
+    pivot(Column(pivotColumn), values.map(lit))
Member: This is going to allow pivot(String, Seq[Any]) to also take Column. Did I misread the code?

Contributor: Yes, you did. This "old" interface only takes a single named column by its name (say, "a", but not "a+1"), but we turn it into a Column just to reuse the same implementation.
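
A minimal sketch of that distinction, assuming the usual import of org.apache.spark.sql.functions._ and the courseSales test fixture used later in this PR:

  // OK: "course" is a plain column name; it is wrapped into a Column internally.
  courseSales.groupBy($"year").pivot("course", Seq("dotNET", "Java")).sum("earnings")

  // Not expressible via the String overload: an arbitrary expression such as
  // lower($"course") requires the Column-based overload.
  courseSales.groupBy($"year").pivot(lower($"course"), Seq(lit("dotnet"), lit("java"))).sum("earnings")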

Member: So, we did allow only literals but not generic columns before, right?

Contributor: Yes, with Seq[Any] we only allow literal values, not Columns.

Contributor:

> This is going to allow pivot(String, Seq[Any]) also take Column

I think using "lit" here is causing the confusion then (perhaps @MaxGekk was not aware of that?). We should keep the current behavior of this signature as it is. Using Column(Literal.create(value)) would do.
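
A minimal sketch of that alternative (hedged: the pre-PR code path used Literal.apply, per the stack trace quoted later in this thread; Literal.create as suggested above behaves similarly for plain values):

  // Hypothetical variant that preserves the old behavior: eagerly convert every
  // value to a literal expression, so a Column passed in fails fast with
  // "Unsupported literal type" instead of being accepted as an expression.
  def pivot(pivotColumn: String, values: Seq[Any]): RelationalGroupedDataset = {
    pivot(Column(pivotColumn), values.map(v => Column(Literal(v))))
  }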

Member: Hm? I think we should allow it and then make this #22030 (comment) assumption stay true.

}

/**
@@ -359,7 +361,7 @@ class RelationalGroupedDataset protected[sql](
* @since 1.6.0
*/
def pivot(pivotColumn: String, values: java.util.List[Any]): RelationalGroupedDataset = {
-    pivot(Column(pivotColumn), values)
+    pivot(Column(pivotColumn), values.asScala.map(lit))
}

/**
@@ -371,6 +373,12 @@ class RelationalGroupedDataset protected[sql](
* df.groupBy($"year").pivot($"course").sum($"earnings");
* }}}
*
+   * For pivoting by multiple columns, use the `struct` function to combine the columns:
+   *
+   * {{{
+   *   df.groupBy($"year").pivot(struct($"course", $"training")).agg(sum($"earnings"))
+   * }}}
+   *
   * @param pivotColumn the column to pivot.
* @since 2.4.0
*/
@@ -384,6 +392,10 @@ class RelationalGroupedDataset protected[sql](
.sort(pivotColumn) // ensure that the output columns are in a consistent logical order
.collect()
.map(_.get(0))
+      .collect {
Contributor: Use "map"?

+        case row: GenericRow => struct(row.values.map(lit): _*)
Contributor: I suspect this will not work for nested struct types, or, say, multiple pivot columns with nested types. Could you please add a test like:

  test("pivoting column list") {
    val expected = ...
    val df = trainingSales
      .groupBy($"sales.year")
      .pivot(struct($"sales", $"training"))
      .agg(sum($"sales.earnings"))
     checkAnswer(df, expected)
  }

And can we also check if it works for other complex nested types, like Array(Struct(...))?

+        case value => lit(value)
+      }
.toSeq

if (values.length > maxValues) {
@@ -403,20 +415,29 @@ class RelationalGroupedDataset protected[sql](
*
* {{{
* // Compute the sum of earnings for each year by course with each course as a separate column
-   *   df.groupBy($"year").pivot($"course", Seq("dotNET", "Java")).sum($"earnings")
+   *   df.groupBy($"year").pivot($"course", Seq(lit("dotNET"), lit("Java"))).sum($"earnings")
* }}}
*
+   * For pivoting by multiple columns, use the `struct` function to combine the columns and values:
+   *
+   * {{{
+   *   df
+   *     .groupBy($"year")
+   *     .pivot(struct($"course", $"training"), Seq(struct(lit("java"), lit("Experts"))))
+   *     .agg(sum($"earnings"))
+   * }}}
+   *
* @param pivotColumn the column to pivot.
* @param values List of values that will be translated to columns in the output DataFrame.
* @since 2.4.0
*/
-  def pivot(pivotColumn: Column, values: Seq[Any]): RelationalGroupedDataset = {
+  def pivot(pivotColumn: Column, values: Seq[Column]): RelationalGroupedDataset = {
Member: @HyukjinKwon I think this change is better than what #21699 did.

Member: Hm, wouldn't we better allow Seq[Column] for both pivot(String ...) and pivot(Column ...) by keeping Seq[Any], since pivot(String ...)'s signature already allows it?

BTW, we should document this in the param and describe the difference clearly in the documentation. Otherwise, the current API change makes the usage potentially quite confusing to me.

Member: I am okay if that's what you all think. If we go ahead with the current way, it should really be clearly documented.

Contributor (@maryannxue, Aug 8, 2018): @HyukjinKwon You can just consider pivot(String, Seq[Any]) a simplified version of pivot(Column, Seq[Column]) for users who don't care to use multiple pivot columns or a pivot column of complex types. Given that we now have the fully functional version and the simple version here, I don't think adding another signature is necessary.
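
Side by side, the two tiers then read as follows (a sketch assembled from examples elsewhere in this PR; courseSales and trainingSales are the test fixtures):

  // Simplified version: a single named pivot column, plain literal values.
  courseSales.groupBy($"year").pivot("course", Seq("dotNET", "Java")).sum("earnings")

  // Full version: arbitrary pivot expressions and complex-typed values.
  trainingSales
    .groupBy($"sales.year")
    .pivot(struct(lower($"sales.course"), $"training"), Seq(struct(lit("java"), lit("Experts"))))
    .agg(sum($"sales.earnings"))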

Member: Yea, I didn't mean to add another signature. My only worry is that pivot(String, Seq[Any]) can take actual values as well, whereas pivot(Column, Seq[Column]) does not allow actual values, right?

I was thinking we should allow both cases for both APIs. Otherwise it can be confusing, can't it? These differences should really be clarified.

Member: No. Seq[Any] takes literal values (objects); Seq[Column] takes Column expressions.

I mean:
Before:

scala> val df = spark.range(10).selectExpr("struct(id) as a")
df: org.apache.spark.sql.DataFrame = [a: struct<id: bigint>]

scala> df.groupBy().pivot("a", Seq(struct(lit(1)))).count().show()
java.lang.RuntimeException: Unsupported literal type class org.apache.spark.sql.Column named_struct(col1, 1)
  at org.apache.spark.sql.catalyst.expressions.Literal$.apply(literals.scala:78)
  at org.apache.spark.sql.RelationalGroupedDataset$$anonfun$pivot$1.apply(RelationalGroupedDataset.scala:419)
  at org.apache.spark.sql.RelationalGroupedDataset$$anonfun$pivot$1.apply(RelationalGroupedDataset.scala:419)
  at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
  at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
  at scala.collection.immutable.List.foreach(List.scala:392)
  at scala.collection.TraversableLike$class.map(TraversableLike.scala:234)
  at scala.collection.immutable.List.map(List.scala:296)
  at org.apache.spark.sql.RelationalGroupedDataset.pivot(RelationalGroupedDataset.scala:419)
  at org.apache.spark.sql.RelationalGroupedDataset.pivot(RelationalGroupedDataset.scala:338)
  ... 51 elided

After:

scala> val df = spark.range(10).selectExpr("struct(id) as a")
df: org.apache.spark.sql.DataFrame = [a: struct<id: bigint>]

scala> df.groupBy().pivot("a", Seq(struct(lit(1)))).count().show()
+---+
|[1]|
+---+
|  1|
+---+

Contributor: I think @MaxGekk's intention was to keep the old signature as it is, but he happened to use "lit", which takes a Column too. Correct me if I'm wrong, @MaxGekk.
So back to the choice between pivot(Column, Seq[Column]) and pivot(Column, Seq[Any]): I think having an explicit Seq[Column] type is less confusing, and by itself it tells people that we now support complex types in pivot values.

Member: I think #22030 (comment) makes perfect sense. We really don't need to make it complicated.

> having an explicit Seq[Column] type is less confusing and by itself tells people that we now support complex types in pivot values.

My question was whether that is speculation or actual feedback from users, since the original interface has existed for a few years and I haven't seen complaints about this so far, as far as I can tell.

It's okay if we clearly document this with some examples. It wouldn't necessarily create differences between the same overloaded APIs.

Member Author:

> My question was whether that is speculation or actual feedback from users...

This is actual feedback from our users, who want to do pivoting by multiple columns. For now they have to use external systems (even Microsoft Excel does it better) for pivoting by many columns, because Spark doesn't allow it. For example, you cannot express this on the latest release:

trainingSales
      .groupBy($"sales.year")
      .pivot(struct(lower($"sales.course"), $"training"), Seq(
        struct(lit("dotnet"), lit("Experts")),
        struct(lit("java"), lit("Dummies")))
      ).agg(sum($"sales.earnings"))

via def pivot(pivotColumn: String, values: Seq[Any]). I am not speaking about the recently added method def pivot(pivotColumn: Column, values: Seq[Any]), which we are going to make more concise by eliminating the unnecessarily generic type Any.

Member: I think the point is whether users really get confused by Any; I as a user used this heavily and have been fine with it for a long time. In that case, I thought we'd better keep it consistent with the original one.

groupType match {
case RelationalGroupedDataset.GroupByType =>
new RelationalGroupedDataset(
df,
groupingExprs,
-          RelationalGroupedDataset.PivotType(pivotColumn.expr, values.map(Literal.apply)))
+          RelationalGroupedDataset.PivotType(pivotColumn.expr, values))
case _: RelationalGroupedDataset.PivotType =>
throw new UnsupportedOperationException("repeated pivots are not supported")
case _ =>
@@ -433,7 +454,7 @@ class RelationalGroupedDataset protected[sql](
* @param values List of values that will be translated to columns in the output DataFrame.
* @since 2.4.0
*/
-  def pivot(pivotColumn: Column, values: java.util.List[Any]): RelationalGroupedDataset = {
+  def pivot(pivotColumn: Column, values: java.util.List[Column]): RelationalGroupedDataset = {
pivot(pivotColumn, values.asScala)
}

@@ -561,5 +582,5 @@ private[sql] object RelationalGroupedDataset {
/**
* To indicate it's the PIVOT
*/
-  private[sql] case class PivotType(pivotCol: Expression, values: Seq[Literal]) extends GroupType
+  private[sql] case class PivotType(pivotCol: Expression, values: Seq[Column]) extends GroupType
}
@@ -306,6 +306,22 @@ public void pivot() {
Assert.assertEquals(30000.0, actual.get(1).getDouble(2), 0.01);
}

+  @Test
+  public void pivotColumnValues() {
+    Dataset<Row> df = spark.table("courseSales");
+    List<Row> actual = df.groupBy("year")
+      .pivot(col("course"), Arrays.asList(lit("dotNET"), lit("Java")))
+      .agg(sum("earnings")).orderBy("year").collectAsList();
Member: nit: 2-space indentation


+    Assert.assertEquals(2012, actual.get(0).getInt(0));
+    Assert.assertEquals(15000.0, actual.get(0).getDouble(1), 0.01);
+    Assert.assertEquals(20000.0, actual.get(0).getDouble(2), 0.01);
+
+    Assert.assertEquals(2013, actual.get(1).getInt(0));
+    Assert.assertEquals(48000.0, actual.get(1).getDouble(1), 0.01);
+    Assert.assertEquals(30000.0, actual.get(1).getDouble(2), 0.01);
+  }

private String getResource(String resource) {
try {
// The following "getResource" has different behaviors in SBT and Maven.
@@ -33,7 +33,7 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext {
.agg(sum($"earnings")),
expected)
checkAnswer(
courseSales.groupBy($"year").pivot($"course", Seq("dotNET", "Java"))
courseSales.groupBy($"year").pivot($"course", Seq(lit("dotNET"), lit("Java")))
.agg(sum($"earnings")),
expected)
}
@@ -44,7 +44,10 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext {
courseSales.groupBy("course").pivot("year", Seq(2012, 2013)).agg(sum($"earnings")),
expected)
checkAnswer(
-      courseSales.groupBy('course).pivot('year, Seq(2012, 2013)).agg(sum('earnings)),
+      courseSales
+        .groupBy('course)
+        .pivot('year, Seq(lit(2012), lit(2013)))
+        .agg(sum('earnings)),
expected)
}

@@ -58,7 +61,7 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext {
expected)
checkAnswer(
courseSales.groupBy($"year")
.pivot($"course", Seq("dotNET", "Java"))
.pivot($"course", Seq(lit("dotNET"), lit("Java")))
.agg(sum($"earnings"), avg($"earnings")),
expected)
}
@@ -204,7 +207,7 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext {
complexData.groupBy().pivot("b", Seq(true, false)).agg(max("a")),
expected)
checkAnswer(
-      complexData.groupBy().pivot('b, Seq(true, false)).agg(max('a)),
+      complexData.groupBy().pivot('b, Seq(lit(true), lit(false))).agg(max('a)),
expected)
}

@@ -272,7 +275,7 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext {
val expected = Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil
val df = trainingSales
.groupBy($"sales.year")
.pivot(lower($"sales.course"), Seq("dotNet", "Java").map(_.toLowerCase))
.pivot(lower($"sales.course"), Seq("dotNet", "Java").map(_.toLowerCase).map(lit))
.agg(sum($"sales.earnings"))

checkAnswer(df, expected)
@@ -282,7 +285,7 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext {
val expected = Row(2012, 10000.0) :: Row(2013, 48000.0) :: Nil
val df = trainingSales
.groupBy($"sales.year")
-      .pivot(concat_ws("-", $"training", $"sales.course"), Seq("Experts-dotNET"))
+      .pivot(concat_ws("-", $"training", $"sales.course"), Seq(lit("Experts-dotNET")))
.agg(sum($"sales.earnings"))

checkAnswer(df, expected)
@@ -292,7 +295,7 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext {
val expected = Row(2012, 35000.0) :: Row(2013, 78000.0) :: Nil
val df1 = trainingSales
.groupBy($"sales.year")
-      .pivot(lit(123), Seq(123))
+      .pivot(lit(123), Seq(lit(123)))
.agg(sum($"sales.earnings"))

checkAnswer(df1, expected)
@@ -302,10 +305,34 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext {
val exception = intercept[AnalysisException] {
trainingSales
.groupBy($"sales.year")
.pivot(min($"training"), Seq("Experts"))
.pivot(min($"training"), Seq(lit("Experts")))
.agg(sum($"sales.earnings"))
}

assert(exception.getMessage.contains("aggregate functions are not allowed"))
}

test("pivoting column list with values") {
val expected = Row(2012, 10000.0, null) :: Row(2013, 48000.0, 30000.0) :: Nil
val df = trainingSales
.groupBy($"sales.year")
.pivot(struct(lower($"sales.course"), $"training"), Seq(
struct(lit("dotnet"), lit("Experts")),
struct(lit("java"), lit("Dummies")))
).agg(sum($"sales.earnings"))

checkAnswer(df, expected)
}

test("pivoting column list") {
val expected = Seq(
Row(2012, 5000.0, 10000.0, null, 20000.0),
Row(2013, null, 48000.0, 30000.0, null))
val df = trainingSales
.groupBy($"sales.year")
.pivot(struct(lower($"sales.course"), $"training"))
.agg(sum($"sales.earnings"))

checkAnswer(df, expected)
}
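
Per the earlier review request about Array(Struct(...)), a hedged sketch of what such coverage could look like; the expected rows are placeholders to be derived from the trainingSales fixture, and whether pivot accepts array-typed values is exactly what the test would verify:

  test("pivoting a column of complex nested type") {
    // Placeholder: fill in the expected rows from the trainingSales data.
    val expected: Seq[Row] = ???
    val df = trainingSales
      .groupBy($"sales.year")
      // array(struct(...)) produces an Array(Struct(...)) pivot column.
      .pivot(array(struct(lower($"sales.course"), $"training")))
      .agg(sum($"sales.earnings"))

    checkAnswer(df, expected)
  }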
}