diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java
index 4d787eaf9644a..d287288ba33fb 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java
@@ -21,6 +21,7 @@
import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.connector.expressions.Expression;
+import org.apache.spark.sql.connector.expressions.SortValue;
import org.apache.spark.sql.internal.connector.ExpressionWithToString;
/**
@@ -41,7 +42,9 @@
*
REGR_R2(input1, input2)
Since 3.4.0
* REGR_SLOPE(input1, input2)
Since 3.4.0
* REGR_SXY(input1, input2)
Since 3.4.0
- * MODE(input1[, inverse])
Since 4.0.0
+ * MODE() WITHIN (ORDER BY input1 [ASC|DESC])
Since 4.0.0
+ * PERCENTILE_CONT(input1) WITHIN (ORDER BY input2 [ASC|DESC])
Since 4.0.0
+ * PERCENTILE_DISC(input1) WITHIN (ORDER BY input2 [ASC|DESC])
Since 4.0.0
*
*
* @since 3.3.0
@@ -51,11 +54,21 @@ public final class GeneralAggregateFunc extends ExpressionWithToString implement
private final String name;
private final boolean isDistinct;
private final Expression[] children;
+ private final SortValue[] orderingWithinGroups;
public GeneralAggregateFunc(String name, boolean isDistinct, Expression[] children) {
this.name = name;
this.isDistinct = isDistinct;
this.children = children;
+ this.orderingWithinGroups = new SortValue[]{};
+ }
+
+ public GeneralAggregateFunc(
+ String name, boolean isDistinct, Expression[] children, SortValue[] orderingWithinGroups) {
+ this.name = name;
+ this.isDistinct = isDistinct;
+ this.children = children;
+ this.orderingWithinGroups = orderingWithinGroups;
}
public String name() { return name; }
@@ -64,6 +77,8 @@ public GeneralAggregateFunc(String name, boolean isDistinct, Expression[] childr
@Override
public Expression[] children() { return children; }
+ public SortValue[] orderingWithinGroups() { return orderingWithinGroups; }
+
@Override
public boolean equals(Object o) {
if (this == o) return true;
@@ -73,7 +88,8 @@ public boolean equals(Object o) {
if (isDistinct != that.isDistinct) return false;
if (!name.equals(that.name)) return false;
- return Arrays.equals(children, that.children);
+ if (!Arrays.equals(children, that.children)) return false;
+ return Arrays.equals(orderingWithinGroups, that.orderingWithinGroups);
}
@Override
@@ -81,6 +97,7 @@ public int hashCode() {
int result = name.hashCode();
result = 31 * result + (isDistinct ? 1 : 0);
result = 31 * result + Arrays.hashCode(children);
+ result = 31 * result + Arrays.hashCode(orderingWithinGroups);
return result;
}
}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
index 5cd28f1c25984..7b930c70faafb 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
@@ -144,8 +144,16 @@ yield visitBinaryArithmetic(
return visitAggregateFunction("AVG", avg.isDistinct(),
expressionsToStringArray(avg.children()));
} else if (expr instanceof GeneralAggregateFunc f) {
- return visitAggregateFunction(f.name(), f.isDistinct(),
- expressionsToStringArray(f.children()));
+ if (f.orderingWithinGroups().length == 0) {
+ return visitAggregateFunction(f.name(), f.isDistinct(),
+ expressionsToStringArray(f.children()));
+ } else {
+ return visitInverseDistributionFunction(
+ f.name(),
+ f.isDistinct(),
+ expressionsToStringArray(f.children()),
+ expressionsToStringArray(f.orderingWithinGroups()));
+ }
} else if (expr instanceof UserDefinedScalarFunc f) {
return visitUserDefinedScalarFunction(f.name(), f.canonicalName(),
expressionsToStringArray(f.children()));
@@ -271,6 +279,15 @@ protected String visitAggregateFunction(
}
}
+ protected String visitInverseDistributionFunction(
+ String funcName, boolean isDistinct, String[] inputs, String[] orderingWithinGroups) {
+ assert(isDistinct == false);
+ String withinGroup =
+ joinArrayToString(orderingWithinGroups, ", ", "WITHIN GROUP (ORDER BY ", ")");
+ String functionCall = joinArrayToString(inputs, ", ", funcName + "(", ")");
+ return functionCall + " " + withinGroup;
+ }
+
protected String visitUserDefinedScalarFunction(
String funcName, String canonicalName, String[] inputs) {
throw new UnsupportedOperationException(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
index 2766bbaa88805..3942d193a3284 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Complete}
import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke}
import org.apache.spark.sql.connector.catalog.functions.ScalarFunction
-import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, Extract => V2Extract, FieldReference, GeneralScalarExpression, LiteralValue, UserDefinedScalarFunc}
+import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, Extract => V2Extract, FieldReference, GeneralScalarExpression, LiteralValue, NullOrdering, SortDirection, SortValue, UserDefinedScalarFunc}
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Avg, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum, UserDefinedAggregateFunc}
import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or, Predicate => V2Predicate}
import org.apache.spark.sql.execution.datasources.PushableExpression
@@ -347,8 +347,16 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) {
Some(new GeneralAggregateFunc("REGR_SXY", isDistinct, Array(left, right)))
// Translate Mode if it is deterministic or reverse is defined.
case aggregate.Mode(PushableExpression(expr), _, _, Some(reverse)) =>
- Some(new GeneralAggregateFunc("MODE", isDistinct,
- Array(expr, LiteralValue(reverse, BooleanType))))
+ Some(new GeneralAggregateFunc(
+ "MODE", isDistinct, Array.empty, Array(generateSortValue(expr, !reverse))))
+ case aggregate.Percentile(
+ PushableExpression(left), PushableExpression(right), LongLiteral(1L), _, _, reverse) =>
+ Some(new GeneralAggregateFunc("PERCENTILE_CONT", isDistinct,
+ Array(right), Array(generateSortValue(left, reverse))))
+ case aggregate.PercentileDisc(
+ PushableExpression(left), PushableExpression(right), reverse, _, _, _) =>
+ Some(new GeneralAggregateFunc("PERCENTILE_DISC", isDistinct,
+ Array(right), Array(generateSortValue(left, reverse))))
// TODO supports other aggregate functions
case aggregate.V2Aggregator(aggrFunc, children, _, _) =>
val translatedExprs = children.flatMap(PushableExpression.unapply(_))
@@ -380,6 +388,12 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) {
None
}
}
+
+ private def generateSortValue(expr: V2Expression, reverse: Boolean): SortValue = if (reverse) {
+ SortValue(expr, SortDirection.DESCENDING, NullOrdering.NULLS_LAST)
+ } else {
+ SortValue(expr, SortDirection.ASCENDING, NullOrdering.NULLS_FIRST)
+ }
}
object ColumnOrField {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
index c3b4092c8e37f..76ea49a814924 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
@@ -42,7 +42,7 @@ private[sql] object H2Dialect extends JdbcDialect {
private val distinctUnsupportedAggregateFunctions =
Set("COVAR_POP", "COVAR_SAMP", "CORR", "REGR_INTERCEPT", "REGR_R2", "REGR_SLOPE", "REGR_SXY",
- "MODE")
+ "MODE", "PERCENTILE_CONT", "PERCENTILE_DISC")
private val supportedAggregateFunctions = Set("MAX", "MIN", "SUM", "COUNT", "AVG",
"VAR_POP", "VAR_SAMP", "STDDEV_POP", "STDDEV_SAMP") ++ distinctUnsupportedAggregateFunctions
@@ -270,18 +270,7 @@ private[sql] object H2Dialect extends JdbcDialect {
throw new UnsupportedOperationException(s"${this.getClass.getSimpleName} does not " +
s"support aggregate function: $funcName with DISTINCT")
} else {
- funcName match {
- case "MODE" =>
- // Support Mode only if it is deterministic or reverse is defined.
- assert(inputs.length == 2)
- if (inputs.last == "true") {
- s"MODE() WITHIN GROUP (ORDER BY ${inputs.head})"
- } else {
- s"MODE() WITHIN GROUP (ORDER BY ${inputs.head} DESC)"
- }
- case _ =>
- super.visitAggregateFunction(funcName, isDistinct, inputs)
- }
+ super.visitAggregateFunction(funcName, isDistinct, inputs)
}
override def visitExtract(field: String, source: String): String = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index 4825568d88eb0..cea01f7f1b980 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -336,7 +336,22 @@ abstract class JdbcDialect extends Serializable with Logging {
super.visitAggregateFunction(dialectFunctionName(funcName), isDistinct, inputs)
} else {
throw new UnsupportedOperationException(
- s"${this.getClass.getSimpleName} does not support aggregate function: $funcName");
+ s"${this.getClass.getSimpleName} does not support aggregate function: $funcName")
+ }
+ }
+
+ override def visitInverseDistributionFunction(
+ funcName: String,
+ isDistinct: Boolean,
+ inputs: Array[String],
+ orderingWithinGroups: Array[String]): String = {
+ if (isSupportedFunction(funcName)) {
+ super.visitInverseDistributionFunction(
+ dialectFunctionName(funcName), isDistinct, inputs, orderingWithinGroups)
+ } else {
+ throw new UnsupportedOperationException(
+ s"${this.getClass.getSimpleName} does not support " +
+ s"inverse distribution function: $funcName")
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index 5e04fca92f4b0..a3990f3cfbb35 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -2435,7 +2435,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
checkAggregateRemoved(df1)
checkPushedInfo(df1,
"""
- |PushedAggregates: [MODE(SALARY, true)],
+ |PushedAggregates: [MODE() WITHIN GROUP (ORDER BY SALARY ASC NULLS FIRST)],
|PushedFilters: [DEPT IS NOT NULL, DEPT > 0],
|PushedGroupByExpressions: [DEPT],
|""".stripMargin.replaceAll("\n", " "))
@@ -2465,7 +2465,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
checkAggregateRemoved(df3)
checkPushedInfo(df3,
"""
- |PushedAggregates: [MODE(SALARY, true)],
+ |PushedAggregates: [MODE() WITHIN GROUP (ORDER BY SALARY ASC NULLS FIRST)],
|PushedFilters: [DEPT IS NOT NULL, DEPT > 0],
|PushedGroupByExpressions: [DEPT],
|""".stripMargin.replaceAll("\n", " "))
@@ -2481,13 +2481,69 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
checkAggregateRemoved(df4)
checkPushedInfo(df4,
"""
- |PushedAggregates: [MODE(SALARY, false)],
+ |PushedAggregates: [MODE() WITHIN GROUP (ORDER BY SALARY DESC NULLS LAST)],
|PushedFilters: [DEPT IS NOT NULL, DEPT > 0],
|PushedGroupByExpressions: [DEPT],
|""".stripMargin.replaceAll("\n", " "))
checkAnswer(df4, Seq(Row(1, 10000.00), Row(2, 12000.00), Row(6, 12000.00)))
}
+ test("scan with aggregate push-down: PERCENTILE & PERCENTILE_DISC with filter and group by") {
+ val df1 = sql(
+ """
+ |SELECT
+ | dept,
+ | PERCENTILE(salary, 0.5)
+ |FROM h2.test.employee WHERE dept > 0 GROUP BY DePt""".stripMargin)
+ checkFiltersRemoved(df1)
+ checkAggregateRemoved(df1)
+ checkPushedInfo(df1,
+ """
+ |PushedAggregates: [PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY SALARY ASC NULLS FIRST)],
+ |PushedFilters: [DEPT IS NOT NULL, DEPT > 0],
+ |PushedGroupByExpressions: [DEPT],
+ |""".stripMargin.replaceAll("\n", " "))
+ checkAnswer(df1, Seq(Row(1, 9500.00), Row(2, 11000.00), Row(6, 12000.00)))
+
+ val df2 = sql(
+ """
+ |SELECT
+ | dept,
+ | PERCENTILE_CONT(0.3) WITHIN GROUP (ORDER BY SALARY),
+ | PERCENTILE_CONT(0.3) WITHIN GROUP (ORDER BY SALARY DESC)
+ |FROM h2.test.employee WHERE dept > 0 GROUP BY DePt""".stripMargin)
+ checkFiltersRemoved(df2)
+ checkAggregateRemoved(df2)
+ checkPushedInfo(df2,
+ """
+ |PushedAggregates: [PERCENTILE_CONT(0.3) WITHIN GROUP (ORDER BY SALARY ASC NULLS FIRST),
+ |PERCENTILE_CONT(0.3) WITHIN GROUP (ORDER BY SALARY DESC NULLS LAST)],
+ |PushedFilters: [DEPT IS NOT NULL, DEPT > 0],
+ |PushedGroupByExpressions: [DEPT],
+ |""".stripMargin.replaceAll("\n", " "))
+ checkAnswer(df2,
+ Seq(Row(1, 9300.0, 9700.0), Row(2, 10600.0, 11400.0), Row(6, 12000.0, 12000.0)))
+
+ val df3 = sql(
+ """
+ |SELECT
+ | dept,
+ | PERCENTILE_DISC(0.3) WITHIN GROUP (ORDER BY SALARY),
+ | PERCENTILE_DISC(0.3) WITHIN GROUP (ORDER BY SALARY DESC)
+ |FROM h2.test.employee WHERE dept > 0 GROUP BY DePt""".stripMargin)
+ checkFiltersRemoved(df3)
+ checkAggregateRemoved(df3)
+ checkPushedInfo(df3,
+ """
+ |PushedAggregates: [PERCENTILE_DISC(0.3) WITHIN GROUP (ORDER BY SALARY ASC NULLS FIRST),
+ |PERCENTILE_DISC(0.3) WITHIN GROUP (ORDER BY SALARY DESC NULLS LAST)],
+ |PushedFilters: [DEPT IS NOT NULL, DEPT > 0],
+ |PushedGroupByExpressions: [DEPT],
+ |""".stripMargin.replaceAll("\n", " "))
+ checkAnswer(df3,
+ Seq(Row(1, 9000.0, 10000.0), Row(2, 10000.0, 12000.0), Row(6, 12000.0, 12000.0)))
+ }
+
test("scan with aggregate push-down: aggregate over alias push down") {
val cols = Seq("a", "b", "c", "d", "e")
val df1 = sql("SELECT * FROM h2.test.employee").toDF(cols: _*)