From 668278bf59b8719b5011bad9dd4c6f1fa947b1a3 Mon Sep 17 00:00:00 2001
From: Jiaan Geng
Date: Mon, 11 Apr 2022 16:52:54 +0800
Subject: [PATCH 1/3] [SPARK-38855][SQL] DS V2 supports push down math functions

---
 .../util/V2ExpressionSQLBuilder.java          |  7 ++++
 .../catalyst/util/V2ExpressionBuilder.scala   | 31 +++++++++++++-
 .../apache/spark/sql/jdbc/JDBCV2Suite.scala   | 41 ++++++++++++++++++-
 3 files changed, 77 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
index a7d1ed7f85e84..c9dfa2003e3c1 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
@@ -95,6 +95,13 @@ public String build(Expression expr) {
         return visitUnaryArithmetic(name, inputToSQL(e.children()[0]));
       case "ABS":
       case "COALESCE":
+      case "LN":
+      case "EXP":
+      case "POWER":
+      case "SQRT":
+      case "FLOOR":
+      case "CEIL":
+      case "WIDTH_BUCKET":
         return visitSQLFunction(name,
           Arrays.stream(e.children()).map(c -> build(c)).toArray(String[]::new));
       case "CASE_WHEN": {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
index 37db499470aa3..32e902f56d667 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.util
 
-import org.apache.spark.sql.catalyst.expressions.{Abs, Add, And, BinaryComparison, BinaryOperator, BitwiseAnd, BitwiseNot, BitwiseOr, BitwiseXor, CaseWhen, Cast, Coalesce, Contains, Divide, EndsWith, EqualTo, Expression, In, InSet, IsNotNull, IsNull, Literal, Multiply, Not, Or, Predicate, Remainder, StartsWith, StringPredicate, Subtract, UnaryMinus}
+import org.apache.spark.sql.catalyst.expressions.{Abs, Add, And, BinaryComparison, BinaryOperator, BitwiseAnd, BitwiseNot, BitwiseOr, BitwiseXor, CaseWhen, Cast, Ceil, Coalesce, Contains, Divide, EndsWith, EqualTo, Exp, Expression, Floor, In, InSet, IsNotNull, IsNull, Literal, Log, Multiply, Not, Or, Pow, Predicate, Remainder, Sqrt, StartsWith, StringPredicate, Subtract, UnaryMinus, WidthBucket}
 import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, FieldReference, GeneralScalarExpression, LiteralValue}
 import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or, Predicate => V2Predicate}
 import org.apache.spark.sql.execution.datasources.PushableColumn
@@ -104,6 +104,35 @@ class V2ExpressionBuilder(
       } else {
         None
       }
+    case Log(child) => generateExpression(child)
+      .map(v => new GeneralScalarExpression("LN", Array[V2Expression](v)))
+    case Exp(child) => generateExpression(child)
+      .map(v => new GeneralScalarExpression("EXP", Array[V2Expression](v)))
+    case Pow(left, right) =>
+      val l = generateExpression(left)
+      val r = generateExpression(right)
+      if (l.isDefined && r.isDefined) {
+        Some(new GeneralScalarExpression("POWER", Array[V2Expression](l.get, r.get)))
+      } else {
+        None
+      }
+    case Sqrt(child) => generateExpression(child)
+      .map(v => new GeneralScalarExpression("SQRT", Array[V2Expression](v)))
+    case Floor(child) => generateExpression(child)
+      .map(v => new GeneralScalarExpression("FLOOR", Array[V2Expression](v)))
+    case Ceil(child) => generateExpression(child)
+      .map(v => new GeneralScalarExpression("CEIL", Array[V2Expression](v)))
+    case WidthBucket(value, minValue, maxValue, numBucket) =>
+      val v = generateExpression(value)
+      val min = generateExpression(minValue)
+      val max = generateExpression(maxValue)
+      val n = generateExpression(numBucket)
+      if (v.isDefined && min.isDefined && max.isDefined && n.isDefined) {
+        Some(new GeneralScalarExpression("WIDTH_BUCKET",
+          Array[V2Expression](v.get, min.get, max.get, n.get)))
+      } else {
+        None
+      }
     case and: And =>
       // AND expects predicate
       val l = generateExpression(and.left, true)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index e60af877e9c5a..7c40e9d4c7c77 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, GlobalLimit, LocalLimit, Sort}
 import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper}
 import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog
-import org.apache.spark.sql.functions.{abs, avg, coalesce, count, count_distinct, lit, not, sum, udf, when}
+import org.apache.spark.sql.functions.{abs, avg, ceil, coalesce, count, count_distinct, exp, floor, lit, log => ln, not, pow, sqrt, sum, udf, when}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.util.Utils
@@ -464,6 +464,45 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
         checkPushedInfo(df5, expectedPlanFragment5)
         checkAnswer(df5, Seq(Row(1, "amy", 10000, 1000, true),
           Row(1, "cathy", 9000, 1200, false), Row(6, "jen", 12000, 1200, true)))
+
+        val df6 = spark.table("h2.test.employee")
+          .filter(ln($"dept") > 1)
+          .filter(exp($"salary") > 2000)
+          .filter(pow($"dept", 2) > 4)
+          .filter(sqrt($"salary") > 100)
+          .filter(floor($"dept") > 1)
+          .filter(ceil($"dept") > 1)
+        checkFiltersRemoved(df6, ansiMode)
+        val expectedPlanFragment6 = if (ansiMode) {
+          "PushedFilters: [DEPT IS NOT NULL, SALARY IS NOT NULL, " +
+            "LN(CAST(DEPT AS double)) > 1.0, EXP(CAST(SALARY AS double)...,"
+        } else {
+          "PushedFilters: [DEPT IS NOT NULL, SALARY IS NOT NULL]"
+        }
+        checkPushedInfo(df6, expectedPlanFragment6)
+        checkAnswer(df6, Seq(Row(6, "jen", 12000, 1200, true)))
+
+        val df7 = sql("""
+          |SELECT * FROM h2.test.employee
+          |WHERE width_bucket(dept, 1, 6, 3) > 1
+          |""".stripMargin)
+        checkFiltersRemoved(df7, ansiMode)
+        val expectedPlanFragment7 = if (ansiMode) {
+          "PushedFilters: [DEPT IS NOT NULL, " +
+            "(WIDTH_BUCKET(CAST(DEPT AS double), 1.0, 6.0, 3)) > 1]"
+        } else {
+          "PushedFilters: [DEPT IS NOT NULL]"
+        }
+        checkPushedInfo(df7, expectedPlanFragment7)
+        if (ansiMode) {
+          val e = intercept[SparkException] {
+            checkAnswer(df7, Seq.empty)
+          }
+          assert(e.getMessage.contains(
+            "org.h2.jdbc.JdbcSQLSyntaxErrorException: Function \"WIDTH_BUCKET\" not found;"))
+        } else {
+          checkAnswer(df7, Seq(Row(6, "jen", 12000, 1200, true)))
+        }
       }
     }
   }

From 22e9fe8133224d02d932deaa34fe20fa19edae1e Mon Sep 17 00:00:00 2001
From: Jiaan Geng
Date: Tue, 12 Apr 2022 19:19:18 +0800
Subject: [PATCH 2/3] Update code

---
 .../expressions/GeneralScalarExpression.java  | 54 +++++++++++++++++++
 .../sql/errors/QueryCompilationErrors.scala   |  4 ++
 .../org/apache/spark/sql/jdbc/H2Dialect.scala | 26 +++++++++
 .../apache/spark/sql/jdbc/JDBCV2Suite.scala   | 21 ++------
 4 files changed, 88 insertions(+), 17 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java
index 8952761f9ef34..58082d5ee09c1 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java
@@ -94,6 +94,60 @@
  *    <li>Since version: 3.3.0</li>
  *   </ul>
  *  </li>
+ *  <li>Name: <code>ABS</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>ABS(expr)</code></li>
+ *    <li>Since version: 3.3.0</li>
+ *   </ul>
+ *  </li>
+ *  <li>Name: <code>COALESCE</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>COALESCE(expr1, expr2)</code></li>
+ *    <li>Since version: 3.3.0</li>
+ *   </ul>
+ *  </li>
+ *  <li>Name: <code>LN</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>LN(expr)</code></li>
+ *    <li>Since version: 3.3.0</li>
+ *   </ul>
+ *  </li>
+ *  <li>Name: <code>EXP</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>EXP(expr)</code></li>
+ *    <li>Since version: 3.3.0</li>
+ *   </ul>
+ *  </li>
+ *  <li>Name: <code>POWER</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>POWER(expr, number)</code></li>
+ *    <li>Since version: 3.3.0</li>
+ *   </ul>
+ *  </li>
+ *  <li>Name: <code>SQRT</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>SQRT(expr)</code></li>
+ *    <li>Since version: 3.3.0</li>
+ *   </ul>
+ *  </li>
+ *  <li>Name: <code>FLOOR</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>FLOOR(expr)</code></li>
+ *    <li>Since version: 3.3.0</li>
+ *   </ul>
+ *  </li>
+ *  <li>Name: <code>CEIL</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>CEIL(expr)</code></li>
+ *    <li>Since version: 3.3.0</li>
+ *   </ul>
+ *  </li>
+ *  <li>Name: <code>WIDTH_BUCKET</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>WIDTH_BUCKET(expr)</code></li>
+ *    <li>Since version: 3.3.0</li>
+ *   </ul>
+ *  </li>
  * </ol>
  * Note: SQL semantic conforms ANSI standard, so some expressions are not supported when ANSI off,
  * including: add, subtract, multiply, divide, remainder, pmod.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
index b6b64804904ee..f743844ebd816 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
@@ -2380,4 +2380,8 @@ object QueryCompilationErrors {
     new AnalysisException(
       "Sinks cannot request distribution and ordering in continuous execution mode")
   }
+
+  def noSuchFunctionError(database: String, funcInfo: String): Throwable = {
+    new AnalysisException(s"$database does not support function: $funcInfo")
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
index 643376cdb126a..0aa971c0d3ab1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
@@ -20,14 +20,40 @@ package org.apache.spark.sql.jdbc
 import java.sql.SQLException
 import java.util.Locale
 
+import scala.util.control.NonFatal
+
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.analysis.{NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException}
+import org.apache.spark.sql.connector.expressions.Expression
 import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc}
+import org.apache.spark.sql.errors.QueryCompilationErrors
 
 private object H2Dialect extends JdbcDialect {
   override def canHandle(url: String): Boolean =
     url.toLowerCase(Locale.ROOT).startsWith("jdbc:h2")
 
+  class H2SQLBuilder extends JDBCSQLBuilder {
+    override def visitSQLFunction(funcName: String, inputs: Array[String]): String = {
+      funcName match {
+        case "WIDTH_BUCKET" =>
+          val functionInfo = super.visitSQLFunction(funcName, inputs)
+          throw QueryCompilationErrors.noSuchFunctionError("H2", functionInfo)
+        case _ => super.visitSQLFunction(funcName, inputs)
+      }
+    }
+  }
+
+  override def compileExpression(expr: Expression): Option[String] = {
+    val h2SQLBuilder = new H2SQLBuilder()
+    try {
+      Some(h2SQLBuilder.build(expr))
+    } catch {
+      case NonFatal(e) =>
+        logWarning("Error occurs while compiling V2 expression", e)
+        None
+    }
+  }
+
   override def compileAggregate(aggFunction: AggregateFunc): Option[String] = {
     super.compileAggregate(aggFunction).orElse(
       aggFunction match {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index 7c40e9d4c7c77..5cfa2f465a2be 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -482,27 +482,14 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
         checkPushedInfo(df6, expectedPlanFragment6)
         checkAnswer(df6, Seq(Row(6, "jen", 12000, 1200, true)))
 
+        // H2 does not support width_bucket
         val df7 = sql("""
           |SELECT * FROM h2.test.employee
           |WHERE width_bucket(dept, 1, 6, 3) > 1
           |""".stripMargin)
-        checkFiltersRemoved(df7, ansiMode)
-        val expectedPlanFragment7 = if (ansiMode) {
-          "PushedFilters: [DEPT IS NOT NULL, " +
-            "(WIDTH_BUCKET(CAST(DEPT AS double), 1.0, 6.0, 3)) > 1]"
-        } else {
-          "PushedFilters: [DEPT IS NOT NULL]"
-        }
-        checkPushedInfo(df7, expectedPlanFragment7)
-        if (ansiMode) {
-          val e = intercept[SparkException] {
-            checkAnswer(df7, Seq.empty)
-          }
-          assert(e.getMessage.contains(
-            "org.h2.jdbc.JdbcSQLSyntaxErrorException: Function \"WIDTH_BUCKET\" not found;"))
-        } else {
-          checkAnswer(df7, Seq(Row(6, "jen", 12000, 1200, true)))
-        }
+        checkFiltersRemoved(df7, false)
+        checkPushedInfo(df7, "PushedFilters: [DEPT IS NOT NULL]")
+        checkAnswer(df7, Seq(Row(6, "jen", 12000, 1200, true)))
       }
     }
   }

From 2f74b0368af02329eb8184b532c884442becd842 Mon Sep 17 00:00:00 2001
From: Jiaan Geng
Date: Wed, 13 Apr 2022 08:51:42 +0800
Subject: [PATCH 3/3] Update code

---
 .../spark/sql/catalyst/util/V2ExpressionBuilder.scala | 11 ++++-------
 1 file changed, 4 insertions(+), 7 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
index 32e902f56d667..487b809d48a01 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
@@ -122,14 +122,11 @@ class V2ExpressionBuilder(
       .map(v => new GeneralScalarExpression("FLOOR", Array[V2Expression](v)))
     case Ceil(child) => generateExpression(child)
       .map(v => new GeneralScalarExpression("CEIL", Array[V2Expression](v)))
-    case WidthBucket(value, minValue, maxValue, numBucket) =>
-      val v = generateExpression(value)
-      val min = generateExpression(minValue)
-      val max = generateExpression(maxValue)
-      val n = generateExpression(numBucket)
-      if (v.isDefined && min.isDefined && max.isDefined && n.isDefined) {
+    case wb: WidthBucket =>
+      val childrenExpressions = wb.children.flatMap(generateExpression(_))
+      if (childrenExpressions.length == wb.children.length) {
         Some(new GeneralScalarExpression("WIDTH_BUCKET",
-          Array[V2Expression](v.get, min.get, max.get, n.get)))
+          childrenExpressions.toArray[V2Expression]))
       } else {
         None
       }
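
Note (not part of the patches above): a minimal usage sketch of the new pushdown,
with the catalog and table names borrowed from the JDBCV2Suite changes in PATCH 1.
It assumes an "h2" JDBC catalog registered the way the suite registers it, and the
expected plan fragment is the ANSI-mode shape asserted by the df6 test; with ANSI
mode off these filters stay in Spark instead of being pushed to H2.

  import org.apache.spark.sql.functions.sqrt
  import spark.implicits._

  // SQRT is compiled to SQL by V2ExpressionSQLBuilder and evaluated by H2.
  val df = spark.table("h2.test.employee").filter(sqrt($"salary") > 100)
  df.explain()
  // Expected scan detail (with spark.sql.ansi.enabled=true), e.g.:
  //   PushedFilters: [SALARY IS NOT NULL, SQRT(CAST(SALARY AS double)) > 100.0]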