diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
index 2a01102614908..541b88a5027d1 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
@@ -28,7 +28,13 @@
 import org.apache.spark.sql.connector.expressions.GeneralScalarExpression;
 import org.apache.spark.sql.connector.expressions.Literal;
 import org.apache.spark.sql.connector.expressions.UserDefinedScalarFunc;
+import org.apache.spark.sql.connector.expressions.aggregate.Avg;
+import org.apache.spark.sql.connector.expressions.aggregate.Count;
+import org.apache.spark.sql.connector.expressions.aggregate.CountStar;
 import org.apache.spark.sql.connector.expressions.aggregate.GeneralAggregateFunc;
+import org.apache.spark.sql.connector.expressions.aggregate.Max;
+import org.apache.spark.sql.connector.expressions.aggregate.Min;
+import org.apache.spark.sql.connector.expressions.aggregate.Sum;
 import org.apache.spark.sql.connector.expressions.aggregate.UserDefinedAggregateFunc;
 import org.apache.spark.sql.types.DataType;
 
@@ -166,9 +172,31 @@ public String build(Expression expr) {
         default:
           return visitUnexpectedExpr(expr);
       }
+    } else if (expr instanceof Min) {
+      Min min = (Min) expr;
+      return visitAggregateFunction("MIN", false,
+        Arrays.stream(min.children()).map(c -> build(c)).toArray(String[]::new));
+    } else if (expr instanceof Max) {
+      Max max = (Max) expr;
+      return visitAggregateFunction("MAX", false,
+        Arrays.stream(max.children()).map(c -> build(c)).toArray(String[]::new));
+    } else if (expr instanceof Count) {
+      Count count = (Count) expr;
+      return visitAggregateFunction("COUNT", count.isDistinct(),
+        Arrays.stream(count.children()).map(c -> build(c)).toArray(String[]::new));
+    } else if (expr instanceof Sum) {
+      Sum sum = (Sum) expr;
+      return visitAggregateFunction("SUM", sum.isDistinct(),
+        Arrays.stream(sum.children()).map(c -> build(c)).toArray(String[]::new));
+    } else if (expr instanceof CountStar) {
+      return visitAggregateFunction("COUNT", false, new String[]{"*"});
+    } else if (expr instanceof Avg) {
+      Avg avg = (Avg) expr;
+      return visitAggregateFunction("AVG", avg.isDistinct(),
+        Arrays.stream(avg.children()).map(c -> build(c)).toArray(String[]::new));
     } else if (expr instanceof GeneralAggregateFunc) {
       GeneralAggregateFunc f = (GeneralAggregateFunc) expr;
-      return visitGeneralAggregateFunction(f.name(), f.isDistinct(),
+      return visitAggregateFunction(f.name(), f.isDistinct(),
         Arrays.stream(f.children()).map(c -> build(c)).toArray(String[]::new));
     } else if (expr instanceof UserDefinedScalarFunc) {
       UserDefinedScalarFunc f = (UserDefinedScalarFunc) expr;
@@ -290,7 +318,7 @@ protected String visitSQLFunction(String funcName, String[] inputs) {
     return funcName + "(" + Arrays.stream(inputs).collect(Collectors.joining(", ")) + ")";
   }
 
-  protected String visitGeneralAggregateFunction(
+  protected String visitAggregateFunction(
       String funcName, boolean isDistinct, String[] inputs) {
     if (isDistinct) {
       return funcName +
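Note: with the aggregate branches folded into `build`, a V2 aggregate expression can be rendered to SQL without any dialect involved. A minimal sketch of driving the builder directly, assuming the Spark 3.4-era connector API (`Expressions.column` and the `Min`/`Count` constructors used in the hunk above):

```scala
import org.apache.spark.sql.connector.expressions.Expressions
import org.apache.spark.sql.connector.expressions.aggregate.{Count, Min}
import org.apache.spark.sql.connector.util.V2ExpressionSQLBuilder

object V2AggregateSQLDemo extends App {
  val builder = new V2ExpressionSQLBuilder

  // Min carries no DISTINCT flag, so build dispatches with isDistinct = false.
  println(builder.build(new Min(Expressions.column("price"))))        // MIN(price)

  // Count forwards its own isDistinct flag into visitAggregateFunction.
  println(builder.build(new Count(Expressions.column("dept"), true))) // COUNT(DISTINCT dept)
}
```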
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala
index 35293b38db780..a3637e572669e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala
@@ -20,9 +20,11 @@ package org.apache.spark.sql.jdbc
 import java.sql.{SQLException, Types}
 import java.util.Locale
 
+import scala.util.control.NonFatal
+
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.analysis.NonEmptyNamespaceException
-import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc}
+import org.apache.spark.sql.connector.expressions.Expression
 import org.apache.spark.sql.types._
 
 private object DB2Dialect extends JdbcDialect {
@@ -31,34 +33,34 @@ private object DB2Dialect extends JdbcDialect {
     url.toLowerCase(Locale.ROOT).startsWith("jdbc:db2")
 
   // See https://www.ibm.com/docs/en/db2/11.5?topic=functions-aggregate
-  override def compileAggregate(aggFunction: AggregateFunc): Option[String] = {
-    super.compileAggregate(aggFunction).orElse(
-      aggFunction match {
-        case f: GeneralAggregateFunc if f.name() == "VAR_POP" =>
-          assert(f.children().length == 1)
-          val distinct = if (f.isDistinct) "DISTINCT " else ""
-          Some(s"VARIANCE($distinct${f.children().head})")
-        case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" =>
-          assert(f.children().length == 1)
-          val distinct = if (f.isDistinct) "DISTINCT " else ""
-          Some(s"VARIANCE_SAMP($distinct${f.children().head})")
-        case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" =>
-          assert(f.children().length == 1)
-          val distinct = if (f.isDistinct) "DISTINCT " else ""
-          Some(s"STDDEV($distinct${f.children().head})")
-        case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" =>
-          assert(f.children().length == 1)
-          val distinct = if (f.isDistinct) "DISTINCT " else ""
-          Some(s"STDDEV_SAMP($distinct${f.children().head})")
-        case f: GeneralAggregateFunc if f.name() == "COVAR_POP" && f.isDistinct == false =>
-          assert(f.children().length == 2)
-          Some(s"COVARIANCE(${f.children().head}, ${f.children().last})")
-        case f: GeneralAggregateFunc if f.name() == "COVAR_SAMP" && f.isDistinct == false =>
-          assert(f.children().length == 2)
-          Some(s"COVARIANCE_SAMP(${f.children().head}, ${f.children().last})")
-        case _ => None
-      }
-    )
+  private val supportedAggregateFunctions = Set("MAX", "MIN", "SUM", "COUNT", "AVG",
+    "VAR_POP", "VAR_SAMP", "STDDEV_POP", "STDDEV_SAMP", "COVAR_POP", "COVAR_SAMP")
+  private val supportedFunctions = supportedAggregateFunctions
+
+  override def isSupportedFunction(funcName: String): Boolean =
+    supportedFunctions.contains(funcName)
+
+  class DB2SQLBuilder extends JDBCSQLBuilder {
+    override def dialectFunctionName(funcName: String): String = funcName match {
+      case "VAR_POP" => "VARIANCE"
+      case "VAR_SAMP" => "VARIANCE_SAMP"
+      case "STDDEV_POP" => "STDDEV"
+      case "STDDEV_SAMP" => "STDDEV_SAMP"
+      case "COVAR_POP" => "COVARIANCE"
+      case "COVAR_SAMP" => "COVARIANCE_SAMP"
+      case _ => super.dialectFunctionName(funcName)
+    }
+  }
+
+  override def compileExpression(expr: Expression): Option[String] = {
+    val db2SQLBuilder = new DB2SQLBuilder()
+    try {
+      Some(db2SQLBuilder.build(expr))
+    } catch {
+      case NonFatal(e) =>
+        logWarning("Error occurs while compiling V2 expression", e)
+        None
+    }
   }
 
   override def getCatalystType(
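Note: the `dialectFunctionName` hook lets the allow-list keep Spark's function names while the rename happens only at SQL-generation time; `MsSqlServerDialect` further down applies the identical pattern. A self-contained sketch of the mechanism with stub types (not the real Spark classes):

```scala
// Stub of the JDBCSQLBuilder pattern: the base visitor owns the SQL shape,
// the dialect only swaps the function name.
trait StubSQLBuilder {
  protected def dialectFunctionName(funcName: String): String = funcName

  def visitAggregateFunction(funcName: String, isDistinct: Boolean, inputs: Array[String]): String = {
    val distinct = if (isDistinct) "DISTINCT " else ""
    s"${dialectFunctionName(funcName)}($distinct${inputs.mkString(", ")})"
  }
}

object Db2StubBuilder extends StubSQLBuilder {
  override protected def dialectFunctionName(funcName: String): String = funcName match {
    case "VAR_POP" => "VARIANCE"
    case "COVAR_SAMP" => "COVARIANCE_SAMP"
    case other => super.dialectFunctionName(other)
  }
}

object Db2NameMappingDemo extends App {
  // Spark pushes VAR_POP; the generated DB2 SQL says VARIANCE.
  println(Db2StubBuilder.visitAggregateFunction("VAR_POP", isDistinct = true, Array("BONUS")))
  // VARIANCE(DISTINCT BONUS)
}
```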
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala
index 36c3c6be4a05c..439e0697d9f3d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.jdbc
 import java.sql.Types
 import java.util.Locale
 
-import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc}
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.types._
 
@@ -31,25 +30,12 @@ private object DerbyDialect extends JdbcDialect {
     url.toLowerCase(Locale.ROOT).startsWith("jdbc:derby")
 
   // See https://db.apache.org/derby/docs/10.15/ref/index.html
-  override def compileAggregate(aggFunction: AggregateFunc): Option[String] = {
-    super.compileAggregate(aggFunction).orElse(
-      aggFunction match {
-        case f: GeneralAggregateFunc if f.name() == "VAR_POP" && f.isDistinct == false =>
-          assert(f.children().length == 1)
-          Some(s"VAR_POP(${f.children().head})")
-        case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" && f.isDistinct == false =>
-          assert(f.children().length == 1)
-          Some(s"VAR_SAMP(${f.children().head})")
-        case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" && f.isDistinct == false =>
-          assert(f.children().length == 1)
-          Some(s"STDDEV_POP(${f.children().head})")
-        case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" && f.isDistinct == false =>
-          assert(f.children().length == 1)
-          Some(s"STDDEV_SAMP(${f.children().head})")
-        case _ => None
-      }
-    )
-  }
+  private val supportedAggregateFunctions = Set("MAX", "MIN", "SUM", "COUNT", "AVG",
+    "VAR_POP", "VAR_SAMP", "STDDEV_POP", "STDDEV_SAMP")
+  private val supportedFunctions = supportedAggregateFunctions
+
+  override def isSupportedFunction(funcName: String): Boolean =
+    supportedFunctions.contains(funcName)
 
   override def getCatalystType(
       sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
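Note: Derby (like MySQL, Oracle, PostgreSQL, and Teradata below) now emits no SQL of its own. It only publishes an allow-list; the shared builder either renders the call or throws, at which point `compileExpression` returns `None` and Spark keeps the aggregate on its side. A condensed sketch of that control flow, with hypothetical stand-ins for the Spark classes:

```scala
import scala.util.control.NonFatal

// Hypothetical stand-ins for JdbcDialect.isSupportedFunction + JDBCSQLBuilder.
object DerbyLikeDialect {
  private val supportedFunctions =
    Set("MAX", "MIN", "SUM", "COUNT", "AVG", "VAR_POP", "VAR_SAMP", "STDDEV_POP", "STDDEV_SAMP")

  private def buildAggregate(funcName: String, input: String): String =
    if (supportedFunctions.contains(funcName)) {
      s"$funcName($input)"
    } else {
      throw new UnsupportedOperationException(s"does not support aggregate function: $funcName")
    }

  // Mirrors compileExpression: any NonFatal failure means "give up the push-down".
  def compileAggregate(funcName: String, input: String): Option[String] =
    try Some(buildAggregate(funcName, input)) catch { case NonFatal(_) => None }
}

object AllowListDemo extends App {
  println(DerbyLikeDialect.compileAggregate("VAR_POP", "SALARY")) // Some(VAR_POP(SALARY))
  println(DerbyLikeDialect.compileAggregate("CORR", "SALARY"))    // None: aggregate stays in Spark
}
```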
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
index f96dd5559f6e8..e58473bb2b31a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
@@ -28,7 +28,6 @@ import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.analysis.{NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException}
 import org.apache.spark.sql.connector.catalog.functions.UnboundFunction
 import org.apache.spark.sql.connector.expressions.Expression
-import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc}
 import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
 import org.apache.spark.sql.types.{BooleanType, ByteType, DataType, DecimalType, ShortType, StringType}
 
@@ -36,7 +35,13 @@ private[sql] object H2Dialect extends JdbcDialect {
   override def canHandle(url: String): Boolean =
     url.toLowerCase(Locale.ROOT).startsWith("jdbc:h2")
 
-  private val supportedFunctions =
+  private val distinctUnsupportedAggregateFunctions =
+    Set("COVAR_POP", "COVAR_SAMP", "CORR", "REGR_INTERCEPT", "REGR_R2", "REGR_SLOPE", "REGR_SXY")
+
+  private val supportedAggregateFunctions = Set("MAX", "MIN", "SUM", "COUNT", "AVG",
+    "VAR_POP", "VAR_SAMP", "STDDEV_POP", "STDDEV_SAMP") ++ distinctUnsupportedAggregateFunctions
+
+  private val supportedFunctions = supportedAggregateFunctions ++
     Set("ABS", "COALESCE", "GREATEST", "LEAST", "RAND", "LOG", "LOG10", "LN", "EXP",
       "POWER", "SQRT", "FLOOR", "CEIL", "ROUND", "SIN", "SINH", "COS", "COSH", "TAN",
       "TANH", "COT", "ASIN", "ACOS", "ATAN", "ATAN2", "DEGREES", "RADIANS", "SIGN",
@@ -45,51 +50,6 @@
   override def isSupportedFunction(funcName: String): Boolean =
     supportedFunctions.contains(funcName)
 
-  override def compileAggregate(aggFunction: AggregateFunc): Option[String] = {
-    super.compileAggregate(aggFunction).orElse(
-      aggFunction match {
-        case f: GeneralAggregateFunc if f.name() == "VAR_POP" =>
-          assert(f.children().length == 1)
-          val distinct = if (f.isDistinct) "DISTINCT " else ""
-          Some(s"VAR_POP($distinct${f.children().head})")
-        case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" =>
-          assert(f.children().length == 1)
-          val distinct = if (f.isDistinct) "DISTINCT " else ""
-          Some(s"VAR_SAMP($distinct${f.children().head})")
-        case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" =>
-          assert(f.children().length == 1)
-          val distinct = if (f.isDistinct) "DISTINCT " else ""
-          Some(s"STDDEV_POP($distinct${f.children().head})")
-        case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" =>
-          assert(f.children().length == 1)
-          val distinct = if (f.isDistinct) "DISTINCT " else ""
-          Some(s"STDDEV_SAMP($distinct${f.children().head})")
-        case f: GeneralAggregateFunc if f.name() == "COVAR_POP" && !f.isDistinct =>
-          assert(f.children().length == 2)
-          Some(s"COVAR_POP(${f.children().head}, ${f.children().last})")
-        case f: GeneralAggregateFunc if f.name() == "COVAR_SAMP" && !f.isDistinct =>
-          assert(f.children().length == 2)
-          Some(s"COVAR_SAMP(${f.children().head}, ${f.children().last})")
-        case f: GeneralAggregateFunc if f.name() == "CORR" && !f.isDistinct =>
-          assert(f.children().length == 2)
-          Some(s"CORR(${f.children().head}, ${f.children().last})")
-        case f: GeneralAggregateFunc if f.name() == "REGR_INTERCEPT" && !f.isDistinct =>
-          assert(f.children().length == 2)
-          Some(s"REGR_INTERCEPT(${f.children().head}, ${f.children().last})")
-        case f: GeneralAggregateFunc if f.name() == "REGR_R2" && !f.isDistinct =>
-          assert(f.children().length == 2)
-          Some(s"REGR_R2(${f.children().head}, ${f.children().last})")
-        case f: GeneralAggregateFunc if f.name() == "REGR_SLOPE" && !f.isDistinct =>
-          assert(f.children().length == 2)
-          Some(s"REGR_SLOPE(${f.children().head}, ${f.children().last})")
-        case f: GeneralAggregateFunc if f.name() == "REGR_SXY" && !f.isDistinct =>
-          assert(f.children().length == 2)
-          Some(s"REGR_SXY(${f.children().head}, ${f.children().last})")
-        case _ => None
-      }
-    )
-  }
-
   override def getJDBCType(dt: DataType): Option[JdbcType] = dt match {
     case StringType => Option(JdbcType("CLOB", Types.CLOB))
     case BooleanType => Some(JdbcType("BOOLEAN", Types.BOOLEAN))
@@ -136,9 +96,9 @@ private[sql] object H2Dialect extends JdbcDialect {
   }
 
   override def compileExpression(expr: Expression): Option[String] = {
-    val jdbcSQLBuilder = new H2JDBCSQLBuilder()
+    val h2SQLBuilder = new H2SQLBuilder()
     try {
-      Some(jdbcSQLBuilder.build(expr))
+      Some(h2SQLBuilder.build(expr))
     } catch {
       case NonFatal(e) =>
         logWarning("Error occurs while compiling V2 expression", e)
@@ -146,7 +106,15 @@
     }
   }
 
-  class H2JDBCSQLBuilder extends JDBCSQLBuilder {
+  class H2SQLBuilder extends JDBCSQLBuilder {
+    override def visitAggregateFunction(
+        funcName: String, isDistinct: Boolean, inputs: Array[String]): String =
+      if (isDistinct && distinctUnsupportedAggregateFunctions.contains(funcName)) {
+        throw new UnsupportedOperationException(s"${this.getClass.getSimpleName} does not " +
+          s"support aggregate function: $funcName with DISTINCT")
+      } else {
+        super.visitAggregateFunction(funcName, isDistinct, inputs)
+      }
+
     override def visitExtract(field: String, source: String): String = {
       val newField = field match {
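Note: H2 can push the two-argument statistics aggregates only without DISTINCT, so its builder rejects that combination before the shared code renders it. A self-contained sketch of the guard above, using stub types rather than the real `H2SQLBuilder`:

```scala
import scala.util.Try

object H2StubBuilder {
  private val distinctUnsupportedAggregateFunctions =
    Set("COVAR_POP", "COVAR_SAMP", "CORR", "REGR_INTERCEPT", "REGR_R2", "REGR_SLOPE", "REGR_SXY")

  def visitAggregateFunction(funcName: String, isDistinct: Boolean, inputs: Array[String]): String =
    if (isDistinct && distinctUnsupportedAggregateFunctions.contains(funcName)) {
      // Same escape hatch as the real builder: throw, and let compileExpression return None.
      throw new UnsupportedOperationException(
        s"does not support aggregate function: $funcName with DISTINCT")
    } else {
      val distinct = if (isDistinct) "DISTINCT " else ""
      s"$funcName($distinct${inputs.mkString(", ")})"
    }
}

object DistinctGuardDemo extends App {
  println(H2StubBuilder.visitAggregateFunction("COVAR_POP", isDistinct = false, Array("B1", "B2")))
  // COVAR_POP(B1, B2)
  println(Try(H2StubBuilder.visitAggregateFunction("COVAR_POP", isDistinct = true, Array("B1", "B2"))))
  // Failure(java.lang.UnsupportedOperationException: ... COVAR_POP with DISTINCT)
}
```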
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index 1c4d2cf0aec04..ba3a3f50ecad6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -36,7 +36,7 @@ import org.apache.spark.sql.connector.catalog.TableChange._
 import org.apache.spark.sql.connector.catalog.functions.UnboundFunction
 import org.apache.spark.sql.connector.catalog.index.TableIndex
 import org.apache.spark.sql.connector.expressions.{Expression, Literal, NamedReference}
-import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Avg, Count, CountStar, Max, Min, Sum}
+import org.apache.spark.sql.connector.expressions.aggregate.AggregateFunc
 import org.apache.spark.sql.connector.util.V2ExpressionSQLBuilder
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.datasources.jdbc.{DriverRegistry, JDBCOptions, JdbcUtils}
@@ -244,7 +244,7 @@ abstract class JdbcDialect extends Serializable with Logging {
 
     override def visitSQLFunction(funcName: String, inputs: Array[String]): String = {
       if (isSupportedFunction(funcName)) {
-        s"""$funcName(${inputs.mkString(", ")})"""
+        s"""${dialectFunctionName(funcName)}(${inputs.mkString(", ")})"""
       } else {
         // The framework will catch the error and give up the push-down.
         // Please see `JdbcDialect.compileExpression(expr: Expression)` for more details.
@@ -253,6 +253,18 @@ abstract class JdbcDialect extends Serializable with Logging {
       }
     }
 
+    override def visitAggregateFunction(
+        funcName: String, isDistinct: Boolean, inputs: Array[String]): String = {
+      if (isSupportedFunction(funcName)) {
+        super.visitAggregateFunction(dialectFunctionName(funcName), isDistinct, inputs)
+      } else {
+        throw new UnsupportedOperationException(
+          s"${this.getClass.getSimpleName} does not support aggregate function: $funcName")
+      }
+    }
+
+    protected def dialectFunctionName(funcName: String): String = funcName
+
     override def visitOverlay(inputs: Array[String]): String = {
       if (isSupportedFunction("OVERLAY")) {
         super.visitOverlay(inputs)
@@ -303,26 +315,8 @@ abstract class JdbcDialect extends Serializable with Logging {
    * @return Converted value.
    */
   @Since("3.3.0")
-  def compileAggregate(aggFunction: AggregateFunc): Option[String] = {
-    aggFunction match {
-      case min: Min =>
-        compileExpression(min.column).map(v => s"MIN($v)")
-      case max: Max =>
-        compileExpression(max.column).map(v => s"MAX($v)")
-      case count: Count =>
-        val distinct = if (count.isDistinct) "DISTINCT " else ""
-        compileExpression(count.column).map(v => s"COUNT($distinct$v)")
-      case sum: Sum =>
-        val distinct = if (sum.isDistinct) "DISTINCT " else ""
-        compileExpression(sum.column).map(v => s"SUM($distinct$v)")
-      case _: CountStar =>
-        Some("COUNT(*)")
-      case avg: Avg =>
-        val distinct = if (avg.isDistinct) "DISTINCT " else ""
-        compileExpression(avg.column).map(v => s"AVG($distinct$v)")
-      case _ => None
-    }
-  }
+  @deprecated("use org.apache.spark.sql.jdbc.JdbcDialect.compileExpression instead.", "3.4.0")
+  def compileAggregate(aggFunction: AggregateFunc): Option[String] = compileExpression(aggFunction)
 
   /**
    * List the user-defined functions in jdbc dialect.
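Note: for third-party dialects this is the core of the API change. `compileAggregate` still works but now just delegates to `compileExpression`, so a dialect declares capabilities instead of emitting strings. A hedged sketch of what a custom dialect needs after this patch (hypothetical `MyDialect`; `isSupportedFunction` and `registerDialect` are the real extension points):

```scala
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}

// Hypothetical dialect: supporting the standard aggregates is now a
// one-method declaration instead of a compileAggregate pattern match.
class MyDialect extends JdbcDialect {
  override def canHandle(url: String): Boolean = url.startsWith("jdbc:mydb")

  private val supportedAggregateFunctions =
    Set("MAX", "MIN", "SUM", "COUNT", "AVG", "VAR_POP", "VAR_SAMP")

  // The shared JDBCSQLBuilder consults this before rendering any function call.
  override def isSupportedFunction(funcName: String): Boolean =
    supportedAggregateFunctions.contains(funcName)
}

object RegisterMyDialect extends App {
  JdbcDialects.registerDialect(new MyDialect) // registration is unchanged by this patch
}
```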
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
index c95489a28761b..625b3eef7fbc7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
@@ -20,9 +20,11 @@ package org.apache.spark.sql.jdbc
 import java.sql.SQLException
 import java.util.Locale
 
+import scala.util.control.NonFatal
+
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.analysis.NonEmptyNamespaceException
-import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc}
+import org.apache.spark.sql.connector.expressions.Expression
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -53,28 +55,32 @@
   // scalastyle:off line.size.limit
   // See https://docs.microsoft.com/en-us/sql/t-sql/functions/aggregate-functions-transact-sql?view=sql-server-ver15
   // scalastyle:on line.size.limit
-  override def compileAggregate(aggFunction: AggregateFunc): Option[String] = {
-    super.compileAggregate(aggFunction).orElse(
-      aggFunction match {
-        case f: GeneralAggregateFunc if f.name() == "VAR_POP" =>
-          assert(f.children().length == 1)
-          val distinct = if (f.isDistinct) "DISTINCT " else ""
-          Some(s"VARP($distinct${f.children().head})")
-        case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" =>
-          assert(f.children().length == 1)
-          val distinct = if (f.isDistinct) "DISTINCT " else ""
-          Some(s"VAR($distinct${f.children().head})")
-        case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" =>
-          assert(f.children().length == 1)
-          val distinct = if (f.isDistinct) "DISTINCT " else ""
-          Some(s"STDEVP($distinct${f.children().head})")
-        case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" =>
-          assert(f.children().length == 1)
-          val distinct = if (f.isDistinct) "DISTINCT " else ""
-          Some(s"STDEV($distinct${f.children().head})")
-        case _ => None
-      }
-    )
+  private val supportedAggregateFunctions = Set("MAX", "MIN", "SUM", "COUNT", "AVG",
+    "VAR_POP", "VAR_SAMP", "STDDEV_POP", "STDDEV_SAMP")
+  private val supportedFunctions = supportedAggregateFunctions
+
+  override def isSupportedFunction(funcName: String): Boolean =
+    supportedFunctions.contains(funcName)
+
+  class MsSqlServerSQLBuilder extends JDBCSQLBuilder {
+    override def dialectFunctionName(funcName: String): String = funcName match {
+      case "VAR_POP" => "VARP"
+      case "VAR_SAMP" => "VAR"
+      case "STDDEV_POP" => "STDEVP"
+      case "STDDEV_SAMP" => "STDEV"
+      case _ => super.dialectFunctionName(funcName)
+    }
+  }
+
+  override def compileExpression(expr: Expression): Option[String] = {
+    val msSqlServerSQLBuilder = new MsSqlServerSQLBuilder()
+    try {
+      Some(msSqlServerSQLBuilder.build(expr))
+    } catch {
+      case NonFatal(e) =>
+        logWarning("Error occurs while compiling V2 expression", e)
+        None
+    }
   }
 
   override def getCatalystType(
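Note: both new `compileExpression` overrides (DB2 and MS SQL Server above) catch only `NonFatal` failures, so an unsupported function quietly cancels push-down while genuinely fatal errors still propagate. A tiny sketch of that behavior:

```scala
import scala.util.control.NonFatal

object NonFatalDemo extends App {
  def compile(body: => String): Option[String] =
    try Some(body) catch { case NonFatal(_) => None }

  // UnsupportedOperationException is non-fatal, so push-down is quietly abandoned.
  println(compile(throw new UnsupportedOperationException("unsupported"))) // None

  // NonFatal deliberately does NOT match VirtualMachineError, ThreadDeath,
  // InterruptedException, LinkageError or ControlThrowable, so genuinely fatal
  // failures still escape compileExpression.
}
```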
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala
index c4cb5369af9e7..96b544bb03ef3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala
@@ -28,7 +28,6 @@ import org.apache.spark.sql.catalyst.SQLConfHelper
 import org.apache.spark.sql.catalyst.analysis.{IndexAlreadyExistsException, NoSuchIndexException}
 import org.apache.spark.sql.connector.catalog.index.TableIndex
 import org.apache.spark.sql.connector.expressions.{FieldReference, NamedReference}
-import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc}
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}
 import org.apache.spark.sql.types.{BooleanType, DataType, FloatType, LongType, MetadataBuilder}
@@ -39,25 +38,12 @@ private case object MySQLDialect extends JdbcDialect with SQLConfHelper {
     url.toLowerCase(Locale.ROOT).startsWith("jdbc:mysql")
 
   // See https://dev.mysql.com/doc/refman/8.0/en/aggregate-functions.html
-  override def compileAggregate(aggFunction: AggregateFunc): Option[String] = {
-    super.compileAggregate(aggFunction).orElse(
-      aggFunction match {
-        case f: GeneralAggregateFunc if f.name() == "VAR_POP" && f.isDistinct == false =>
-          assert(f.children().length == 1)
-          Some(s"VAR_POP(${f.children().head})")
-        case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" && f.isDistinct == false =>
-          assert(f.children().length == 1)
-          Some(s"VAR_SAMP(${f.children().head})")
-        case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" && f.isDistinct == false =>
-          assert(f.children().length == 1)
-          Some(s"STDDEV_POP(${f.children().head})")
-        case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" && f.isDistinct == false =>
-          assert(f.children().length == 1)
-          Some(s"STDDEV_SAMP(${f.children().head})")
-        case _ => None
-      }
-    )
-  }
+  private val supportedAggregateFunctions = Set("MAX", "MIN", "SUM", "COUNT", "AVG",
+    "VAR_POP", "VAR_SAMP", "STDDEV_POP", "STDDEV_SAMP")
+  private val supportedFunctions = supportedAggregateFunctions
+
+  override def isSupportedFunction(funcName: String): Boolean =
+    supportedFunctions.contains(funcName)
 
   override def getCatalystType(
       sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
index 40333c1757c4a..820bff354ca5c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
@@ -21,7 +21,6 @@ import java.sql.{Date, Timestamp, Types}
 import java.util.{Locale, TimeZone}
 
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
@@ -37,34 +36,12 @@ private case object OracleDialect extends JdbcDialect {
   // scalastyle:off line.size.limit
   // https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/Aggregate-Functions.html#GUID-62BE676B-AF18-4E63-BD14-25206FEA0848
   // scalastyle:on line.size.limit
-  override def compileAggregate(aggFunction: AggregateFunc): Option[String] = {
-    super.compileAggregate(aggFunction).orElse(
-      aggFunction match {
-        case f: GeneralAggregateFunc if f.name() == "VAR_POP" && f.isDistinct == false =>
-          assert(f.children().length == 1)
-          Some(s"VAR_POP(${f.children().head})")
-        case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" && f.isDistinct == false =>
-          assert(f.children().length == 1)
-          Some(s"VAR_SAMP(${f.children().head})")
-        case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" && f.isDistinct == false =>
-          assert(f.children().length == 1)
-          Some(s"STDDEV_POP(${f.children().head})")
-        case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" && f.isDistinct == false =>
-          assert(f.children().length == 1)
-          Some(s"STDDEV_SAMP(${f.children().head})")
-        case f: GeneralAggregateFunc if f.name() == "COVAR_POP" && f.isDistinct == false =>
-          assert(f.children().length == 2)
-          Some(s"COVAR_POP(${f.children().head}, ${f.children().last})")
-        case f: GeneralAggregateFunc if f.name() == "COVAR_SAMP" && f.isDistinct == false =>
-          assert(f.children().length == 2)
-          Some(s"COVAR_SAMP(${f.children().head}, ${f.children().last})")
-        case f: GeneralAggregateFunc if f.name() == "CORR" && f.isDistinct == false =>
-          assert(f.children().length == 2)
-          Some(s"CORR(${f.children().head}, ${f.children().last})")
-        case _ => None
-      }
-    )
-  }
+  private val supportedAggregateFunctions = Set("MAX", "MIN", "SUM", "COUNT", "AVG",
+    "VAR_POP", "VAR_SAMP", "STDDEV_POP", "STDDEV_SAMP", "COVAR_POP", "COVAR_SAMP", "CORR")
+  private val supportedFunctions = supportedAggregateFunctions
+
+  override def isSupportedFunction(funcName: String): Boolean =
+    supportedFunctions.contains(funcName)
 
   private def supportTimeZoneTypes: Boolean = {
     val timeZone = DateTimeUtils.getTimeZone(SQLConf.get.sessionLocalTimeZone)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
index a668d66ee2f9a..551f8d6262191 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
@@ -25,7 +25,6 @@ import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.SQLConfHelper
 import org.apache.spark.sql.catalyst.analysis.{IndexAlreadyExistsException, NonEmptyNamespaceException, NoSuchIndexException}
 import org.apache.spark.sql.connector.expressions.NamedReference
-import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc}
 import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}
 import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
 import org.apache.spark.sql.types._
@@ -37,41 +36,12 @@ private object PostgresDialect extends JdbcDialect with SQLConfHelper {
     url.toLowerCase(Locale.ROOT).startsWith("jdbc:postgresql")
 
   // See https://www.postgresql.org/docs/8.4/functions-aggregate.html
-  override def compileAggregate(aggFunction: AggregateFunc): Option[String] = {
-    super.compileAggregate(aggFunction).orElse(
-      aggFunction match {
-        case f: GeneralAggregateFunc if f.name() == "VAR_POP" =>
-          assert(f.children().length == 1)
-          val distinct = if (f.isDistinct) "DISTINCT " else ""
-          Some(s"VAR_POP($distinct${f.children().head})")
-        case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" =>
-          assert(f.children().length == 1)
-          val distinct = if (f.isDistinct) "DISTINCT " else ""
-          Some(s"VAR_SAMP($distinct${f.children().head})")
-        case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" =>
-          assert(f.children().length == 1)
-          val distinct = if (f.isDistinct) "DISTINCT " else ""
-          Some(s"STDDEV_POP($distinct${f.children().head})")
-        case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" =>
-          assert(f.children().length == 1)
-          val distinct = if (f.isDistinct) "DISTINCT " else ""
-          Some(s"STDDEV_SAMP($distinct${f.children().head})")
-        case f: GeneralAggregateFunc if f.name() == "COVAR_POP" =>
-          assert(f.children().length == 2)
-          val distinct = if (f.isDistinct) "DISTINCT " else ""
-          Some(s"COVAR_POP($distinct${f.children().head}, ${f.children().last})")
-        case f: GeneralAggregateFunc if f.name() == "COVAR_SAMP" =>
-          assert(f.children().length == 2)
-          val distinct = if (f.isDistinct) "DISTINCT " else ""
-          Some(s"COVAR_SAMP($distinct${f.children().head}, ${f.children().last})")
-        case f: GeneralAggregateFunc if f.name() == "CORR" =>
-          assert(f.children().length == 2)
-          val distinct = if (f.isDistinct) "DISTINCT " else ""
-          Some(s"CORR($distinct${f.children().head}, ${f.children().last})")
-        case _ => None
-      }
-    )
-  }
+  private val supportedAggregateFunctions = Set("MAX", "MIN", "SUM", "COUNT", "AVG",
+    "VAR_POP", "VAR_SAMP", "STDDEV_POP", "STDDEV_SAMP", "COVAR_POP", "COVAR_SAMP", "CORR")
+  private val supportedFunctions = supportedAggregateFunctions
+
+  override def isSupportedFunction(funcName: String): Boolean =
+    supportedFunctions.contains(funcName)
 
   override def getCatalystType(
       sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
"STDDEV_POP", "STDDEV_SAMP", "COVAR_POP", "COVAR_SAMP", "CORR") + private val supportedFunctions = supportedAggregateFunctions + + override def isSupportedFunction(funcName: String): Boolean = + supportedFunctions.contains(funcName) override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { case StringType => Some(JdbcType("VARCHAR(255)", java.sql.Types.VARCHAR)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index 4156ae5b27928..2dd6280091bc0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -30,7 +30,6 @@ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, GlobalLim import org.apache.spark.sql.connector.{IntegralAverage, StrLen} import org.apache.spark.sql.connector.catalog.functions.{ScalarFunction, UnboundFunction} import org.apache.spark.sql.connector.expressions.Expression -import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, UserDefinedAggregateFunc} import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper} import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog import org.apache.spark.sql.functions.{abs, acos, asin, atan, atan2, avg, ceil, coalesce, cos, cosh, cot, count, count_distinct, degrees, exp, floor, lit, log => logarithm, log10, not, pow, radians, round, signum, sin, sinh, sqrt, sum, tan, tanh, udf, when} @@ -66,9 +65,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel canonicalName match { case "h2.iavg" => if (isDistinct) { - s"$funcName(DISTINCT ${inputs.mkString(", ")})" + s"AVG(DISTINCT ${inputs.mkString(", ")})" } else { - s"$funcName(${inputs.mkString(", ")})" + s"AVG(${inputs.mkString(", ")})" } case _ => super.visitUserDefinedAggregateFunction(funcName, canonicalName, isDistinct, inputs) @@ -87,18 +86,6 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel } } - override def compileAggregate(aggFunction: AggregateFunc): Option[String] = { - super.compileAggregate(aggFunction).orElse( - aggFunction match { - case f: UserDefinedAggregateFunc if f.name() == "iavg" => - assert(f.children().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - compileExpression(f.children().head).map(v => s"AVG($distinct$v)") - case _ => None - } - ) - } - override def functions: Seq[(String, UnboundFunction)] = H2Dialect.functions } @@ -1815,7 +1802,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkPushedInfo(df1, """ |PushedAggregates: [REGR_INTERCEPT(BONUS, BONUS), REGR_R2(BONUS, BONUS), - |REGR_SLOPE(BONUS, BONUS), REGR_SXY(BONUS, B..., + |REGR_SLOPE(BONUS, BONUS), REGR_SXY(BONUS, BONUS)], |PushedFilters: [DEPT IS NOT NULL, DEPT > 0], |PushedGroupByExpressions: [DEPT], |""".stripMargin.replaceAll("\n", " "))