From 87561b79b84ec623c2e5fe5964aac757c6da5e44 Mon Sep 17 00:00:00 2001
From: Jiaan Geng
Date: Mon, 18 Dec 2023 15:45:14 +0800
Subject: [PATCH 1/2] [SPARK-46442][SQL] DS V2 supports push down PERCENTILE_CONT and PERCENTILE_DISC

---
 .../aggregate/GeneralAggregateFunc.java       | 20 +++++-
 .../util/V2ExpressionSQLBuilder.java          | 21 ++++++-
 .../catalyst/util/V2ExpressionBuilder.scala   | 20 +++++-
 .../org/apache/spark/sql/jdbc/H2Dialect.scala | 15 +----
 .../apache/spark/sql/jdbc/JdbcDialects.scala  | 17 ++++-
 .../apache/spark/sql/jdbc/JDBCV2Suite.scala   | 62 ++++++++++++++++++-
 6 files changed, 131 insertions(+), 24 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java
index 4d787eaf9644a..488db74e3161a 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java
@@ -41,7 +41,9 @@
  *  <li><pre>REGR_R2(input1, input2)</pre> Since 3.4.0</li>
  *  <li><pre>REGR_SLOPE(input1, input2)</pre> Since 3.4.0</li>
  *  <li><pre>REGR_SXY(input1, input2)</pre> Since 3.4.0</li>
- *  <li><pre>MODE(input1[, inverse])</pre> Since 4.0.0</li>
+ *  <li><pre>MODE() WITHIN (ORDER BY input1 [ASC|DESC])</pre> Since 4.0.0</li>
+ *  <li><pre>PERCENTILE_CONT(input1) WITHIN (ORDER BY input2 [ASC|DESC])</pre> Since 4.0.0</li>
+ *  <li><pre>PERCENTILE_DISC(input1) WITHIN (ORDER BY input2 [ASC|DESC])</pre> Since 4.0.0</li>
  * </ol>
  *
  * @since 3.3.0
@@ -51,11 +53,21 @@ public final class GeneralAggregateFunc extends ExpressionWithToString implement
   private final String name;
   private final boolean isDistinct;
   private final Expression[] children;
+  private final Expression[] orderingWithinGroups;
 
   public GeneralAggregateFunc(String name, boolean isDistinct, Expression[] children) {
     this.name = name;
     this.isDistinct = isDistinct;
     this.children = children;
+    this.orderingWithinGroups = null;
+  }
+
+  public GeneralAggregateFunc(
+      String name, boolean isDistinct, Expression[] children, Expression[] orderingWithinGroups) {
+    this.name = name;
+    this.isDistinct = isDistinct;
+    this.children = children;
+    this.orderingWithinGroups = orderingWithinGroups;
   }
 
   public String name() { return name; }
@@ -64,6 +76,8 @@ public GeneralAggregateFunc(String name, boolean isDistinct, Expression[] childr
   @Override
   public Expression[] children() { return children; }
 
+  public Expression[] orderingWithinGroups() { return orderingWithinGroups; }
+
   @Override
   public boolean equals(Object o) {
     if (this == o) return true;
@@ -73,7 +87,8 @@ public boolean equals(Object o) {
 
     if (isDistinct != that.isDistinct) return false;
     if (!name.equals(that.name)) return false;
-    return Arrays.equals(children, that.children);
+    if (!Arrays.equals(children, that.children)) return false;
+    return Arrays.equals(orderingWithinGroups, that.orderingWithinGroups);
   }
 
   @Override
@@ -81,6 +96,7 @@ public int hashCode() {
     int result = name.hashCode();
     result = 31 * result + (isDistinct ? 1 : 0);
     result = 31 * result + Arrays.hashCode(children);
+    result = 31 * result + Arrays.hashCode(orderingWithinGroups);
     return result;
   }
 }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
index 5cd28f1c25984..90f17ca783ee2 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
@@ -144,8 +144,16 @@ yield visitBinaryArithmetic(
         return visitAggregateFunction("AVG", avg.isDistinct(),
           expressionsToStringArray(avg.children()));
       } else if (expr instanceof GeneralAggregateFunc f) {
-        return visitAggregateFunction(f.name(), f.isDistinct(),
-          expressionsToStringArray(f.children()));
+        if (f.orderingWithinGroups() == null) {
+          return visitAggregateFunction(f.name(), f.isDistinct(),
+            expressionsToStringArray(f.children()));
+        } else {
+          return visitInverseDistributionFunction(
+            f.name(),
+            f.isDistinct(),
+            expressionsToStringArray(f.children()),
+            expressionsToStringArray(f.orderingWithinGroups()));
+        }
       } else if (expr instanceof UserDefinedScalarFunc f) {
         return visitUserDefinedScalarFunction(f.name(), f.canonicalName(),
           expressionsToStringArray(f.children()));
@@ -271,6 +279,15 @@ protected String visitAggregateFunction(
     }
   }
 
+  protected String visitInverseDistributionFunction(
+      String funcName, boolean isDistinct, String[] inputs, String[] orderingWithinGroups) {
+    assert(isDistinct == false);
+    String withinGroup =
+      joinArrayToString(orderingWithinGroups, ", ", "WITHIN GROUP (ORDER BY ", ")");
+    String functionCall = joinArrayToString(inputs, ", ", funcName + "(", ")");
+    return functionCall + " " + withinGroup;
+  }
+
   protected String visitUserDefinedScalarFunction(
       String funcName, String canonicalName, String[] inputs) {
     throw new UnsupportedOperationException(
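For reference, the shape of the V2 aggregate this builder receives for a pushed-down percentile, and the SQL the new visitInverseDistributionFunction path emits for it, look roughly like the following Scala sketch. It leans on Spark-internal connector expression classes and on the SortValue[] constructor signature introduced in PATCH 2/2 below; the SALARY column and the 0.5 percentage are only illustrative.

    import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, LiteralValue, NullOrdering, SortDirection, SortValue}
    import org.apache.spark.sql.connector.expressions.aggregate.GeneralAggregateFunc
    import org.apache.spark.sql.types.DoubleType

    // PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY SALARY ASC NULLS FIRST):
    // the percentage literal travels in children, the sort key in orderingWithinGroups.
    val percentileCont = new GeneralAggregateFunc(
      "PERCENTILE_CONT",
      false,
      Array[Expression](LiteralValue(0.5, DoubleType)),
      Array(SortValue(FieldReference("SALARY"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST)))

    // visitInverseDistributionFunction("PERCENTILE_CONT", false, Array("0.5"),
    //   Array("SALARY ASC NULLS FIRST"))
    // then joins the two pieces into:
    //   PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY SALARY ASC NULLS FIRST)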
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
index 2766bbaa88805..3942d193a3284 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Complete}
 import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke}
 import org.apache.spark.sql.connector.catalog.functions.ScalarFunction
-import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, Extract => V2Extract, FieldReference, GeneralScalarExpression, LiteralValue, UserDefinedScalarFunc}
+import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, Extract => V2Extract, FieldReference, GeneralScalarExpression, LiteralValue, NullOrdering, SortDirection, SortValue, UserDefinedScalarFunc}
 import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Avg, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum, UserDefinedAggregateFunc}
 import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or, Predicate => V2Predicate}
 import org.apache.spark.sql.execution.datasources.PushableExpression
@@ -347,8 +347,16 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) {
       Some(new GeneralAggregateFunc("REGR_SXY", isDistinct, Array(left, right)))
     // Translate Mode if it is deterministic or reverse is defined.
     case aggregate.Mode(PushableExpression(expr), _, _, Some(reverse)) =>
-      Some(new GeneralAggregateFunc("MODE", isDistinct,
-        Array(expr, LiteralValue(reverse, BooleanType))))
+      Some(new GeneralAggregateFunc(
+        "MODE", isDistinct, Array.empty, Array(generateSortValue(expr, !reverse))))
+    case aggregate.Percentile(
+        PushableExpression(left), PushableExpression(right), LongLiteral(1L), _, _, reverse) =>
+      Some(new GeneralAggregateFunc("PERCENTILE_CONT", isDistinct,
+        Array(right), Array(generateSortValue(left, reverse))))
+    case aggregate.PercentileDisc(
+        PushableExpression(left), PushableExpression(right), reverse, _, _, _) =>
+      Some(new GeneralAggregateFunc("PERCENTILE_DISC", isDistinct,
+        Array(right), Array(generateSortValue(left, reverse))))
     // TODO supports other aggregate functions
     case aggregate.V2Aggregator(aggrFunc, children, _, _) =>
       val translatedExprs = children.flatMap(PushableExpression.unapply(_))
@@ -380,6 +388,12 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) {
       None
     }
   }
+
+  private def generateSortValue(expr: V2Expression, reverse: Boolean): SortValue = if (reverse) {
+    SortValue(expr, SortDirection.DESCENDING, NullOrdering.NULLS_LAST)
+  } else {
+    SortValue(expr, SortDirection.ASCENDING, NullOrdering.NULLS_FIRST)
+  }
 }
 
 object ColumnOrField {
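As a usage-level sketch of which Catalyst aggregates reach the new cases above (the session wiring and the H2 URL are illustrative; JDBCV2Suite below does the equivalent setup against a temporary H2 database):

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder()
      .master("local[*]")
      // Illustrative catalog wiring; the test suite registers the same kind of H2 catalog.
      .config("spark.sql.catalog.h2", "org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog")
      .config("spark.sql.catalog.h2.url", "jdbc:h2:mem:testdb0")
      .config("spark.sql.catalog.h2.driver", "org.h2.Driver")
      .getOrCreate()

    // Hits the aggregate.Percentile case (the frequency argument is the literal 1L by default):
    spark.sql(
      """SELECT dept, percentile_cont(0.5) WITHIN GROUP (ORDER BY salary)
        |FROM h2.test.employee GROUP BY dept""".stripMargin).show()

    // Hits the aggregate.PercentileDisc case with reverse = true, so generateSortValue
    // produces SALARY DESC NULLS LAST in the pushed-down ordering:
    spark.sql(
      """SELECT dept, percentile_disc(0.3) WITHIN GROUP (ORDER BY salary DESC)
        |FROM h2.test.employee GROUP BY dept""".stripMargin).show()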
"COVAR_SAMP", "CORR", "REGR_INTERCEPT", "REGR_R2", "REGR_SLOPE", "REGR_SXY", - "MODE") + "MODE", "PERCENTILE_CONT", "PERCENTILE_DISC") private val supportedAggregateFunctions = Set("MAX", "MIN", "SUM", "COUNT", "AVG", "VAR_POP", "VAR_SAMP", "STDDEV_POP", "STDDEV_SAMP") ++ distinctUnsupportedAggregateFunctions @@ -270,18 +270,7 @@ private[sql] object H2Dialect extends JdbcDialect { throw new UnsupportedOperationException(s"${this.getClass.getSimpleName} does not " + s"support aggregate function: $funcName with DISTINCT") } else { - funcName match { - case "MODE" => - // Support Mode only if it is deterministic or reverse is defined. - assert(inputs.length == 2) - if (inputs.last == "true") { - s"MODE() WITHIN GROUP (ORDER BY ${inputs.head})" - } else { - s"MODE() WITHIN GROUP (ORDER BY ${inputs.head} DESC)" - } - case _ => - super.visitAggregateFunction(funcName, isDistinct, inputs) - } + super.visitAggregateFunction(funcName, isDistinct, inputs) } override def visitExtract(field: String, source: String): String = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 4825568d88eb0..cea01f7f1b980 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -336,7 +336,22 @@ abstract class JdbcDialect extends Serializable with Logging { super.visitAggregateFunction(dialectFunctionName(funcName), isDistinct, inputs) } else { throw new UnsupportedOperationException( - s"${this.getClass.getSimpleName} does not support aggregate function: $funcName"); + s"${this.getClass.getSimpleName} does not support aggregate function: $funcName") + } + } + + override def visitInverseDistributionFunction( + funcName: String, + isDistinct: Boolean, + inputs: Array[String], + orderingWithinGroups: Array[String]): String = { + if (isSupportedFunction(funcName)) { + super.visitInverseDistributionFunction( + dialectFunctionName(funcName), isDistinct, inputs, orderingWithinGroups) + } else { + throw new UnsupportedOperationException( + s"${this.getClass.getSimpleName} does not support " + + s"inverse distribution function: $funcName") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index 5e04fca92f4b0..a3990f3cfbb35 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -2435,7 +2435,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df1) checkPushedInfo(df1, """ - |PushedAggregates: [MODE(SALARY, true)], + |PushedAggregates: [MODE() WITHIN GROUP (ORDER BY SALARY ASC NULLS FIRST)], |PushedFilters: [DEPT IS NOT NULL, DEPT > 0], |PushedGroupByExpressions: [DEPT], |""".stripMargin.replaceAll("\n", " ")) @@ -2465,7 +2465,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df3) checkPushedInfo(df3, """ - |PushedAggregates: [MODE(SALARY, true)], + |PushedAggregates: [MODE() WITHIN GROUP (ORDER BY SALARY ASC NULLS FIRST)], |PushedFilters: [DEPT IS NOT NULL, DEPT > 0], |PushedGroupByExpressions: [DEPT], |""".stripMargin.replaceAll("\n", " ")) @@ -2481,13 +2481,69 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df4) checkPushedInfo(df4, """ - 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index 5e04fca92f4b0..a3990f3cfbb35 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -2435,7 +2435,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     checkAggregateRemoved(df1)
     checkPushedInfo(df1,
       """
-        |PushedAggregates: [MODE(SALARY, true)],
+        |PushedAggregates: [MODE() WITHIN GROUP (ORDER BY SALARY ASC NULLS FIRST)],
         |PushedFilters: [DEPT IS NOT NULL, DEPT > 0],
         |PushedGroupByExpressions: [DEPT],
         |""".stripMargin.replaceAll("\n", " "))
@@ -2465,7 +2465,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     checkAggregateRemoved(df3)
     checkPushedInfo(df3,
       """
-        |PushedAggregates: [MODE(SALARY, true)],
+        |PushedAggregates: [MODE() WITHIN GROUP (ORDER BY SALARY ASC NULLS FIRST)],
         |PushedFilters: [DEPT IS NOT NULL, DEPT > 0],
         |PushedGroupByExpressions: [DEPT],
         |""".stripMargin.replaceAll("\n", " "))
@@ -2481,13 +2481,69 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     checkAggregateRemoved(df4)
     checkPushedInfo(df4,
       """
-        |PushedAggregates: [MODE(SALARY, false)],
+        |PushedAggregates: [MODE() WITHIN GROUP (ORDER BY SALARY DESC NULLS LAST)],
         |PushedFilters: [DEPT IS NOT NULL, DEPT > 0],
         |PushedGroupByExpressions: [DEPT],
         |""".stripMargin.replaceAll("\n", " "))
     checkAnswer(df4, Seq(Row(1, 10000.00), Row(2, 12000.00), Row(6, 12000.00)))
   }
 
+  test("scan with aggregate push-down: PERCENTILE & PERCENTILE_DISC with filter and group by") {
+    val df1 = sql(
+      """
+        |SELECT
+        |  dept,
+        |  PERCENTILE(salary, 0.5)
+        |FROM h2.test.employee WHERE dept > 0 GROUP BY DePt""".stripMargin)
+    checkFiltersRemoved(df1)
+    checkAggregateRemoved(df1)
+    checkPushedInfo(df1,
+      """
+        |PushedAggregates: [PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY SALARY ASC NULLS FIRST)],
+        |PushedFilters: [DEPT IS NOT NULL, DEPT > 0],
+        |PushedGroupByExpressions: [DEPT],
+        |""".stripMargin.replaceAll("\n", " "))
+    checkAnswer(df1, Seq(Row(1, 9500.00), Row(2, 11000.00), Row(6, 12000.00)))
+
+    val df2 = sql(
+      """
+        |SELECT
+        |  dept,
+        |  PERCENTILE_CONT(0.3) WITHIN GROUP (ORDER BY SALARY),
+        |  PERCENTILE_CONT(0.3) WITHIN GROUP (ORDER BY SALARY DESC)
+        |FROM h2.test.employee WHERE dept > 0 GROUP BY DePt""".stripMargin)
+    checkFiltersRemoved(df2)
+    checkAggregateRemoved(df2)
+    checkPushedInfo(df2,
+      """
+        |PushedAggregates: [PERCENTILE_CONT(0.3) WITHIN GROUP (ORDER BY SALARY ASC NULLS FIRST),
+        |PERCENTILE_CONT(0.3) WITHIN GROUP (ORDER BY SALARY DESC NULLS LAST)],
+        |PushedFilters: [DEPT IS NOT NULL, DEPT > 0],
+        |PushedGroupByExpressions: [DEPT],
+        |""".stripMargin.replaceAll("\n", " "))
+    checkAnswer(df2,
+      Seq(Row(1, 9300.0, 9700.0), Row(2, 10600.0, 11400.0), Row(6, 12000.0, 12000.0)))
+
+    val df3 = sql(
+      """
+        |SELECT
+        |  dept,
+        |  PERCENTILE_DISC(0.3) WITHIN GROUP (ORDER BY SALARY),
+        |  PERCENTILE_DISC(0.3) WITHIN GROUP (ORDER BY SALARY DESC)
+        |FROM h2.test.employee WHERE dept > 0 GROUP BY DePt""".stripMargin)
+    checkFiltersRemoved(df3)
+    checkAggregateRemoved(df3)
+    checkPushedInfo(df3,
+      """
+        |PushedAggregates: [PERCENTILE_DISC(0.3) WITHIN GROUP (ORDER BY SALARY ASC NULLS FIRST),
+        |PERCENTILE_DISC(0.3) WITHIN GROUP (ORDER BY SALARY DESC NULLS LAST)],
+        |PushedFilters: [DEPT IS NOT NULL, DEPT > 0],
+        |PushedGroupByExpressions: [DEPT],
+        |""".stripMargin.replaceAll("\n", " "))
+    checkAnswer(df3,
+      Seq(Row(1, 9000.0, 10000.0), Row(2, 10000.0, 12000.0), Row(6, 12000.0, 12000.0)))
+  }
+
   test("scan with aggregate push-down: aggregate over alias push down") {
     val cols = Seq("a", "b", "c", "d", "e")
     val df1 = sql("SELECT * FROM h2.test.employee").toDF(cols: _*)
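Outside the test helpers, the same behaviour is visible in an EXPLAIN of such a query. A sketch, reusing the `spark` session from the earlier snippet (the expected strings in the comments are taken from the test expectations above):

    // Same query shape as df2 in the new test.
    val df = spark.sql(
      """SELECT dept, PERCENTILE_CONT(0.3) WITHIN GROUP (ORDER BY salary)
        |FROM h2.test.employee WHERE dept > 0 GROUP BY dept""".stripMargin)

    // df.explain() should place the aggregate in the JDBC scan node rather than in Spark, e.g.:
    //   PushedAggregates: [PERCENTILE_CONT(0.3) WITHIN GROUP (ORDER BY SALARY ASC NULLS FIRST)]
    //   PushedFilters: [DEPT IS NOT NULL, DEPT > 0]
    //   PushedGroupByExpressions: [DEPT]
    df.explain()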
From 7a89d7726f81f6a945924757d2443b57649bcec9 Mon Sep 17 00:00:00 2001
From: beliefer
Date: Sun, 7 Jan 2024 21:13:48 +0800
Subject: [PATCH 2/2] Update code

---
 .../expressions/aggregate/GeneralAggregateFunc.java      | 9 +++++----
 .../spark/sql/connector/util/V2ExpressionSQLBuilder.java | 2 +-
 2 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java
index 488db74e3161a..d287288ba33fb 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java
@@ -21,6 +21,7 @@
 
 import org.apache.spark.annotation.Evolving;
 import org.apache.spark.sql.connector.expressions.Expression;
+import org.apache.spark.sql.connector.expressions.SortValue;
 import org.apache.spark.sql.internal.connector.ExpressionWithToString;
 
 /**
@@ -53,17 +54,17 @@ public final class GeneralAggregateFunc extends ExpressionWithToString implement
   private final String name;
   private final boolean isDistinct;
   private final Expression[] children;
-  private final Expression[] orderingWithinGroups;
+  private final SortValue[] orderingWithinGroups;
 
   public GeneralAggregateFunc(String name, boolean isDistinct, Expression[] children) {
     this.name = name;
     this.isDistinct = isDistinct;
     this.children = children;
-    this.orderingWithinGroups = null;
+    this.orderingWithinGroups = new SortValue[]{};
   }
 
   public GeneralAggregateFunc(
-      String name, boolean isDistinct, Expression[] children, Expression[] orderingWithinGroups) {
+      String name, boolean isDistinct, Expression[] children, SortValue[] orderingWithinGroups) {
     this.name = name;
     this.isDistinct = isDistinct;
     this.children = children;
@@ -76,7 +77,7 @@ public GeneralAggregateFunc(
   @Override
   public Expression[] children() { return children; }
 
-  public Expression[] orderingWithinGroups() { return orderingWithinGroups; }
+  public SortValue[] orderingWithinGroups() { return orderingWithinGroups; }
 
   @Override
   public boolean equals(Object o) {
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
index 90f17ca783ee2..7b930c70faafb 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
@@ -144,7 +144,7 @@ yield visitBinaryArithmetic(
         return visitAggregateFunction("AVG", avg.isDistinct(),
           expressionsToStringArray(avg.children()));
       } else if (expr instanceof GeneralAggregateFunc f) {
-        if (f.orderingWithinGroups() == null) {
+        if (f.orderingWithinGroups().length == 0) {
           return visitAggregateFunction(f.name(), f.isDistinct(),
             expressionsToStringArray(f.children()));
         } else {
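A small sketch of what this follow-up buys callers: the three-argument constructor now yields an empty (rather than null) orderingWithinGroups, typed as SortValue[], so the SQL builder and equals/hashCode can treat the no-WITHIN-GROUP case uniformly by length instead of null checks. The VAR_POP example below is illustrative and again assumes access to Spark's internal connector expression classes.

    import org.apache.spark.sql.connector.expressions.{Expression, FieldReference}
    import org.apache.spark.sql.connector.expressions.aggregate.GeneralAggregateFunc

    // No WITHIN GROUP clause: orderingWithinGroups comes back as an empty SortValue[],
    // so V2ExpressionSQLBuilder takes the plain visitAggregateFunction branch.
    val varPop = new GeneralAggregateFunc("VAR_POP", false, Array[Expression](FieldReference("SALARY")))
    assert(varPop.orderingWithinGroups().isEmpty)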