DataSourceStrategy.scala
@@ -701,19 +701,19 @@ object DataSourceStrategy
   protected[sql] def translateAggregate(aggregates: AggregateExpression): Option[AggregateFunc] = {
     if (aggregates.filter.isEmpty) {
       aggregates.aggregateFunction match {
-        case aggregate.Min(PushableColumnWithoutNestedColumn(name)) =>
+        case aggregate.Min(PushableColumnNoLegacy(name)) =>
           Some(new Min(FieldReference.column(name)))
-        case aggregate.Max(PushableColumnWithoutNestedColumn(name)) =>
+        case aggregate.Max(PushableColumnNoLegacy(name)) =>
           Some(new Max(FieldReference.column(name)))
         case count: aggregate.Count if count.children.length == 1 =>
           count.children.head match {
             // SELECT COUNT(*) FROM table is translated to SELECT 1 FROM table
             case Literal(_, _) => Some(new CountStar())
-            case PushableColumnWithoutNestedColumn(name) =>
+            case PushableColumnNoLegacy(name) =>
               Some(new Count(FieldReference.column(name), aggregates.isDistinct))
             case _ => None
           }
-        case aggregate.Sum(PushableColumnWithoutNestedColumn(name), _) =>
+        case aggregate.Sum(PushableColumnNoLegacy(name), _) =>
           Some(new Sum(FieldReference.column(name), aggregates.isDistinct))
         case _ => None
       }
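
With PushableColumnNoLegacy in place, translateAggregate can now translate MIN/MAX/COUNT/SUM over a column whose name literally contains a dot. A minimal sketch of the effect (assuming the H2 test catalog from JDBCV2Suite below is registered as "h2"; names are illustrative):

    // Scala, in a Spark session with the h2 catalog configured
    val df = spark.sql("SELECT COUNT(`dept.id`) FROM h2.test.dept")
    df.explain()
    // Previously the dotted name failed the PushableColumnWithoutNestedColumn
    // check and the aggregate stayed in Spark; now the plan is expected to
    // contain a fragment like: PushedAggregates: [COUNT(`dept.id`)]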
@@ -745,6 +745,7 @@ object DataSourceStrategy
  */
 abstract class PushableColumnBase {
   val nestedPredicatePushdownEnabled: Boolean
+  val allowDotInAttrName: Boolean
 
   def unapply(e: Expression): Option[String] = {
     import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper
@@ -756,9 +757,9 @@ abstract class PushableColumnBase {
   }
 
   private def extractTopLevelCol(e: Expression): Option[String] = e match {
-    // Attribute that contains dot "." in name is supported only when nested predicate pushdown
-    // is enabled.
-    case a: Attribute if !a.name.contains(".") => Some(a.name)
+    // To keep backward compatibility, we can't push down filters with a column name containing
+    // a dot, as the underlying data source may mistakenly treat it as a nested column.
+    case a: Attribute if !a.name.contains(".") || allowDotInAttrName => Some(a.name)
     case _ => None
   }
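
The guard exists because a dotted attribute name is textually ambiguous with a nested field reference. A self-contained model of the new behavior (DotColumnDemo and its Attribute are simplified stand-ins, not Spark's classes):

    object DotColumnDemo {
      // Stand-in for Catalyst's Attribute: a reference to a top-level column.
      final case class Attribute(name: String)

      // Mirrors the guard above: a name containing "." is accepted only when
      // allowDotInAttrName is set, since a bare "a.b" is otherwise
      // indistinguishable from field b nested inside struct column a.
      def extractTopLevelCol(a: Attribute, allowDotInAttrName: Boolean): Option[String] =
        if (!a.name.contains(".") || allowDotInAttrName) Some(a.name) else None

      def main(args: Array[String]): Unit = {
        println(extractTopLevelCol(Attribute("dept id"), allowDotInAttrName = false)) // Some(dept id)
        println(extractTopLevelCol(Attribute("dept.id"), allowDotInAttrName = false)) // None
        println(extractTopLevelCol(Attribute("dept.id"), allowDotInAttrName = true))  // Some(dept.id)
      }
    }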

@@ -782,8 +783,15 @@ object PushableColumn {
 
 object PushableColumnAndNestedColumn extends PushableColumnBase {
   override val nestedPredicatePushdownEnabled = true
+  override val allowDotInAttrName = true
 }
 
 object PushableColumnWithoutNestedColumn extends PushableColumnBase {
   override val nestedPredicatePushdownEnabled = false
+  override val allowDotInAttrName = false
 }
+
+object PushableColumnNoLegacy extends PushableColumnBase {
+  override val nestedPredicatePushdownEnabled = false
+  override val allowDotInAttrName = true
+}
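
The three objects now cover the three useful combinations: PushableColumnAndNestedColumn accepts nested fields (quoting each name part), PushableColumnWithoutNestedColumn keeps the legacy behavior of rejecting any dotted name so that V1 filter consumers are not confused, and the new PushableColumnNoLegacy accepts a dotted name as a single top-level column, which is safe wherever the extracted name is handed to an API such as FieldReference.column that never re-parses it.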
PushDownUtils.scala
@@ -25,9 +25,7 @@ import org.apache.spark.sql.catalyst.util.CharVarcharUtils
 import org.apache.spark.sql.connector.expressions.FieldReference
 import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
 import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownRequiredColumns}
-import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownRequiredColumns}
-import org.apache.spark.sql.execution.datasources.DataSourceStrategy
-import org.apache.spark.sql.execution.datasources.PushableColumnWithoutNestedColumn
+import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, PushableColumnNoLegacy}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources
 import org.apache.spark.sql.types.StructType
@@ -86,7 +84,7 @@ object PushDownUtils extends PredicateHelper
       groupBy: Seq[Expression]): Option[Aggregation] = {
 
     def columnAsString(e: Expression): Option[FieldReference] = e match {
-      case PushableColumnWithoutNestedColumn(name) =>
+      case PushableColumnNoLegacy(name) =>
         Some(FieldReference.column(name).asInstanceOf[FieldReference])
       case _ => None
     }
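
PushableColumnNoLegacy is safe here precisely because the extracted name goes through FieldReference.column, which keeps the whole string as one name part instead of parsing it. A short illustration (a sketch against Spark's connector expressions API; the vals are just for demonstration):

    import org.apache.spark.sql.connector.expressions.FieldReference

    // FieldReference.apply parses a multipart identifier, so "dept.id" is
    // read as field `id` nested under column `dept`...
    val parsed = FieldReference("dept.id")
    // ...whereas FieldReference.column keeps the string as a single name
    // part, faithfully representing a column literally named "dept.id".
    val single = FieldReference.column("dept.id")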
JDBCV2Suite.scala
@@ -80,9 +80,10 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHelper
       .executeUpdate()
 
     conn.prepareStatement(
-      "CREATE TABLE \"test\".\"dept\" (\"dept id\" INTEGER NOT NULL)").executeUpdate()
-    conn.prepareStatement("INSERT INTO \"test\".\"dept\" VALUES (1)").executeUpdate()
-    conn.prepareStatement("INSERT INTO \"test\".\"dept\" VALUES (2)").executeUpdate()
+      "CREATE TABLE \"test\".\"dept\" (\"dept id\" INTEGER NOT NULL, \"dept.id\" INTEGER)")
+      .executeUpdate()
+    conn.prepareStatement("INSERT INTO \"test\".\"dept\" VALUES (1, 1)").executeUpdate()
+    conn.prepareStatement("INSERT INTO \"test\".\"dept\" VALUES (2, 1)").executeUpdate()
 
     // scalastyle:off
     conn.prepareStatement(
@@ -480,15 +481,36 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHelper
 
 
   test("column name with composite field") {
-    checkAnswer(sql("SELECT `dept id` FROM h2.test.dept"), Seq(Row(1), Row(2)))
-    val df = sql("SELECT COUNT(`dept id`) FROM h2.test.dept")
-    df.queryExecution.optimizedPlan.collect {
+    checkAnswer(sql("SELECT `dept id`, `dept.id` FROM h2.test.dept"), Seq(Row(1, 1), Row(2, 1)))
+
+    val df1 = sql("SELECT COUNT(`dept id`) FROM h2.test.dept")
+    df1.queryExecution.optimizedPlan.collect {
       case _: DataSourceV2ScanRelation =>
         val expected_plan_fragment =
           "PushedAggregates: [COUNT(`dept id`)]"
-        checkKeywordsExistsInExplain(df, expected_plan_fragment)
+        checkKeywordsExistsInExplain(df1, expected_plan_fragment)
     }
-    checkAnswer(df, Seq(Row(2)))
+    checkAnswer(df1, Seq(Row(2)))
+
+    val df2 = sql("SELECT `dept.id`, COUNT(`dept id`) FROM h2.test.dept GROUP BY `dept.id`")
+    df2.queryExecution.optimizedPlan.collect {
+      case _: DataSourceV2ScanRelation =>
+        val expected_plan_fragment = Seq(
+          "PushedAggregates: [COUNT(`dept id`)]",
+          "PushedGroupby: [`dept.id`]")
+        checkKeywordsExistsInExplain(df2, expected_plan_fragment: _*)
+    }
+    checkAnswer(df2, Seq(Row(1, 2)))
+
+    val df3 = sql("SELECT `dept id`, COUNT(`dept.id`) FROM h2.test.dept GROUP BY `dept id`")
+    df3.queryExecution.optimizedPlan.collect {
+      case _: DataSourceV2ScanRelation =>
+        val expected_plan_fragment = Seq(
+          "PushedAggregates: [COUNT(`dept.id`)]",
+          "PushedGroupby: [`dept id`]")
+        checkKeywordsExistsInExplain(df3, expected_plan_fragment: _*)
+    }
+    checkAnswer(df3, Seq(Row(1, 1), Row(2, 1)))
   }
 
   test("column name with non-ascii") {
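
Taken together, the new tests check both directions: a dotted name in the aggregate (COUNT(`dept.id`)) and a dotted name in the grouping key (GROUP BY `dept.id`). When the df2 plan is fully pushed, the scan sent to H2 roughly corresponds to SELECT "dept.id", COUNT("dept id") FROM "test"."dept" GROUP BY "dept.id" (illustrative; the exact generated SQL depends on the JDBC dialect).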