diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala
index 1a25cd2802dd7..b4f832e7902f3 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala
@@ -62,6 +62,7 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest {
     .set("spark.sql.catalog.db2", classOf[JDBCTableCatalog].getName)
     .set("spark.sql.catalog.db2.url", db.getJdbcUrl(dockerIp, externalPort))
     .set("spark.sql.catalog.db2.pushDownAggregate", "true")
+    .set("spark.sql.catalog.db2.pushDownLimit", "true")
 
   override def tablePreparation(connection: Connection): Unit = {
     connection.prepareStatement(
diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala
index a527c6f8cb5b6..6b8d62f8f7b1d 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala
@@ -59,6 +59,7 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD
     .set("spark.sql.catalog.mssql", classOf[JDBCTableCatalog].getName)
     .set("spark.sql.catalog.mssql.url", db.getJdbcUrl(dockerIp, externalPort))
     .set("spark.sql.catalog.mssql.pushDownAggregate", "true")
+    .set("spark.sql.catalog.mssql.pushDownLimit", "true")
 
   override val connectionTimeout = timeout(7.minutes)
diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala
index 072fdbb3f3424..97f9843b9ce5e 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala
@@ -55,6 +55,7 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest
     .set("spark.sql.catalog.mysql", classOf[JDBCTableCatalog].getName)
     .set("spark.sql.catalog.mysql.url", db.getJdbcUrl(dockerIp, externalPort))
     .set("spark.sql.catalog.mysql.pushDownAggregate", "true")
+    .set("spark.sql.catalog.mysql.pushDownLimit", "true")
 
   override val connectionTimeout = timeout(7.minutes)
diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala
index 5de7608918852..1f8e55d04f1ce 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala
@@ -76,6 +76,7 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTes
     .set("spark.sql.catalog.oracle", classOf[JDBCTableCatalog].getName)
     .set("spark.sql.catalog.oracle.url", db.getJdbcUrl(dockerIp, externalPort))
.set("spark.sql.catalog.oracle.pushDownAggregate", "true") + .set("spark.sql.catalog.oracle.pushDownLimit", "true") override val connectionTimeout = timeout(7.minutes) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala index f16d9b507d5f2..323aeab477ec5 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala @@ -21,10 +21,11 @@ import org.apache.logging.log4j.Level import org.apache.spark.sql.{AnalysisException, DataFrame} import org.apache.spark.sql.catalyst.analysis.{IndexAlreadyExistsException, NoSuchIndexException, UnresolvedAttribute} -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Sample} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Sample, Sort} import org.apache.spark.sql.catalyst.util.quoteIdentifier import org.apache.spark.sql.connector.catalog.{Catalogs, Identifier, TableCatalog} import org.apache.spark.sql.connector.catalog.index.SupportsIndex +import org.apache.spark.sql.connector.expressions.NullOrdering import org.apache.spark.sql.connector.expressions.aggregate.GeneralAggregateFunc import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper} import org.apache.spark.sql.jdbc.DockerIntegrationFunSuite @@ -402,7 +403,49 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu } } - protected def checkAggregateRemoved(df: DataFrame): Unit = { + private def checkSortRemoved(df: DataFrame): Unit = { + val sorts = df.queryExecution.optimizedPlan.collect { + case s: Sort => s + } + assert(sorts.isEmpty) + } + + test("simple scan with LIMIT") { + val df = sql(s"SELECT name, salary, bonus FROM $catalogAndNamespace." + + s"${caseConvert("employee")} WHERE dept > 0 LIMIT 1") + assert(limitPushed(df, 1)) + val rows = df.collect() + assert(rows.length === 1) + assert(rows(0).getString(0) === "amy") + assert(rows(0).getDecimal(1) === new java.math.BigDecimal("10000.00")) + assert(rows(0).getDouble(2) === 1000d) + } + + test("simple scan with top N") { + Seq(NullOrdering.values()).flatten.foreach { nullOrdering => + val df1 = sql(s"SELECT name, salary, bonus FROM $catalogAndNamespace." + + s"${caseConvert("employee")} WHERE dept > 0 ORDER BY salary $nullOrdering LIMIT 1") + assert(limitPushed(df1, 1)) + checkSortRemoved(df1) + val rows1 = df1.collect() + assert(rows1.length === 1) + assert(rows1(0).getString(0) === "cathy") + assert(rows1(0).getDecimal(1) === new java.math.BigDecimal("9000.00")) + assert(rows1(0).getDouble(2) === 1200d) + + val df2 = sql(s"SELECT name, salary, bonus FROM $catalogAndNamespace." 
+        s"${caseConvert("employee")} WHERE dept > 0 ORDER BY bonus DESC $nullOrdering LIMIT 1")
+      assert(limitPushed(df2, 1))
+      checkSortRemoved(df2)
+      val rows2 = df2.collect()
+      assert(rows2.length === 1)
+      assert(rows2(0).getString(0) === "david")
+      assert(rows2(0).getDecimal(1) === new java.math.BigDecimal("10000.00"))
+      assert(rows2(0).getDouble(2) === 1300d)
+    }
+  }
+
+  private def checkAggregateRemoved(df: DataFrame): Unit = {
     val aggregates = df.queryExecution.optimizedPlan.collect {
       case agg: Aggregate => agg
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index e23fe05a8a4fe..de7dfeab643f1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -178,54 +178,6 @@ private[jdbc] class JDBCRDD(
    */
   override def getPartitions: Array[Partition] = partitions
 
-  /**
-   * `columns`, but as a String suitable for injection into a SQL query.
-   */
-  private val columnList: String = if (columns.isEmpty) "1" else columns.mkString(",")
-
-  /**
-   * `filters`, but as a WHERE clause suitable for injection into a SQL query.
-   */
-  private val filterWhereClause: String = {
-    val dialect = JdbcDialects.get(url)
-    predicates.flatMap(dialect.compileExpression(_)).map(p => s"($p)").mkString(" AND ")
-  }
-
-  /**
-   * A WHERE clause representing both `filters`, if any, and the current partition.
-   */
-  private def getWhereClause(part: JDBCPartition): String = {
-    if (part.whereClause != null && filterWhereClause.length > 0) {
-      "WHERE " + s"($filterWhereClause)" + " AND " + s"(${part.whereClause})"
-    } else if (part.whereClause != null) {
-      "WHERE " + part.whereClause
-    } else if (filterWhereClause.length > 0) {
-      "WHERE " + filterWhereClause
-    } else {
-      ""
-    }
-  }
-
-  /**
-   * A GROUP BY clause representing pushed-down grouping columns.
-   */
-  private def getGroupByClause: String = {
-    if (groupByColumns.nonEmpty && groupByColumns.get.nonEmpty) {
-      // The GROUP BY columns should already be quoted by the caller side.
-      s"GROUP BY ${groupByColumns.get.mkString(", ")}"
-    } else {
-      ""
-    }
-  }
-
-  private def getOrderByClause: String = {
-    if (sortOrders.nonEmpty) {
-      s" ORDER BY ${sortOrders.mkString(", ")}"
-    } else {
-      ""
-    }
-  }
-
   /**
    * Runs the SQL query against the JDBC driver.
    *
@@ -299,20 +251,23 @@ private[jdbc] class JDBCRDD(
     // fully-qualified table name in the SELECT statement.  I don't know how to
     // talk about a table in a completely portable way.
-    val myWhereClause = getWhereClause(part)
+    var builder = dialect
+      .getJdbcSQLQueryBuilder(options)
+      .withColumns(columns)
+      .withPredicates(predicates, part)
+      .withSortOrders(sortOrders)
+      .withLimit(limit)
+      .withOffset(offset)
 
-    val myTableSampleClause: String = if (sample.nonEmpty) {
-      JdbcDialects.get(url).getTableSample(sample.get)
-    } else {
-      ""
+    groupByColumns.foreach { groupByKeys =>
+      builder = builder.withGroupByColumns(groupByKeys)
     }
 
-    val myLimitClause: String = dialect.getLimitClause(limit)
-    val myOffsetClause: String = dialect.getOffsetClause(offset)
+    sample.foreach { tableSampleInfo =>
+      builder = builder.withTableSample(tableSampleInfo)
     }
 
-    val sqlText = options.prepareQuery +
-      s"SELECT $columnList FROM ${options.tableOrQuery} $myTableSampleClause" +
-      s" $myWhereClause $getGroupByClause $getOrderByClause $myLimitClause $myOffsetClause"
+    val sqlText = builder.build()
     stmt = conn.prepareStatement(sqlText,
         ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
     stmt.setFetchSize(options.fetchSize)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index 230276e7100d2..2e9477356e6b7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -538,19 +538,25 @@ abstract class JdbcDialect extends Serializable with Logging {
   }
 
   /**
-   * returns the LIMIT clause for the SELECT statement
+   * Returns the LIMIT clause for the SELECT statement
   */
  def getLimitClause(limit: Integer): String = {
    if (limit > 0 ) s"LIMIT $limit" else ""
  }
 
  /**
-   * returns the OFFSET clause for the SELECT statement
+   * Returns the OFFSET clause for the SELECT statement
   */
  def getOffsetClause(offset: Integer): String = {
    if (offset > 0 ) s"OFFSET $offset" else ""
  }
 
+  /**
+   * Returns the SQL builder for the SELECT statement.
+   */
+  def getJdbcSQLQueryBuilder(options: JDBCOptions): JdbcSQLQueryBuilder =
+    new JdbcSQLQueryBuilder(this, options)
+
  def supportsTableSample: Boolean = false
 
  def getTableSample(sample: TableSampleInfo): String =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcSQLQueryBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcSQLQueryBuilder.scala
new file mode 100644
index 0000000000000..6113da3d4e8f4
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcSQLQueryBuilder.scala
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.jdbc
+
+import org.apache.spark.sql.connector.expressions.filter.Predicate
+import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCPartition}
+import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
+
+/**
+ * The builder to build a single SELECT query.
+ *
+ * Note: All the `withXXX` methods will be invoked at most once. The invocation order does not
+ * matter, as all these clauses follow the natural SQL order: sample the table first, then filter,
+ * then group by, then sort, then offset, then limit.
+ *
+ * @since 3.5.0
+ */
+class JdbcSQLQueryBuilder(dialect: JdbcDialect, options: JDBCOptions) {
+
+  /**
+   * `columns`, but as a String suitable for injection into a SQL query.
+   */
+  protected var columnList: String = "1"
+
+  /**
+   * A WHERE clause representing both `filters`, if any, and the current partition.
+   */
+  protected var whereClause: String = ""
+
+  /**
+   * A GROUP BY clause representing pushed-down grouping columns.
+   */
+  protected var groupByClause: String = ""
+
+  /**
+   * An ORDER BY clause representing the pushed-down sort for top N.
+   */
+  protected var orderByClause: String = ""
+
+  /**
+   * A LIMIT value representing the pushed-down limit.
+   */
+  protected var limit: Int = -1
+
+  /**
+   * An OFFSET value representing the pushed-down offset.
+   */
+  protected var offset: Int = -1
+
+  /**
+   * A table sample clause representing the pushed-down table sample.
+   */
+  protected var tableSampleClause: String = ""
+
+  /**
+   * The column names, following the dialect's SQL syntax,
+   * e.g. a column name may be the raw name or the quoted name.
+   */
+  def withColumns(columns: Array[String]): JdbcSQLQueryBuilder = {
+    if (columns.nonEmpty) {
+      columnList = columns.mkString(",")
+    }
+    this
+  }
+
+  /**
+   * Constructs the WHERE clause following the dialect's SQL syntax.
+   */
+  def withPredicates(predicates: Array[Predicate], part: JDBCPartition): JdbcSQLQueryBuilder = {
+    // `filters`, but as a WHERE clause suitable for injection into a SQL query.
+    val filterWhereClause: String = {
+      predicates.flatMap(dialect.compileExpression(_)).map(p => s"($p)").mkString(" AND ")
+    }
+
+    // A WHERE clause representing both `filters`, if any, and the current partition.
+    whereClause = if (part.whereClause != null && filterWhereClause.length > 0) {
+      "WHERE " + s"($filterWhereClause)" + " AND " + s"(${part.whereClause})"
+    } else if (part.whereClause != null) {
+      "WHERE " + part.whereClause
+    } else if (filterWhereClause.length > 0) {
+      "WHERE " + filterWhereClause
+    } else {
+      ""
+    }
+
+    this
+  }
+
+  /**
+   * Constructs the GROUP BY clause following the dialect's SQL syntax.
+   */
+  def withGroupByColumns(groupByColumns: Array[String]): JdbcSQLQueryBuilder = {
+    if (groupByColumns.nonEmpty) {
+      // The GROUP BY columns should already be quoted by the caller side.
+      groupByClause = s"GROUP BY ${groupByColumns.mkString(", ")}"
+    }
+
+    this
+  }
+
+  /**
+   * Constructs the ORDER BY clause following the dialect's SQL syntax.
+   */
+  def withSortOrders(sortOrders: Array[String]): JdbcSQLQueryBuilder = {
+    if (sortOrders.nonEmpty) {
+      orderByClause = s" ORDER BY ${sortOrders.mkString(", ")}"
+    }
+
+    this
+  }
+
+  /**
+   * Saves the limit value used to construct the LIMIT clause.
+   */
+  def withLimit(limit: Int): JdbcSQLQueryBuilder = {
+    this.limit = limit
+
+    this
+  }
+
+  /**
+   * Saves the offset value used to construct the OFFSET clause.
+   */
+  def withOffset(offset: Int): JdbcSQLQueryBuilder = {
+    this.offset = offset
+
+    this
+  }
+
+  /**
+   * Constructs the table sample clause following the dialect's SQL syntax.
+   */
+  def withTableSample(sample: TableSampleInfo): JdbcSQLQueryBuilder = {
+    tableSampleClause = dialect.getTableSample(sample)
+
+    this
+  }
+
+  /**
+   * Builds the final SQL query following the dialect's SQL syntax.
+   */
+  def build(): String = {
+    // Constructs the LIMIT clause following the dialect's SQL syntax.
+    val limitClause = dialect.getLimitClause(limit)
+    // Constructs the OFFSET clause following the dialect's SQL syntax.
+    val offsetClause = dialect.getOffsetClause(offset)
+
+    options.prepareQuery +
+      s"SELECT $columnList FROM ${options.tableOrQuery} $tableSampleClause" +
+      s" $whereClause $groupByClause $orderByClause $limitClause $offsetClause"
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
index 625b3eef7fbc7..39b617135ce36 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
@@ -24,8 +24,9 @@ import scala.util.control.NonFatal
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.analysis.NonEmptyNamespaceException
-import org.apache.spark.sql.connector.expressions.Expression
+import org.apache.spark.sql.connector.expressions.{Expression, NullOrdering, SortDirection}
 import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -63,6 +64,20 @@ private object MsSqlServerDialect extends JdbcDialect {
     supportedFunctions.contains(funcName)
 
   class MsSqlServerSQLBuilder extends JDBCSQLBuilder {
+    override def visitSortOrder(
+        sortKey: String, sortDirection: SortDirection, nullOrdering: NullOrdering): String = {
+      (sortDirection, nullOrdering) match {
+        case (SortDirection.ASCENDING, NullOrdering.NULLS_FIRST) =>
+          s"$sortKey $sortDirection"
+        case (SortDirection.ASCENDING, NullOrdering.NULLS_LAST) =>
+          s"CASE WHEN $sortKey IS NULL THEN 1 ELSE 0 END, $sortKey $sortDirection"
+        case (SortDirection.DESCENDING, NullOrdering.NULLS_FIRST) =>
+          s"CASE WHEN $sortKey IS NULL THEN 0 ELSE 1 END, $sortKey $sortDirection"
+        case (SortDirection.DESCENDING, NullOrdering.NULLS_LAST) =>
+          s"$sortKey $sortDirection"
+      }
+    }
+
     override def dialectFunctionName(funcName: String): String = funcName match {
       case "VAR_POP" => "VARP"
       case "VAR_SAMP" => "VAR"
@@ -168,7 +183,7 @@ private object MsSqlServerDialect extends JdbcDialect {
   }
 
   override def getLimitClause(limit: Integer): String = {
-    ""
+    if (limit > 0) s"TOP $limit" else ""
   }
 
   override def classifyException(message: String, e: Throwable): AnalysisException = {
@@ -181,4 +196,20 @@ private object MsSqlServerDialect extends JdbcDialect {
       case _ => super.classifyException(message, e)
     }
   }
+
+  class MsSqlServerSQLQueryBuilder(dialect: JdbcDialect, options: JDBCOptions)
+    extends JdbcSQLQueryBuilder(dialect, options) {
+
+    // TODO[SPARK-42289]: DS V2 pushdown could let JDBC dialect decide to push down offset
+    override def build(): String = {
+      val limitClause = dialect.getLimitClause(limit)
+
+      options.prepareQuery +
+        s"SELECT $limitClause $columnList FROM ${options.tableOrQuery} $tableSampleClause" +
+        s" $whereClause $groupByClause $orderByClause"
+    }
+  }
+
+  override def getJdbcSQLQueryBuilder(options: JDBCOptions): JdbcSQLQueryBuilder =
+    new MsSqlServerSQLQueryBuilder(this, options)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala
index 12882dc8e676b..1f615ed76c5b9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.SQLConfHelper
 import org.apache.spark.sql.catalyst.analysis.{IndexAlreadyExistsException, NoSuchIndexException}
 import org.apache.spark.sql.connector.catalog.Identifier
 import org.apache.spark.sql.connector.catalog.index.TableIndex
-import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, NamedReference}
+import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, NamedReference, NullOrdering, SortDirection}
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}
 import org.apache.spark.sql.types.{BooleanType, DataType, FloatType, LongType, MetadataBuilder}
@@ -51,6 +51,20 @@ private case object MySQLDialect extends JdbcDialect with SQLConfHelper {
     supportedFunctions.contains(funcName)
 
   class MySQLSQLBuilder extends JDBCSQLBuilder {
+    override def visitSortOrder(
+        sortKey: String, sortDirection: SortDirection, nullOrdering: NullOrdering): String = {
+      (sortDirection, nullOrdering) match {
+        case (SortDirection.ASCENDING, NullOrdering.NULLS_FIRST) =>
+          s"$sortKey $sortDirection"
+        case (SortDirection.ASCENDING, NullOrdering.NULLS_LAST) =>
+          s"CASE WHEN $sortKey IS NULL THEN 1 ELSE 0 END, $sortKey $sortDirection"
+        case (SortDirection.DESCENDING, NullOrdering.NULLS_FIRST) =>
+          s"CASE WHEN $sortKey IS NULL THEN 0 ELSE 1 END, $sortKey $sortDirection"
+        case (SortDirection.DESCENDING, NullOrdering.NULLS_LAST) =>
+          s"$sortKey $sortDirection"
+      }
+    }
+
     override def visitAggregateFunction(
         funcName: String, isDistinct: Boolean, inputs: Array[String]): String =
       if (isDistinct && distinctUnsupportedAggregateFunctions.contains(funcName)) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
index 79ac248d723e3..d0e925bad3848 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
@@ -24,6 +24,7 @@ import scala.util.control.NonFatal
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.connector.expressions.Expression
+import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -173,4 +174,29 @@ private case object OracleDialect extends JdbcDialect {
     val nullable = if (isNullable) "NULL" else "NOT NULL"
     s"ALTER TABLE $tableName MODIFY ${quoteIdentifier(columnName)} $nullable"
   }
+
+  override def getLimitClause(limit: Integer): String = {
+    // Oracle doesn't support LIMIT clause.
+    // We can use rownum <= n to limit the number of rows in the result set.
+ if (limit > 0) s"WHERE rownum <= $limit" else "" + } + + class OracleSQLQueryBuilder(dialect: JdbcDialect, options: JDBCOptions) + extends JdbcSQLQueryBuilder(dialect, options) { + + // TODO[SPARK-42289]: DS V2 pushdown could let JDBC dialect decide to push down offset + override def build(): String = { + val selectStmt = s"SELECT $columnList FROM ${options.tableOrQuery} $tableSampleClause" + + s" $whereClause $groupByClause $orderByClause" + if (limit > 0) { + val limitClause = dialect.getLimitClause(limit) + options.prepareQuery + s"SELECT tab.* FROM ($selectStmt) tab $limitClause" + } else { + options.prepareQuery + selectStmt + } + } + } + + override def getJdbcSQLQueryBuilder(options: JDBCOptions): JdbcSQLQueryBuilder = + new OracleSQLQueryBuilder(this, options) }