From eb11a682ce35b92519dced7e00ee41dd4c95f4a7 Mon Sep 17 00:00:00 2001 From: Ivan Sadikov Date: Fri, 20 Jan 2023 17:25:22 +1300 Subject: [PATCH 1/2] add top implementation --- .../execution/datasources/jdbc/JDBCRDD.scala | 3 +- .../apache/spark/sql/jdbc/JdbcDialects.scala | 8 +++++ .../spark/sql/jdbc/MsSqlServerDialect.scala | 5 ++++ .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 29 +++++++++++++++++++ 4 files changed, 44 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index e23fe05a8a4fe..4a1a8321a8734 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -307,11 +307,12 @@ private[jdbc] class JDBCRDD( "" } + val myTopExpression: String = dialect.getTopExpression(limit) // SQL Server Limit alternative val myLimitClause: String = dialect.getLimitClause(limit) val myOffsetClause: String = dialect.getOffsetClause(offset) val sqlText = options.prepareQuery + - s"SELECT $columnList FROM ${options.tableOrQuery} $myTableSampleClause" + + s"SELECT $myTopExpression $columnList FROM ${options.tableOrQuery} $myTableSampleClause" + s" $myWhereClause $getGroupByClause $getOrderByClause $myLimitClause $myOffsetClause" stmt = conn.prepareStatement(sqlText, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 230276e7100d2..5df6e7971d852 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -544,6 +544,14 @@ abstract class JdbcDialect extends Serializable with Logging { if (limit > 0 ) s"LIMIT $limit" else "" } + 
/** + * MS SQL Server version of `getLimitClause`. + * This is only supported by SQL Server, which uses TOP (N) instead of LIMIT. + */ + def getTopExpression(limit: Integer): String = { + "" + } + /** * returns the OFFSET clause for the SELECT statement */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala index 625b3eef7fbc7..07289a39525d7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala @@ -167,10 +167,15 @@ private object MsSqlServerDialect extends JdbcDialect { throw QueryExecutionErrors.commentOnTableUnsupportedError() } + // SQL Server does not support the LIMIT clause; it uses `getTopExpression` instead. override def getLimitClause(limit: Integer): String = { "" } + override def getTopExpression(limit: Integer): String = { + if (limit > 0) s"TOP ($limit)" else "" + } + override def classifyException(message: String, e: Throwable): AnalysisException = { e match { case sqlException: SQLException => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 3e317dc95476b..ac6579baf535f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -1001,6 +1001,35 @@ class JDBCSuite extends QueryTest with SharedSparkSession { } } + test("Dialect Limit and Top implementation") { + // Dialects that support LIMIT N. 
+ val limitDialects = Seq( + JdbcDialects.get("jdbc:mysql"), + JdbcDialects.get("jdbc:postgresql"), + JdbcDialects.get("jdbc:db2"), + JdbcDialects.get("jdbc:h2") + ) + + for (dialect <- limitDialects) { + assert(dialect.getLimitClause(0) == "") + assert(dialect.getTopExpression(0) == "") + assert(dialect.getLimitClause(100) == "LIMIT 100") + assert(dialect.getTopExpression(100) == "") + } + + // Dialects that support TOP (N) + val topDialects = Seq( + JdbcDialects.get("jdbc:sqlserver") + ) + + for (dialect <- topDialects) { + assert(dialect.getLimitClause(0) == "") + assert(dialect.getTopExpression(0) == "") + assert(dialect.getLimitClause(100) == "") + assert(dialect.getTopExpression(100) == "TOP (100)") + } + } + test("table exists query by jdbc dialect") { val MySQL = JdbcDialects.get("jdbc:mysql://127.0.0.1/db") val Postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") From c451c7b98c3a42abedbaabb7849f2d82e8bf9475 Mon Sep 17 00:00:00 2001 From: Ivan Sadikov Date: Mon, 23 Jan 2023 11:26:32 +1300 Subject: [PATCH 2/2] update test name --- .../src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index ac6579baf535f..694275eca941a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -1001,7 +1001,7 @@ class JDBCSuite extends QueryTest with SharedSparkSession { } } - test("Dialect Limit and Top implementation") { + test("SPARK-42128: Dialect Limit and Top implementation") { // Dialects that support LIMIT N. val limitDialects = Seq( JdbcDialects.get("jdbc:mysql"),