@@ -307,6 +307,8 @@ private[spark] object LogKeys {
case object HIVE_METASTORE_VERSION extends LogKey
case object HIVE_OPERATION_STATE extends LogKey
case object HIVE_OPERATION_TYPE extends LogKey
case object HMS_CURRENT_BATCH_SIZE extends LogKey
case object HMS_INITIAL_BATCH_SIZE extends LogKey
case object HOST extends LogKey
case object HOSTS extends LogKey
case object HOST_LOCAL_BLOCKS_SIZE extends LogKey
@@ -5199,6 +5199,29 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

val HMS_BATCH_SIZE = buildConf("spark.sql.hive.metastore.batchSize")
.internal()
.doc("This setting defines the batch size for fetching metadata partitions from the" +
"Hive Metastore. A value of -1 disables batching by default. To enable batching," +
"specify a positive integer, which will determine the batch size for partition fetching."
)
.version("4.0.0")
.intConf
.createWithDefault(-1)

val METASTORE_PARTITION_BATCH_RETRY_COUNT = buildConf(
"spark.sql.metastore.partition.batch.retry.count")
.internal()
.doc(
"This setting specifies the number of retries for fetching partitions from the metastore" +
"in case of failure to fetch batch metadata. This retry mechanism is applicable only" +
"when HMS_BATCH_SIZE is enabled. It defines the count for the number of " +
"retries to be done."
)
.version("4.0.0")
.intConf
.createWithDefault(3)
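
// As a usage sketch (hypothetical values), batching could be enabled per session;
// the retry count below just restates its default of 3:
//   spark.conf.set("spark.sql.hive.metastore.batchSize", "1000")
//   spark.conf.set("spark.sql.metastore.partition.batch.retry.count", "3")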

/**
* Holds information about keys that have been deprecated.
*
@@ -6177,6 +6200,10 @@ class SQLConf extends Serializable with Logging with SqlApiConf {

def legacyEvalCurrentTime: Boolean = getConf(SQLConf.LEGACY_EVAL_CURRENT_TIME)

def hiveMetastoreBatchSize: Int = getConf(HMS_BATCH_SIZE)

def metastorePartitionBatchRetryCount: Int = getConf(METASTORE_PARTITION_BATCH_RETRY_COUNT)

/** ********************** SQLConf functionality methods ************ */

/** Set Spark SQL configuration properties. */
@@ -22,6 +22,7 @@ import java.lang.reflect.{InvocationTargetException, Method}
import java.util.{ArrayList => JArrayList, List => JList, Locale, Map => JMap}
import java.util.concurrent.TimeUnit

import scala.collection.mutable
import scala.jdk.CollectionConverters._
import scala.util.control.NonFatal

@@ -37,7 +38,7 @@ import org.apache.hadoop.hive.ql.processors.{CommandProcessor, CommandProcessorF
import org.apache.hadoop.hive.ql.session.SessionState
import org.apache.hadoop.hive.serde.serdeConstants

import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.{Logging, LogKeys, MDC}
import org.apache.spark.internal.LogKeys.{CONFIG, CONFIG2, CONFIG3}
import org.apache.spark.metrics.source.HiveCatalogMetrics
import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow}
@@ -389,6 +390,70 @@ private[client] class Shim_v2_0 extends Shim with Logging {
partitions.asScala.toSeq
}

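// Lists all partition names of the table in one HMS call (max = -1 means no
// limit) and returns the count alongside the names.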
private def getPartitionNamesWithCount(hive: Hive, table: Table): (Int, Seq[String]) = {
recordHiveCall()
val partitionNames = hive.getPartitionNames(
table.getDbName, table.getTableName, -1).asScala.toSeq
(partitionNames.length, partitionNames)
}

private def getPartitionsInBatches(
hive: Hive,
table: Table,
initialBatchSize: Int,
partNames: Seq[String]): java.util.Collection[Partition] = {
val maxRetries = SQLConf.get.metastorePartitionBatchRetryCount
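// On each retry the batch size is divided by this factor (floored at 1), so
// repeated failures fall back to progressively smaller HMS requests.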
val decayingFactor = 2

if (initialBatchSize <= 0) {
throw new IllegalArgumentException(s"Invalid batch size $initialBatchSize provided " +
s"for fetching partitions.Batch size must be greater than 0")
}

if (maxRetries < 0) {
throw new IllegalArgumentException(s"Invalid number of maximum retries $maxRetries " +
s"provided for fetching partitions.It must be a non-negative integer value")
}

logInfo(log"Breaking your request into small batches" +
log" of ${MDC(LogKeys.HMS_INITIAL_BATCH_SIZE, initialBatchSize)}.")

var batchSize = initialBatchSize
val processedPartitions = mutable.ListBuffer[Partition]()
var retryCount = 0
var index = 0

def getNextBatchSize(): Int = {
val currentBatchSize = batchSize
batchSize = (batchSize / decayingFactor) max 1
currentBatchSize
}

var currentBatchSize = getNextBatchSize()
var partitions: java.util.Collection[Partition] = null

while (index < partNames.size) {
val batch = partNames.slice(index, index + currentBatchSize)

try {
recordHiveCall()
partitions = hive.getPartitionsByNames(table, batch.asJava)
processedPartitions ++= partitions.asScala
index += batch.size
} catch {
case ex: Exception =>
logWarning(s"Caught exception while fetching partitions for batch, attempting retry.", ex)
retryCount += 1
currentBatchSize = getNextBatchSize()
logInfo(log"Further reducing batch size to " +
log"${MDC(LogKeys.HMS_CURRENT_BATCH_SIZE, currentBatchSize)}.")
if (retryCount > maxRetries) {
logError(s"Failed to fetch partitions for the request. Retries count exceeded.")
}
}
}

processedPartitions.asJava
}
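
// For intuition, a sketch of the batch sizes getNextBatchSize() yields under
// repeated failures, assuming a hypothetical initial size of 1000:
//   Iterator.iterate(1000)(size => (size / 2) max 1).take(5).toList
//   // => List(1000, 500, 250, 125, 62)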

private def prunePartitionsFastFallback(
hive: Hive,
table: Table,
Expand All @@ -406,11 +471,19 @@ private[client] class Shim_v2_0 extends Shim with Logging {
}
}

val batchSize = SQLConf.get.hiveMetastoreBatchSize

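// A batch size of -1 (the default) keeps the old behavior of fetching all
// partition metadata in a single HMS call; batching kicks in only once the
// partition count reaches the configured batch size.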
if (!SQLConf.get.metastorePartitionPruningFastFallback ||
predicates.isEmpty ||
predicates.exists(hasTimeZoneAwareExpression)) {
val (count, partNames) = getPartitionNamesWithCount(hive, table)
if (batchSize == -1 || count < batchSize) {
recordHiveCall()
hive.getAllPartitionsOf(table)
} else {
getPartitionsInBatches(hive, table, batchSize, partNames)
}
} else {
try {
val partitionSchema = CharVarcharUtils.replaceCharVarcharWithStringInSchema(
@@ -442,8 +515,14 @@
case ex: HiveException if ex.getCause.isInstanceOf[MetaException] =>
logWarning("Caught Hive MetaException attempting to get partition metadata by " +
"filter from client side. Falling back to fetching all partition metadata", ex)
val (count, partNames) = getPartitionNamesWithCount(hive, table)
if (batchSize == -1 || count < batchSize) {
recordHiveCall()
hive.getAllPartitionsOf(table)
} else {
getPartitionsInBatches(hive, table, batchSize, partNames)
}
}
}
}
@@ -726,6 +726,40 @@ class HivePartitionFilteringSuite(version: String)
}
}

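// Verifies that getPartitionsByFilter returns identical results whether batching
// is forced per partition (1), disabled (-1), or sized above the partition count
// (5000).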
test("getPartitionsByFilter: getPartitionsInBatches") {
var filteredPartitions: Seq[CatalogTablePartition] = Seq()
var filteredPartitionsNoBatch: Seq[CatalogTablePartition] = Seq()
var filteredPartitionsHighBatch: Seq[CatalogTablePartition] = Seq()

withSQLConf(SQLConf.HMS_BATCH_SIZE.key -> "1") {
filteredPartitions = client.getPartitionsByFilter(
client.getRawHiveTable("default", "test"),
Seq(attr("ds") === 20170101)
)
}
withSQLConf(SQLConf.HMS_BATCH_SIZE.key -> "-1") {
filteredPartitionsNoBatch = client.getPartitionsByFilter(
client.getRawHiveTable("default", "test"),
Seq(attr("ds") === 20170101)
)
}
withSQLConf(SQLConf.HMS_BATCH_SIZE.key -> "5000") {
filteredPartitionsHighBatch = client.getPartitionsByFilter(
client.getRawHiveTable("default", "test"),
Seq(attr("ds") === 20170101)
)
}

assert(filteredPartitions.size == filteredPartitionsNoBatch.size)
assert(filteredPartitions.size == filteredPartitionsHighBatch.size)
assert(
filteredPartitions.map(_.spec.toSet).toSet ==
filteredPartitionsNoBatch.map(_.spec.toSet).toSet)
assert(
filteredPartitions.map(_.spec.toSet).toSet ==
filteredPartitionsHighBatch.map(_.spec.toSet).toSet)
}

private def testMetastorePartitionFiltering(
filterExpr: Expression,
expectedDs: Seq[Int],