diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index a6ce8c61199cc..56dd201220b00 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -701,19 +701,19 @@ object DataSourceStrategy
   protected[sql] def translateAggregate(aggregates: AggregateExpression): Option[AggregateFunc] = {
     if (aggregates.filter.isEmpty) {
       aggregates.aggregateFunction match {
-        case aggregate.Min(PushableColumnWithoutNestedColumn(name)) =>
+        case aggregate.Min(PushableColumnNoLegacy(name)) =>
           Some(new Min(FieldReference.column(name)))
-        case aggregate.Max(PushableColumnWithoutNestedColumn(name)) =>
+        case aggregate.Max(PushableColumnNoLegacy(name)) =>
           Some(new Max(FieldReference.column(name)))
         case count: aggregate.Count if count.children.length == 1 =>
           count.children.head match {
             // SELECT COUNT(*) FROM table is translated to SELECT 1 FROM table
             case Literal(_, _) => Some(new CountStar())
-            case PushableColumnWithoutNestedColumn(name) =>
+            case PushableColumnNoLegacy(name) =>
               Some(new Count(FieldReference.column(name), aggregates.isDistinct))
             case _ => None
           }
-        case aggregate.Sum(PushableColumnWithoutNestedColumn(name), _) =>
+        case aggregate.Sum(PushableColumnNoLegacy(name), _) =>
           Some(new Sum(FieldReference.column(name), aggregates.isDistinct))
         case _ => None
       }
@@ -745,6 +745,7 @@ object DataSourceStrategy
  */
 abstract class PushableColumnBase {
   val nestedPredicatePushdownEnabled: Boolean
+  val allowDotInAttrName: Boolean
 
   def unapply(e: Expression): Option[String] = {
     import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper
@@ -756,9 +757,9 @@
   }
 
   private def extractTopLevelCol(e: Expression): Option[String] = e match {
-    // Attribute that contains dot "." in name is supported only when nested predicate pushdown
-    // is enabled.
-    case a: Attribute if !a.name.contains(".") => Some(a.name)
+    // To keep backward compatibility, we can't push down a filter whose column name contains a
+    // dot unless `allowDotInAttrName` is set, as the source may misread it as a nested column.
+    case a: Attribute if !a.name.contains(".") || allowDotInAttrName => Some(a.name)
     case _ => None
   }
 
@@ -782,8 +783,15 @@ object PushableColumn {
 
 object PushableColumnAndNestedColumn extends PushableColumnBase {
   override val nestedPredicatePushdownEnabled = true
+  override val allowDotInAttrName = true
 }
 
 object PushableColumnWithoutNestedColumn extends PushableColumnBase {
   override val nestedPredicatePushdownEnabled = false
+  override val allowDotInAttrName = false
+}
+
+object PushableColumnNoLegacy extends PushableColumnBase {
+  override val nestedPredicatePushdownEnabled = false
+  override val allowDotInAttrName = true
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
index 434abdd8570db..bb3f1091b9e1f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
@@ -25,9 +25,7 @@ import org.apache.spark.sql.catalyst.util.CharVarcharUtils
 import org.apache.spark.sql.connector.expressions.FieldReference
 import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
 import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownRequiredColumns}
-import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownRequiredColumns}
-import org.apache.spark.sql.execution.datasources.DataSourceStrategy
-import org.apache.spark.sql.execution.datasources.PushableColumnWithoutNestedColumn
+import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, PushableColumnNoLegacy}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources
 import org.apache.spark.sql.types.StructType
@@ -86,7 +84,7 @@
       groupBy: Seq[Expression]): Option[Aggregation] = {
 
     def columnAsString(e: Expression): Option[FieldReference] = e match {
-      case PushableColumnWithoutNestedColumn(name) =>
+      case PushableColumnNoLegacy(name) =>
         Some(FieldReference.column(name).asInstanceOf[FieldReference])
       case _ => None
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index 20b4706d28ab5..ad86d21d73331 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -80,9 +80,10 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
       .executeUpdate()
 
     conn.prepareStatement(
-      "CREATE TABLE \"test\".\"dept\" (\"dept id\" INTEGER NOT NULL)").executeUpdate()
-    conn.prepareStatement("INSERT INTO \"test\".\"dept\" VALUES (1)").executeUpdate()
-    conn.prepareStatement("INSERT INTO \"test\".\"dept\" VALUES (2)").executeUpdate()
+      "CREATE TABLE \"test\".\"dept\" (\"dept id\" INTEGER NOT NULL, \"dept.id\" INTEGER)")
+      .executeUpdate()
+    conn.prepareStatement("INSERT INTO \"test\".\"dept\" VALUES (1, 1)").executeUpdate()
+    conn.prepareStatement("INSERT INTO \"test\".\"dept\" VALUES (2, 1)").executeUpdate()
 
     // scalastyle:off
     conn.prepareStatement(
@@ -480,15 +481,36 @@
 
   test("column name with composite field") {
-    checkAnswer(sql("SELECT `dept id` FROM h2.test.dept"), Seq(Row(1), Row(2)))
-    val df = sql("SELECT COUNT(`dept id`) FROM h2.test.dept")
-    df.queryExecution.optimizedPlan.collect {
+    checkAnswer(sql("SELECT `dept id`, `dept.id` FROM h2.test.dept"), Seq(Row(1, 1), Row(2, 1)))
+
+    val df1 = sql("SELECT COUNT(`dept id`) FROM h2.test.dept")
+    df1.queryExecution.optimizedPlan.collect {
       case _: DataSourceV2ScanRelation =>
         val expected_plan_fragment =
           "PushedAggregates: [COUNT(`dept id`)]"
-        checkKeywordsExistsInExplain(df, expected_plan_fragment)
+        checkKeywordsExistsInExplain(df1, expected_plan_fragment)
     }
-    checkAnswer(df, Seq(Row(2)))
+    checkAnswer(df1, Seq(Row(2)))
+
+    val df2 = sql("SELECT `dept.id`, COUNT(`dept id`) FROM h2.test.dept GROUP BY `dept.id`")
+    df2.queryExecution.optimizedPlan.collect {
+      case _: DataSourceV2ScanRelation =>
+        val expected_plan_fragment = Seq(
+          "PushedAggregates: [COUNT(`dept id`)]",
+          "PushedGroupby: [`dept.id`]")
+        checkKeywordsExistsInExplain(df2, expected_plan_fragment: _*)
+    }
+    checkAnswer(df2, Seq(Row(1, 2)))
+
+    val df3 = sql("SELECT `dept id`, COUNT(`dept.id`) FROM h2.test.dept GROUP BY `dept id`")
+    df3.queryExecution.optimizedPlan.collect {
+      case _: DataSourceV2ScanRelation =>
+        val expected_plan_fragment = Seq(
+          "PushedAggregates: [COUNT(`dept.id`)]",
+          "PushedGroupby: [`dept id`]")
+        checkKeywordsExistsInExplain(df3, expected_plan_fragment: _*)
+    }
+    checkAnswer(df3, Seq(Row(1, 1), Row(2, 1)))
   }
 
   test("column name with non-ascii") {
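
Reviewer note: a rough, self-contained sketch of the behavior split the new allowDotInAttrName flag introduces. Attr is a hypothetical stand-in for Catalyst's Attribute; only the guard from extractTopLevelCol above is modeled, under those stated assumptions.

// Sketch only, not Spark code: Attr stands in for Catalyst's Attribute.
case class Attr(name: String)

// Like PushableColumnWithoutNestedColumn (legacy V1 filter pushdown): a dotted
// name is rejected, since a source receiving the raw string "dept.id" may
// misread it as the nested field `dept`.`id`.
def extractLegacy(a: Attr): Option[String] =
  if (!a.name.contains(".")) Some(a.name) else None

// Like PushableColumnNoLegacy (used for aggregate pushdown): nested fields are
// still not extracted, but a dot in a top-level name is accepted because the
// caller wraps the name in a single-part FieldReference instead of re-parsing it.
def extractNoLegacy(a: Attr): Option[String] = Some(a.name)

assert(extractLegacy(Attr("dept id")).contains("dept id"))
assert(extractLegacy(Attr("dept.id")).isEmpty)                // pushdown skipped
assert(extractNoLegacy(Attr("dept.id")).contains("dept.id"))  // now pushed down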
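
Why the non-legacy path is safe: FieldReference.column(name) builds a single-part reference, whereas parsing the same string as a multipart identifier would split on the dot. A simplified model of that distinction (Ref, column, and parse are illustrative names, not the connector API verbatim):

// Illustrative model of single-part vs. parsed field references.
case class Ref(parts: Seq[String]) {
  // Render parts the way multipart identifiers appear in plan output.
  def describe: String = parts.map(p => "`" + p.replace("`", "``") + "`").mkString(".")
}

// Like FieldReference.column(name): the whole name is kept as one part.
def column(name: String): Ref = Ref(Seq(name))

// Like parsing a multipart identifier: dots separate parts (crude stand-in parser).
def parse(name: String): Ref = Ref(name.split('.').toSeq)

println(column("dept.id").describe) // `dept.id`   -- one top-level column
println(parse("dept.id").describe)  // `dept`.`id` -- looks like a nested field

This single-part behavior is also why the test expectations above render COUNT(`dept.id`) and PushedGroupby: [`dept.id`] as one backtick-quoted identifier.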