Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression =>
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Avg, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum, UserDefinedAggregateFunc}
import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or, Predicate => V2Predicate}
import org.apache.spark.sql.execution.datasources.PushableExpression
import org.apache.spark.sql.types.{BooleanType, IntegerType, StringType}
import org.apache.spark.sql.types.{BooleanType, DataType, IntegerType, StringType}

/**
* The builder to generate V2 expressions from catalyst expressions.
Expand Down Expand Up @@ -98,45 +98,45 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) {
generateExpression(child).map(v => new V2Cast(v, dataType))
case AggregateExpression(aggregateFunction, Complete, isDistinct, None, _) =>
generateAggregateFunc(aggregateFunction, isDistinct)
case Abs(child, true) => generateExpressionWithName("ABS", Seq(child))
case Coalesce(children) => generateExpressionWithName("COALESCE", children)
case Greatest(children) => generateExpressionWithName("GREATEST", children)
case Least(children) => generateExpressionWithName("LEAST", children)
case Rand(child, hideSeed) =>
case Abs(_, true) => generateExpressionWithName("ABS", expr, isPredicate)
case _: Coalesce => generateExpressionWithName("COALESCE", expr, isPredicate)
case _: Greatest => generateExpressionWithName("GREATEST", expr, isPredicate)
case _: Least => generateExpressionWithName("LEAST", expr, isPredicate)
case Rand(_, hideSeed) =>
if (hideSeed) {
Some(new GeneralScalarExpression("RAND", Array.empty[V2Expression]))
} else {
generateExpressionWithName("RAND", Seq(child))
generateExpressionWithName("RAND", expr, isPredicate)
}
case log: Logarithm => generateExpressionWithName("LOG", log.children)
case Log10(child) => generateExpressionWithName("LOG10", Seq(child))
case Log2(child) => generateExpressionWithName("LOG2", Seq(child))
case Log(child) => generateExpressionWithName("LN", Seq(child))
case Exp(child) => generateExpressionWithName("EXP", Seq(child))
case pow: Pow => generateExpressionWithName("POWER", pow.children)
case Sqrt(child) => generateExpressionWithName("SQRT", Seq(child))
case Floor(child) => generateExpressionWithName("FLOOR", Seq(child))
case Ceil(child) => generateExpressionWithName("CEIL", Seq(child))
case round: Round => generateExpressionWithName("ROUND", round.children)
case Sin(child) => generateExpressionWithName("SIN", Seq(child))
case Sinh(child) => generateExpressionWithName("SINH", Seq(child))
case Cos(child) => generateExpressionWithName("COS", Seq(child))
case Cosh(child) => generateExpressionWithName("COSH", Seq(child))
case Tan(child) => generateExpressionWithName("TAN", Seq(child))
case Tanh(child) => generateExpressionWithName("TANH", Seq(child))
case Cot(child) => generateExpressionWithName("COT", Seq(child))
case Asin(child) => generateExpressionWithName("ASIN", Seq(child))
case Asinh(child) => generateExpressionWithName("ASINH", Seq(child))
case Acos(child) => generateExpressionWithName("ACOS", Seq(child))
case Acosh(child) => generateExpressionWithName("ACOSH", Seq(child))
case Atan(child) => generateExpressionWithName("ATAN", Seq(child))
case Atanh(child) => generateExpressionWithName("ATANH", Seq(child))
case atan2: Atan2 => generateExpressionWithName("ATAN2", atan2.children)
case Cbrt(child) => generateExpressionWithName("CBRT", Seq(child))
case ToDegrees(child) => generateExpressionWithName("DEGREES", Seq(child))
case ToRadians(child) => generateExpressionWithName("RADIANS", Seq(child))
case Signum(child) => generateExpressionWithName("SIGN", Seq(child))
case wb: WidthBucket => generateExpressionWithName("WIDTH_BUCKET", wb.children)
case _: Logarithm => generateExpressionWithName("LOG", expr, isPredicate)
case _: Log10 => generateExpressionWithName("LOG10", expr, isPredicate)
case _: Log2 => generateExpressionWithName("LOG2", expr, isPredicate)
case _: Log => generateExpressionWithName("LN", expr, isPredicate)
case _: Exp => generateExpressionWithName("EXP", expr, isPredicate)
case _: Pow => generateExpressionWithName("POWER", expr, isPredicate)
case _: Sqrt => generateExpressionWithName("SQRT", expr, isPredicate)
case _: Floor => generateExpressionWithName("FLOOR", expr, isPredicate)
case _: Ceil => generateExpressionWithName("CEIL", expr, isPredicate)
case _: Round => generateExpressionWithName("ROUND", expr, isPredicate)
case _: Sin => generateExpressionWithName("SIN", expr, isPredicate)
case _: Sinh => generateExpressionWithName("SINH", expr, isPredicate)
case _: Cos => generateExpressionWithName("COS", expr, isPredicate)
case _: Cosh => generateExpressionWithName("COSH", expr, isPredicate)
case _: Tan => generateExpressionWithName("TAN", expr, isPredicate)
case _: Tanh => generateExpressionWithName("TANH", expr, isPredicate)
case _: Cot => generateExpressionWithName("COT", expr, isPredicate)
case _: Asin => generateExpressionWithName("ASIN", expr, isPredicate)
case _: Asinh => generateExpressionWithName("ASINH", expr, isPredicate)
case _: Acos => generateExpressionWithName("ACOS", expr, isPredicate)
case _: Acosh => generateExpressionWithName("ACOSH", expr, isPredicate)
case _: Atan => generateExpressionWithName("ATAN", expr, isPredicate)
case _: Atanh => generateExpressionWithName("ATANH", expr, isPredicate)
case _: Atan2 => generateExpressionWithName("ATAN2", expr, isPredicate)
case _: Cbrt => generateExpressionWithName("CBRT", expr, isPredicate)
case _: ToDegrees => generateExpressionWithName("DEGREES", expr, isPredicate)
case _: ToRadians => generateExpressionWithName("RADIANS", expr, isPredicate)
case _: Signum => generateExpressionWithName("SIGN", expr, isPredicate)
case _: WidthBucket => generateExpressionWithName("WIDTH_BUCKET", expr, isPredicate)
case and: And =>
// AND expects predicate
val l = generateExpression(and.left, true)
Expand Down Expand Up @@ -187,57 +187,56 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) {
assert(v.isInstanceOf[V2Predicate])
new V2Not(v.asInstanceOf[V2Predicate])
}
case UnaryMinus(child, true) => generateExpressionWithName("-", Seq(child))
case BitwiseNot(child) => generateExpressionWithName("~", Seq(child))
case CaseWhen(branches, elseValue) =>
case UnaryMinus(_, true) => generateExpressionWithName("-", expr, isPredicate)
case _: BitwiseNot => generateExpressionWithName("~", expr, isPredicate)
case caseWhen @ CaseWhen(branches, elseValue) =>
val conditions = branches.map(_._1).flatMap(generateExpression(_, true))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I reserved isPredicate=true for conditions of casewhen

val values = branches.map(_._2).flatMap(generateExpression(_, true))
if (conditions.length == branches.length && values.length == branches.length) {
val values = branches.map(_._2).flatMap(generateExpression(_))
val elseExprOpt = elseValue.flatMap(generateExpression(_))
if (conditions.length == branches.length && values.length == branches.length &&
elseExprOpt.size == elseValue.size) {
val branchExpressions = conditions.zip(values).flatMap { case (c, v) =>
Seq[V2Expression](c, v)
}
if (elseValue.isDefined) {
elseValue.flatMap(generateExpression(_)).map { v =>
val children = (branchExpressions :+ v).toArray[V2Expression]
// The children looks like [condition1, value1, ..., conditionN, valueN, elseValue]
new V2Predicate("CASE_WHEN", children)
}
val children = (branchExpressions ++ elseExprOpt).toArray[V2Expression]
// The children looks like [condition1, value1, ..., conditionN, valueN (, elseValue)]
if (isPredicate && caseWhen.dataType.isInstanceOf[BooleanType]) {
Some(new V2Predicate("CASE_WHEN", children))
} else {
// The children looks like [condition1, value1, ..., conditionN, valueN]
Some(new V2Predicate("CASE_WHEN", branchExpressions.toArray[V2Expression]))
Some(new GeneralScalarExpression("CASE_WHEN", children))
}
} else {
None
}
case iff: If => generateExpressionWithName("CASE_WHEN", iff.children)
case _: If => generateExpressionWithName("CASE_WHEN", expr, isPredicate)
case substring: Substring =>
val children = if (substring.len == Literal(Integer.MAX_VALUE)) {
Seq(substring.str, substring.pos)
} else {
substring.children
}
generateExpressionWithName("SUBSTRING", children)
case Upper(child) => generateExpressionWithName("UPPER", Seq(child))
case Lower(child) => generateExpressionWithName("LOWER", Seq(child))
generateExpressionWithNameByChildren("SUBSTRING", children, substring.dataType, isPredicate)
case _: Upper => generateExpressionWithName("UPPER", expr, isPredicate)
case _: Lower => generateExpressionWithName("LOWER", expr, isPredicate)
case BitLength(child) if child.dataType.isInstanceOf[StringType] =>
generateExpressionWithName("BIT_LENGTH", Seq(child))
generateExpressionWithName("BIT_LENGTH", expr, isPredicate)
case Length(child) if child.dataType.isInstanceOf[StringType] =>
generateExpressionWithName("CHAR_LENGTH", Seq(child))
case concat: Concat => generateExpressionWithName("CONCAT", concat.children)
case translate: StringTranslate => generateExpressionWithName("TRANSLATE", translate.children)
case trim: StringTrim => generateExpressionWithName("TRIM", trim.children)
case trim: StringTrimLeft => generateExpressionWithName("LTRIM", trim.children)
case trim: StringTrimRight => generateExpressionWithName("RTRIM", trim.children)
generateExpressionWithName("CHAR_LENGTH", expr, isPredicate)
case _: Concat => generateExpressionWithName("CONCAT", expr, isPredicate)
case _: StringTranslate => generateExpressionWithName("TRANSLATE", expr, isPredicate)
case _: StringTrim => generateExpressionWithName("TRIM", expr, isPredicate)
case _: StringTrimLeft => generateExpressionWithName("LTRIM", expr, isPredicate)
case _: StringTrimRight => generateExpressionWithName("RTRIM", expr, isPredicate)
case overlay: Overlay =>
val children = if (overlay.len == Literal(-1)) {
Seq(overlay.input, overlay.replace, overlay.pos)
} else {
overlay.children
}
generateExpressionWithName("OVERLAY", children)
case date: DateAdd => generateExpressionWithName("DATE_ADD", date.children)
case date: DateDiff => generateExpressionWithName("DATE_DIFF", date.children)
case date: TruncDate => generateExpressionWithName("TRUNC", date.children)
generateExpressionWithNameByChildren("OVERLAY", children, overlay.dataType, isPredicate)
case _: DateAdd => generateExpressionWithName("DATE_ADD", expr, isPredicate)
case _: DateDiff => generateExpressionWithName("DATE_DIFF", expr, isPredicate)
case _: TruncDate => generateExpressionWithName("TRUNC", expr, isPredicate)
case Second(child, _) =>
generateExpression(child).map(v => new V2Extract("SECOND", v))
case Minute(child, _) =>
Expand Down Expand Up @@ -270,12 +269,12 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) {
generateExpression(child).map(v => new V2Extract("WEEK", v))
case YearOfWeek(child) =>
generateExpression(child).map(v => new V2Extract("YEAR_OF_WEEK", v))
case encrypt: AesEncrypt => generateExpressionWithName("AES_ENCRYPT", encrypt.children)
case decrypt: AesDecrypt => generateExpressionWithName("AES_DECRYPT", decrypt.children)
case Crc32(child) => generateExpressionWithName("CRC32", Seq(child))
case Md5(child) => generateExpressionWithName("MD5", Seq(child))
case Sha1(child) => generateExpressionWithName("SHA1", Seq(child))
case sha2: Sha2 => generateExpressionWithName("SHA2", sha2.children)
case _: AesEncrypt => generateExpressionWithName("AES_ENCRYPT", expr, isPredicate)
case _: AesDecrypt => generateExpressionWithName("AES_DECRYPT", expr, isPredicate)
case _: Crc32 => generateExpressionWithName("CRC32", expr, isPredicate)
case _: Md5 => generateExpressionWithName("MD5", expr, isPredicate)
case _: Sha1 => generateExpressionWithName("SHA1", expr, isPredicate)
case _: Sha2 => generateExpressionWithName("SHA2", expr, isPredicate)
// TODO supports other expressions
case ApplyFunctionExpression(function, children) =>
val childrenExpressions = children.flatMap(generateExpression(_))
Expand Down Expand Up @@ -380,10 +379,26 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) {
}

private def generateExpressionWithName(
v2ExpressionName: String, children: Seq[Expression]): Option[V2Expression] = {
v2ExpressionName: String,
expr: Expression,
isPredicate: Boolean): Option[V2Expression] = {
generateExpressionWithNameByChildren(
v2ExpressionName, expr.children, expr.dataType, isPredicate)
}

private def generateExpressionWithNameByChildren(
v2ExpressionName: String,
children: Seq[Expression],
dataType: DataType,
isPredicate: Boolean): Option[V2Expression] = {
val childrenExpressions = children.flatMap(generateExpression(_))
if (childrenExpressions.length == children.length) {
Some(new GeneralScalarExpression(v2ExpressionName, childrenExpressions.toArray[V2Expression]))
if (isPredicate && dataType.isInstanceOf[BooleanType]) {
Some(new V2Predicate(v2ExpressionName, childrenExpressions.toArray[V2Expression]))
} else {
Some(new GeneralScalarExpression(
v2ExpressionName, childrenExpressions.toArray[V2Expression]))
}
} else {
None
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -966,6 +966,16 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS
)
}
}

test("SPARK-47463: Pushed down v2 filter with if expression") {
withTempView("t1") {
spark.read.format(classOf[AdvancedDataSourceV2WithV2Filter].getName).load()
.createTempView("t1")
val df = sql("SELECT * FROM t1 WHERE if(i = 1, i, 0) > 0")
val result = df.collect()
assert(result.length == 1)
}
}
}

case class RangeInputPartition(start: Int, end: Int) extends InputPartition
Expand Down