diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala index 020dd79f8f0d7..de8eaa2d60ead 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala @@ -32,6 +32,7 @@ import org.apache.spark.annotation.{Evolving, Stable} sealed abstract class Filter { /** * List of columns that are referenced by this filter. + * Note that, if a column name contains `dots`, it will be quoted to avoid confusion. * @since 2.1.0 */ def references: Array[String] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 1641b660a271d..175325aa1859a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -652,10 +652,11 @@ object DataSourceStrategy { */ object PushableColumn { def unapply(e: Expression): Option[String] = { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ def helper(e: Expression) = e match { case a: Attribute => Some(a.name) case _ => None } - helper(e) + helper(e).map(quoteIfNeeded) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFiltersBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFiltersBase.scala index e673309188756..2cf4903a16acf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFiltersBase.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFiltersBase.scala @@ -17,8 +17,10 @@ package org.apache.spark.sql.execution.datasources.orc +import 
org.apache.spark.sql.connector.catalog.CatalogV2Implicits.quoteIfNeeded import org.apache.spark.sql.sources.{And, Filter} import org.apache.spark.sql.types.{AtomicType, BinaryType, DataType} +import org.apache.spark.sql.types.StructType /** * Methods that can be shared when upgrading the built-in Hive. @@ -45,4 +47,11 @@ trait OrcFiltersBase { case _: AtomicType => true case _ => false } + + /** + * The keys of the dataTypeMap will be quoted if they contain `dots`. + */ + protected[sql] def quotedDataTypeMap(schema: StructType): Map[String, DataType] = { + schema.map(f => quoteIfNeeded(f.name) -> f.dataType).toMap + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index 07065018343cf..bffdc672e5209 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -34,6 +34,7 @@ import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._ import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.catalyst.util.DateTimeUtils.SQLDate +import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.quoteIfNeeded import org.apache.spark.sql.sources import org.apache.spark.unsafe.types.UTF8String @@ -55,7 +56,7 @@ class ParquetFilters( // and it does not support to create filters for them. 
val primitiveFields = schema.getFields.asScala.filter(_.isPrimitive).map(_.asPrimitiveType()).map { f => - f.getName -> ParquetField(f.getName, + quoteIfNeeded(f.getName) -> ParquetField(f.getName, ParquetSchemaType(f.getOriginalType, f.getPrimitiveTypeName, f.getTypeLength, f.getDecimalMetadata)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala index 1421ffd8b6de4..17223ab87f3d4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala @@ -22,6 +22,7 @@ import scala.collection.JavaConverters._ import org.apache.orc.mapreduce.OrcInputFormat import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.quoteIfNeeded import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownFilters} import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.orc.OrcFilters @@ -59,7 +60,7 @@ case class OrcScanBuilder( // changed `hadoopConf` in executors. 
OrcInputFormat.setSearchArgument(hadoopConf, f, schema.fieldNames) } - val dataTypeMap = schema.map(f => f.name -> f.dataType).toMap + val dataTypeMap = schema.map(f => quoteIfNeeded(f.name) -> f.dataType).toMap _pushedFilters = OrcFilters.convertibleFilters(schema, dataTypeMap, filters).toArray } filters diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala index 7bd3213b378ce..f690d76ab4793 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala @@ -26,15 +26,19 @@ import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructT class DataSourceStrategySuite extends PlanTest with SharedSparkSession { val attrInts = Seq( - 'cint.int + 'cint.int, + Symbol("c.int").int ).zip(Seq( - "cint" + "cint", + "`c.int`" // single level field that contains `dot` in name )) val attrStrs = Seq( - 'cstr.string + 'cstr.string, + Symbol("c.str").string ).zip(Seq( - "cstr" + "cstr", + "`c.str`" // single level field that contains `dot` in name )) test("translate simple expression") { attrInts.zip(attrStrs) diff --git a/sql/core/v1.2/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala b/sql/core/v1.2/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala index b9cbc484e1fc1..d21d41d523758 100644 --- a/sql/core/v1.2/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala +++ b/sql/core/v1.2/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala @@ -65,7 +65,7 @@ private[sql] object OrcFilters extends OrcFiltersBase { * Create ORC filter as a SearchArgument instance. 
*/ def createFilter(schema: StructType, filters: Seq[Filter]): Option[SearchArgument] = { - val dataTypeMap = schema.map(f => f.name -> f.dataType).toMap + val dataTypeMap = quotedDataTypeMap(schema) // Combines all convertible filters using `And` to produce a single conjunction val conjunctionOptional = buildTree(convertibleFilters(schema, dataTypeMap, filters)) conjunctionOptional.map { conjunction => @@ -222,48 +222,39 @@ private[sql] object OrcFilters extends OrcFiltersBase { // Since ORC 1.5.0 (ORC-323), we need to quote for column names with `.` characters // in order to distinguish predicate pushdown for nested columns. expression match { - case EqualTo(attribute, value) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteIfNeeded(attribute) - val castedValue = castLiteralValue(value, dataTypeMap(attribute)) - Some(builder.startAnd().equals(quotedName, getType(attribute), castedValue).end()) + case EqualTo(name, value) if isSearchableType(dataTypeMap(name)) => + val castedValue = castLiteralValue(value, dataTypeMap(name)) + Some(builder.startAnd().equals(name, getType(name), castedValue).end()) - case EqualNullSafe(attribute, value) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteIfNeeded(attribute) - val castedValue = castLiteralValue(value, dataTypeMap(attribute)) - Some(builder.startAnd().nullSafeEquals(quotedName, getType(attribute), castedValue).end()) + case EqualNullSafe(name, value) if isSearchableType(dataTypeMap(name)) => + val castedValue = castLiteralValue(value, dataTypeMap(name)) + Some(builder.startAnd().nullSafeEquals(name, getType(name), castedValue).end()) - case LessThan(attribute, value) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteIfNeeded(attribute) - val castedValue = castLiteralValue(value, dataTypeMap(attribute)) - Some(builder.startAnd().lessThan(quotedName, getType(attribute), castedValue).end()) + case LessThan(name, value) if isSearchableType(dataTypeMap(name)) 
=> + val castedValue = castLiteralValue(value, dataTypeMap(name)) + Some(builder.startAnd().lessThan(name, getType(name), castedValue).end()) - case LessThanOrEqual(attribute, value) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteIfNeeded(attribute) - val castedValue = castLiteralValue(value, dataTypeMap(attribute)) - Some(builder.startAnd().lessThanEquals(quotedName, getType(attribute), castedValue).end()) + case LessThanOrEqual(name, value) if isSearchableType(dataTypeMap(name)) => + val castedValue = castLiteralValue(value, dataTypeMap(name)) + Some(builder.startAnd().lessThanEquals(name, getType(name), castedValue).end()) - case GreaterThan(attribute, value) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteIfNeeded(attribute) - val castedValue = castLiteralValue(value, dataTypeMap(attribute)) - Some(builder.startNot().lessThanEquals(quotedName, getType(attribute), castedValue).end()) + case GreaterThan(name, value) if isSearchableType(dataTypeMap(name)) => + val castedValue = castLiteralValue(value, dataTypeMap(name)) + Some(builder.startNot().lessThanEquals(name, getType(name), castedValue).end()) - case GreaterThanOrEqual(attribute, value) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteIfNeeded(attribute) - val castedValue = castLiteralValue(value, dataTypeMap(attribute)) - Some(builder.startNot().lessThan(quotedName, getType(attribute), castedValue).end()) + case GreaterThanOrEqual(name, value) if isSearchableType(dataTypeMap(name)) => + val castedValue = castLiteralValue(value, dataTypeMap(name)) + Some(builder.startNot().lessThan(name, getType(name), castedValue).end()) - case IsNull(attribute) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteIfNeeded(attribute) - Some(builder.startAnd().isNull(quotedName, getType(attribute)).end()) + case IsNull(name) if isSearchableType(dataTypeMap(name)) => + Some(builder.startAnd().isNull(name, getType(name)).end()) - case 
IsNotNull(attribute) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteIfNeeded(attribute) - Some(builder.startNot().isNull(quotedName, getType(attribute)).end()) + case IsNotNull(name) if isSearchableType(dataTypeMap(name)) => + Some(builder.startNot().isNull(name, getType(name)).end()) - case In(attribute, values) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteIfNeeded(attribute) - val castedValues = values.map(v => castLiteralValue(v, dataTypeMap(attribute))) - Some(builder.startAnd().in(quotedName, getType(attribute), + case In(name, values) if isSearchableType(dataTypeMap(name)) => + val castedValues = values.map(v => castLiteralValue(v, dataTypeMap(name))) + Some(builder.startAnd().in(name, getType(name), castedValues.map(_.asInstanceOf[AnyRef]): _*).end()) case _ => None diff --git a/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala b/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala index 6e9e592be13be..a4e44263e7749 100644 --- a/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala +++ b/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala @@ -24,7 +24,6 @@ import org.apache.hadoop.hive.ql.io.sarg.SearchArgumentFactory.newBuilder import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable import org.apache.spark.SparkException -import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.quoteIfNeeded import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types._ @@ -65,7 +64,7 @@ private[sql] object OrcFilters extends OrcFiltersBase { * Create ORC filter as a SearchArgument instance. 
*/ def createFilter(schema: StructType, filters: Seq[Filter]): Option[SearchArgument] = { - val dataTypeMap = schema.map(f => f.name -> f.dataType).toMap + val dataTypeMap = quotedDataTypeMap(schema) // Combines all convertible filters using `And` to produce a single conjunction val conjunctionOptional = buildTree(convertibleFilters(schema, dataTypeMap, filters)) conjunctionOptional.map { conjunction => @@ -222,48 +221,39 @@ private[sql] object OrcFilters extends OrcFiltersBase { // Since ORC 1.5.0 (ORC-323), we need to quote for column names with `.` characters // in order to distinguish predicate pushdown for nested columns. expression match { - case EqualTo(attribute, value) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteIfNeeded(attribute) - val castedValue = castLiteralValue(value, dataTypeMap(attribute)) - Some(builder.startAnd().equals(quotedName, getType(attribute), castedValue).end()) + case EqualTo(name, value) if isSearchableType(dataTypeMap(name)) => + val castedValue = castLiteralValue(value, dataTypeMap(name)) + Some(builder.startAnd().equals(name, getType(name), castedValue).end()) - case EqualNullSafe(attribute, value) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteIfNeeded(attribute) - val castedValue = castLiteralValue(value, dataTypeMap(attribute)) - Some(builder.startAnd().nullSafeEquals(quotedName, getType(attribute), castedValue).end()) + case EqualNullSafe(name, value) if isSearchableType(dataTypeMap(name)) => + val castedValue = castLiteralValue(value, dataTypeMap(name)) + Some(builder.startAnd().nullSafeEquals(name, getType(name), castedValue).end()) - case LessThan(attribute, value) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteIfNeeded(attribute) - val castedValue = castLiteralValue(value, dataTypeMap(attribute)) - Some(builder.startAnd().lessThan(quotedName, getType(attribute), castedValue).end()) + case LessThan(name, value) if isSearchableType(dataTypeMap(name)) 
=> + val castedValue = castLiteralValue(value, dataTypeMap(name)) + Some(builder.startAnd().lessThan(name, getType(name), castedValue).end()) - case LessThanOrEqual(attribute, value) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteIfNeeded(attribute) - val castedValue = castLiteralValue(value, dataTypeMap(attribute)) - Some(builder.startAnd().lessThanEquals(quotedName, getType(attribute), castedValue).end()) + case LessThanOrEqual(name, value) if isSearchableType(dataTypeMap(name)) => + val castedValue = castLiteralValue(value, dataTypeMap(name)) + Some(builder.startAnd().lessThanEquals(name, getType(name), castedValue).end()) - case GreaterThan(attribute, value) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteIfNeeded(attribute) - val castedValue = castLiteralValue(value, dataTypeMap(attribute)) - Some(builder.startNot().lessThanEquals(quotedName, getType(attribute), castedValue).end()) + case GreaterThan(name, value) if isSearchableType(dataTypeMap(name)) => + val castedValue = castLiteralValue(value, dataTypeMap(name)) + Some(builder.startNot().lessThanEquals(name, getType(name), castedValue).end()) - case GreaterThanOrEqual(attribute, value) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteIfNeeded(attribute) - val castedValue = castLiteralValue(value, dataTypeMap(attribute)) - Some(builder.startNot().lessThan(quotedName, getType(attribute), castedValue).end()) + case GreaterThanOrEqual(name, value) if isSearchableType(dataTypeMap(name)) => + val castedValue = castLiteralValue(value, dataTypeMap(name)) + Some(builder.startNot().lessThan(name, getType(name), castedValue).end()) - case IsNull(attribute) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteIfNeeded(attribute) - Some(builder.startAnd().isNull(quotedName, getType(attribute)).end()) + case IsNull(name) if isSearchableType(dataTypeMap(name)) => + Some(builder.startAnd().isNull(name, getType(name)).end()) - case 
IsNotNull(attribute) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteIfNeeded(attribute) - Some(builder.startNot().isNull(quotedName, getType(attribute)).end()) + case IsNotNull(name) if isSearchableType(dataTypeMap(name)) => + Some(builder.startNot().isNull(name, getType(name)).end()) - case In(attribute, values) if isSearchableType(dataTypeMap(attribute)) => - val quotedName = quoteIfNeeded(attribute) - val castedValues = values.map(v => castLiteralValue(v, dataTypeMap(attribute))) - Some(builder.startAnd().in(quotedName, getType(attribute), + case In(name, values) if isSearchableType(dataTypeMap(name)) => + val castedValues = values.map(v => castLiteralValue(v, dataTypeMap(name))) + Some(builder.startAnd().in(name, getType(name), castedValues.map(_.asInstanceOf[AnyRef]): _*).end()) case _ => None diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala index cd1bffb6b7ab7..b7f37f8f134c0 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala @@ -25,6 +25,7 @@ import org.apache.hadoop.hive.ql.io.sarg.SearchArgumentFactory.newBuilder import org.apache.spark.SparkException import org.apache.spark.internal.Logging +import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.quoteIfNeeded import org.apache.spark.sql.execution.datasources.orc.{OrcFilters => DatasourceOrcFilters} import org.apache.spark.sql.execution.datasources.orc.OrcFilters.buildTree import org.apache.spark.sql.hive.HiveUtils @@ -73,7 +74,7 @@ private[orc] object OrcFilters extends Logging { if (HiveUtils.isHive23) { DatasourceOrcFilters.createFilter(schema, filters).asInstanceOf[Option[SearchArgument]] } else { - val dataTypeMap = schema.map(f => f.name -> f.dataType).toMap + val dataTypeMap = schema.map(f => quoteIfNeeded(f.name) -> f.dataType).toMap // Combines all 
convertible filters using `And` to produce a single conjunction val conjunctionOptional = buildTree(convertibleFilters(schema, dataTypeMap, filters)) conjunctionOptional.map { conjunction =>