@@ -62,6 +62,7 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest {
.set("spark.sql.catalog.db2", classOf[JDBCTableCatalog].getName)
.set("spark.sql.catalog.db2.url", db.getJdbcUrl(dockerIp, externalPort))
.set("spark.sql.catalog.db2.pushDownAggregate", "true")
.set("spark.sql.catalog.db2.pushDownLimit", "true")

override def tablePreparation(connection: Connection): Unit = {
connection.prepareStatement(

@@ -59,6 +59,7 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest {
.set("spark.sql.catalog.mssql", classOf[JDBCTableCatalog].getName)
.set("spark.sql.catalog.mssql.url", db.getJdbcUrl(dockerIp, externalPort))
.set("spark.sql.catalog.mssql.pushDownAggregate", "true")
.set("spark.sql.catalog.mssql.pushDownLimit", "true")

override val connectionTimeout = timeout(7.minutes)


@@ -55,6 +55,7 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest {
.set("spark.sql.catalog.mysql", classOf[JDBCTableCatalog].getName)
.set("spark.sql.catalog.mysql.url", db.getJdbcUrl(dockerIp, externalPort))
.set("spark.sql.catalog.mysql.pushDownAggregate", "true")
.set("spark.sql.catalog.mysql.pushDownLimit", "true")

override val connectionTimeout = timeout(7.minutes)


@@ -76,6 +76,7 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest {
.set("spark.sql.catalog.oracle", classOf[JDBCTableCatalog].getName)
.set("spark.sql.catalog.oracle.url", db.getJdbcUrl(dockerIp, externalPort))
.set("spark.sql.catalog.oracle.pushDownAggregate", "true")
.set("spark.sql.catalog.oracle.pushDownLimit", "true")

override val connectionTimeout = timeout(7.minutes)
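
Note: all four docker suites enable the same pair of catalog options, so V2JDBCTest exercises the new LIMIT code path against each database. For reference, the equivalent user-side setup looks like the following sketch (the catalog name `mydb` and the JDBC URL are illustrative, not part of this PR):

  import org.apache.spark.SparkConf
  import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog

  val conf = new SparkConf()
    .set("spark.sql.catalog.mydb", classOf[JDBCTableCatalog].getName)
    .set("spark.sql.catalog.mydb.url", "jdbc:mysql://localhost:3306/test") // assumed URL
    .set("spark.sql.catalog.mydb.pushDownAggregate", "true") // pre-existing option
    .set("spark.sql.catalog.mydb.pushDownLimit", "true")     // option exercised by this PR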


@@ -21,10 +21,11 @@ import org.apache.logging.log4j.Level

import org.apache.spark.sql.{AnalysisException, DataFrame}
import org.apache.spark.sql.catalyst.analysis.{IndexAlreadyExistsException, NoSuchIndexException, UnresolvedAttribute}
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Sample}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Sample, Sort}
import org.apache.spark.sql.catalyst.util.quoteIdentifier
import org.apache.spark.sql.connector.catalog.{Catalogs, Identifier, TableCatalog}
import org.apache.spark.sql.connector.catalog.index.SupportsIndex
+import org.apache.spark.sql.connector.expressions.NullOrdering
import org.apache.spark.sql.connector.expressions.aggregate.GeneralAggregateFunc
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper}
import org.apache.spark.sql.jdbc.DockerIntegrationFunSuite

@@ -402,7 +403,49 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFunSuite
}
}

-protected def checkAggregateRemoved(df: DataFrame): Unit = {
+private def checkSortRemoved(df: DataFrame): Unit = {
+val sorts = df.queryExecution.optimizedPlan.collect {
+case s: Sort => s
+}
+assert(sorts.isEmpty)
+}
+
+test("simple scan with LIMIT") {
+val df = sql(s"SELECT name, salary, bonus FROM $catalogAndNamespace." +
+s"${caseConvert("employee")} WHERE dept > 0 LIMIT 1")
+assert(limitPushed(df, 1))
+val rows = df.collect()
+assert(rows.length === 1)
+assert(rows(0).getString(0) === "amy")
+assert(rows(0).getDecimal(1) === new java.math.BigDecimal("10000.00"))
+assert(rows(0).getDouble(2) === 1000d)
+}
+
+test("simple scan with top N") {
+Seq(NullOrdering.values()).flatten.foreach { nullOrdering =>
+val df1 = sql(s"SELECT name, salary, bonus FROM $catalogAndNamespace." +
+s"${caseConvert("employee")} WHERE dept > 0 ORDER BY salary $nullOrdering LIMIT 1")
+assert(limitPushed(df1, 1))
+checkSortRemoved(df1)
+val rows1 = df1.collect()
+assert(rows1.length === 1)
+assert(rows1(0).getString(0) === "cathy")
+assert(rows1(0).getDecimal(1) === new java.math.BigDecimal("9000.00"))
+assert(rows1(0).getDouble(2) === 1200d)
+
+val df2 = sql(s"SELECT name, salary, bonus FROM $catalogAndNamespace." +
+s"${caseConvert("employee")} WHERE dept > 0 ORDER BY bonus DESC $nullOrdering LIMIT 1")
+assert(limitPushed(df2, 1))
+checkSortRemoved(df2)
+val rows2 = df2.collect()
+assert(rows2.length === 1)
+assert(rows2(0).getString(0) === "david")
+assert(rows2(0).getDecimal(1) === new java.math.BigDecimal("10000.00"))
+assert(rows2(0).getDouble(2) === 1300d)
+}
+}
+
+private def checkAggregateRemoved(df: DataFrame): Unit = {
val aggregates = df.queryExecution.optimizedPlan.collect {
case agg: Aggregate => agg
}
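
Note: the `limitPushed(df, n)` helper used above lies outside this hunk. A minimal sketch of such a check, assuming the JDBC scan is wrapped in a `V1ScanWrapper` (imported above) whose pushed-down operators carry an `Option[Int]` limit, could be:

  private def limitPushed(df: DataFrame, limit: Int): Boolean = {
    df.queryExecution.optimizedPlan.collectFirst {
      case relation: DataSourceV2ScanRelation => relation.scan
    }.exists {
      case wrapper: V1ScanWrapper =>
        // compare the limit the scan reports as pushed down with the expected one
        wrapper.pushedDownOperators.limit.contains(limit)
      case _ => false
    }
  }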

@@ -178,54 +178,6 @@ private[jdbc] class JDBCRDD(
*/
override def getPartitions: Array[Partition] = partitions

-/**
-* `columns`, but as a String suitable for injection into a SQL query.
-*/
-private val columnList: String = if (columns.isEmpty) "1" else columns.mkString(",")
-
-/**
-* `filters`, but as a WHERE clause suitable for injection into a SQL query.
-*/
-private val filterWhereClause: String = {
-val dialect = JdbcDialects.get(url)
-predicates.flatMap(dialect.compileExpression(_)).map(p => s"($p)").mkString(" AND ")
-}
-
-/**
-* A WHERE clause representing both `filters`, if any, and the current partition.
-*/
-private def getWhereClause(part: JDBCPartition): String = {
-if (part.whereClause != null && filterWhereClause.length > 0) {
-"WHERE " + s"($filterWhereClause)" + " AND " + s"(${part.whereClause})"
-} else if (part.whereClause != null) {
-"WHERE " + part.whereClause
-} else if (filterWhereClause.length > 0) {
-"WHERE " + filterWhereClause
-} else {
-""
-}
-}
-
-/**
-* A GROUP BY clause representing pushed-down grouping columns.
-*/
-private def getGroupByClause: String = {
-if (groupByColumns.nonEmpty && groupByColumns.get.nonEmpty) {
-// The GROUP BY columns should already be quoted by the caller side.
-s"GROUP BY ${groupByColumns.get.mkString(", ")}"
-} else {
-""
-}
-}
-
-private def getOrderByClause: String = {
-if (sortOrders.nonEmpty) {
-s" ORDER BY ${sortOrders.mkString(", ")}"
-} else {
-""
-}
-}
-
/**
* Runs the SQL query against the JDBC driver.
*

@@ -299,20 +251,23 @@
// fully-qualified table name in the SELECT statement. I don't know how to
// talk about a table in a completely portable way.

-val myWhereClause = getWhereClause(part)
+var builder = dialect
+.getJdbcSQLQueryBuilder(options)
+.withColumns(columns)
+.withPredicates(predicates, part)
+.withSortOrders(sortOrders)
+.withLimit(limit)
+.withOffset(offset)

-val myTableSampleClause: String = if (sample.nonEmpty) {
-JdbcDialects.get(url).getTableSample(sample.get)
-} else {
-""
-}
+groupByColumns.foreach { groupByKeys =>
+builder = builder.withGroupByColumns(groupByKeys)
+}

-val myLimitClause: String = dialect.getLimitClause(limit)
-val myOffsetClause: String = dialect.getOffsetClause(offset)
+sample.foreach { tableSampleInfo =>
+builder = builder.withTableSample(tableSampleInfo)
+}

-val sqlText = options.prepareQuery +
-s"SELECT $columnList FROM ${options.tableOrQuery} $myTableSampleClause" +
-s" $myWhereClause $getGroupByClause $getOrderByClause $myLimitClause $myOffsetClause"
+val sqlText = builder.build()
stmt = conn.prepareStatement(sqlText,
ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
stmt.setFetchSize(options.fetchSize)
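
Note: the refactor is behavior-preserving. Each `withXXX` call records a clause and `build()` concatenates them in the same order the old string concatenation did. As a rough illustration (the inputs here are made up, and quoting is dialect-dependent), the chain for a pushed-down top-N scan assembles something like:

  val sqlText = dialect.getJdbcSQLQueryBuilder(options)
    .withColumns(Array("name", "salary", "bonus"))
    .withPredicates(predicates, part) // say this compiles to: WHERE ("dept" > 0)
    .withSortOrders(Array("salary ASC NULLS FIRST"))
    .withLimit(1)
    .withOffset(-1) // non-positive values produce no OFFSET clause
    .build()
  // roughly: SELECT name,salary,bonus FROM employee
  //          WHERE ("dept" > 0) ORDER BY salary ASC NULLS FIRST LIMIT 1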

@@ -538,19 +538,25 @@ abstract class JdbcDialect extends Serializable with Logging {
}

/**
-* returns the LIMIT clause for the SELECT statement
+* Returns the LIMIT clause for the SELECT statement
*/
def getLimitClause(limit: Integer): String = {
if (limit > 0) s"LIMIT $limit" else ""
}

/**
-* returns the OFFSET clause for the SELECT statement
+* Returns the OFFSET clause for the SELECT statement
*/
def getOffsetClause(offset: Integer): String = {
if (offset > 0) s"OFFSET $offset" else ""
}

+/**
+* Returns the SQL builder for the SELECT statement.
+*/
+def getJdbcSQLQueryBuilder(options: JDBCOptions): JdbcSQLQueryBuilder =
+new JdbcSQLQueryBuilder(this, options)
+
def supportsTableSample: Boolean = false

def getTableSample(sample: TableSampleInfo): String =
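
Note: because `getJdbcSQLQueryBuilder` is overridable and the builder's fields are `protected var`s (see the new class below), a dialect whose SQL has no LIMIT keyword can substitute its own builder. This PR ships no dialect-specific builders; a hypothetical subclass for a TOP-based dialect such as SQL Server (ignoring OFFSET for brevity) might look like:

  class TopNSQLQueryBuilder(dialect: JdbcDialect, options: JDBCOptions)
    extends JdbcSQLQueryBuilder(dialect, options) {

    // Render the pushed-down limit as "SELECT TOP (n) ..." rather than a
    // trailing LIMIT clause.
    override def build(): String = {
      val topClause = if (limit > 0) s"TOP ($limit) " else ""
      options.prepareQuery +
        s"SELECT $topClause$columnList FROM ${options.tableOrQuery}" +
        s" $tableSampleClause $whereClause $groupByClause $orderByClause"
    }
  }

A dialect would then override the hook: override def getJdbcSQLQueryBuilder(options: JDBCOptions): JdbcSQLQueryBuilder = new TopNSQLQueryBuilder(this, options).
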
@@ -0,0 +1,167 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.jdbc

import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCPartition}
import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo

/**
* The builder to build a single SELECT query.
*
* Note: All the `withXXX` methods will be invoked at most once. The invocation order does not
* matter, as all these clauses follow the natural SQL order: sample the table first, then filter,
* then group by, then sort, then offset, then limit.
*
* @since 3.5.0
*/
class JdbcSQLQueryBuilder(dialect: JdbcDialect, options: JDBCOptions) {

/**
* `columns`, but as a String suitable for injection into a SQL query.
*/
protected var columnList: String = "1"

/**
* A WHERE clause representing both `filters`, if any, and the current partition.
*/
protected var whereClause: String = ""

/**
* A GROUP BY clause representing pushed-down grouping columns.
*/
protected var groupByClause: String = ""

/**
* An ORDER BY clause representing the pushed-down sort of a top-N query.
*/
protected var orderByClause: String = ""

/**
* The LIMIT value representing the pushed-down limit.
*/
protected var limit: Int = -1

/**
* The OFFSET value representing the pushed-down offset.
*/
protected var offset: Int = -1

/**
* A table sample clause representing the pushed-down table sample.
*/
protected var tableSampleClause: String = ""

/**
* The column names, following the dialect's SQL syntax.
* e.g. a column name may be the raw name or the quoted name.
*/
[Review thread]
Contributor: so it's the raw name string or quoted name? From API's perspective, I think it's better to pass the raw name string like a#b instead of 'a#b', and the implementation should call the dialect API to quote it. We can change it later, as this PR targets the master branch only.
Author: OK. It seems to need more changes.
def withColumns(columns: Array[String]): JdbcSQLQueryBuilder = {
[Review thread]
Contributor: can we add API docs for each `with` method? Like whether the column name is the raw name or the quoted name, following the dialect's SQL syntax.
if (columns.nonEmpty) {
columnList = columns.mkString(",")
}
this
}

/**
* Constructs the WHERE clause following the dialect's SQL syntax.
*/
def withPredicates(predicates: Array[Predicate], part: JDBCPartition): JdbcSQLQueryBuilder = {
// `filters`, but as a WHERE clause suitable for injection into a SQL query.
val filterWhereClause: String = {
predicates.flatMap(dialect.compileExpression(_)).map(p => s"($p)").mkString(" AND ")
}

// A WHERE clause representing both `filters`, if any, and the current partition.
whereClause = if (part.whereClause != null && filterWhereClause.length > 0) {
"WHERE " + s"($filterWhereClause)" + " AND " + s"(${part.whereClause})"
} else if (part.whereClause != null) {
"WHERE " + part.whereClause
} else if (filterWhereClause.length > 0) {
"WHERE " + filterWhereClause
} else {
""
}

this
}

/**
* Constructs the GROUP BY clause following the dialect's SQL syntax.
*/
def withGroupByColumns(groupByColumns: Array[String]): JdbcSQLQueryBuilder = {
if (groupByColumns.nonEmpty) {
// The GROUP BY columns should already be quoted by the caller side.
groupByClause = s"GROUP BY ${groupByColumns.mkString(", ")}"
}

this
}

/**
* Constructs the ORDER BY clause following the dialect's SQL syntax.
*/
def withSortOrders(sortOrders: Array[String]): JdbcSQLQueryBuilder = {
if (sortOrders.nonEmpty) {
orderByClause = s" ORDER BY ${sortOrders.mkString(", ")}"
}

this
}

/**
* Saves the limit value used to construct LIMIT clause.
*/
def withLimit(limit: Int): JdbcSQLQueryBuilder = {
this.limit = limit

this
}

/**
* Saves the offset value used to construct OFFSET clause.
*/
def withOffset(offset: Int): JdbcSQLQueryBuilder = {
this.offset = offset

this
}

/**
* Constructs the table sample clause following the dialect's SQL syntax.
*/
def withTableSample(sample: TableSampleInfo): JdbcSQLQueryBuilder = {
tableSampleClause = dialect.getTableSample(sample)

this
}

/**
* Builds the final SQL query following the dialect's SQL syntax.
*/
def build(): String = {
[Review thread]
Contributor: api doc here as well.

// Construct the LIMIT clause following the dialect's SQL syntax.
val limitClause = dialect.getLimitClause(limit)
// Construct the OFFSET clause following the dialect's SQL syntax.
val offsetClause = dialect.getOffsetClause(offset)

options.prepareQuery +
s"SELECT $columnList FROM ${options.tableOrQuery} $tableSampleClause" +
s" $whereClause $groupByClause $orderByClause $limitClause $offsetClause"
}
}
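
Note: the builder can be sanity-checked in isolation. A sketch, assuming the `JDBCOptions(Map[String, String])` convenience constructor and an H2 URL (both illustrative):

  import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
  import org.apache.spark.sql.jdbc.JdbcDialects

  val url = "jdbc:h2:mem:testdb"
  val options = new JDBCOptions(Map("url" -> url, "dbtable" -> "employee"))
  val sql = JdbcDialects.get(url)
    .getJdbcSQLQueryBuilder(options)
    .withColumns(Array("name", "salary"))
    .withLimit(2)
    .build()
  // e.g. "SELECT name,salary FROM employee    LIMIT 2" (the empty clauses
  // leave extra whitespace, which databases ignore)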