diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala index d884ad4c6246..fd7efb1efb76 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala @@ -20,7 +20,11 @@ package org.apache.spark.sql.jdbc.v2 import java.sql.Connection import org.apache.spark.{SparkConf, SparkSQLFeatureNotSupportedException} +import org.apache.spark.rdd.RDD import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan} +import org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog import org.apache.spark.sql.jdbc.MsSQLServerDatabaseOnDocker import org.apache.spark.sql.types._ @@ -37,6 +41,17 @@ import org.apache.spark.tags.DockerTest @DockerTest class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest { + def getExternalEngineQuery(executedPlan: SparkPlan): String = { + getExternalEngineRdd(executedPlan).asInstanceOf[JDBCRDD].getExternalEngineQuery + } + + def getExternalEngineRdd(executedPlan: SparkPlan): RDD[InternalRow] = { + val queryNode = executedPlan.collect { case r: RowDataSourceScanExec => + r + }.head + queryNode.rdd + } + override def excluded: Seq[String] = Seq( "simple scan with OFFSET", "simple scan with LIMIT and OFFSET", @@ -146,4 +161,68 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD |""".stripMargin) assert(df.collect().length == 2) } + + test("SPARK-50087: SqlServer handle booleans in CASE WHEN test") { + val df = sql( + s"""|SELECT * FROM $catalogName.employee + |WHERE CASE WHEN name = 'Legolas' THEN name = 'Elf' ELSE NOT (name = 'Wizard') END + |""".stripMargin + ) + + // scalastyle:off + assert(getExternalEngineQuery(df.queryExecution.executedPlan) == + """SELECT "dept","name","salary","bonus" FROM "employee" WHERE (CASE WHEN ("name" = 'Legolas') THEN IIF(("name" = 'Elf'), 1, 0) ELSE IIF(("name" <> 'Wizard'), 1, 0) END = 1) """ + ) + // scalastyle:on + df.collect() + } + + test("SPARK-50087: SqlServer handle booleans in CASE WHEN with always true test") { + val df = sql( + s"""|SELECT * FROM $catalogName.employee + |WHERE CASE WHEN (name = 'Legolas') THEN (name = 'Elf') ELSE (1=1) END + |""".stripMargin + ) + + // scalastyle:off + assert(getExternalEngineQuery(df.queryExecution.executedPlan) == + """SELECT "dept","name","salary","bonus" FROM "employee" WHERE (CASE WHEN ("name" = 'Legolas') THEN IIF(("name" = 'Elf'), 1, 0) ELSE 1 END = 1) """ + ) + // scalastyle:on + df.collect() + } + + test("SPARK-50087: SqlServer handle booleans in nested CASE WHEN test") { + val df = sql( + s"""|SELECT * FROM $catalogName.employee + |WHERE CASE WHEN (name = 'Legolas') THEN + | CASE WHEN (name = 'Elf') THEN (name = 'Elrond') ELSE (name = 'Gandalf') END + | ELSE (name = 'Sauron') END + |""".stripMargin + ) + + // scalastyle:off + assert(getExternalEngineQuery(df.queryExecution.executedPlan) == + """SELECT "dept","name","salary","bonus" FROM "employee" WHERE (CASE WHEN ("name" = 'Legolas') THEN IIF((CASE WHEN ("name" = 'Elf') THEN IIF(("name" = 'Elrond'), 1, 0) ELSE IIF(("name" = 'Gandalf'), 1, 0) END = 1), 1, 0) ELSE IIF(("name" = 'Sauron'), 1, 0) END = 1) """ + ) + // scalastyle:on + df.collect() + } + + test("SPARK-50087: SqlServer handle non-booleans in nested CASE WHEN test") { + val df = sql( + s"""|SELECT * FROM $catalogName.employee + |WHERE CASE WHEN (name = 'Legolas') THEN + | CASE WHEN (name = 'Elf') THEN 'Elf' ELSE 'Wizard' END + | ELSE 'Sauron' END = name + |""".stripMargin + ) + + // scalastyle:off + assert(getExternalEngineQuery(df.queryExecution.executedPlan) == + """SELECT "dept","name","salary","bonus" FROM "employee" WHERE ("name" IS NOT NULL) AND ((CASE WHEN "name" = 'Legolas' THEN CASE WHEN "name" = 'Elf' THEN 'Elf' ELSE 'Wizard' END ELSE 'Sauron' END) = "name") """ + ) + // scalastyle:on + df.collect() + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala index 61a26d7a4fbd..b0ce2bb4293e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala @@ -221,8 +221,8 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) extends L case _: BitwiseNot => generateExpressionWithName("~", expr, isPredicate) case caseWhen @ CaseWhen(branches, elseValue) => val conditions = branches.map(_._1).flatMap(generateExpression(_, true)) - val values = branches.map(_._2).flatMap(generateExpression(_)) - val elseExprOpt = elseValue.flatMap(generateExpression(_)) + val values = branches.map(_._2).flatMap(generateExpression(_, isPredicate)) + val elseExprOpt = elseValue.flatMap(generateExpression(_, isPredicate)) if (conditions.length == branches.length && values.length == branches.length && elseExprOpt.size == elseValue.size) { val branchExpressions = conditions.zip(values).flatMap { case (c, v) => @@ -421,7 +421,7 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) extends L children: Seq[Expression], dataType: DataType, isPredicate: Boolean): Option[V2Expression] = { - val childrenExpressions = children.flatMap(generateExpression(_)) + val childrenExpressions = children.flatMap(generateExpression(_, isPredicate)) if (childrenExpressions.length == children.length) { if (isPredicate && dataType.isInstanceOf[BooleanType]) { Some(new V2Predicate(v2ExpressionName, childrenExpressions.toArray[V2Expression])) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 3bf1390cb664..81ad1a6d38bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -42,6 +42,7 @@ import org.apache.spark.sql.connector.catalog.functions.UnboundFunction import org.apache.spark.sql.connector.catalog.index.TableIndex import org.apache.spark.sql.connector.expressions.{Expression, Literal, NamedReference} import org.apache.spark.sql.connector.expressions.aggregate.AggregateFunc +import org.apache.spark.sql.connector.expressions.filter.Predicate import org.apache.spark.sql.connector.util.V2ExpressionSQLBuilder import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.jdbc.{DriverRegistry, JDBCOptions, JdbcOptionsInWrite, JdbcUtils} @@ -377,6 +378,18 @@ abstract class JdbcDialect extends Serializable with Logging { } private[jdbc] class JDBCSQLBuilder extends V2ExpressionSQLBuilder { + // Some dialects do not support boolean type and this convenient util function is + // provided to generate SQL string without boolean values. + protected def inputToSQLNoBool(input: Expression): String = input match { + case p: Predicate if p.name() == "ALWAYS_TRUE" => "1" + case p: Predicate if p.name() == "ALWAYS_FALSE" => "0" + case p: Predicate => predicateToIntSQL(inputToSQL(p)) + case _ => inputToSQL(input) + } + + protected def predicateToIntSQL(input: String): String = + "CASE WHEN " + input + " THEN 1 ELSE 0 END" + override def visitLiteral(literal: Literal[_]): String = { Option(literal.value()).map(v => compileValue(CatalystTypeConverters.convertToScala(v, literal.dataType())).toString) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala index 7d476d43e5c7..7d339a90db8c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala @@ -59,6 +59,8 @@ private case class MsSqlServerDialect() extends JdbcDialect with NoLegacyJDBCErr supportedFunctions.contains(funcName) class MsSqlServerSQLBuilder extends JDBCSQLBuilder { + override protected def predicateToIntSQL(input: String): String = + "IIF(" + input + ", 1, 0)" override def visitSortOrder( sortKey: String, sortDirection: SortDirection, nullOrdering: NullOrdering): String = { (sortDirection, nullOrdering) match { @@ -87,12 +89,24 @@ private case class MsSqlServerDialect() extends JdbcDialect with NoLegacyJDBCErr expr match { case e: Predicate => e.name() match { case "=" | "<>" | "<=>" | "<" | "<=" | ">" | ">=" => - val Array(l, r) = e.children().map { - case p: Predicate => s"CASE WHEN ${inputToSQL(p)} THEN 1 ELSE 0 END" - case o => inputToSQL(o) - } + val Array(l, r) = e.children().map(inputToSQLNoBool) visitBinaryComparison(e.name(), l, r) - case "CASE_WHEN" => visitCaseWhen(expressionsToStringArray(e.children())) + " = 1" + case "CASE_WHEN" => + // Since MsSqlServer cannot handle boolean expressions inside + // a CASE WHEN, it is necessary to convert those to another + // CASE WHEN expression that will return 1 or 0 depending on + // the result. + // Example: + // In: ... CASE WHEN a = b THEN c = d ... END + // Out: ... CASE WHEN a = b THEN CASE WHEN c = d THEN 1 ELSE 0 END ... END = 1 + val stringArray = e.children().grouped(2).flatMap { + case Array(whenExpression, thenExpression) => + Array(inputToSQL(whenExpression), inputToSQLNoBool(thenExpression)) + case Array(elseExpression) => + Array(inputToSQLNoBool(elseExpression)) + }.toArray + + visitCaseWhen(stringArray) + " = 1" case _ => super.build(expr) } case _ => super.build(expr)