From 5f6223c65b01df17044d2e15583505cd0419b8d6 Mon Sep 17 00:00:00 2001 From: Zhen Wang <643348094@qq.com> Date: Tue, 19 Mar 2024 18:48:09 +0800 Subject: [PATCH 1/9] Use V2Predicate to wrap If when building v2 expressions --- .../catalyst/util/V2ExpressionBuilder.scala | 8 +++++- .../sql/connector/DataSourceV2Suite.scala | 26 +++++++++++++++++++ 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala index 3942d193a328..54b57a53493e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala @@ -209,7 +209,13 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) { } else { None } - case iff: If => generateExpressionWithName("CASE_WHEN", iff.children) + case iff: If => + val childrenExpressions = iff.children.flatMap(generateExpression(_)) + if (iff.children.length == childrenExpressions.length) { + Some(new V2Predicate("CASE_WHEN", childrenExpressions.toArray[V2Expression])) + } else { + None + } case substring: Substring => val children = if (substring.len == Literal(Integer.MAX_VALUE)) { Seq(substring.str, substring.pos) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala index a7fb2c054e80..3c3295f2a7ee 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala @@ -966,6 +966,32 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS ) } } + + test("SPARK-47463: Pushed down v2 filter that folded predicate into (if / case) branches") { + withTempView("t1") { + spark.read.format(classOf[AdvancedDataSourceV2WithV2Filter].getName).load() + .createTempView("t1") + val df1 = sql( + s""" + |select * from + |(select if(i = 1, i, 0) as c from t1) t + |where t.c > 0 + |""".stripMargin + ) + val result1 = df1.collect() + assert(result1.length == 1) + + val df2 = sql( + s""" + |select * from + |(select case when i = 1 then i else 0 end as c from t1) t + |where t.c > 0 + |""".stripMargin + ) + val result2 = df2.collect() + assert(result2.length == 1) + } + } } case class RangeInputPartition(start: Int, end: Int) extends InputPartition From 065345bda5182e20ec557976f8a29f4fa171e57c Mon Sep 17 00:00:00 2001 From: Zhen Wang <643348094@qq.com> Date: Wed, 10 Apr 2024 10:45:44 +0800 Subject: [PATCH 2/9] consider isPredicate --- .../sql/catalyst/util/V2ExpressionBuilder.scala | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala index 54b57a53493e..e89e22a87288 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala @@ -210,11 +210,10 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) { None } case iff: If => - val childrenExpressions = iff.children.flatMap(generateExpression(_)) - if (iff.children.length == childrenExpressions.length) { - Some(new V2Predicate("CASE_WHEN", childrenExpressions.toArray[V2Expression])) + if (isPredicate && iff.dataType.isInstanceOf[BooleanType]) { + generatePredicateWithName("CASE_WHEN", iff.children) } else { - None + generateExpressionWithName("CASE_WHEN", iff.children) } case substring: Substring => val children = if (substring.len == Literal(Integer.MAX_VALUE)) { @@ -395,6 +394,16 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) { } } + private def generatePredicateWithName( + v2PredicateName: String, children: Seq[Expression]): Option[V2Expression] = { + val childrenExpressions = children.flatMap(generateExpression(_)) + if (childrenExpressions.length == children.length) { + Some(new V2Predicate(v2PredicateName, childrenExpressions.toArray[V2Expression])) + } else { + None + } + } + private def generateSortValue(expr: V2Expression, reverse: Boolean): SortValue = if (reverse) { SortValue(expr, SortDirection.DESCENDING, NullOrdering.NULLS_LAST) } else { From efb183ebb38c444096cc254e6a837b54cd00724c Mon Sep 17 00:00:00 2001 From: Zhen Wang <643348094@qq.com> Date: Wed, 10 Apr 2024 10:55:45 +0800 Subject: [PATCH 3/9] add `nullif` test case --- .../spark/sql/connector/DataSourceV2Suite.scala | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala index 3c3295f2a7ee..b622c5060c01 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala @@ -967,7 +967,7 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS } } - test("SPARK-47463: Pushed down v2 filter that folded predicate into (if / case) branches") { + test("SPARK-47463: Pushed down v2 filter with (if / case when/ nullif) expression") { withTempView("t1") { spark.read.format(classOf[AdvancedDataSourceV2WithV2Filter].getName).load() .createTempView("t1") @@ -990,6 +990,15 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS ) val result2 = df2.collect() assert(result2.length == 1) + + val df3 = sql( + s""" + |select * from t1 + |where nullif(i, 1) is null + |""".stripMargin + ) + val result3 = df3.collect() + assert(result3.length == 1) } } } From 07b88aeecb7221017e067fccae79e2e7c355beed Mon Sep 17 00:00:00 2001 From: Zhen Wang <643348094@qq.com> Date: Thu, 11 Apr 2024 13:30:57 +0800 Subject: [PATCH 4/9] address comments --- .../catalyst/util/V2ExpressionBuilder.scala | 154 +++++++++--------- .../sql/connector/DataSourceV2Suite.scala | 61 +++---- 2 files changed, 110 insertions(+), 105 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala index e89e22a87288..f456fa6bc4a1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Avg, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum, UserDefinedAggregateFunc} import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or, Predicate => V2Predicate} import org.apache.spark.sql.execution.datasources.PushableExpression -import org.apache.spark.sql.types.{BooleanType, IntegerType, StringType} +import org.apache.spark.sql.types.{BooleanType, DataType, IntegerType, StringType} /** * The builder to generate V2 expressions from catalyst expressions. @@ -98,45 +98,45 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) { generateExpression(child).map(v => new V2Cast(v, dataType)) case AggregateExpression(aggregateFunction, Complete, isDistinct, None, _) => generateAggregateFunc(aggregateFunction, isDistinct) - case Abs(child, true) => generateExpressionWithName("ABS", Seq(child)) - case Coalesce(children) => generateExpressionWithName("COALESCE", children) - case Greatest(children) => generateExpressionWithName("GREATEST", children) - case Least(children) => generateExpressionWithName("LEAST", children) - case Rand(child, hideSeed) => + case abs @ Abs(_, true) => generateExpressionWithName("ABS", abs, isPredicate) + case _: Coalesce => generateExpressionWithName("COALESCE", expr, isPredicate) + case _: Greatest => generateExpressionWithName("GREATEST", expr, isPredicate) + case _: Least => generateExpressionWithName("LEAST", expr, isPredicate) + case rand @ Rand(_, hideSeed) => if (hideSeed) { Some(new GeneralScalarExpression("RAND", Array.empty[V2Expression])) } else { - generateExpressionWithName("RAND", Seq(child)) + generateExpressionWithName("RAND", rand, isPredicate) } - case log: Logarithm => generateExpressionWithName("LOG", log.children) - case Log10(child) => generateExpressionWithName("LOG10", Seq(child)) - case Log2(child) => generateExpressionWithName("LOG2", Seq(child)) - case Log(child) => generateExpressionWithName("LN", Seq(child)) - case Exp(child) => generateExpressionWithName("EXP", Seq(child)) - case pow: Pow => generateExpressionWithName("POWER", pow.children) - case Sqrt(child) => generateExpressionWithName("SQRT", Seq(child)) - case Floor(child) => generateExpressionWithName("FLOOR", Seq(child)) - case Ceil(child) => generateExpressionWithName("CEIL", Seq(child)) - case round: Round => generateExpressionWithName("ROUND", round.children) - case Sin(child) => generateExpressionWithName("SIN", Seq(child)) - case Sinh(child) => generateExpressionWithName("SINH", Seq(child)) - case Cos(child) => generateExpressionWithName("COS", Seq(child)) - case Cosh(child) => generateExpressionWithName("COSH", Seq(child)) - case Tan(child) => generateExpressionWithName("TAN", Seq(child)) - case Tanh(child) => generateExpressionWithName("TANH", Seq(child)) - case Cot(child) => generateExpressionWithName("COT", Seq(child)) - case Asin(child) => generateExpressionWithName("ASIN", Seq(child)) - case Asinh(child) => generateExpressionWithName("ASINH", Seq(child)) - case Acos(child) => generateExpressionWithName("ACOS", Seq(child)) - case Acosh(child) => generateExpressionWithName("ACOSH", Seq(child)) - case Atan(child) => generateExpressionWithName("ATAN", Seq(child)) - case Atanh(child) => generateExpressionWithName("ATANH", Seq(child)) - case atan2: Atan2 => generateExpressionWithName("ATAN2", atan2.children) - case Cbrt(child) => generateExpressionWithName("CBRT", Seq(child)) - case ToDegrees(child) => generateExpressionWithName("DEGREES", Seq(child)) - case ToRadians(child) => generateExpressionWithName("RADIANS", Seq(child)) - case Signum(child) => generateExpressionWithName("SIGN", Seq(child)) - case wb: WidthBucket => generateExpressionWithName("WIDTH_BUCKET", wb.children) + case _: Logarithm => generateExpressionWithName("LOG", expr, isPredicate) + case _: Log10 => generateExpressionWithName("LOG10", expr, isPredicate) + case _: Log2 => generateExpressionWithName("LOG2", expr, isPredicate) + case _: Log => generateExpressionWithName("LN", expr, isPredicate) + case _: Exp => generateExpressionWithName("EXP", expr, isPredicate) + case _: Pow => generateExpressionWithName("POWER", expr, isPredicate) + case _: Sqrt => generateExpressionWithName("SQRT", expr, isPredicate) + case _: Floor => generateExpressionWithName("FLOOR", expr, isPredicate) + case _: Ceil => generateExpressionWithName("CEIL", expr, isPredicate) + case _: Round => generateExpressionWithName("ROUND", expr, isPredicate) + case _: Sin => generateExpressionWithName("SIN", expr, isPredicate) + case _: Sinh => generateExpressionWithName("SINH", expr, isPredicate) + case _: Cos => generateExpressionWithName("COS", expr, isPredicate) + case _: Cosh => generateExpressionWithName("COSH", expr, isPredicate) + case _: Tan => generateExpressionWithName("TAN", expr, isPredicate) + case _: Tanh => generateExpressionWithName("TANH", expr, isPredicate) + case _: Cot => generateExpressionWithName("COT", expr, isPredicate) + case _: Asin => generateExpressionWithName("ASIN", expr, isPredicate) + case _: Asinh => generateExpressionWithName("ASINH", expr, isPredicate) + case _: Acos => generateExpressionWithName("ACOS", expr, isPredicate) + case _: Acosh => generateExpressionWithName("ACOSH", expr, isPredicate) + case _: Atan => generateExpressionWithName("ATAN", expr, isPredicate) + case _: Atanh => generateExpressionWithName("ATANH", expr, isPredicate) + case _: Atan2 => generateExpressionWithName("ATAN2", expr, isPredicate) + case _: Cbrt => generateExpressionWithName("CBRT", expr, isPredicate) + case _: ToDegrees => generateExpressionWithName("DEGREES", expr, isPredicate) + case _: ToRadians => generateExpressionWithName("RADIANS", expr, isPredicate) + case _: Signum => generateExpressionWithName("SIGN", expr, isPredicate) + case _: WidthBucket => generateExpressionWithName("WIDTH_BUCKET", expr, isPredicate) case and: And => // AND expects predicate val l = generateExpression(and.left, true) @@ -187,8 +187,9 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) { assert(v.isInstanceOf[V2Predicate]) new V2Not(v.asInstanceOf[V2Predicate]) } - case UnaryMinus(child, true) => generateExpressionWithName("-", Seq(child)) - case BitwiseNot(child) => generateExpressionWithName("~", Seq(child)) + case unaryMinus @ UnaryMinus(_, true) => + generateExpressionWithName("-", unaryMinus, isPredicate) + case _: BitwiseNot => generateExpressionWithName("~", expr, isPredicate) case CaseWhen(branches, elseValue) => val conditions = branches.map(_._1).flatMap(generateExpression(_, true)) val values = branches.map(_._2).flatMap(generateExpression(_, true)) @@ -209,40 +210,35 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) { } else { None } - case iff: If => - if (isPredicate && iff.dataType.isInstanceOf[BooleanType]) { - generatePredicateWithName("CASE_WHEN", iff.children) - } else { - generateExpressionWithName("CASE_WHEN", iff.children) - } + case _: If => generateExpressionWithName("CASE_WHEN", expr, isPredicate) case substring: Substring => val children = if (substring.len == Literal(Integer.MAX_VALUE)) { Seq(substring.str, substring.pos) } else { substring.children } - generateExpressionWithName("SUBSTRING", children) - case Upper(child) => generateExpressionWithName("UPPER", Seq(child)) - case Lower(child) => generateExpressionWithName("LOWER", Seq(child)) - case BitLength(child) if child.dataType.isInstanceOf[StringType] => - generateExpressionWithName("BIT_LENGTH", Seq(child)) - case Length(child) if child.dataType.isInstanceOf[StringType] => - generateExpressionWithName("CHAR_LENGTH", Seq(child)) - case concat: Concat => generateExpressionWithName("CONCAT", concat.children) - case translate: StringTranslate => generateExpressionWithName("TRANSLATE", translate.children) - case trim: StringTrim => generateExpressionWithName("TRIM", trim.children) - case trim: StringTrimLeft => generateExpressionWithName("LTRIM", trim.children) - case trim: StringTrimRight => generateExpressionWithName("RTRIM", trim.children) + generateExpressionWithNameByChildren("SUBSTRING", children, substring.dataType, isPredicate) + case _: Upper => generateExpressionWithName("UPPER", expr, isPredicate) + case _: Lower => generateExpressionWithName("LOWER", expr, isPredicate) + case bitLength @ BitLength(child) if child.dataType.isInstanceOf[StringType] => + generateExpressionWithName("BIT_LENGTH", bitLength, isPredicate) + case length @ Length(child) if child.dataType.isInstanceOf[StringType] => + generateExpressionWithName("CHAR_LENGTH", length, isPredicate) + case _: Concat => generateExpressionWithName("CONCAT", expr, isPredicate) + case _: StringTranslate => generateExpressionWithName("TRANSLATE", expr, isPredicate) + case _: StringTrim => generateExpressionWithName("TRIM", expr, isPredicate) + case _: StringTrimLeft => generateExpressionWithName("LTRIM", expr, isPredicate) + case _: StringTrimRight => generateExpressionWithName("RTRIM", expr, isPredicate) case overlay: Overlay => val children = if (overlay.len == Literal(-1)) { Seq(overlay.input, overlay.replace, overlay.pos) } else { overlay.children } - generateExpressionWithName("OVERLAY", children) - case date: DateAdd => generateExpressionWithName("DATE_ADD", date.children) - case date: DateDiff => generateExpressionWithName("DATE_DIFF", date.children) - case date: TruncDate => generateExpressionWithName("TRUNC", date.children) + generateExpressionWithNameByChildren("OVERLAY", children, overlay.dataType, isPredicate) + case _: DateAdd => generateExpressionWithName("DATE_ADD", expr, isPredicate) + case _: DateDiff => generateExpressionWithName("DATE_DIFF", expr, isPredicate) + case _: TruncDate => generateExpressionWithName("TRUNC", expr, isPredicate) case Second(child, _) => generateExpression(child).map(v => new V2Extract("SECOND", v)) case Minute(child, _) => @@ -275,12 +271,12 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) { generateExpression(child).map(v => new V2Extract("WEEK", v)) case YearOfWeek(child) => generateExpression(child).map(v => new V2Extract("YEAR_OF_WEEK", v)) - case encrypt: AesEncrypt => generateExpressionWithName("AES_ENCRYPT", encrypt.children) - case decrypt: AesDecrypt => generateExpressionWithName("AES_DECRYPT", decrypt.children) - case Crc32(child) => generateExpressionWithName("CRC32", Seq(child)) - case Md5(child) => generateExpressionWithName("MD5", Seq(child)) - case Sha1(child) => generateExpressionWithName("SHA1", Seq(child)) - case sha2: Sha2 => generateExpressionWithName("SHA2", sha2.children) + case _: AesEncrypt => generateExpressionWithName("AES_ENCRYPT", expr, isPredicate) + case _: AesDecrypt => generateExpressionWithName("AES_DECRYPT", expr, isPredicate) + case _: Crc32 => generateExpressionWithName("CRC32", expr, isPredicate) + case _: Md5 => generateExpressionWithName("MD5", expr, isPredicate) + case _: Sha1 => generateExpressionWithName("SHA1", expr, isPredicate) + case _: Sha2 => generateExpressionWithName("SHA2", expr, isPredicate) // TODO supports other expressions case ApplyFunctionExpression(function, children) => val childrenExpressions = children.flatMap(generateExpression(_)) @@ -385,20 +381,26 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) { } private def generateExpressionWithName( - v2ExpressionName: String, children: Seq[Expression]): Option[V2Expression] = { - val childrenExpressions = children.flatMap(generateExpression(_)) - if (childrenExpressions.length == children.length) { - Some(new GeneralScalarExpression(v2ExpressionName, childrenExpressions.toArray[V2Expression])) - } else { - None - } + v2ExpressionName: String, + expr: Expression, + isPredicate: Boolean): Option[V2Expression] = { + generateExpressionWithNameByChildren( + v2ExpressionName, expr.children, expr.dataType, isPredicate) } - private def generatePredicateWithName( - v2PredicateName: String, children: Seq[Expression]): Option[V2Expression] = { + private def generateExpressionWithNameByChildren( + v2ExpressionName: String, + children: Seq[Expression], + dataType: DataType, + isPredicate: Boolean): Option[V2Expression] = { val childrenExpressions = children.flatMap(generateExpression(_)) if (childrenExpressions.length == children.length) { - Some(new V2Predicate(v2PredicateName, childrenExpressions.toArray[V2Expression])) + if (isPredicate && dataType.isInstanceOf[BooleanType]) { + Some(new V2Predicate(v2ExpressionName, childrenExpressions.toArray[V2Expression])) + } else { + Some(new GeneralScalarExpression( + v2ExpressionName, childrenExpressions.toArray[V2Expression])) + } } else { None } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala index b622c5060c01..57964e84cf94 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala @@ -967,38 +967,41 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS } } - test("SPARK-47463: Pushed down v2 filter with (if / case when/ nullif) expression") { + test("SPARK-47463: Pushed down v2 filter with (if / case when / nvl) expression") { withTempView("t1") { - spark.read.format(classOf[AdvancedDataSourceV2WithV2Filter].getName).load() - .createTempView("t1") - val df1 = sql( - s""" - |select * from - |(select if(i = 1, i, 0) as c from t1) t - |where t.c > 0 - |""".stripMargin - ) - val result1 = df1.collect() - assert(result1.length == 1) + withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") { + spark.read.format(classOf[AdvancedDataSourceV2WithV2Filter].getName).load() + .createTempView("t1") + val df1 = sql( + s""" + |select * from + |(select if(i = 1, i, 0) as c from t1) t + |where t.c > 0 + |""".stripMargin + ) + val result1 = df1.collect() + assert(result1.length == 1) - val df2 = sql( - s""" - |select * from - |(select case when i = 1 then i else 0 end as c from t1) t - |where t.c > 0 - |""".stripMargin - ) - val result2 = df2.collect() - assert(result2.length == 1) + val df2 = sql( + s""" + |select * from + |(select case when i = 1 then i else 0 end as c from t1) t + |where t.c > 0 + |""".stripMargin + ) + val result2 = df2.collect() + assert(result2.length == 1) - val df3 = sql( - s""" - |select * from t1 - |where nullif(i, 1) is null - |""".stripMargin - ) - val result3 = df3.collect() - assert(result3.length == 1) + val df3 = sql( + s""" + |select * from + |(select nvl(cast(i as boolean), false) c from t1) t + |where t.c is true + |""".stripMargin + ) + val result3 = df3.collect() + assert(result3.length > 0) + } } } } From a8a5ed59f0804f0b70473faf751b023b1610d441 Mon Sep 17 00:00:00 2001 From: Zhen Wang <643348094@qq.com> Date: Fri, 12 Apr 2024 10:13:32 +0800 Subject: [PATCH 5/9] fix casewhen --- .../catalyst/util/V2ExpressionBuilder.scala | 21 +++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala index f456fa6bc4a1..315faf15ffc7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala @@ -190,22 +190,21 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) { case unaryMinus @ UnaryMinus(_, true) => generateExpressionWithName("-", unaryMinus, isPredicate) case _: BitwiseNot => generateExpressionWithName("~", expr, isPredicate) - case CaseWhen(branches, elseValue) => + case caseWhen @ CaseWhen(branches, elseValue) => val conditions = branches.map(_._1).flatMap(generateExpression(_, true)) - val values = branches.map(_._2).flatMap(generateExpression(_, true)) - if (conditions.length == branches.length && values.length == branches.length) { + val values = branches.map(_._2).flatMap(generateExpression(_, false)) + val elseExprOpt = elseValue.flatMap(generateExpression(_)) + if (conditions.length == branches.length && values.length == branches.length && + elseExprOpt.size == elseValue.size) { val branchExpressions = conditions.zip(values).flatMap { case (c, v) => Seq[V2Expression](c, v) } - if (elseValue.isDefined) { - elseValue.flatMap(generateExpression(_)).map { v => - val children = (branchExpressions :+ v).toArray[V2Expression] - // The children looks like [condition1, value1, ..., conditionN, valueN, elseValue] - new V2Predicate("CASE_WHEN", children) - } + val children = (branchExpressions ++ elseExprOpt).toArray[V2Expression] + // The children looks like [condition1, value1, ..., conditionN, valueN (, elseValue)] + if (isPredicate && caseWhen.dataType.isInstanceOf[BooleanType]) { + Some(new V2Predicate("CASE_WHEN", children)) } else { - // The children looks like [condition1, value1, ..., conditionN, valueN] - Some(new V2Predicate("CASE_WHEN", branchExpressions.toArray[V2Expression])) + Some(new GeneralScalarExpression("CASE_WHEN", children)) } } else { None From e176461491700419db997fbe741839981269b389 Mon Sep 17 00:00:00 2001 From: Zhen Wang <643348094@qq.com> Date: Fri, 12 Apr 2024 10:15:06 +0800 Subject: [PATCH 6/9] default value --- .../apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala index 315faf15ffc7..d405eee636b0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala @@ -192,7 +192,7 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) { case _: BitwiseNot => generateExpressionWithName("~", expr, isPredicate) case caseWhen @ CaseWhen(branches, elseValue) => val conditions = branches.map(_._1).flatMap(generateExpression(_, true)) - val values = branches.map(_._2).flatMap(generateExpression(_, false)) + val values = branches.map(_._2).flatMap(generateExpression(_)) val elseExprOpt = elseValue.flatMap(generateExpression(_)) if (conditions.length == branches.length && values.length == branches.length && elseExprOpt.size == elseValue.size) { From 08426d108c632180fe3a6a89a2698c16fa6d4744 Mon Sep 17 00:00:00 2001 From: Zhen Wang <643348094@qq.com> Date: Fri, 12 Apr 2024 15:10:47 +0800 Subject: [PATCH 7/9] address comments --- .../sql/catalyst/util/V2ExpressionBuilder.scala | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala index d405eee636b0..398f21e01b80 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala @@ -98,15 +98,15 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) { generateExpression(child).map(v => new V2Cast(v, dataType)) case AggregateExpression(aggregateFunction, Complete, isDistinct, None, _) => generateAggregateFunc(aggregateFunction, isDistinct) - case abs @ Abs(_, true) => generateExpressionWithName("ABS", abs, isPredicate) + case Abs(_, true) => generateExpressionWithName("ABS", expr, isPredicate) case _: Coalesce => generateExpressionWithName("COALESCE", expr, isPredicate) case _: Greatest => generateExpressionWithName("GREATEST", expr, isPredicate) case _: Least => generateExpressionWithName("LEAST", expr, isPredicate) - case rand @ Rand(_, hideSeed) => + case Rand(_, hideSeed) => if (hideSeed) { Some(new GeneralScalarExpression("RAND", Array.empty[V2Expression])) } else { - generateExpressionWithName("RAND", rand, isPredicate) + generateExpressionWithName("RAND", expr, isPredicate) } case _: Logarithm => generateExpressionWithName("LOG", expr, isPredicate) case _: Log10 => generateExpressionWithName("LOG10", expr, isPredicate) @@ -187,8 +187,7 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) { assert(v.isInstanceOf[V2Predicate]) new V2Not(v.asInstanceOf[V2Predicate]) } - case unaryMinus @ UnaryMinus(_, true) => - generateExpressionWithName("-", unaryMinus, isPredicate) + case UnaryMinus(_, true) => generateExpressionWithName("-", expr, isPredicate) case _: BitwiseNot => generateExpressionWithName("~", expr, isPredicate) case caseWhen @ CaseWhen(branches, elseValue) => val conditions = branches.map(_._1).flatMap(generateExpression(_, true)) @@ -219,10 +218,10 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) { generateExpressionWithNameByChildren("SUBSTRING", children, substring.dataType, isPredicate) case _: Upper => generateExpressionWithName("UPPER", expr, isPredicate) case _: Lower => generateExpressionWithName("LOWER", expr, isPredicate) - case bitLength @ BitLength(child) if child.dataType.isInstanceOf[StringType] => - generateExpressionWithName("BIT_LENGTH", bitLength, isPredicate) - case length @ Length(child) if child.dataType.isInstanceOf[StringType] => - generateExpressionWithName("CHAR_LENGTH", length, isPredicate) + case BitLength(child) if child.dataType.isInstanceOf[StringType] => + generateExpressionWithName("BIT_LENGTH", expr, isPredicate) + case Length(child) if child.dataType.isInstanceOf[StringType] => + generateExpressionWithName("CHAR_LENGTH", expr, isPredicate) case _: Concat => generateExpressionWithName("CONCAT", expr, isPredicate) case _: StringTranslate => generateExpressionWithName("TRANSLATE", expr, isPredicate) case _: StringTrim => generateExpressionWithName("TRIM", expr, isPredicate) From 3ecacd1df91444b1697ac86bad3e769c46918329 Mon Sep 17 00:00:00 2001 From: Zhen Wang <643348094@qq.com> Date: Fri, 12 Apr 2024 16:32:33 +0800 Subject: [PATCH 8/9] simplify unit test --- .../sql/connector/DataSourceV2Suite.scala | 46 +++++-------------- 1 file changed, 12 insertions(+), 34 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala index 57964e84cf94..2f50b2ee0824 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala @@ -967,41 +967,19 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS } } - test("SPARK-47463: Pushed down v2 filter with (if / case when / nvl) expression") { + test("SPARK-47463: Pushed down v2 filter with if expression") { withTempView("t1") { - withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") { - spark.read.format(classOf[AdvancedDataSourceV2WithV2Filter].getName).load() - .createTempView("t1") - val df1 = sql( - s""" - |select * from - |(select if(i = 1, i, 0) as c from t1) t - |where t.c > 0 - |""".stripMargin - ) - val result1 = df1.collect() - assert(result1.length == 1) - - val df2 = sql( - s""" - |select * from - |(select case when i = 1 then i else 0 end as c from t1) t - |where t.c > 0 - |""".stripMargin - ) - val result2 = df2.collect() - assert(result2.length == 1) - - val df3 = sql( - s""" - |select * from - |(select nvl(cast(i as boolean), false) c from t1) t - |where t.c is true - |""".stripMargin - ) - val result3 = df3.collect() - assert(result3.length > 0) - } + spark.read.format(classOf[AdvancedDataSourceV2WithV2Filter].getName).load() + .createTempView("t1") + val df1 = sql( + s""" + |select * from + |(select if(i = 1, i, 0) as c from t1) t + |where t.c > 0 + |""".stripMargin + ) + val result1 = df1.collect() + assert(result1.length == 1) } } } From ffebea2979b268e49544da4853dddccf3b698ff8 Mon Sep 17 00:00:00 2001 From: Zhen Wang <643348094@qq.com> Date: Tue, 16 Apr 2024 09:43:18 +0800 Subject: [PATCH 9/9] address comment --- .../spark/sql/connector/DataSourceV2Suite.scala | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala index 2f50b2ee0824..1de535df246b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala @@ -971,15 +971,9 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS withTempView("t1") { spark.read.format(classOf[AdvancedDataSourceV2WithV2Filter].getName).load() .createTempView("t1") - val df1 = sql( - s""" - |select * from - |(select if(i = 1, i, 0) as c from t1) t - |where t.c > 0 - |""".stripMargin - ) - val result1 = df1.collect() - assert(result1.length == 1) + val df = sql("SELECT * FROM t1 WHERE if(i = 1, i, 0) > 0") + val result = df.collect() + assert(result.length == 1) } } }