sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
@@ -28,7 +28,13 @@
 import org.apache.spark.sql.connector.expressions.GeneralScalarExpression;
 import org.apache.spark.sql.connector.expressions.Literal;
 import org.apache.spark.sql.connector.expressions.UserDefinedScalarFunc;
+import org.apache.spark.sql.connector.expressions.aggregate.Avg;
+import org.apache.spark.sql.connector.expressions.aggregate.Max;
+import org.apache.spark.sql.connector.expressions.aggregate.Min;
+import org.apache.spark.sql.connector.expressions.aggregate.Count;
+import org.apache.spark.sql.connector.expressions.aggregate.CountStar;
 import org.apache.spark.sql.connector.expressions.aggregate.GeneralAggregateFunc;
+import org.apache.spark.sql.connector.expressions.aggregate.Sum;
 import org.apache.spark.sql.connector.expressions.aggregate.UserDefinedAggregateFunc;
 import org.apache.spark.sql.types.DataType;

@@ -166,9 +172,31 @@ public String build(Expression expr) {
         default:
           return visitUnexpectedExpr(expr);
       }
+    } else if (expr instanceof Min) {
+      Min min = (Min) expr;
+      return visitAggregateFunction("MIN", false,
+        Arrays.stream(min.children()).map(c -> build(c)).toArray(String[]::new));
+    } else if (expr instanceof Max) {
+      Max max = (Max) expr;
+      return visitAggregateFunction("MAX", false,
+        Arrays.stream(max.children()).map(c -> build(c)).toArray(String[]::new));
+    } else if (expr instanceof Count) {
+      Count count = (Count) expr;
+      return visitAggregateFunction("COUNT", count.isDistinct(),
+        Arrays.stream(count.children()).map(c -> build(c)).toArray(String[]::new));
+    } else if (expr instanceof Sum) {
+      Sum sum = (Sum) expr;
+      return visitAggregateFunction("SUM", sum.isDistinct(),
+        Arrays.stream(sum.children()).map(c -> build(c)).toArray(String[]::new));
+    } else if (expr instanceof CountStar) {
+      return visitAggregateFunction("COUNT", false, new String[]{"*"});
+    } else if (expr instanceof Avg) {
+      Avg avg = (Avg) expr;
+      return visitAggregateFunction("AVG", avg.isDistinct(),
+        Arrays.stream(avg.children()).map(c -> build(c)).toArray(String[]::new));
     } else if (expr instanceof GeneralAggregateFunc) {
       GeneralAggregateFunc f = (GeneralAggregateFunc) expr;
-      return visitGeneralAggregateFunction(f.name(), f.isDistinct(),
+      return visitAggregateFunction(f.name(), f.isDistinct(),
         Arrays.stream(f.children()).map(c -> build(c)).toArray(String[]::new));
     } else if (expr instanceof UserDefinedScalarFunc) {
       UserDefinedScalarFunc f = (UserDefinedScalarFunc) expr;
@@ -290,7 +318,7 @@ protected String visitSQLFunction(String funcName, String[] inputs) {
     return funcName + "(" + Arrays.stream(inputs).collect(Collectors.joining(", ")) + ")";
   }
 
-  protected String visitGeneralAggregateFunction(
+  protected String visitAggregateFunction(
       String funcName, boolean isDistinct, String[] inputs) {
     if (isDistinct) {
       return funcName +
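For context, a minimal usage sketch of the consolidated visitAggregateFunction path above. The builder instantiation, the Expressions.column factory, and the price column are illustrative assumptions, not code from this change.

import org.apache.spark.sql.connector.expressions.{Expression, Expressions}
import org.apache.spark.sql.connector.expressions.aggregate.{Count, Min}
import org.apache.spark.sql.connector.util.V2ExpressionSQLBuilder

// Hypothetical column reference standing in for a pushed-down expression.
val price: Expression = Expressions.column("price")
val builder = new V2ExpressionSQLBuilder

// Min is now rendered through visitAggregateFunction("MIN", false, ...).
builder.build(new Min(price))          // "MIN(price)"
// isDistinct flows straight into the generated SQL.
builder.build(new Count(price, true))  // "COUNT(DISTINCT price)"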
60 changes: 31 additions & 29 deletions sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala
@@ -20,9 +20,11 @@ package org.apache.spark.sql.jdbc
 import java.sql.{SQLException, Types}
 import java.util.Locale
 
+import scala.util.control.NonFatal
+
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.analysis.NonEmptyNamespaceException
-import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc}
+import org.apache.spark.sql.connector.expressions.Expression
 import org.apache.spark.sql.types._
 
 private object DB2Dialect extends JdbcDialect {
@@ -31,34 +33,34 @@ private object DB2Dialect extends JdbcDialect {
   url.toLowerCase(Locale.ROOT).startsWith("jdbc:db2")
 
   // See https://www.ibm.com/docs/en/db2/11.5?topic=functions-aggregate
-  override def compileAggregate(aggFunction: AggregateFunc): Option[String] = {
-    super.compileAggregate(aggFunction).orElse(
-      aggFunction match {
-        case f: GeneralAggregateFunc if f.name() == "VAR_POP" =>
-          assert(f.children().length == 1)
-          val distinct = if (f.isDistinct) "DISTINCT " else ""
-          Some(s"VARIANCE($distinct${f.children().head})")
-        case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" =>
-          assert(f.children().length == 1)
-          val distinct = if (f.isDistinct) "DISTINCT " else ""
-          Some(s"VARIANCE_SAMP($distinct${f.children().head})")
-        case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" =>
-          assert(f.children().length == 1)
-          val distinct = if (f.isDistinct) "DISTINCT " else ""
-          Some(s"STDDEV($distinct${f.children().head})")
-        case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" =>
-          assert(f.children().length == 1)
-          val distinct = if (f.isDistinct) "DISTINCT " else ""
-          Some(s"STDDEV_SAMP($distinct${f.children().head})")
-        case f: GeneralAggregateFunc if f.name() == "COVAR_POP" && f.isDistinct == false =>
-          assert(f.children().length == 2)
-          Some(s"COVARIANCE(${f.children().head}, ${f.children().last})")
-        case f: GeneralAggregateFunc if f.name() == "COVAR_SAMP" && f.isDistinct == false =>
-          assert(f.children().length == 2)
-          Some(s"COVARIANCE_SAMP(${f.children().head}, ${f.children().last})")
-        case _ => None
-      }
-    )
+  private val supportedAggregateFunctions = Set("MAX", "MIN", "SUM", "COUNT", "AVG",
+    "VAR_POP", "VAR_SAMP", "STDDEV_POP", "STDDEV_SAMP", "COVAR_POP", "COVAR_SAMP")
+  private val supportedFunctions = supportedAggregateFunctions
+
+  override def isSupportedFunction(funcName: String): Boolean =
+    supportedFunctions.contains(funcName)
+
+  class DB2SQLBuilder extends JDBCSQLBuilder {
Review comment (Contributor): shall we update JDBCSQLBuilder to respect supportedFunctions in visitAggregateFunction?

Reply (Author): JDBCSQLBuilder already checks supportedFunctions.

+    override def dialectFunctionName(funcName: String): String = funcName match {
+      case "VAR_POP" => "VARIANCE"
+      case "VAR_SAMP" => "VARIANCE_SAMP"
+      case "STDDEV_POP" => "STDDEV"
+      case "STDDEV_SAMP" => "STDDEV_SAMP"
+      case "COVAR_POP" => "COVARIANCE"
+      case "COVAR_SAMP" => "COVARIANCE_SAMP"
+      case _ => super.dialectFunctionName(funcName)
+    }
+  }

+  override def compileExpression(expr: Expression): Option[String] = {
+    val db2SQLBuilder = new DB2SQLBuilder()
+    try {
+      Some(db2SQLBuilder.build(expr))
+    } catch {
+      case NonFatal(e) =>
+        logWarning("Error occurs while compiling V2 expression", e)
+        None
+    }
+  }
 
   override def getCatalystType(
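A hedged sketch of how the dialectFunctionName hook above plays out end to end. The JDBC URL and column name are assumptions; DB2Dialect itself is private, so the lookup goes through the JdbcDialects registry.

import org.apache.spark.sql.connector.expressions.{Expression, Expressions}
import org.apache.spark.sql.connector.expressions.aggregate.GeneralAggregateFunc
import org.apache.spark.sql.jdbc.JdbcDialects

val db2 = JdbcDialects.get("jdbc:db2://host:50000/sample")  // resolves to DB2Dialect
val cols: Array[Expression] = Array(Expressions.column("dept_salary"))

// VAR_POP passes the isSupportedFunction check under its Spark name,
// then dialectFunctionName rewrites it to DB2's VARIANCE spelling.
db2.compileExpression(new GeneralAggregateFunc("VAR_POP", true, cols))
// expected: Some("VARIANCE(DISTINCT dept_salary)")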
sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.jdbc
 import java.sql.Types
 import java.util.Locale
 
-import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc}
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.types._
 
@@ -31,25 +30,12 @@ private object DerbyDialect extends JdbcDialect {
   url.toLowerCase(Locale.ROOT).startsWith("jdbc:derby")
 
   // See https://db.apache.org/derby/docs/10.15/ref/index.html
-  override def compileAggregate(aggFunction: AggregateFunc): Option[String] = {
-    super.compileAggregate(aggFunction).orElse(
-      aggFunction match {
-        case f: GeneralAggregateFunc if f.name() == "VAR_POP" && f.isDistinct == false =>
-          assert(f.children().length == 1)
-          Some(s"VAR_POP(${f.children().head})")
-        case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" && f.isDistinct == false =>
-          assert(f.children().length == 1)
-          Some(s"VAR_SAMP(${f.children().head})")
-        case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" && f.isDistinct == false =>
-          assert(f.children().length == 1)
-          Some(s"STDDEV_POP(${f.children().head})")
-        case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" && f.isDistinct == false =>
-          assert(f.children().length == 1)
-          Some(s"STDDEV_SAMP(${f.children().head})")
-        case _ => None
-      }
-    )
-  }
+  private val supportedAggregateFunctions = Set("MAX", "MIN", "SUM", "COUNT", "AVG",
+    "VAR_POP", "VAR_SAMP", "STDDEV_POP", "STDDEV_SAMP")
+  private val supportedFunctions = supportedAggregateFunctions
+
+  override def isSupportedFunction(funcName: String): Boolean =
+    supportedFunctions.contains(funcName)
 
   override def getCatalystType(
     sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
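For Derby the refactor reduces to declaring the supported set; a small sketch (JDBC URL assumed) of how that gates push-down:

import org.apache.spark.sql.jdbc.JdbcDialects

val derby = JdbcDialects.get("jdbc:derby:memory:demo")  // resolves to DerbyDialect
derby.isSupportedFunction("STDDEV_SAMP")  // true: listed in supportedAggregateFunctions
derby.isSupportedFunction("COVAR_POP")    // false: compilation fails and push-down is skipped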
68 changes: 18 additions & 50 deletions sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
@@ -28,15 +28,20 @@ import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.analysis.{NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException}
 import org.apache.spark.sql.connector.catalog.functions.UnboundFunction
 import org.apache.spark.sql.connector.expressions.Expression
-import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc}
 import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
 import org.apache.spark.sql.types.{BooleanType, ByteType, DataType, DecimalType, ShortType, StringType}
 
 private[sql] object H2Dialect extends JdbcDialect {
   override def canHandle(url: String): Boolean =
     url.toLowerCase(Locale.ROOT).startsWith("jdbc:h2")
 
-  private val supportedFunctions =
+  private val distinctUnsupportedAggregateFunctions =
+    Set("COVAR_POP", "COVAR_SAMP", "CORR", "REGR_INTERCEPT", "REGR_R2", "REGR_SLOPE", "REGR_SXY")
+
+  private val supportedAggregateFunctions = Set("MAX", "MIN", "SUM", "COUNT", "AVG",
+    "VAR_POP", "VAR_SAMP", "STDDEV_POP", "STDDEV_SAMP") ++ distinctUnsupportedAggregateFunctions
+
+  private val supportedFunctions = supportedAggregateFunctions ++
     Set("ABS", "COALESCE", "GREATEST", "LEAST", "RAND", "LOG", "LOG10", "LN", "EXP",
       "POWER", "SQRT", "FLOOR", "CEIL", "ROUND", "SIN", "SINH", "COS", "COSH", "TAN",
       "TANH", "COT", "ASIN", "ACOS", "ATAN", "ATAN2", "DEGREES", "RADIANS", "SIGN",
@@ -45,51 +50,6 @@ private[sql] object H2Dialect extends JdbcDialect {
   override def isSupportedFunction(funcName: String): Boolean =
     supportedFunctions.contains(funcName)
 
-  override def compileAggregate(aggFunction: AggregateFunc): Option[String] = {
-    super.compileAggregate(aggFunction).orElse(
-      aggFunction match {
-        case f: GeneralAggregateFunc if f.name() == "VAR_POP" =>
-          assert(f.children().length == 1)
-          val distinct = if (f.isDistinct) "DISTINCT " else ""
-          Some(s"VAR_POP($distinct${f.children().head})")
-        case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" =>
-          assert(f.children().length == 1)
-          val distinct = if (f.isDistinct) "DISTINCT " else ""
-          Some(s"VAR_SAMP($distinct${f.children().head})")
-        case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" =>
-          assert(f.children().length == 1)
-          val distinct = if (f.isDistinct) "DISTINCT " else ""
-          Some(s"STDDEV_POP($distinct${f.children().head})")
-        case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" =>
-          assert(f.children().length == 1)
-          val distinct = if (f.isDistinct) "DISTINCT " else ""
-          Some(s"STDDEV_SAMP($distinct${f.children().head})")
-        case f: GeneralAggregateFunc if f.name() == "COVAR_POP" && !f.isDistinct =>
-          assert(f.children().length == 2)
-          Some(s"COVAR_POP(${f.children().head}, ${f.children().last})")
-        case f: GeneralAggregateFunc if f.name() == "COVAR_SAMP" && !f.isDistinct =>
-          assert(f.children().length == 2)
-          Some(s"COVAR_SAMP(${f.children().head}, ${f.children().last})")
-        case f: GeneralAggregateFunc if f.name() == "CORR" && !f.isDistinct =>
-          assert(f.children().length == 2)
-          Some(s"CORR(${f.children().head}, ${f.children().last})")
-        case f: GeneralAggregateFunc if f.name() == "REGR_INTERCEPT" && !f.isDistinct =>
-          assert(f.children().length == 2)
-          Some(s"REGR_INTERCEPT(${f.children().head}, ${f.children().last})")
-        case f: GeneralAggregateFunc if f.name() == "REGR_R2" && !f.isDistinct =>
-          assert(f.children().length == 2)
-          Some(s"REGR_R2(${f.children().head}, ${f.children().last})")
-        case f: GeneralAggregateFunc if f.name() == "REGR_SLOPE" && !f.isDistinct =>
-          assert(f.children().length == 2)
-          Some(s"REGR_SLOPE(${f.children().head}, ${f.children().last})")
-        case f: GeneralAggregateFunc if f.name() == "REGR_SXY" && !f.isDistinct =>
-          assert(f.children().length == 2)
-          Some(s"REGR_SXY(${f.children().head}, ${f.children().last})")
-        case _ => None
-      }
-    )
-  }
 
   override def getJDBCType(dt: DataType): Option[JdbcType] = dt match {
     case StringType => Option(JdbcType("CLOB", Types.CLOB))
     case BooleanType => Some(JdbcType("BOOLEAN", Types.BOOLEAN))
@@ -136,17 +96,25 @@ private[sql] object H2Dialect extends JdbcDialect {
   }
 
   override def compileExpression(expr: Expression): Option[String] = {
-    val jdbcSQLBuilder = new H2JDBCSQLBuilder()
+    val h2SQLBuilder = new H2SQLBuilder()
     try {
-      Some(jdbcSQLBuilder.build(expr))
+      Some(h2SQLBuilder.build(expr))
     } catch {
       case NonFatal(e) =>
         logWarning("Error occurs while compiling V2 expression", e)
         None
     }
   }
 
-  class H2JDBCSQLBuilder extends JDBCSQLBuilder {
+  class H2SQLBuilder extends JDBCSQLBuilder {
+    override def visitAggregateFunction(
+        funcName: String, isDistinct: Boolean, inputs: Array[String]): String =
+      if (isDistinct && distinctUnsupportedAggregateFunctions.contains(funcName)) {
+        throw new UnsupportedOperationException(s"${this.getClass.getSimpleName} does not " +
+          s"support aggregate function: $funcName with DISTINCT");
+      } else {
+        super.visitAggregateFunction(funcName, isDistinct, inputs)
+      }
 
     override def visitExtract(field: String, source: String): String = {
       val newField = field match {
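A sketch of the observable behavior of the DISTINCT guard above, with assumed URL and column names. The UnsupportedOperationException is caught by compileExpression, which abandons the push-down instead of failing the query.

import org.apache.spark.sql.connector.expressions.{Expression, Expressions}
import org.apache.spark.sql.connector.expressions.aggregate.GeneralAggregateFunc
import org.apache.spark.sql.jdbc.JdbcDialects

val h2 = JdbcDialects.get("jdbc:h2:mem:test")
val xy: Array[Expression] = Array(Expressions.column("x"), Expressions.column("y"))

h2.compileExpression(new GeneralAggregateFunc("COVAR_POP", false, xy))
// expected: Some("COVAR_POP(x, y)")
h2.compileExpression(new GeneralAggregateFunc("COVAR_POP", true, xy))
// expected: None, because H2SQLBuilder rejects COVAR_POP with DISTINCT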
sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -36,7 +36,7 @@ import org.apache.spark.sql.connector.catalog.TableChange._
 import org.apache.spark.sql.connector.catalog.functions.UnboundFunction
 import org.apache.spark.sql.connector.catalog.index.TableIndex
 import org.apache.spark.sql.connector.expressions.{Expression, Literal, NamedReference}
-import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Avg, Count, CountStar, Max, Min, Sum}
+import org.apache.spark.sql.connector.expressions.aggregate.AggregateFunc
 import org.apache.spark.sql.connector.util.V2ExpressionSQLBuilder
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.datasources.jdbc.{DriverRegistry, JDBCOptions, JdbcUtils}
@@ -244,7 +244,7 @@ abstract class JdbcDialect extends Serializable with Logging {

   override def visitSQLFunction(funcName: String, inputs: Array[String]): String = {
     if (isSupportedFunction(funcName)) {
-      s"""$funcName(${inputs.mkString(", ")})"""
+      s"""${dialectFunctionName(funcName)}(${inputs.mkString(", ")})"""
     } else {
       // The framework will catch the error and give up the push-down.
       // Please see `JdbcDialect.compileExpression(expr: Expression)` for more details.
@@ -253,6 +253,18 @@ abstract class JdbcDialect extends Serializable with Logging {
     }
   }
 
+  override def visitAggregateFunction(
+      funcName: String, isDistinct: Boolean, inputs: Array[String]): String = {
+    if (isSupportedFunction(funcName)) {
+      super.visitAggregateFunction(dialectFunctionName(funcName), isDistinct, inputs)
+    } else {
+      throw new UnsupportedOperationException(
+        s"${this.getClass.getSimpleName} does not support aggregate function: $funcName");
+    }
+  }
+
+  protected def dialectFunctionName(funcName: String): String = funcName
 
   override def visitOverlay(inputs: Array[String]): String = {
     if (isSupportedFunction("OVERLAY")) {
       super.visitOverlay(inputs)
@@ -303,26 +315,8 @@ abstract class JdbcDialect extends Serializable with Logging {
    * @return Converted value.
    */
   @Since("3.3.0")
-  def compileAggregate(aggFunction: AggregateFunc): Option[String] = {
Review comment (Contributor): can we deprecate it?

Reply (Author): OK
-    aggFunction match {
-      case min: Min =>
-        compileExpression(min.column).map(v => s"MIN($v)")
-      case max: Max =>
-        compileExpression(max.column).map(v => s"MAX($v)")
-      case count: Count =>
-        val distinct = if (count.isDistinct) "DISTINCT " else ""
-        compileExpression(count.column).map(v => s"COUNT($distinct$v)")
-      case sum: Sum =>
-        val distinct = if (sum.isDistinct) "DISTINCT " else ""
-        compileExpression(sum.column).map(v => s"SUM($distinct$v)")
-      case _: CountStar =>
-        Some("COUNT(*)")
-      case avg: Avg =>
-        val distinct = if (avg.isDistinct) "DISTINCT " else ""
-        compileExpression(avg.column).map(v => s"AVG($distinct$v)")
-      case _ => None
-    }
-  }
+  @deprecated("use org.apache.spark.sql.jdbc.JdbcDialect.compileExpression instead.", "3.4.0")
+  def compileAggregate(aggFunction: AggregateFunc): Option[String] = compileExpression(aggFunction)
 
   /**
    * List the user-defined functions in jdbc dialect.
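A migration sketch under the same assumptions as the earlier examples: after this change the deprecated entry point simply delegates, so both calls should render identical SQL.

import org.apache.spark.sql.connector.expressions.Expressions
import org.apache.spark.sql.connector.expressions.aggregate.Min
import org.apache.spark.sql.jdbc.JdbcDialects

val dialect = JdbcDialects.get("jdbc:h2:mem:test")
val min = new Min(Expressions.column("ts"))

dialect.compileAggregate(min)   // deprecated path: Some("MIN(ts)")
dialect.compileExpression(min)  // preferred path:  Some("MIN(ts)")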