sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -172,12 +172,12 @@ object JDBCRDD extends Logging {
    *
    * @param sc - Your SparkContext.
    * @param schema - The Catalyst schema of the underlying database table.
-   * @param requiredColumns - The names of the columns to SELECT.
+   * @param requiredColumns - The names of the columns or aggregate columns to SELECT.
    * @param filters - The filters to include in all WHERE clauses.
    * @param parts - An array of JDBCPartitions specifying partition ids and
    *                per-partition WHERE clauses.
    * @param options - JDBC options that contain url, table and other information.
-   * @param outputSchema - The schema of the columns to SELECT.
+   * @param outputSchema - The schema of the columns or aggregate columns to SELECT.
    * @param groupByColumns - The pushed down group by columns.
    *
    * @return An RDD representing "SELECT requiredColumns FROM fqTable".
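For context, a hypothetical illustration of what the new wording covers (the column names and generated SQL below are assumptions for illustration, not taken from this diff):

// With aggregate pushdown, requiredColumns may carry compiled aggregate
// expressions instead of plain column names, for example:
//   requiredColumns = Array("MAX(\"SALARY\")", "\"DEPT\"")
//   groupByColumns  = Some(Array("\"DEPT\""))
// which would make scanTable emit roughly:
//   SELECT MAX("SALARY"),"DEPT" FROM <table> WHERE <filters> GROUP BY "DEPT"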
@@ -213,8 +213,8 @@ object JDBCRDD extends Logging {
 }
 
 /**
- * An RDD representing a table in a database accessed via JDBC. Both the
- * driver code and the workers must be able to access the database; the driver
+ * An RDD representing a query on a table in a database accessed via JDBC.
+ * Both the driver code and the workers must be able to access the database; the driver
  * needs to fetch the schema while the workers need to fetch the data.
  */
 private[jdbc] class JDBCRDD(
@@ -237,11 +237,7 @@ private[jdbc] class JDBCRDD(
   /**
    * `columns`, but as a String suitable for injection into a SQL query.
    */
-  private val columnList: String = {
-    val sb = new StringBuilder()
-    columns.foreach(x => sb.append(",").append(x))
-    if (sb.isEmpty) "1" else sb.substring(1)
-  }
+  private val columnList: String = if (columns.isEmpty) "1" else columns.mkString(",")
 
   /**
    * `filters`, but as a WHERE clause suitable for injection into a SQL query.
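As a sanity check, a minimal standalone sketch (with hypothetical input) showing the refactored one-liner behaves like the old StringBuilder code:

// `columns` stands in for the quoted column names held by JDBCRDD.
val columns = Seq("\"NAME\"", "\"THEID\"")
val oldStyle = {
  val sb = new StringBuilder()
  columns.foreach(x => sb.append(",").append(x))
  if (sb.isEmpty) "1" else sb.substring(1)
}
val newStyle = if (columns.isEmpty) "1" else columns.mkString(",")
assert(oldStyle == newStyle) // "\"NAME\",\"THEID\""; both fall back to "1" for no columns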
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
@@ -278,12 +278,18 @@ private[sql] case class JDBCRelation(
   }
 
   override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
+    // When pushDownPredicate is false, all filters that need to be pushed down should be ignored
+    val pushedFilters = if (jdbcOptions.pushDownPredicate) {
+      filters
+    } else {
+      Array.empty[Filter]
+    }

Contributor: If pushDownPredicate is false, the unhandledFilters are set to all the filters here: https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala#L276. It seems to me that the unhandledFilters shouldn't be pushed down to JDBC at all.
Contributor: Yea, it's a bug since day 1: #21875

I think we should update the test added at that time and check the real pushed filters in the JDBC source.

     // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row]
     JDBCRDD.scanTable(
       sparkSession.sparkContext,
       schema,
       requiredColumns,
-      filters,
+      pushedFilters,
       parts,
       jdbcOptions).asInstanceOf[RDD[Row]]
   }
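For readers following the comment thread above, a simplified sketch of the behavior the reviewers point at (an approximation of JDBCRelation.unhandledFilters, not the exact Spark source): when pushDownPredicate is false every filter is reported as unhandled, so Spark re-evaluates all of them on top of the scan; with this patch the scan itself also stops sending them to the database.

// Approximate sketch, assuming JDBCRDD.compileFilter returns Option[String]:
override def unhandledFilters(filters: Array[Filter]): Array[Filter] = {
  if (jdbcOptions.pushDownPredicate) {
    // Only filters the dialect cannot compile into SQL remain for Spark.
    filters.filter(JDBCRDD.compileFilter(_, JdbcDialects.get(jdbcOptions.url)).isEmpty)
  } else {
    // Everything stays in Spark.
    filters
  }
}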
sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -1723,6 +1723,40 @@ class JDBCSuite extends QueryTest
Row("fred", 1) :: Nil)
}

test(
"SPARK-36574: pushDownPredicate=false should prevent push down filters to JDBC data source") {
val df = spark.read.format("jdbc")
.option("Url", urlWithUserAndPass)
.option("dbTable", "test.people")
val df1 = df
.option("pushDownPredicate", false)
.load()
.filter("theid = 1")
.select("name", "theid")
val df2 = df
.option("pushDownPredicate", true)
.load()
.filter("theid = 1")
.select("name", "theid")
val df3 = df
.load()
.select("name", "theid")

def getRowCount(df: DataFrame): Long = {
val queryExecution = df.queryExecution
val rawPlan = queryExecution.executedPlan.collect {
case p: DataSourceScanExec => p
} match {
case Seq(p) => p
case _ => fail(s"More than one PhysicalRDD found\n$queryExecution")
}
rawPlan.execute().count()
}

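+    // df1 filters only in Spark, so its raw JDBC scan still returns every row,
+    // while df2 pushes the filter down and its scan returns fewer rows.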
+    assert(getRowCount(df1) == df3.count)
+    assert(getRowCount(df2) < df3.count)
+  }
+
test("SPARK-26383 throw IllegalArgumentException if wrong kind of driver to the given url") {
val e = intercept[IllegalArgumentException] {
val opts = Map(
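Finally, a hedged sketch of how an end user would toggle the behavior this patch fixes (the url and table below are placeholders):

// With pushDownPredicate=false the database sees an unfiltered SELECT and
// Spark applies `theid = 1` itself; flip it to true (the default) to let the
// database evaluate the predicate.
val people = spark.read.format("jdbc")
  .option("url", "jdbc:h2:mem:testdb;user=testUser;password=testPass")
  .option("dbtable", "test.people")
  .option("pushDownPredicate", "false")
  .load()
  .filter("theid = 1")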