This repository has been archived by the owner on Sep 18, 2023. It is now read-only.

Commit

revert and update

rui-mo committed Apr 29, 2021
1 parent 1d4a5ff commit e7866dc
Showing 5 changed files with 38 additions and 122 deletions.
@@ -344,11 +344,7 @@ class ColumnarHashAggregation(
     aggregateAttr.toList
   }
 
-  def existsAttrNotFound(allAggregateResultAttributes: List[Attribute])
-    : (Boolean, ListBuffer[Attribute], ListBuffer[Attribute]) = {
-    var notFoundInInput = false
-    var notFoundAttr = new ListBuffer[Attribute]()
-    var foundAttr = new ListBuffer[Attribute]()
+  def existsAttrNotFound(allAggregateResultAttributes: List[Attribute]): Unit = {
     if (resultExpressions.size == allAggregateResultAttributes.size) {
       var resAllAttr = true
       breakable {
@@ -363,34 +359,12 @@ class ColumnarHashAggregation(
       for (attr <- resultExpressions) {
         if (allAggregateResultAttributes
             .indexOf(attr.asInstanceOf[AttributeReference]) == -1) {
-          notFoundInInput = true
-          notFoundAttr += attr.asInstanceOf[AttributeReference]
-        } else {
-          foundAttr += attr.asInstanceOf[AttributeReference]
+          throw new IllegalArgumentException(
+            s"$attr in resultExpressions is not found in allAggregateResultAttributes!")
         }
       }
     }
-    (notFoundInInput, notFoundAttr, foundAttr)
   }
 
-  def getNewInputAttr(allAggregateResultAttributes: List[Attribute],
-                      notFoundAttr: ListBuffer[Attribute],
-                      foundAttr: ListBuffer[Attribute]): List[Attribute] = {
-    // This function replace the unfound attributes to those from result expressions.
-    for (attr <- notFoundAttr) {
-      for (inputAttr <- allAggregateResultAttributes) {
-        if (attr.name.split('#')(0) == inputAttr.name.split('#')(0)) {
-          foundAttr += attr
-        }
-      }
-    }
-    // If any attribute is still not found, an exception will be thrown.
-    if (existsAttrNotFound(foundAttr.toList)._1) {
-      throw new IllegalArgumentException(s"Attribute in resultExpressions " +
-        s"${resultExpressions} can't be found in input attributes: $foundAttr")
-    }
-    foundAttr.toList
-  }
-
   def prepareKernelFunction: TreeNode = {
@@ -462,7 +436,7 @@ class ColumnarHashAggregation(
     groupingAttributes.toList ::: getAttrForAggregateExpr(
       aggregateExpressions,
       aggregateAttributes)
-    var aggregateAttributeFieldList =
+    val aggregateAttributeFieldList =
       allAggregateResultAttributes.map(attr => {
         Field
           .nullable(
@@ -471,21 +445,8 @@ class ColumnarHashAggregation(
       })
 
     // If some Attributes in result expressions (contain attributes only) are not found
-    // in allAggregateResultAttributes, a new attribute list will be created
-    // or an exception will be thrown.
-    val (notFound, notFoundAttr, foundAttr) =
-      existsAttrNotFound(allAggregateResultAttributes)
-    if (notFound) {
-      val newResAttrList =
-        getNewInputAttr(allAggregateResultAttributes, notFoundAttr, foundAttr)
-      aggregateAttributeFieldList =
-        newResAttrList.map(attr => {
-          Field
-            .nullable(
-              s"${attr.name}#${attr.exprId.id}",
-              CodeGeneration.getResultType(attr.dataType))
-        })
-    }
+    // in allAggregateResultAttributes, an exception will be thrown.
+    existsAttrNotFound(allAggregateResultAttributes)
 
     val nativeFuncNodes = groupingNativeFuncNodes ::: aggrNativeFuncNodes
 
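Net effect of the change above: the two-step fallback (collect the missing attributes, then rebuild the attribute list by matching names) is gone, and the check now fails fast. Below is a minimal, self-contained sketch of the new behavior, with plain strings standing in for Spark's Attribute and the size-equality branch omitted; the object and method names here are illustrative, not the plugin's API.

object FailFastCheckSketch {
  // Mirrors the reworked existsAttrNotFound: no buffers and no second pass;
  // the first result expression missing from the aggregate output aborts.
  def existsAttrNotFound(
      resultExpressions: Seq[String],
      allAggregateResultAttributes: List[String]): Unit = {
    for (attr <- resultExpressions) {
      if (allAggregateResultAttributes.indexOf(attr) == -1) {
        throw new IllegalArgumentException(
          s"$attr in resultExpressions is not found in allAggregateResultAttributes!")
      }
    }
  }

  def main(args: Array[String]): Unit = {
    existsAttrNotFound(Seq("a", "sum(b)"), List("a", "sum(b)")) // passes silently
    existsAttrNotFound(Seq("a", "c"), List("a", "sum(b)"))      // throws
  }
}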
@@ -992,7 +992,7 @@ class DataFrameAggregateSuite extends QueryTest
   }
 
   Seq(true, false).foreach { value =>
-    test(s"SPARK-31620: agg with subquery (whole-stage-codegen = $value)") {
+    ignore(s"SPARK-31620: agg with subquery (whole-stage-codegen = $value)") {
       withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> value.toString) {
         withTempView("t1", "t2") {
           sql("create temporary view t1 as select * from values (1, 2) as t1(a, b)")
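The suite changes in this commit are mostly flips between test and ignore. For readers unfamiliar with ScalaTest: ignore is a drop-in replacement for test that still registers the case but skips it at run time. A minimal sketch (plain ScalaTest, assuming the AnyFunSuite style that Spark's suites build on; the suite name is illustrative):

import org.scalatest.funsuite.AnyFunSuite

class ToggleSketchSuite extends AnyFunSuite {
  test("runs and must pass") {
    assert(1 + 1 == 2)
  }

  ignore("registered but reported as ignored") {
    assert(false) // body is never executed while ignored
  }
}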
@@ -59,7 +59,7 @@ class TravisDataFrameAggregateSuite extends QueryTest
 
   val absTol = 1e-8
 
-  ignore("groupBy") {
+  test("groupBy") {
     checkAnswer(
       testData2.groupBy("a").agg(sum($"b")),
       Seq(Row(1, 3), Row(2, 3), Row(3, 3))
@@ -127,7 +127,7 @@ class TravisDataFrameAggregateSuite extends QueryTest
     )
   }
 
-  ignore("SPARK-18952: regexes fail codegen when used as keys due to bad forward-slash escapes") {
+  test("SPARK-18952: regexes fail codegen when used as keys due to bad forward-slash escapes") {
     val df = Seq(("some[thing]", "random-string")).toDF("key", "val")
 
     checkAnswer(
@@ -149,7 +149,7 @@ class TravisDataFrameAggregateSuite extends QueryTest
     )
   }
 
-  ignore("cube") {
+  test("cube") {
     checkAnswer(
       courseSales.cube("course", "year").sum("earnings"),
       Row("Java", 2012, 20000.0) ::
@@ -173,7 +173,7 @@ class TravisDataFrameAggregateSuite extends QueryTest
     assert(cube0.where("date IS NULL").count > 0)
   }
 
-  ignore("grouping and grouping_id") {
+  test("grouping and grouping_id") {
     checkAnswer(
       courseSales.cube("course", "year")
         .agg(grouping("course"), grouping("year"), grouping_id("course", "year")),
@@ -211,7 +211,7 @@ class TravisDataFrameAggregateSuite extends QueryTest
     }
   }
 
-  ignore("grouping/grouping_id inside window function") {
+  test("grouping/grouping_id inside window function") {
 
     val w = Window.orderBy(sum("earnings"))
     checkAnswer(
@@ -231,7 +231,7 @@ class TravisDataFrameAggregateSuite extends QueryTest
     )
   }
 
-  ignore("SPARK-21980: References in grouping functions should be indexed with semanticEquals") {
+  test("SPARK-21980: References in grouping functions should be indexed with semanticEquals") {
     checkAnswer(
       courseSales.cube("course", "year")
         .agg(grouping("CouRse"), grouping("year")),
@@ -302,14 +302,14 @@ class TravisDataFrameAggregateSuite extends QueryTest
     )
   }
 
-  ignore("agg without groups and functions") {
+  test("agg without groups and functions") {
     checkAnswer(
       testData2.agg(lit(1)),
       Row(1)
     )
   }
 
-  ignore("average") {
+  test("average") {
     checkAnswer(
       testData2.agg(avg($"a"), mean($"a")),
       Row(2.0, 2.0))
@@ -350,7 +350,7 @@ class TravisDataFrameAggregateSuite extends QueryTest
       Row(2.0, 2.0))
   }
 
-  ignore("zero average") {
+  test("zero average") {
     val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b")
     checkAnswer(
       emptyTableData.agg(avg($"a")),
@@ -369,7 +369,7 @@ class TravisDataFrameAggregateSuite extends QueryTest
       Row(6, 6.0))
   }
 
-  ignore("null count") {
+  test("null count") {
     checkAnswer(
       testData3.groupBy($"a").agg(count($"b")),
       Seq(Row(1, 0), Row(2, 1))
@@ -392,7 +392,7 @@ class TravisDataFrameAggregateSuite extends QueryTest
     )
   }
 
-  ignore("multiple column distinct count") {
+  test("multiple column distinct count") {
     val df1 = Seq(
       ("a", "b", "c"),
       ("a", "b", "c"),
@@ -417,14 +417,14 @@ class TravisDataFrameAggregateSuite extends QueryTest
     )
   }
 
-  ignore("zero count") {
+  test("zero count") {
     val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b")
     checkAnswer(
       emptyTableData.agg(count($"a"), sumDistinct($"a")), // non-partial
       Row(0, null))
   }
 
-  ignore("stddev") {
+  test("stddev") {
     val testData2ADev = math.sqrt(4.0 / 5.0)
     checkAnswer(
       testData2.agg(stddev($"a"), stddev_pop($"a"), stddev_samp($"a")),
@@ -434,28 +434,28 @@ class TravisDataFrameAggregateSuite extends QueryTest
       Row(testData2ADev, math.sqrt(4 / 6.0), testData2ADev))
   }
 
-  ignore("zero stddev") {
+  test("zero stddev") {
     val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b")
     checkAnswer(
-    emptyTableData.agg(stddev($"a"), stddev_pop($"a"), stddev_samp($"a")),
-    Row(null, null, null))
+      emptyTableData.agg(stddev($"a"), stddev_pop($"a"), stddev_samp($"a")),
+      Row(null, null, null))
   }
 
-  ignore("zero sum") {
+  test("zero sum") {
     val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b")
     checkAnswer(
       emptyTableData.agg(sum($"a")),
       Row(null))
   }
 
-  ignore("zero sum distinct") {
+  test("zero sum distinct") {
     val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b")
     checkAnswer(
       emptyTableData.agg(sumDistinct($"a")),
       Row(null))
   }
 
-  ignore("moments") {
+  test("moments") {
 
     val sparkVariance = testData2.agg(variance($"a"))
     checkAggregatesWithTol(sparkVariance, Row(4.0 / 5.0), absTol)
@@ -473,7 +473,7 @@ class TravisDataFrameAggregateSuite extends QueryTest
     checkAggregatesWithTol(sparkKurtosis, Row(-1.5), absTol)
   }
 
-  ignore("zero moments") {
+  test("zero moments") {
     val input = Seq((1, 2)).toDF("a", "b")
     checkAnswer(
       input.agg(stddev($"a"), stddev_samp($"a"), stddev_pop($"a"), variance($"a"),
@@ -495,7 +495,7 @@ class TravisDataFrameAggregateSuite extends QueryTest
         Double.NaN, Double.NaN))
   }
 
-  ignore("null moments") {
+  test("null moments") {
     val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b")
     checkAnswer(emptyTableData.agg(
       variance($"a"), var_samp($"a"), var_pop($"a"), skewness($"a"), kurtosis($"a")),
@@ -547,7 +547,7 @@ class TravisDataFrameAggregateSuite extends QueryTest
     )
   }
 
-  ignore("SPARK-31500: collect_set() of BinaryType returns duplicate elements") {
+  test("SPARK-31500: collect_set() of BinaryType returns duplicate elements") {
     val bytesTest1 = "test1".getBytes
     val bytesTest2 = "test2".getBytes
     val df = Seq(bytesTest1, bytesTest1, bytesTest2).toDF("a")
@@ -593,7 +593,7 @@ class TravisDataFrameAggregateSuite extends QueryTest
       Seq(Row(Seq(1.0, 2.0))))
   }
 
-  ignore("SPARK-14664: Decimal sum/avg over window should work.") {
+  test("SPARK-14664: Decimal sum/avg over window should work.") {
     checkAnswer(
       spark.sql("select sum(a) over () from values 1.0, 2.0, 3.0 T(a)"),
       Row(6.0) :: Row(6.0) :: Row(6.0) :: Nil)
@@ -602,7 +602,7 @@ class TravisDataFrameAggregateSuite extends QueryTest
       Row(2.0) :: Row(2.0) :: Row(2.0) :: Nil)
   }
 
-  ignore("SQL decimal test (used for catching certain decimal handling bugs in aggregates)") {
+  test("SQL decimal test (used for catching certain decimal handling bugs in aggregates)") {
     checkAnswer(
       decimalData.groupBy($"a" cast DecimalType(10, 2)).agg(avg($"b" cast DecimalType(10, 2))),
       Seq(Row(new java.math.BigDecimal(1), new java.math.BigDecimal("1.5")),
@@ -626,7 +626,7 @@ class TravisDataFrameAggregateSuite extends QueryTest
       limit2Df.select($"id"))
   }
 
-  ignore("SPARK-17237 remove backticks in a pivot result schema") {
+  test("SPARK-17237 remove backticks in a pivot result schema") {
     val df = Seq((2, 3, 4), (3, 4, 5)).toDF("a", "x", "y")
     withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> "false") {
       checkAnswer(
@@ -645,7 +645,7 @@ class TravisDataFrameAggregateSuite extends QueryTest
 
   private def assertNoExceptions(c: Column): Unit = {
     for ((wholeStage, useObjectHashAgg) <-
-      Seq((true, true), (true, false), (false, true), (false, false))) {
+         Seq((true, true), (true, false), (false, true), (false, false))) {
       withSQLConf(
         (SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStage.toString),
         (SQLConf.USE_OBJECT_HASH_AGG.key, useObjectHashAgg.toString)) {
@@ -679,7 +679,7 @@ class TravisDataFrameAggregateSuite extends QueryTest
     }
   }
 
-  ignore("SPARK-19471: AggregationIterator does not initialize the generated result projection" +
+  test("SPARK-19471: AggregationIterator does not initialize the generated result projection" +
     " before using it") {
     Seq(
       monotonically_increasing_id(), spark_partition_id(),
@@ -732,8 +732,6 @@ class TravisDataFrameAggregateSuite extends QueryTest
     }
   }
 
-  //TODO: failed ut
-  /*
   testWithWholeStageCodegenOnAndOff("SPARK-22951: dropDuplicates on empty dataFrames " +
     "should produce correct aggregate") { _ =>
     // explicit global aggregations
@@ -748,7 +746,6 @@ class TravisDataFrameAggregateSuite extends QueryTest
     // global aggregation is converted to grouping aggregation:
     assert(spark.emptyDataFrame.dropDuplicates().count() == 0)
   }
-  */
 
   test("SPARK-21896: Window functions inside aggregate functions") {
     def checkWindowError(df: => DataFrame): Unit = {
@@ -790,7 +787,7 @@ class TravisDataFrameAggregateSuite extends QueryTest
         "type: GroupBy]"))
   }
 
-  ignore("SPARK-26021: NaN and -0.0 in grouping expressions") {
+  test("SPARK-26021: NaN and -0.0 in grouping expressions") {
     checkAnswer(
       Seq(0.0f, -0.0f, 0.0f/0.0f, Float.NaN).toDF("f").groupBy("f").count(),
       Row(0.0f, 2) :: Row(Float.NaN, 2) :: Nil)
@@ -842,7 +839,7 @@ class TravisDataFrameAggregateSuite extends QueryTest
     checkAnswer(countAndDistinct, Row(100000, 100))
   }
 
-  ignore("max_by") {
+  test("max_by") {
     val yearOfMaxEarnings =
       sql("SELECT course, max_by(year, earnings) FROM courseSales GROUP BY course")
     checkAnswer(yearOfMaxEarnings, Row("dotNET", 2013) :: Row("Java", 2013) :: Nil)
@@ -898,7 +895,7 @@ class TravisDataFrameAggregateSuite extends QueryTest
     }
   }
 
-  ignore("min_by") {
+  test("min_by") {
     val yearOfMinEarnings =
       sql("SELECT course, min_by(year, earnings) FROM courseSales GROUP BY course")
     checkAnswer(yearOfMinEarnings, Row("dotNET", 2012) :: Row("Java", 2012) :: Nil)
@@ -954,7 +951,7 @@ class TravisDataFrameAggregateSuite extends QueryTest
     }
   }
 
-  ignore("count_if") {
+  test("count_if") {
     withTempView("tempView") {
       Seq(("a", None), ("a", Some(1)), ("a", Some(2)), ("a", Some(3)),
         ("b", None), ("b", Some(4)), ("b", Some(5)), ("b", Some(6)))
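The @@ -732 hunk above un-comments a test that depends on testWithWholeStageCodegenOnAndOff. As a rough sketch of what such a helper does (an assumed shape, not necessarily Spark's exact definition; test and withSQLConf are supplied by the enclosing suite), it registers the same body twice, once per codegen setting:

// Assumed context: inside a Spark test suite mixing in ScalaTest and SQLHelper.
import org.apache.spark.sql.internal.SQLConf

def testWithWholeStageCodegenOnAndOff(testName: String)(f: String => Unit): Unit = {
  Seq("false", "true").foreach { enabled =>
    val mode = if (enabled == "true") "on" else "off"
    // Register one test case per WHOLESTAGE_CODEGEN_ENABLED value.
    test(s"$testName (whole-stage-codegen $mode)") {
      withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> enabled) {
        f(enabled)
      }
    }
  }
}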
