diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala
index 9c4b8a5819a3..db2936099040 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala
@@ -155,6 +155,16 @@ private[spark] object HiveUtils extends Logging {
     .booleanConf
     .createWithDefault(true)
 
+  val HIVE_TABLE_SCAN_MAX_PARALLELISM = buildConf("spark.sql.hive.tableScan.maxParallelism")
+    .doc("When reading a Hive partitioned table, the default parallelism is the sum of the " +
+      "parallelism of all the Hive partition RDDs. For a table with many partitions, each " +
+      "containing many files, this parallelism can become very large and hurt Spark job " +
+      "scheduling. This optional config sets an upper bound on the parallelism of reading a " +
+      "Hive partitioned table. If the resulting RDD has more partitions than this value, " +
+      "Spark reduces the partition count by coalescing the RDD.")
+    .intConf
+    .createOptional
+
   /**
    * The version of the hive client that will be used to communicate with the metastore. Note that
    * this does not necessarily need to be the same version of Hive that is used internally by
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
index b1182b271912..7977da339107 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
@@ -267,7 +267,17 @@ class HadoopTableReader(
     if (hivePartitionRDDs.size == 0) {
       new EmptyRDD[InternalRow](sparkSession.sparkContext)
     } else {
-      new UnionRDD(hivePartitionRDDs(0).context, hivePartitionRDDs)
+      val unionRDD = new UnionRDD(hivePartitionRDDs(0).context, hivePartitionRDDs)
+      val partNums = unionRDD.partitions.length
+      val maxPartNums = SQLConf.get.getConf(HiveUtils.HIVE_TABLE_SCAN_MAX_PARALLELISM)
+      if (maxPartNums.isDefined && partNums > maxPartNums.get) {
+        logWarning(s"Union of Hive partitions' HadoopRDDs has ${partNums} partitions, " +
+          "which exceeds the configured `spark.sql.hive.tableScan.maxParallelism`. " +
+          s"Coalescing the union RDD to ${maxPartNums.get} partitions.")
+        unionRDD.coalesce(maxPartNums.get)
+      } else {
+        unionRDD
+      }
     }
   }
 
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala
index 3f9bb8de42e0..cae78aaa0811 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.hive.execution
 
 import org.apache.spark.sql.Row
+import org.apache.spark.sql.hive.HiveUtils
 import org.apache.spark.sql.hive.test.{TestHive, TestHiveSingleton}
 import org.apache.spark.sql.hive.test.TestHive._
 import org.apache.spark.sql.hive.test.TestHive.implicits._
@@ -187,6 +188,35 @@ class HiveTableScanSuite extends HiveComparisonTest with SQLTestUtils with TestH
     }
   }
 
+  test("HiveTableScanExec should not increase data parallelism") {
+    withSQLConf(HiveUtils.HIVE_TABLE_SCAN_MAX_PARALLELISM.key -> "1") {
+      val view = "src"
+      withTempView(view) {
+        spark.range(1, 5).createOrReplaceTempView(view)
+        val table = "hive_tbl_part"
+        withTable(table) {
+          sql(
+            s"""
+               |CREATE TABLE $table (id int)
+               |PARTITIONED BY (a int, b int)
+             """.stripMargin)
+          sql(
+            s"""
+               |FROM $view v
+               |INSERT INTO TABLE $table
+               |PARTITION (a=1, b=2)
+               |SELECT v.id
+               |INSERT INTO TABLE $table
+               |PARTITION (a=2, b=3)
+               |SELECT v.id
+             """.stripMargin)
+          val scanRdd = getHiveTableScanExec(s"SELECT * FROM $table").execute()
+          assert(scanRdd.partitions.length == 1)
+        }
+      }
+    }
+  }
+
   private def getHiveTableScanExec(query: String): HiveTableScanExec = {
     sql(query).queryExecution.sparkPlan.collectFirst {
       case p: HiveTableScanExec => p