HiveMetastoreCatalog.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.hive

import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache}
+import com.google.common.util.concurrent.Striped
import org.apache.hadoop.fs.Path

import org.apache.spark.internal.Logging
@@ -32,7 +33,6 @@ import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, Pa
import org.apache.spark.sql.hive.orc.OrcFileFormat
import org.apache.spark.sql.types._

-
/**
* Legacy catalog for interacting with the Hive metastore.
*
@@ -53,6 +53,18 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
      tableIdent.table.toLowerCase)
  }

+  /** These locks guard against multiple attempts to instantiate a table, which wastes memory. */
+  private val tableCreationLocks = Striped.lazyWeakLock(100)
Contributor:
nit: "These locks guard against multiple attempts to instantiate a table, which wastes memory."

Member Author:
fix done

+  /** Acquires a lock on the table cache for the duration of `f`. */
+  private def withTableCreationLock[A](tableName: QualifiedTableName, f: => A): A = {
+    val lock = tableCreationLocks.get(tableName)
+    lock.lock()
+    try f finally {
+      lock.unlock()
+    }
+  }
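
Not part of the PR, but useful context on the primitive it uses: Guava's Striped.lazyWeakLock(n) keeps at most n lock stripes, maps equal keys to the same java.util.concurrent.locks.Lock, and holds the stripes weakly so unused locks can be garbage-collected. A minimal standalone sketch of those semantics (the object and method names here are invented):

import com.google.common.util.concurrent.Striped

object StripedLockDemo {
  // At most 100 lock stripes; equal keys always map to the same Lock.
  private val locks = Striped.lazyWeakLock(100)

  def withLock[A](key: String)(f: => A): A = {
    val lock = locks.get(key)
    lock.lock()
    try f finally lock.unlock()
  }

  def main(args: Array[String]): Unit = {
    // Calls with the same key contend on one lock; a different key
    // (usually) lands on a different stripe and proceeds independently.
    println(withLock("db.tableA") { "loaded A" })
    println(withLock("db.tableB") { "loaded B" })
  }
}

Striping bounds memory: no more than 100 lock objects exist however many tables are resolved, at the cost that two distinct tables can occasionally share a stripe and serialize needlessly, which is harmless for correctness.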

  /** A cache of Spark SQL data source tables that have been accessed. */
  protected[hive] val cachedDataSourceTables: LoadingCache[QualifiedTableName, LogicalPlan] = {
    val cacheLoader = new CacheLoader[QualifiedTableName, LogicalPlan]() {
@@ -209,72 +221,76 @@
        }
      }

-      val cached = getCached(
-        tableIdentifier,
-        rootPaths,
-        metastoreRelation,
-        metastoreSchema,
-        fileFormatClass,
-        bucketSpec,
-        Some(partitionSchema))
-
-      val logicalRelation = cached.getOrElse {
-        val sizeInBytes = metastoreRelation.statistics.sizeInBytes.toLong
-        val fileCatalog = {
-          val catalog = new CatalogFileIndex(
-            sparkSession, metastoreRelation.catalogTable, sizeInBytes)
-          if (lazyPruningEnabled) {
-            catalog
-          } else {
-            catalog.filterPartitions(Nil) // materialize all the partitions in memory
-          }
-        }
-        val partitionSchemaColumnNames = partitionSchema.map(_.name.toLowerCase).toSet
-        val dataSchema =
-          StructType(metastoreSchema
-            .filterNot(field => partitionSchemaColumnNames.contains(field.name.toLowerCase)))
-
-        val relation = HadoopFsRelation(
-          location = fileCatalog,
-          partitionSchema = partitionSchema,
-          dataSchema = dataSchema,
-          bucketSpec = bucketSpec,
-          fileFormat = defaultSource,
-          options = options)(sparkSession = sparkSession)
-
-        val created = LogicalRelation(relation, catalogTable = Some(metastoreRelation.catalogTable))
-        cachedDataSourceTables.put(tableIdentifier, created)
-        created
-      }
-
-      logicalRelation
+      withTableCreationLock(tableIdentifier, {
+        val cached = getCached(
+          tableIdentifier,
+          rootPaths,
+          metastoreRelation,
+          metastoreSchema,
+          fileFormatClass,
+          bucketSpec,
+          Some(partitionSchema))
+
+        val logicalRelation = cached.getOrElse {
+          val sizeInBytes = metastoreRelation.statistics.sizeInBytes.toLong
+          val fileCatalog = {
+            val catalog = new CatalogFileIndex(
+              sparkSession, metastoreRelation.catalogTable, sizeInBytes)
+            if (lazyPruningEnabled) {
+              catalog
+            } else {
+              catalog.filterPartitions(Nil) // materialize all the partitions in memory
+            }
+          }
+          val partitionSchemaColumnNames = partitionSchema.map(_.name.toLowerCase).toSet
+          val dataSchema =
+            StructType(metastoreSchema
+              .filterNot(field => partitionSchemaColumnNames.contains(field.name.toLowerCase)))
+
+          val relation = HadoopFsRelation(
+            location = fileCatalog,
+            partitionSchema = partitionSchema,
+            dataSchema = dataSchema,
+            bucketSpec = bucketSpec,
+            fileFormat = defaultSource,
+            options = options)(sparkSession = sparkSession)
+
+          val created = LogicalRelation(relation,
+            catalogTable = Some(metastoreRelation.catalogTable))
+          cachedDataSourceTables.put(tableIdentifier, created)
+          created
+        }
+
+        logicalRelation
+      })
    } else {
      val rootPath = metastoreRelation.hiveQlTable.getDataLocation

-      val cached = getCached(tableIdentifier,
-        Seq(rootPath),
-        metastoreRelation,
-        metastoreSchema,
-        fileFormatClass,
-        bucketSpec,
-        None)
-      val logicalRelation = cached.getOrElse {
-        val created =
-          LogicalRelation(
-            DataSource(
-              sparkSession = sparkSession,
-              paths = rootPath.toString :: Nil,
-              userSpecifiedSchema = Some(metastoreRelation.schema),
-              bucketSpec = bucketSpec,
-              options = options,
-              className = fileType).resolveRelation(),
-            catalogTable = Some(metastoreRelation.catalogTable))
-
-        cachedDataSourceTables.put(tableIdentifier, created)
-        created
-      }
-
-      logicalRelation
+      withTableCreationLock(tableIdentifier, {
+        val cached = getCached(tableIdentifier,
+          Seq(rootPath),
+          metastoreRelation,
+          metastoreSchema,
+          fileFormatClass,
+          bucketSpec,
+          None)
+        val logicalRelation = cached.getOrElse {
+          val created =
+            LogicalRelation(
+              DataSource(
+                sparkSession = sparkSession,
+                paths = rootPath.toString :: Nil,
+                userSpecifiedSchema = Some(metastoreRelation.schema),
+                bucketSpec = bucketSpec,
+                options = options,
+                className = fileType).resolveRelation(),
+              catalogTable = Some(metastoreRelation.catalogTable))
+
+          cachedDataSourceTables.put(tableIdentifier, created)
+          created
+        }
+
+        logicalRelation
+      })
    }
    result.copy(expectedOutputAttributes = Some(metastoreRelation.output))
  }
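
Taken together, the two hunks change the resolution path from "check the cache, then build and insert" to the same sequence executed entirely under a per-table lock. A hypothetical, self-contained sketch of that pattern (OncePerKeyBuilder and all names in it are invented for illustration, not Spark's API):

import java.util.concurrent.ConcurrentHashMap

import com.google.common.util.concurrent.Striped

// Models cachedDataSourceTables plus withTableCreationLock in miniature.
class OncePerKeyBuilder[K <: AnyRef, V <: AnyRef](build: K => V) {
  private val cache = new ConcurrentHashMap[K, V]()
  private val locks = Striped.lazyWeakLock(100)

  def get(key: K): V = {
    val lock = locks.get(key)
    lock.lock()
    try {
      // Re-check inside the critical section: another thread may have
      // built and cached the value while we were waiting for the lock.
      val cached = cache.get(key)
      if (cached != null) {
        cached
      } else {
        val created = build(key) // the expensive part, e.g. file listing
        cache.put(key, created)
        created
      }
    } finally {
      lock.unlock()
    }
  }
}

The re-check under the lock is the essential step: a second thread arrives after the first finishes, finds the cached entry, and skips the expensive build. Before this change both threads would run the build concurrently and one result would be discarded, which is the wasted memory and duplicated listing work SPARK-18700 describes.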
PartitionedTablePerfStatsSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.hive

import java.io.File
+import java.util.concurrent.{Executors, TimeUnit}

import org.scalatest.BeforeAndAfterEach

@@ -352,4 +353,34 @@ class PartitionedTablePerfStatsSuite
      }
    }
  }

test("SPARK-18700: table loaded only once even when resolved concurrently") {
withSQLConf(SQLConf.HIVE_MANAGE_FILESOURCE_PARTITIONS.key -> "false") {
withTable("test") {
withTempDir { dir =>
HiveCatalogMetrics.reset()
setupPartitionedHiveTable("test", dir, 50)
// select the table in multi-threads
val executorPool = Executors.newFixedThreadPool(10)
(1 to 10).map(threadId => {
val runnable = new Runnable {
override def run(): Unit = {
spark.sql("select * from test where partCol1 = 999").count()
}
}
executorPool.execute(runnable)
None
})
executorPool.shutdown()
executorPool.awaitTermination(30, TimeUnit.SECONDS)
// check the cache hit, we use the metric of METRIC_FILES_DISCOVERED and
// METRIC_PARALLEL_LISTING_JOB_COUNT to check this, while the lock take effect,
// only one thread can really do the build, so the listing job count is 2, the other
// one is cache.load func. Also METRIC_FILES_DISCOVERED is $partition_num * 2
Member:
This comment is wrong. The extra counts are from the DataFrameWriter's save() API.

Member:
Working on a fix to avoid the useless filesystem scan caused by the save() API.

Member Author:
@gatorsmile Xiao fixed this in #16481

+          assert(HiveCatalogMetrics.METRIC_FILES_DISCOVERED.getCount() == 100)
+          assert(HiveCatalogMetrics.METRIC_PARALLEL_LISTING_JOB_COUNT.getCount() == 2)
+        }
+      }
+    }
+  }
}
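
Restated outside Spark, the property this test pins down is: N concurrent resolutions of one table trigger exactly one build. A standalone sketch under the same assumptions, reusing the hypothetical OncePerKeyBuilder from the earlier sketch, with an AtomicInteger standing in for the listing-job metric:

import java.util.concurrent.{Executors, TimeUnit}
import java.util.concurrent.atomic.AtomicInteger

object ConcurrentLoadOnceDemo {
  def main(args: Array[String]): Unit = {
    val buildCount = new AtomicInteger(0)
    val builder = new OncePerKeyBuilder[String, String](key => {
      buildCount.incrementAndGet() // stands in for a filesystem listing job
      s"relation-for-$key"
    })

    val pool = Executors.newFixedThreadPool(10)
    (1 to 10).foreach { _ =>
      pool.execute(new Runnable {
        override def run(): Unit = builder.get("default.test")
      })
    }
    pool.shutdown()
    pool.awaitTermination(30, TimeUnit.SECONDS)

    // Despite ten concurrent resolutions, the guarded build ran exactly once.
    assert(buildCount.get() == 1)
  }
}

Without the per-key lock this assertion would be flaky: any threads that miss the cache before the first build completes would each run build(key) again.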