@@ -149,6 +149,7 @@ package object dsl {
     }
   }

+    def rand(e: Long): Expression = Rand(Literal.create(e, LongType))

@viirya (Member) commented on Jun 30, 2018:
We can just use Rand(seed: Long). See object Rand in randomExpressions.
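
For context, the companion object being referenced looks roughly like this in randomExpressions.scala (paraphrased; the exact form may differ slightly across Spark versions):

    object Rand {
      // Wraps a raw seed into a Literal child expression.
      def apply(seed: Long): Rand = Rand(Literal(seed, LongType))
    }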

Contributor (PR author) replied:
Since we already have a bunch of expressions here, I don't think it would hurt to add this one?

Member replied:
I mean: def rand(e: Long): Expression = Rand(e).

Member replied:
I addressed the comment when I merged the code.

     def sum(e: Expression): Expression = Sum(e).toAggregateExpression()
     def sumDistinct(e: Expression): Expression = Sum(e).toAggregateExpression(isDistinct = true)
     def count(e: Expression): Expression = Count(e).toAggregateExpression()
@@ -526,9 +526,10 @@ object ColumnPruning extends Rule[LogicalPlan] {

   /**
    * The Project before Filter is not necessary but conflict with PushPredicatesThroughProject,
-   * so remove it.
+   * so remove it. Since the Projects have been added top-down, we need to remove in bottom-up
+   * order, otherwise lower Projects can be missed.
    */
-  private def removeProjectBeforeFilter(plan: LogicalPlan): LogicalPlan = plan transform {
+  private def removeProjectBeforeFilter(plan: LogicalPlan): LogicalPlan = plan transformUp {
     case p1 @ Project(_, f @ Filter(_, p2 @ Project(_, child)))
       if p2.outputSet.subsetOf(child.outputSet) =>
       p1.copy(child = f.copy(child = child))
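
To see why the traversal order matters here, consider a minimal self-contained sketch. Plan, Project, Filter, Leaf, and TransformDemo below are simplified stand-ins defined only for illustration, not Spark's Catalyst classes; the two methods mirror the semantics of Catalyst's transform (an alias for top-down transformDown) and transformUp:

    sealed trait Plan {
      def children: Seq[Plan]
      def withChildren(newChildren: Seq[Plan]): Plan

      // Pre-order (like Catalyst's transform/transformDown): rewrite this
      // node first, then recurse into the rewritten node's children.
      def transformDown(rule: PartialFunction[Plan, Plan]): Plan = {
        val applied = rule.applyOrElse(this, identity[Plan])
        applied.withChildren(applied.children.map(_.transformDown(rule)))
      }

      // Post-order (like Catalyst's transformUp): rewrite the children
      // first, then this node.
      def transformUp(rule: PartialFunction[Plan, Plan]): Plan = {
        val newNode = withChildren(children.map(_.transformUp(rule)))
        rule.applyOrElse(newNode, identity[Plan])
      }
    }

    case class Project(child: Plan) extends Plan {
      def children = Seq(child)
      def withChildren(c: Seq[Plan]) = copy(child = c.head)
    }

    case class Filter(child: Plan) extends Plan {
      def children = Seq(child)
      def withChildren(c: Seq[Plan]) = copy(child = c.head)
    }

    case object Leaf extends Plan {
      def children = Nil
      def withChildren(c: Seq[Plan]) = this
    }

    object TransformDemo extends App {
      // Same shape as removeProjectBeforeFilter: drop a Project sitting
      // directly under a Filter (the outputSet check is elided here).
      val rule: PartialFunction[Plan, Plan] = {
        case p1 @ Project(f @ Filter(Project(grandchild))) =>
          p1.copy(child = f.copy(child = grandchild))
      }

      // Two nested Project/Filter pairs, as ColumnPruning adds top-down.
      val plan = Project(Filter(Project(Filter(Project(Leaf)))))

      // Top-down: removing the middle Project destroys the inner
      // Project(Filter(Project)) match, so the innermost Project survives.
      println(plan.transformDown(rule)) // Project(Filter(Filter(Project(Leaf))))

      // Bottom-up: the inner pair is rewritten first, then the outer one,
      // so both extra Projects are removed.
      println(plan.transformUp(rule))   // Project(Filter(Filter(Leaf)))
    }

With the top-down pass, rewriting the outer match consumes the middle Project that the inner match needed as its root, so the lower Project is missed; the bottom-up pass handles the inner pair first. This is exactly the situation the updated comment describes.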
@@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.optimizer

 import scala.reflect.runtime.universe.TypeTag

-import org.apache.spark.sql.catalyst.analysis
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
@@ -370,5 +369,13 @@ class ColumnPruningSuite extends PlanTest {
     comparePlans(optimized2, expected2.analyze)
   }

+  test("SPARK-24696 ColumnPruning rule fails to remove extra Project") {
+    val input = LocalRelation('key.int, 'value.string)
+    val query = input.select('key).where(rand(0L) > 0.5).where('key < 10).analyze
+    val optimized = Optimize.execute(query)
+    val expected = input.where(rand(0L) > 0.5).where('key < 10).select('key).analyze
+    comparePlans(optimized, expected)
+  }
+
   // todo: add more tests for column pruning
 }
sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala (21 additions, 0 deletions)
@@ -2792,4 +2792,25 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
       }
     }
   }
+
+  test("SPARK-24696 ColumnPruning rule fails to remove extra Project") {
@viirya (Member) commented on Jun 30, 2018:
The test in Jira is simpler than this. Do we need to have two tables and a join? Why not just use the test in Jira?

Member replied:
The new unit test in ColumnPruningSuite.scala already covers that.
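
For reference, the simpler shape under discussion looks roughly like this at the DataFrame level; a hypothetical sketch mirroring the new ColumnPruningSuite test, not the exact Jira snippet:

    import org.apache.spark.sql.functions.{col, rand}

    // select -> nondeterministic filter -> deterministic filter: the shape
    // that produced the conflicting extra Project before this fix.
    spark.range(10).toDF("key")
      .select("key")
      .where(rand(0L) > 0.5)   // nondeterministic, cannot be pushed through the Project
      .where(col("key") < 10)  // deterministic, gets pushed, triggering the conflict
      .collect()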

withTable("fact_stats", "dim_stats") {
val factData = Seq((1, 1, 99, 1), (2, 2, 99, 2), (3, 1, 99, 3), (4, 2, 99, 4))
val storeData = Seq((1, "BW", "DE"), (2, "AZ", "US"))
spark.udf.register("filterND", udf((value: Int) => value > 2).asNondeterministic)
factData.toDF("date_id", "store_id", "product_id", "units_sold")
.write.mode("overwrite").partitionBy("store_id").format("parquet").saveAsTable("fact_stats")
storeData.toDF("store_id", "state_province", "country")
.write.mode("overwrite").format("parquet").saveAsTable("dim_stats")
val df = sql(
"""
|SELECT f.date_id, f.product_id, f.store_id FROM
|(SELECT date_id, product_id, store_id
| FROM fact_stats WHERE filterND(date_id)) AS f
|JOIN dim_stats s
|ON f.store_id = s.store_id WHERE s.country = 'DE'
""".stripMargin)
checkAnswer(df, Seq(Row(3, 99, 1)))
}
}
}