diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index e7d1ed34f5ab..670400f8a973 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.hive
 import scala.collection.JavaConverters._
 
 import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache}
+import com.google.common.util.concurrent.Striped
 import org.apache.hadoop.fs.Path
 
 import org.apache.spark.internal.Logging
@@ -65,6 +66,18 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
       t.identifier.table.toLowerCase)
   }
 
+  /** These locks guard against multiple attempts to instantiate a table, which wastes memory. */
+  private val tableCreationLocks = Striped.lazyWeakLock(100)
+
+  /** Acquires a lock on the table cache for the duration of `f`. */
+  private def withTableCreationLock[A](tableName: QualifiedTableName, f: => A): A = {
+    val lock = tableCreationLocks.get(tableName)
+    lock.lock()
+    try f finally {
+      lock.unlock()
+    }
+  }
+
   /** A cache of Spark SQL data source tables that have been accessed. */
   protected[hive] val cachedDataSourceTables: LoadingCache[QualifiedTableName, LogicalPlan] = {
     val cacheLoader = new CacheLoader[QualifiedTableName, LogicalPlan]() {
@@ -274,77 +287,81 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
         partitionPaths
       }
 
-      val cached = getCached(
-        tableIdentifier,
-        paths,
-        metastoreRelation,
-        metastoreSchema,
-        fileFormatClass,
-        bucketSpec,
-        Some(partitionSpec))
-
-      val hadoopFsRelation = cached.getOrElse {
-        val fileCatalog = new MetaStorePartitionedTableFileCatalog(
-          sparkSession,
-          new Path(metastoreRelation.catalogTable.storage.locationUri.get),
-          partitionSpec)
-
-        val inferredSchema = if (fileType.equals("parquet")) {
-          val inferredSchema =
-            defaultSource.inferSchema(sparkSession, options, fileCatalog.allFiles())
-          inferredSchema.map { inferred =>
-            ParquetFileFormat.mergeMetastoreParquetSchema(metastoreSchema, inferred)
-          }.getOrElse(metastoreSchema)
-        } else {
-          defaultSource.inferSchema(sparkSession, options, fileCatalog.allFiles()).get
-        }
+      withTableCreationLock(tableIdentifier, {
+        val cached = getCached(
+          tableIdentifier,
+          paths,
+          metastoreRelation,
+          metastoreSchema,
+          fileFormatClass,
+          bucketSpec,
+          Some(partitionSpec))
+
+        val hadoopFsRelation = cached.getOrElse {
+          val fileCatalog = new MetaStorePartitionedTableFileCatalog(
+            sparkSession,
+            new Path(metastoreRelation.catalogTable.storage.locationUri.get),
+            partitionSpec)
+
+          val inferredSchema = if (fileType.equals("parquet")) {
+            val inferredSchema =
+              defaultSource.inferSchema(sparkSession, options, fileCatalog.allFiles())
+            inferredSchema.map { inferred =>
+              ParquetFileFormat.mergeMetastoreParquetSchema(metastoreSchema, inferred)
+            }.getOrElse(metastoreSchema)
+          } else {
+            defaultSource.inferSchema(sparkSession, options, fileCatalog.allFiles()).get
+          }
 
-        val relation = HadoopFsRelation(
-          location = fileCatalog,
-          partitionSchema = partitionSchema,
-          dataSchema = inferredSchema,
-          bucketSpec = bucketSpec,
-          fileFormat = defaultSource,
-          options = options)(sparkSession = sparkSession)
-
-        val created = LogicalRelation(
-          relation,
-          metastoreTableIdentifier =
-            Some(TableIdentifier(tableIdentifier.name, Some(tableIdentifier.database))))
-        cachedDataSourceTables.put(tableIdentifier, created)
-        created
-      }
+          val relation = HadoopFsRelation(
+            location = fileCatalog,
+            partitionSchema = partitionSchema,
+            dataSchema = inferredSchema,
+            bucketSpec = bucketSpec,
+            fileFormat = defaultSource,
+            options = options)(sparkSession = sparkSession)
+
+          val created = LogicalRelation(
+            relation,
+            metastoreTableIdentifier =
+              Some(TableIdentifier(tableIdentifier.name, Some(tableIdentifier.database))))
+          cachedDataSourceTables.put(tableIdentifier, created)
+          created
+        }
 
-      hadoopFsRelation
+        hadoopFsRelation
+      })
     } else {
       val paths = Seq(metastoreRelation.hiveQlTable.getDataLocation.toString)
 
-      val cached = getCached(tableIdentifier,
-        paths,
-        metastoreRelation,
-        metastoreSchema,
-        fileFormatClass,
-        bucketSpec,
-        None)
-      val logicalRelation = cached.getOrElse {
-        val created =
-          LogicalRelation(
-            DataSource(
-              sparkSession = sparkSession,
-              paths = paths,
-              userSpecifiedSchema = Some(metastoreRelation.schema),
-              bucketSpec = bucketSpec,
-              options = options,
-              className = fileType).resolveRelation(),
-            metastoreTableIdentifier =
-              Some(TableIdentifier(tableIdentifier.name, Some(tableIdentifier.database))))
-
-
-        cachedDataSourceTables.put(tableIdentifier, created)
-        created
-      }
+      withTableCreationLock(tableIdentifier, {
+        val cached = getCached(tableIdentifier,
+          paths,
+          metastoreRelation,
+          metastoreSchema,
+          fileFormatClass,
+          bucketSpec,
+          None)
+        val logicalRelation = cached.getOrElse {
+          val created =
+            LogicalRelation(
+              DataSource(
+                sparkSession = sparkSession,
+                paths = paths,
+                userSpecifiedSchema = Some(metastoreRelation.schema),
+                bucketSpec = bucketSpec,
+                options = options,
+                className = fileType).resolveRelation(),
+              metastoreTableIdentifier =
+                Some(TableIdentifier(tableIdentifier.name, Some(tableIdentifier.database))))
+
+
+          cachedDataSourceTables.put(tableIdentifier, created)
+          created
+        }
 
-      logicalRelation
+        logicalRelation
+      })
     }
     result.copy(expectedOutputAttributes = Some(metastoreRelation.output))
   }
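
For context, the locking pattern introduced above (a bounded pool of striped locks keyed by table name, held only around the look-up-or-build step) can be sketched outside Spark roughly as follows. This is a minimal illustration, not part of the patch: CacheKey, getOrCreate, and the ConcurrentHashMap cache are hypothetical stand-ins for QualifiedTableName and the Guava LoadingCache behind cachedDataSourceTables.

import java.util.concurrent.ConcurrentHashMap

import com.google.common.util.concurrent.Striped

object StripedLockExample {
  // Hypothetical stand-in for QualifiedTableName: any key with stable equals/hashCode works.
  final case class CacheKey(database: String, table: String)

  // 100 lazily created, weakly referenced lock stripes; keys whose hashes land on the
  // same stripe share a lock, which bounds memory regardless of how many keys exist.
  private val creationLocks = Striped.lazyWeakLock(100)

  private def withCreationLock[A](key: CacheKey)(f: => A): A = {
    val lock = creationLocks.get(key)
    lock.lock()
    try f finally lock.unlock()
  }

  private val cache = new ConcurrentHashMap[CacheKey, String]()

  // Only one thread per key runs the expensive build; concurrent callers for the same
  // key block on the stripe and then see the value that was cached while they waited.
  def getOrCreate(key: CacheKey)(build: => String): String =
    withCreationLock(key) {
      Option(cache.get(key)).getOrElse {
        val created = build
        cache.put(key, created)
        created
      }
    }
}

As in the patch, the lock is taken per key rather than globally, so only callers racing on the same table name are serialized; lookups for unrelated tables proceed in parallel, at the cost of occasional contention when two names hash to the same stripe.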