@@ -20,9 +20,10 @@ package org.apache.spark.sql.hive
import org.apache.spark.annotation.{Experimental, InterfaceStability}
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.analysis.Analyzer
import org.apache.spark.sql.catalyst.optimizer.Optimizer
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlanner
import org.apache.spark.sql.execution.{SparkOptimizer, SparkPlanner}
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.hive.client.HiveClient
import org.apache.spark.sql.internal.{BaseSessionStateBuilder, SessionResourceLoader, SessionState}
@@ -88,6 +89,20 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session
customCheckRules
}

/**
* Logical query plan optimizer that takes Hive-specific optimization rules into account.
*/
override lazy val optimizer: Optimizer =
new SparkOptimizer(catalog, experimentalMethods) {
override def postHocOptimizationBatches: Seq[Batch] = Seq(
Batch("Prune Hive Table Partitions", Once,
new PruneHiveTablePartitions(session))
)

override def extendedOperatorOptimizationRules: Seq[Rule[LogicalPlan]] =
super.extendedOperatorOptimizationRules ++ customOperatorOptimizationRules
}

/**
* Planner that takes into account Hive-specific strategies.
*/
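For context, a minimal sketch (not part of this diff; NoOpRule and the local-mode session are illustrative assumptions) of how a rule injected through SparkSessionExtensions reaches the customOperatorOptimizationRules that the optimizer override above appends to the built-in operator-optimization rules:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule

// Illustrative rule that leaves the plan unchanged.
case class NoOpRule(spark: SparkSession) extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = plan
}

val spark = SparkSession.builder()
  .master("local[*]")
  .enableHiveSupport()
  // Rules registered here are picked up as customOperatorOptimizationRules.
  .withExtensions(_.injectOptimizerRule(NoOpRule))
  .getOrCreate()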
@@ -21,16 +21,17 @@ import java.io.IOException
import java.util.Locale

import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.hive.common.StatsSetupConst

import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning._
import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoDir, InsertIntoTable, LogicalPlan,
ScriptTransformation}
import org.apache.spark.sql.catalyst.plans.logical.{Filter, InsertIntoDir, InsertIntoTable,
LogicalPlan, ScriptTransformation}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.command.{CreateTableCommand, DDLUtils}
import org.apache.spark.sql.execution.command.{CommandUtils, CreateTableCommand, DDLUtils}
import org.apache.spark.sql.execution.datasources.{CreateTable, LogicalRelation}
import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, ParquetOptions}
import org.apache.spark.sql.hive.execution._
@@ -139,6 +140,62 @@ class DetermineTableStats(session: SparkSession) extends Rule[LogicalPlan] {
}
}

/**
* Prunes the partitions of a partitioned Hive table using the partition-column predicates of a
* Filter directly above the HiveTableRelation, and recomputes the table statistics from the
* matching partitions so that size-based optimizations (e.g. broadcast joins) get a tighter estimate.
*
* TODO: merge this with PruneFileSourcePartitions after we completely make Hive a data source.
*/
case class PruneHiveTablePartitions(
session: SparkSession) extends Rule[LogicalPlan] with PredicateHelper {
override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
case filter @ Filter(condition, relation: HiveTableRelation) if relation.isPartitioned =>
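// Split the filter condition into conjuncts and normalize attribute names to the exact names
// in the relation output, since predicates may reference the columns with a different case.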
val predicates = splitConjunctivePredicates(condition)
val normalizedFilters = predicates.map { e =>
e transform {
case a: AttributeReference =>
a.withName(relation.output.find(_.semanticEquals(a)).get.name)
}
}
val partitionSet = AttributeSet(relation.partitionCols)
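// Only predicates that reference partition columns exclusively can be pushed down to the
// metastore for partition pruning.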
val pruningPredicates = normalizedFilters.filter { predicate =>
predicate.references.nonEmpty &&
predicate.references.subsetOf(partitionSet)
}
if (pruningPredicates.nonEmpty && session.sessionState.conf.fallBackToHdfsForStatsEnabled &&
session.sessionState.conf.metastorePartitionPruning) {
val prunedPartitions = session.sharedState.externalCatalog.listPartitionsByFilter(
relation.tableMeta.database,
relation.tableMeta.identifier.table,
pruningPredicates,
session.sessionState.conf.sessionLocalTimeZone)
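// Estimate the relation size from the matching partitions only: prefer rawDataSize, then
// totalSize, and finally fall back to the size of each partition's storage location.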
val sizeInBytes = try {
prunedPartitions.map { part =>
val rawDataSize = part.parameters.get(StatsSetupConst.RAW_DATA_SIZE).map(_.toLong)
val totalSize = part.parameters.get(StatsSetupConst.TOTAL_SIZE).map(_.toLong)
if (rawDataSize.isDefined && rawDataSize.get > 0) {
rawDataSize.get
} else if (totalSize.isDefined && totalSize.get > 0L) {
totalSize.get
} else {
CommandUtils.calculateLocationSize(
session.sessionState, relation.tableMeta.identifier, part.storage.locationUri)
}
}.sum
} catch {
case e: IOException =>
logWarning("Failed to get table size from hdfs.", e)
session.sessionState.conf.defaultSizeInBytes
}
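// Attach the recomputed statistics to a copy of the table metadata and rebuild the original
// Filter on top of the relation that now carries them.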
val withStats = relation.tableMeta.copy(
stats = Some(CatalogStatistics(sizeInBytes = BigInt(sizeInBytes))))
val prunedCatalogRelation = relation.copy(tableMeta = withStats)
val filterExpression = predicates.reduceLeft(And)
Filter(filterExpression, prunedCatalogRelation)
} else {
filter
}
}
}

/**
* Replaces generic operations with specific variants that are designed to work with Hive.
*
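As a usage sketch for PruneHiveTablePartitions above (assuming a partitioned Hive table named partTbl like the one created in the test below, and a Spark version where LogicalPlan.stats takes no arguments), the effect of the rule shows up in the optimized plan's size estimate:

// Both configs guarding the rule must be enabled.
spark.conf.set("spark.sql.statistics.fallBackToHdfs", "true")
spark.conf.set("spark.sql.hive.metastorePartitionPruning", "true")

val optimized = spark.sql(
  "SELECT * FROM partTbl WHERE part1 = 'a' AND part2 = 1").queryExecution.optimizedPlan

// With the rule applied, the HiveTableRelation under the Filter carries CatalogStatistics
// computed from the matching partitions only, so this estimate is far smaller than the
// whole-table size and can enable a broadcast join.
println(optimized.stats.sizeInBytes)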
@@ -1261,4 +1261,42 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto
}

}

test("auto converts to broadcast join by size estimate of scanned partitions " +
"for partitioned table") {
withTempView("tempTbl", "largeTbl") {
withTable("partTbl") {
spark.range(0, 1000, 1, 2).selectExpr("id as col1", "id as col2")
.createOrReplaceTempView("tempTbl")
spark.range(0, 100000, 1, 2).selectExpr("id as col1", "id as col2")
.createOrReplaceTempView("largeTbl")
sql("CREATE TABLE partTbl (col1 INT, col2 STRING) " +
"PARTITIONED BY (part1 STRING, part2 INT) STORED AS textfile")
for (part1 <- Seq("a", "b", "c", "d"); part2 <- Seq(1, 2)) {
sql(
s"""
|INSERT OVERWRITE TABLE partTbl PARTITION (part1='$part1',part2='$part2')
|select col1, col2 from tempTbl
""".stripMargin)
}
val query = "select * from largeTbl join partTbl on (largeTbl.col1 = partTbl.col1 " +
"and partTbl.part1 = 'a' and partTbl.part2 = 1)"
withSQLConf(SQLConf.ENABLE_FALL_BACK_TO_HDFS_FOR_STATS.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "8001") {

withSQLConf(SQLConf.HIVE_METASTORE_PARTITION_PRUNING.key -> "true") {
val broadcastJoins =
sql(query).queryExecution.sparkPlan.collect { case j: BroadcastHashJoinExec => j }
assert(broadcastJoins.nonEmpty)
}

withSQLConf(SQLConf.HIVE_METASTORE_PARTITION_PRUNING.key -> "false") {
val broadcastJoins =
sql(query).queryExecution.sparkPlan.collect { case j: BroadcastHashJoinExec => j }
assert(broadcastJoins.isEmpty)
}
}
}
}
}
}