diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
index 92cb4ef11c9e..01c8e37cf64f 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
@@ -20,9 +20,10 @@ package org.apache.spark.sql.hive
 import org.apache.spark.annotation.{Experimental, InterfaceStability}
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.analysis.Analyzer
+import org.apache.spark.sql.catalyst.optimizer.Optimizer
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution.SparkPlanner
+import org.apache.spark.sql.execution.{SparkOptimizer, SparkPlanner}
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.hive.client.HiveClient
 import org.apache.spark.sql.internal.{BaseSessionStateBuilder, SessionResourceLoader, SessionState}
@@ -88,6 +89,20 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session
     customCheckRules
   }
 
+  /**
+   * Logical query plan optimizer that takes Hive into account.
+   */
+  override lazy val optimizer: Optimizer =
+    new SparkOptimizer(catalog, experimentalMethods) {
+      override def postHocOptimizationBatches: Seq[Batch] = Seq(
+        Batch("Prune Hive Table Partitions", Once,
+          new PruneHiveTablePartitions(session))
+      )
+
+      override def extendedOperatorOptimizationRules: Seq[Rule[LogicalPlan]] =
+        super.extendedOperatorOptimizationRules ++ customOperatorOptimizationRules
+    }
+
   /**
    * Planner that takes into account Hive-specific strategies.
    */
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
index 805b3171cdaa..a18a468786f1 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
@@ -21,16 +21,17 @@ import java.io.IOException
 import java.util.Locale
 
 import org.apache.hadoop.fs.{FileSystem, Path}
+import org.apache.hadoop.hive.common.StatsSetupConst
 
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.planning._
-import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoDir, InsertIntoTable, LogicalPlan,
-  ScriptTransformation}
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, InsertIntoDir, InsertIntoTable,
+  LogicalPlan, ScriptTransformation}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.command.{CreateTableCommand, DDLUtils}
+import org.apache.spark.sql.execution.command.{CommandUtils, CreateTableCommand, DDLUtils}
 import org.apache.spark.sql.execution.datasources.{CreateTable, LogicalRelation}
 import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, ParquetOptions}
 import org.apache.spark.sql.hive.execution._
@@ -139,6 +140,62 @@ class DetermineTableStats(session: SparkSession) extends Rule[LogicalPlan] {
   }
 }
 
+/**
+ * Prunes Hive table partitions by partition filters and records the pruned size as table stats.
+ * TODO: merge this with PruneFileSourcePartitions after we completely make Hive a data source.
+ */
+case class PruneHiveTablePartitions(
+    session: SparkSession) extends Rule[LogicalPlan] with PredicateHelper {
+  override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
+    case filter @ Filter(condition, relation: HiveTableRelation) if relation.isPartitioned =>
+      val predicates = splitConjunctivePredicates(condition)
+      val normalizedFilters = predicates.map { e =>
+        e transform {
+          case a: AttributeReference =>
+            a.withName(relation.output.find(_.semanticEquals(a)).get.name)
+        }
+      }
+      val partitionSet = AttributeSet(relation.partitionCols)
+      val pruningPredicates = normalizedFilters.filter { predicate =>
+        !predicate.references.isEmpty &&
+          predicate.references.subsetOf(partitionSet)
+      }
+      if (pruningPredicates.nonEmpty && session.sessionState.conf.fallBackToHdfsForStatsEnabled &&
+          session.sessionState.conf.metastorePartitionPruning) {
+        val prunedPartitions = session.sharedState.externalCatalog.listPartitionsByFilter(
+          relation.tableMeta.database,
+          relation.tableMeta.identifier.table,
+          pruningPredicates,
+          session.sessionState.conf.sessionLocalTimeZone)
+        val sizeInBytes = try {
+          prunedPartitions.map { part =>
+            val rawDataSize = part.parameters.get(StatsSetupConst.RAW_DATA_SIZE).map(_.toLong)
+            val totalSize = part.parameters.get(StatsSetupConst.TOTAL_SIZE).map(_.toLong)
+            if (rawDataSize.isDefined && rawDataSize.get > 0) {
+              rawDataSize.get
+            } else if (totalSize.isDefined && totalSize.get > 0L) {
+              totalSize.get
+            } else {
+              CommandUtils.calculateLocationSize(
+                session.sessionState, relation.tableMeta.identifier, part.storage.locationUri)
+            }
+          }.sum
+        } catch {
+          case e: IOException =>
+            logWarning("Failed to get table size from hdfs.", e)
+            session.sessionState.conf.defaultSizeInBytes
+        }
+        val withStats = relation.tableMeta.copy(
+          stats = Some(CatalogStatistics(sizeInBytes = BigInt(sizeInBytes))))
+        val prunedCatalogRelation = relation.copy(tableMeta = withStats)
+        val filterExpression = predicates.reduceLeft(And)
+        Filter(filterExpression, prunedCatalogRelation)
+      } else {
+        filter
+      }
+  }
+}
+
 /**
  * Replaces generic operations with specific variants that are designed to work with Hive.
  *
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
index 9ff9ecf7f367..561dbd75d1a2 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
@@ -1261,4 +1261,42 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto
       }
     }
   }
+
+  test("auto converts to broadcast join by size estimate of scanned partitions " +
+    "for partitioned table") {
+    withTempView("tempTbl", "largeTbl") {
+      withTable("partTbl") {
+        spark.range(0, 1000, 1, 2).selectExpr("id as col1", "id as col2")
+          .createOrReplaceTempView("tempTbl")
+        spark.range(0, 100000, 1, 2).selectExpr("id as col1", "id as col2")
+          .createOrReplaceTempView("largeTbl")
+        sql("CREATE TABLE partTbl (col1 INT, col2 STRING) " +
+          "PARTITIONED BY (part1 STRING, part2 INT) STORED AS textfile")
+        for (part1 <- Seq("a", "b", "c", "d"); part2 <- Seq(1, 2)) {
+          sql(
+            s"""
+               |INSERT OVERWRITE TABLE partTbl PARTITION (part1='$part1', part2='$part2')
+               |SELECT col1, col2 FROM tempTbl
+             """.stripMargin)
+        }
+        val query = "SELECT * FROM largeTbl JOIN partTbl ON (largeTbl.col1 = partTbl.col1 " +
+          "AND partTbl.part1 = 'a' AND partTbl.part2 = 1)"
+        withSQLConf(SQLConf.ENABLE_FALL_BACK_TO_HDFS_FOR_STATS.key -> "true",
+          SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "8001") {
+
+          withSQLConf(SQLConf.HIVE_METASTORE_PARTITION_PRUNING.key -> "true") {
+            val broadcastJoins =
+              sql(query).queryExecution.sparkPlan.collect { case j: BroadcastHashJoinExec => j }
+            assert(broadcastJoins.nonEmpty)
+          }
+
+          withSQLConf(SQLConf.HIVE_METASTORE_PARTITION_PRUNING.key -> "false") {
+            val broadcastJoins =
+              sql(query).queryExecution.sparkPlan.collect { case j: BroadcastHashJoinExec => j }
+            assert(broadcastJoins.isEmpty)
+          }
+        }
+      }
+    }
+  }
 }
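
Note (not part of the patch): below is a minimal spark-shell style sketch of how the new rule is expected to show up in practice, mirroring the test above. It assumes a Hive-enabled session; the table name part_tbl, the data volumes, and the 1 MB broadcast threshold are made up for illustration, and the exact stats accessor (queryExecution.optimizedPlan.stats) may vary between Spark versions.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec

val spark = SparkSession.builder()
  .appName("prune-hive-table-partitions-sketch")
  .enableHiveSupport()
  .getOrCreate()

// Fall back to file sizes when the metastore carries no table-level stats, and let the
// metastore prune partitions by the predicates on the partition columns.
spark.conf.set("spark.sql.statistics.fallBackToHdfs", "true")
spark.conf.set("spark.sql.hive.metastorePartitionPruning", "true")
// Keep the threshold small so only a single scanned partition is small enough to broadcast.
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", (1024 * 1024).toString)

spark.sql("CREATE TABLE IF NOT EXISTS part_tbl (col1 INT, col2 STRING) " +
  "PARTITIONED BY (part1 STRING) STORED AS TEXTFILE")
spark.sql("INSERT OVERWRITE TABLE part_tbl PARTITION (part1 = 'a') " +
  "SELECT id, CAST(id AS STRING) FROM range(1000)")
spark.sql("INSERT OVERWRITE TABLE part_tbl PARTITION (part1 = 'b') " +
  "SELECT id, CAST(id AS STRING) FROM range(1000000)")

val df = spark.sql(
  "SELECT * FROM range(100000) t JOIN part_tbl p ON t.id = p.col1 AND p.part1 = 'a'")

// With the rule in place, the estimate should cover only partition 'a' rather than the whole
// table, so the Hive side can be chosen as the broadcast side.
println(df.queryExecution.optimizedPlan.stats.sizeInBytes)
println(df.queryExecution.sparkPlan.collect { case j: BroadcastHashJoinExec => j }.nonEmpty)

Flipping spark.sql.hive.metastorePartitionPruning back to false should make the estimate revert to the full table size, matching the negative case in the test.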