@@ -529,6 +529,10 @@ class Analyzer(
|| (p.groupByExprsOpt.isDefined && !p.groupByExprsOpt.get.forall(_.resolved))
|| !p.pivotColumn.resolved || !p.pivotValues.forall(_.resolved) => p
case Pivot(groupByExprsOpt, pivotColumn, pivotValues, aggregates, child) =>
if (!RowOrdering.isOrderable(pivotColumn.dataType)) {
throw new AnalysisException(
s"Invalid pivot column '${pivotColumn}'. Pivot columns must be comparable.")
@gatorsmile (Member) commented on Jul 31, 2018:

To the other reviewers: this is consistent with the requirements of group-by columns.
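
For readers comparing the two checks: grouping by a non-orderable column (for example, a map) already fails analysis for the same reason. A minimal, illustrative sketch (class and object names below are hypothetical, not part of this PR):

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical, self-contained demo: GROUP BY on a map-typed column is rejected by the
// analyzer because map values are not orderable/comparable; the same rule the new
// pivot-column check enforces above.
object GroupByMapColumnSketch extends App {
  val spark = SparkSession.builder().master("local[1]").appName("sketch").getOrCreate()
  import spark.implicits._

  val df = Seq((1, Map("a" -> 1)), (2, Map("b" -> 2))).toDF("id", "m")

  // Fails with an AnalysisException stating that the map column's data type
  // is not an orderable data type.
  df.groupBy("m").count().show()
}
```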

}
// Check all aggregate expressions.
aggregates.foreach(checkValidAggregateExpression)
// Check all pivot values are literal and match pivot column data type.
@@ -574,10 +578,14 @@ class Analyzer(
// Since evaluating |pivotValues| if statements for each input row can get slow this is an
// alternate plan that instead uses two steps of aggregation.
val namedAggExps: Seq[NamedExpression] = aggregates.map(a => Alias(a, a.sql)())
val bigGroup = groupByExprs ++ pivotColumn.references
val namedPivotCol = pivotColumn match {
Contributor Author commented:

This reverts the original workaround that was added to avoid the PivotFirst issue. Now that PivotFirst works correctly for complex types, we can revert it.

case n: NamedExpression => n
case _ => Alias(pivotColumn, "__pivot_col")()
}
val bigGroup = groupByExprs :+ namedPivotCol
val firstAgg = Aggregate(bigGroup, bigGroup ++ namedAggExps, child)
val pivotAggs = namedAggExps.map { a =>
Alias(PivotFirst(pivotColumn, a.toAttribute, evalPivotValues)
Alias(PivotFirst(namedPivotCol.toAttribute, a.toAttribute, evalPivotValues)
.toAggregateExpression()
, "__pivot_" + a.sql)()
}
@@ -17,11 +17,11 @@

package org.apache.spark.sql.catalyst.expressions.aggregate

import scala.collection.immutable.HashMap
import scala.collection.immutable.{HashMap, TreeMap}

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.catalyst.util.{GenericArrayData, TypeUtils}
import org.apache.spark.sql.types._

object PivotFirst {
@@ -83,7 +83,12 @@ case class PivotFirst(

override val dataType: DataType = ArrayType(valueDataType)

val pivotIndex = HashMap(pivotColumnValues.zipWithIndex: _*)
val pivotIndex = if (pivotColumn.dataType.isInstanceOf[AtomicType]) {
HashMap(pivotColumnValues.zipWithIndex: _*)
} else {
TreeMap(pivotColumnValues.zipWithIndex: _*)(
TypeUtils.getInterpretedOrdering(pivotColumn.dataType))
}

val indexSize = pivotIndex.size

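The HashMap-to-TreeMap switch above matters because, for complex types, the value read from an input row and the literal pivot value used to build the index may be different object representations of the same data and need not satisfy equals/hashCode, so a hash lookup can miss; an ordering-based lookup compares them structurally. A simplified, self-contained sketch of that idea (ArrayValue and the other names are illustrative, not Spark classes):

```scala
import scala.collection.immutable.{HashMap, TreeMap}
import scala.math.Ordering.Implicits._  // supplies an element-wise Ordering[Seq[Int]]

// Stand-in for a value class with no useful equals/hashCode, loosely mirroring how
// Spark's array data representations can differ while holding the same elements.
final class ArrayValue(val elems: Array[Int])

object PivotIndexSketch extends App {
  // Compare values by their contents rather than by reference.
  val byElements: Ordering[ArrayValue] = Ordering.by((v: ArrayValue) => v.elems.toSeq)

  val literalValue = new ArrayValue(Array(1, 1))   // value listed in the PIVOT ... IN (...) clause
  val incomingValue = new ArrayValue(Array(1, 1))  // value arriving from an input row

  val hashIndex = HashMap(literalValue -> 0)
  val treeIndex = TreeMap(literalValue -> 0)(byElements)

  println(hashIndex.get(incomingValue))  // None: default reference equality, the lookup misses
  println(treeIndex.get(incomingValue))  // Some(0): the ordering compares element by element
}
```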
78 changes: 72 additions & 6 deletions sql/core/src/test/resources/sql-tests/inputs/pivot.sql
@@ -11,10 +11,10 @@ create temporary view years as select * from values
(2013, 2)
as years(y, s);

create temporary view yearsWithArray as select * from values
(2012, array(1, 1)),
(2013, array(2, 2))
as yearsWithArray(y, a);
create temporary view yearsWithComplexTypes as select * from values
(2012, array(1, 1), map('1', 1), struct(1, 'a')),
(2013, array(2, 2), map('2', 2), struct(2, 'b'))
as yearsWithComplexTypes(y, a, m, s);

-- pivot courses
SELECT * FROM (
@@ -204,7 +204,7 @@ PIVOT (
SELECT * FROM (
SELECT course, year, a
FROM courseSales
JOIN yearsWithArray ON year = y
JOIN yearsWithComplexTypes ON year = y
)
PIVOT (
min(a)
@@ -215,9 +215,75 @@
SELECT * FROM (
SELECT course, year, y, a
FROM courseSales
JOIN yearsWithArray ON year = y
JOIN yearsWithComplexTypes ON year = y
)
PIVOT (
max(a)
FOR (y, course) IN ((2012, 'dotNET'), (2013, 'Java'))
);

-- pivot on pivot column of array type
SELECT * FROM (
SELECT earnings, year, a
FROM courseSales
JOIN yearsWithComplexTypes ON year = y
)
PIVOT (
sum(earnings)
FOR a IN (array(1, 1), array(2, 2))
);

-- pivot on multiple pivot columns containing array type
SELECT * FROM (
SELECT course, earnings, year, a
FROM courseSales
JOIN yearsWithComplexTypes ON year = y
)
PIVOT (
sum(earnings)
FOR (course, a) IN (('dotNET', array(1, 1)), ('Java', array(2, 2)))
);

-- pivot on pivot column of struct type
SELECT * FROM (
SELECT earnings, year, s
FROM courseSales
JOIN yearsWithComplexTypes ON year = y
)
PIVOT (
sum(earnings)
FOR s IN ((1, 'a'), (2, 'b'))
);

-- pivot on multiple pivot columns containing struct type
SELECT * FROM (
SELECT course, earnings, year, s
FROM courseSales
JOIN yearsWithComplexTypes ON year = y
)
PIVOT (
sum(earnings)
FOR (course, s) IN (('dotNET', (1, 'a')), ('Java', (2, 'b')))
);

-- pivot on pivot column of map type
SELECT * FROM (
SELECT earnings, year, m
FROM courseSales
JOIN yearsWithComplexTypes ON year = y
)
PIVOT (
sum(earnings)
FOR m IN (map('1', 1), map('2', 2))
);

-- pivot on multiple pivot columns containing map type
SELECT * FROM (
SELECT course, earnings, year, m
FROM courseSales
JOIN yearsWithComplexTypes ON year = y
)
PIVOT (
sum(earnings)
FOR (course, m) IN (('dotNET', map('1', 1)), ('Java', map('2', 2)))
);
116 changes: 109 additions & 7 deletions sql/core/src/test/resources/sql-tests/results/pivot.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 25
-- Number of queries: 31


-- !query 0
@@ -28,10 +28,10 @@ struct<>


-- !query 2
create temporary view yearsWithArray as select * from values
(2012, array(1, 1)),
(2013, array(2, 2))
as yearsWithArray(y, a)
create temporary view yearsWithComplexTypes as select * from values
(2012, array(1, 1), map('1', 1), struct(1, 'a')),
(2013, array(2, 2), map('2', 2), struct(2, 'b'))
as yearsWithComplexTypes(y, a, m, s)
-- !query 2 schema
struct<>
-- !query 2 output
@@ -346,7 +346,7 @@ Literal expressions required for pivot values, found 'course#x';
SELECT * FROM (
SELECT course, year, a
FROM courseSales
JOIN yearsWithArray ON year = y
JOIN yearsWithComplexTypes ON year = y
)
PIVOT (
min(a)
@@ -363,7 +363,7 @@ struct<year:int,dotNET:array<int>,Java:array<int>>
SELECT * FROM (
SELECT course, year, y, a
FROM courseSales
JOIN yearsWithArray ON year = y
JOIN yearsWithComplexTypes ON year = y
)
PIVOT (
max(a)
@@ -374,3 +374,105 @@ struct<year:int,[2012, dotNET]:array<int>,[2013, Java]:array<int>>
-- !query 24 output
2012 [1,1] NULL
2013 NULL [2,2]


-- !query 25
SELECT * FROM (
SELECT earnings, year, a
FROM courseSales
JOIN yearsWithComplexTypes ON year = y
)
PIVOT (
sum(earnings)
FOR a IN (array(1, 1), array(2, 2))
)
-- !query 25 schema
struct<year:int,[1, 1]:bigint,[2, 2]:bigint>
-- !query 25 output
2012 35000 NULL
2013 NULL 78000


-- !query 26
SELECT * FROM (
SELECT course, earnings, year, a
FROM courseSales
JOIN yearsWithComplexTypes ON year = y
)
PIVOT (
sum(earnings)
FOR (course, a) IN (('dotNET', array(1, 1)), ('Java', array(2, 2)))
)
-- !query 26 schema
struct<year:int,[dotNET, [1, 1]]:bigint,[Java, [2, 2]]:bigint>
-- !query 26 output
2012 15000 NULL
2013 NULL 30000


-- !query 27
SELECT * FROM (
SELECT earnings, year, s
FROM courseSales
JOIN yearsWithComplexTypes ON year = y
)
PIVOT (
sum(earnings)
FOR s IN ((1, 'a'), (2, 'b'))
)
-- !query 27 schema
struct<year:int,[1, a]:bigint,[2, b]:bigint>
-- !query 27 output
2012 35000 NULL
2013 NULL 78000


-- !query 28
SELECT * FROM (
SELECT course, earnings, year, s
FROM courseSales
JOIN yearsWithComplexTypes ON year = y
)
PIVOT (
sum(earnings)
FOR (course, s) IN (('dotNET', (1, 'a')), ('Java', (2, 'b')))
)
-- !query 28 schema
struct<year:int,[dotNET, [1, a]]:bigint,[Java, [2, b]]:bigint>
-- !query 28 output
2012 15000 NULL
2013 NULL 30000


-- !query 29
SELECT * FROM (
SELECT earnings, year, m
FROM courseSales
JOIN yearsWithComplexTypes ON year = y
)
PIVOT (
sum(earnings)
FOR m IN (map('1', 1), map('2', 2))
)
-- !query 29 schema
struct<>
-- !query 29 output
org.apache.spark.sql.AnalysisException
Invalid pivot column 'm#x'. Pivot columns must be comparable.;


-- !query 30
SELECT * FROM (
SELECT course, earnings, year, m
FROM courseSales
JOIN yearsWithComplexTypes ON year = y
)
PIVOT (
sum(earnings)
FOR (course, m) IN (('dotNET', map('1', 1)), ('Java', map('2', 2)))
)
-- !query 30 schema
struct<>
-- !query 30 output
org.apache.spark.sql.AnalysisException
Invalid pivot column 'named_struct(course, course#x, m, m#x)'. Pivot columns must be comparable.;
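
Queries 29 and 30 fail while queries 25 through 28 succeed, which lines up with the new check: array and struct pivot columns are orderable, map columns are not. A small probe of the predicate the check uses, assuming spark-catalyst on the classpath (object name is hypothetical):

```scala
import org.apache.spark.sql.catalyst.expressions.RowOrdering
import org.apache.spark.sql.types._

// Probes RowOrdering.isOrderable, the predicate behind the new pivot-column check.
object OrderableTypesSketch extends App {
  val structOfAtomics = StructType(Seq(StructField("i", IntegerType), StructField("s", StringType)))

  println(RowOrdering.isOrderable(ArrayType(IntegerType)))            // true: arrays of orderable elements
  println(RowOrdering.isOrderable(structOfAtomics))                   // true: structs of orderable fields
  println(RowOrdering.isOrderable(MapType(StringType, IntegerType)))  // false: maps are not comparable
}
```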