diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md
index b9be7a7545ef..dde2a9b708c7 100644
--- a/docs/sql-programming-guide.md
+++ b/docs/sql-programming-guide.md
@@ -1086,6 +1086,13 @@ the following case-sensitive options:
+  <tr>
+    <td><code>maxConnections</code></td>
+    <td>
+      The maximum number of concurrent JDBC connections that can be used, if set. This option applies only to writing. It works by limiting the write operation's parallelism, which depends on the input's partition count: if the partition count exceeds this limit, the input is coalesced to this many partitions before writing.
+    </td>
+  </tr>
+
<tr>
  <td><code>isolationLevel</code></td>
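For context on the documented behavior, here is a minimal sketch of how a caller would set the new option when writing over JDBC; the session setup, JDBC URL, table name, and credentials are illustrative placeholders, not part of this patch:

```scala
import org.apache.spark.sql.SparkSession

// Minimal sketch of using the new `maxConnections` write option.
// The URL, table, and credentials below are placeholders.
val spark = SparkSession.builder().appName("maxConnectionsExample").getOrCreate()
val df = spark.range(0L, 1000000L, 1L, numPartitions = 64).toDF("id")

df.write
  .format("jdbc")
  .option("url", "jdbc:postgresql://db.example.com/testdb") // placeholder
  .option("dbtable", "public.ids")                          // placeholder
  .option("user", "writer")                                 // placeholder
  .option("password", "secret")                             // placeholder
  // With 64 input partitions, this coalesces the write down to 8
  // partitions, so at most 8 JDBC connections are open at once.
  .option("maxConnections", "8")
  .save()
```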
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
index 7f419b5788c4..d416eec6ddae 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
@@ -122,6 +122,11 @@ class JDBCOptions(
case "REPEATABLE_READ" => Connection.TRANSACTION_REPEATABLE_READ
case "SERIALIZABLE" => Connection.TRANSACTION_SERIALIZABLE
}
+ // The maximum number of concurrent JDBC connections to use when writing, if set.
+ val maxConnections = parameters.get(JDBC_MAX_CONNECTIONS).map(_.toInt)
+ require(maxConnections.isEmpty || maxConnections.get > 0,
+   s"Invalid value `${maxConnections.get}` for parameter `$JDBC_MAX_CONNECTIONS`. " +
+     "The minimum value is 1.")
}
object JDBCOptions {
@@ -144,4 +149,5 @@ object JDBCOptions {
val JDBC_CREATE_TABLE_OPTIONS = newOption("createTableOptions")
val JDBC_BATCH_INSERT_SIZE = newOption("batchsize")
val JDBC_TXN_ISOLATION_LEVEL = newOption("isolationLevel")
+ val JDBC_MAX_CONNECTIONS = newOption("maxConnections")
}
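One detail worth noting in the validation above: the second argument to Scala's `require` is by-name, so `maxConnections.get` inside the message is only evaluated when the check fails, and the check can only fail when the option is present. A self-contained sketch of the same parse-and-validate pattern (`parseMaxConnections` is an illustrative name, not Spark API):

```scala
// Standalone version of the option parsing above;
// `parseMaxConnections` is a hypothetical helper, not Spark API.
def parseMaxConnections(parameters: Map[String, String]): Option[Int] = {
  val maxConnections = parameters.get("maxConnections").map(_.toInt)
  // `require`'s message is by-name, so `.get` below only runs when the
  // requirement fails, which implies the Option is defined.
  require(maxConnections.isEmpty || maxConnections.get > 0,
    s"Invalid value `${maxConnections.get}` for parameter `maxConnections`. " +
      "The minimum value is 1.")
  maxConnections
}

parseMaxConnections(Map("maxConnections" -> "8")) // Some(8)
parseMaxConnections(Map.empty)                    // None
// parseMaxConnections(Map("maxConnections" -> "0")) // throws IllegalArgumentException
```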
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index 41edb6511c2c..cdc3c99daa1a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -667,7 +667,16 @@ object JdbcUtils extends Logging {
val getConnection: () => Connection = createConnectionFactory(options)
val batchSize = options.batchSize
val isolationLevel = options.isolationLevel
- df.foreachPartition(iterator => savePartition(
+ val maxConnections = options.maxConnections
+ // If `maxConnections` is set, cap the write's parallelism by coalescing the
+ // input: each partition is written through its own JDBC connection.
+ val coalescedDF =
+   if (maxConnections.isDefined && maxConnections.get < df.rdd.getNumPartitions) {
+     df.coalesce(maxConnections.get)
+   } else {
+     df
+   }
+ coalescedDF.foreachPartition(iterator => savePartition(
getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect, isolationLevel)
)
}
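The coalesce-or-pass-through decision above can be summarized in isolation. Below is a sketch under the assumption that each write task opens exactly one connection in `savePartition`, so capping the partition count also caps the number of simultaneous connections; the helper name `capPartitions` is hypothetical:

```scala
import org.apache.spark.sql.DataFrame

// Hypothetical standalone version of the decision made in `saveTable`:
// cap write parallelism, and with it concurrent JDBC connections, by
// coalescing the input only when it has more partitions than the cap.
def capPartitions(df: DataFrame, maxConnections: Option[Int]): DataFrame =
  maxConnections match {
    // `coalesce` (unlike `repartition`) avoids a shuffle and can only
    // reduce the partition count, so the guard makes this a no-op
    // whenever the input is already within the cap.
    case Some(n) if n < df.rdd.getNumPartitions => df.coalesce(n)
    case _ => df
  }
```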
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
index e3d3c6c3a887..5795b4d860cb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
@@ -312,4 +312,16 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
.options(properties.asScala)
.save()
}
+
+ test("SPARK-18413: Add `maxConnections` JDBCOption") {
+   val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
+   val e = intercept[IllegalArgumentException] {
+     df.write.format("jdbc")
+       .option("dbtable", "TEST.SAVETEST")
+       .option("url", url1)
+       .option(JDBCOptions.JDBC_MAX_CONNECTIONS, "0")
+       .save()
+   }.getMessage
+   assert(e.contains("Invalid value `0` for parameter `maxConnections`. The minimum value is 1"))
+ }
}
|