From 732e592ccaab9b8105fab8dbd57c66eff067a4f4 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 6 Jul 2016 12:36:30 -0700 Subject: [PATCH 1/5] jdbc. --- .../execution/datasources/DataSource.scala | 6 ++--- .../jdbc/JdbcRelationProvider.scala | 26 ++++++++++++++++--- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 12 +++++++++ .../spark/sql/jdbc/JDBCWriteSuite.scala | 15 +++++++++++ 4 files changed, 52 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 6dc27c19521e..404a20eb7b01 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -423,7 +423,6 @@ case class DataSource( if (data.schema.map(_.dataType).exists(_.isInstanceOf[CalendarIntervalType])) { throw new AnalysisException("Cannot save interval data type into external storage.") } - providingClass.newInstance() match { case dataSource: CreatableRelationProvider => dataSource.createRelation(sparkSession.sqlContext, mode, options, data) @@ -485,12 +484,11 @@ case class DataSource( data.logicalPlan, mode) sparkSession.sessionState.executePlan(plan).toRdd + // Replace the schema with that of the DataFrame we just wrote out to avoid re-inferring it. + copy(userSpecifiedSchema = Some(data.schema.asNullable)).resolveRelation() case _ => sys.error(s"${providingClass.getCanonicalName} does not allow create table as select.") } - - // We replace the schema with that of the DataFrame we just wrote out to avoid re-inferring it. - copy(userSpecifiedSchema = Some(data.schema.asNullable)).resolveRelation() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala index 106ed1d44010..6acdcd997acd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala @@ -19,10 +19,14 @@ package org.apache.spark.sql.execution.datasources.jdbc import java.util.Properties -import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister, RelationProvider} +import org.apache.spark.Partition +import org.apache.spark.sql._ +import org.apache.spark.sql.sources.{CreatableRelationProvider, BaseRelation, DataSourceRegister, RelationProvider} -class JdbcRelationProvider extends RelationProvider with DataSourceRegister { +class JdbcRelationProvider + extends RelationProvider + with CreatableRelationProvider + with DataSourceRegister { override def shortName(): String = "jdbc" @@ -52,4 +56,20 @@ class JdbcRelationProvider extends RelationProvider with DataSourceRegister { parameters.foreach(kv => properties.setProperty(kv._1, kv._2)) JDBCRelation(jdbcOptions.url, jdbcOptions.table, parts, properties)(sqlContext.sparkSession) } + + + override def createRelation( + sqlContext: SQLContext, + mode: SaveMode, + parameters: Map[String, String], + data: DataFrame): BaseRelation = { + val jdbcOptions = new JDBCOptions(parameters) + val parts = Array.empty[Partition] + val properties = new Properties() // Additional properties that we will pass to getConnection + parameters.foreach(kv => properties.setProperty(kv._1, kv._2)) + 
data.write + .mode(mode) + .jdbc(jdbcOptions.url, jdbcOptions.table, properties) + JDBCRelation(jdbcOptions.url, jdbcOptions.table, parts, properties)(sqlContext.sparkSession) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 11e66ad08009..ccfb1bc0b0a1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -370,6 +370,18 @@ class JDBCSuite extends SparkFunSuite } } + test("load API") { + val dfUsingOption = + spark.read + .option("url", url) + .option("dbtable", "(SELECT * FROM TEST.PEOPLE)") + .option("user", "testUser") + .option("password", "testPass") + .format("jdbc") + .load() + assert(dfUsingOption.count == 3) + } + test("Partitioning via JDBCPartitioningInfo API") { assert( spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3, new Properties()) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index 2c6449fa6870..7999f6132faf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -155,6 +155,21 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { assert(2 === spark.read.jdbc(url1, "TEST.TRUNCATETEST", properties).collect()(0).length) } + test("jdbc save") { + val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + val df2 = spark.createDataFrame(sparkContext.parallelize(arr1x2), schema2) + import scala.collection.JavaConverters._ + + // df.write.jdbc(url1, "TEST.TRUNCATETEST", properties) + df2.write.mode(SaveMode.Overwrite).format("jdbc") + .option("url", url1) + .option("dbtable", "TEST.TRUNCATETEST") + .options(properties.asScala) + .save() + assert(1 === spark.read.jdbc(url1, "TEST.TRUNCATETEST", properties).count()) + assert(2 === spark.read.jdbc(url1, "TEST.TRUNCATETEST", properties).collect()(0).length) + } + test("Incompatible INSERT to append") { val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) val df2 = spark.createDataFrame(sparkContext.parallelize(arr2x3), schema3) From 1b4db1ad59ec59275f6d5ac9f8479b16072e9bc8 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 6 Jul 2016 14:17:15 -0700 Subject: [PATCH 2/5] fix --- .../jdbc/JdbcRelationProvider.scala | 55 +++++++++++------- .../spark/sql/jdbc/JDBCWriteSuite.scala | 56 ++++++++++++++----- 2 files changed, 76 insertions(+), 35 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala index 6acdcd997acd..a3acf0323e68 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala @@ -21,7 +21,7 @@ import java.util.Properties import org.apache.spark.Partition import org.apache.spark.sql._ -import org.apache.spark.sql.sources.{CreatableRelationProvider, BaseRelation, DataSourceRegister, RelationProvider} +import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister, RelationProvider} class JdbcRelationProvider extends RelationProvider @@ -35,41 +35,56 @@ class JdbcRelationProvider 
sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = { val jdbcOptions = new JDBCOptions(parameters) - if (jdbcOptions.partitionColumn != null - && (jdbcOptions.lowerBound == null - || jdbcOptions.upperBound == null - || jdbcOptions.numPartitions == null)) { - sys.error("Partitioning incompletely specified") - } - - val partitionInfo = if (jdbcOptions.partitionColumn == null) { - null - } else { - JDBCPartitioningInfo( - jdbcOptions.partitionColumn, - jdbcOptions.lowerBound.toLong, - jdbcOptions.upperBound.toLong, - jdbcOptions.numPartitions.toInt) - } - val parts = JDBCRelation.columnPartition(partitionInfo) + val parts = buildJDBCPartition(jdbcOptions) val properties = new Properties() // Additional properties that we will pass to getConnection parameters.foreach(kv => properties.setProperty(kv._1, kv._2)) JDBCRelation(jdbcOptions.url, jdbcOptions.table, parts, properties)(sqlContext.sparkSession) } - + /** + * Save the DataFrame to the destination and return a relation with the given parameters based on + * the contents of the given DataFrame. + */ override def createRelation( sqlContext: SQLContext, mode: SaveMode, parameters: Map[String, String], data: DataFrame): BaseRelation = { val jdbcOptions = new JDBCOptions(parameters) - val parts = Array.empty[Partition] + val parts = buildJDBCPartition(jdbcOptions) val properties = new Properties() // Additional properties that we will pass to getConnection parameters.foreach(kv => properties.setProperty(kv._1, kv._2)) data.write .mode(mode) .jdbc(jdbcOptions.url, jdbcOptions.table, properties) + JDBCRelation(jdbcOptions.url, jdbcOptions.table, parts, properties)(sqlContext.sparkSession) } + + /** + * Build Partitions based on the user-provided options: + * - "partitionColumn": the column used to partition + * - "lowerBound": the lower bound of partition column + * - "upperBound": the upper bound of the partition column + * - "numPartitions": the number of partitions + */ + private def buildJDBCPartition(jdbcOptions: JDBCOptions): Array[Partition] = { + if (jdbcOptions.partitionColumn != null + && (jdbcOptions.lowerBound == null + || jdbcOptions.upperBound == null + || jdbcOptions.numPartitions == null)) { + sys.error("Partitioning incompletely specified") + } + + val partitionInfo = if (jdbcOptions.partitionColumn == null) { + null + } else { + JDBCPartitioningInfo( + jdbcOptions.partitionColumn, + jdbcOptions.lowerBound.toLong, + jdbcOptions.upperBound.toLong, + jdbcOptions.numPartitions.toInt) + } + JDBCRelation.columnPartition(partitionInfo) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index 7999f6132faf..82c1ce5c3111 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -155,21 +155,6 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { assert(2 === spark.read.jdbc(url1, "TEST.TRUNCATETEST", properties).collect()(0).length) } - test("jdbc save") { - val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) - val df2 = spark.createDataFrame(sparkContext.parallelize(arr1x2), schema2) - import scala.collection.JavaConverters._ - - // df.write.jdbc(url1, "TEST.TRUNCATETEST", properties) - df2.write.mode(SaveMode.Overwrite).format("jdbc") - .option("url", url1) - .option("dbtable", "TEST.TRUNCATETEST") - .options(properties.asScala) - .save() - assert(1 === 
spark.read.jdbc(url1, "TEST.TRUNCATETEST", properties).count()) - assert(2 === spark.read.jdbc(url1, "TEST.TRUNCATETEST", properties).collect()(0).length) - } - test("Incompatible INSERT to append") { val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) val df2 = spark.createDataFrame(sparkContext.parallelize(arr2x3), schema3) @@ -192,4 +177,45 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { assert(2 === spark.read.jdbc(url1, "TEST.PEOPLE1", properties).count()) assert(2 === spark.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) } + + test("save API") { + import scala.collection.JavaConverters._ + + val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + val df2 = spark.createDataFrame(sparkContext.parallelize(arr1x2), schema2) + + df.write.format("jdbc") + .option("url", url1) + .option("dbtable", "TEST.TRUNCATETEST") + .options(properties.asScala) + .save() + df2.write.mode(SaveMode.Overwrite).format("jdbc") + .option("url", url1) + .option("dbtable", "TEST.TRUNCATETEST") + .options(properties.asScala) + .save() + assert(1 === spark.read.jdbc(url1, "TEST.TRUNCATETEST", properties).count()) + assert(2 === spark.read.jdbc(url1, "TEST.TRUNCATETEST", properties).collect()(0).length) + } + + test("save API - error handling") { + import scala.collection.JavaConverters._ + val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + + var e = intercept[RuntimeException] { + df.write.format("jdbc") + .option("dbtable", "TEST.TRUNCATETEST") + .options(properties.asScala) + .save() + }.getMessage + assert(e.contains("Option 'url' not specified")) + + e = intercept[org.h2.jdbc.JdbcSQLException] { + df.write.format("jdbc") + .option("dbtable", "TEST.TRUNCATETEST") + .option("url", url1) + .save() + }.getMessage + assert(e.contains("Wrong user name or password")) + } } From 50c9de85fd76a87caa4eb432c74cb9717d48f9e6 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 6 Jul 2016 14:18:48 -0700 Subject: [PATCH 3/5] revert --- .../org/apache/spark/sql/execution/datasources/DataSource.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 404a20eb7b01..f572b93991e0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -423,6 +423,7 @@ case class DataSource( if (data.schema.map(_.dataType).exists(_.isInstanceOf[CalendarIntervalType])) { throw new AnalysisException("Cannot save interval data type into external storage.") } + providingClass.newInstance() match { case dataSource: CreatableRelationProvider => dataSource.createRelation(sparkSession.sqlContext, mode, options, data) From 2e799ce86652bc5c03d21fdbf0a11fab20b37c39 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 1 Sep 2016 14:27:29 -0700 Subject: [PATCH 4/5] address comments. 
--- .../apache/spark/sql/DataFrameWriter.scala | 60 ++----------------- .../jdbc/JdbcRelationProvider.scala | 11 ++-- .../datasources/jdbc/JdbcUtils.scala | 59 +++++++++++++++--- 3 files changed, 60 insertions(+), 70 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index a9049a60f25e..3898f0e74a1e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -414,62 +414,10 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { def jdbc(url: String, table: String, connectionProperties: Properties): Unit = { assertNotPartitioned("jdbc") assertNotBucketed("jdbc") - - // to add required options like URL and dbtable - val params = extraOptions.toMap ++ Map("url" -> url, "dbtable" -> table) - val jdbcOptions = new JDBCOptions(params) - val jdbcUrl = jdbcOptions.url - val jdbcTable = jdbcOptions.table - - val props = new Properties() - extraOptions.foreach { case (key, value) => - props.put(key, value) - } - // connectionProperties should override settings in extraOptions - props.putAll(connectionProperties) - val conn = JdbcUtils.createConnectionFactory(jdbcUrl, props)() - - try { - var tableExists = JdbcUtils.tableExists(conn, jdbcUrl, jdbcTable) - - if (mode == SaveMode.Ignore && tableExists) { - return - } - - if (mode == SaveMode.ErrorIfExists && tableExists) { - sys.error(s"Table $jdbcTable already exists.") - } - - if (mode == SaveMode.Overwrite && tableExists) { - if (jdbcOptions.isTruncate && - JdbcUtils.isCascadingTruncateTable(jdbcUrl) == Some(false)) { - JdbcUtils.truncateTable(conn, jdbcTable) - } else { - JdbcUtils.dropTable(conn, jdbcTable) - tableExists = false - } - } - - // Create the table if the table didn't exist. - if (!tableExists) { - val schema = JdbcUtils.schemaString(df, jdbcUrl) - // To allow certain options to append when create a new table, which can be - // table_options or partition_options. 
- // E.g., "CREATE TABLE t (name string) ENGINE=InnoDB DEFAULT CHARSET=utf8" - val createtblOptions = jdbcOptions.createTableOptions - val sql = s"CREATE TABLE $jdbcTable ($schema) $createtblOptions" - val statement = conn.createStatement - try { - statement.executeUpdate(sql) - } finally { - statement.close() - } - } - } finally { - conn.close() - } - - JdbcUtils.saveTable(df, jdbcUrl, jdbcTable, props) + // to add connectionProperties and required options like URL and dbtable + val params = + extraOptions.toMap ++ connectionProperties.asScala ++ Map("url" -> url, "dbtable" -> table) + JdbcUtils.saveTable(mode, params, df) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala index a3acf0323e68..d808717359cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala @@ -52,12 +52,9 @@ class JdbcRelationProvider data: DataFrame): BaseRelation = { val jdbcOptions = new JDBCOptions(parameters) val parts = buildJDBCPartition(jdbcOptions) - val properties = new Properties() // Additional properties that we will pass to getConnection + val properties = new Properties() parameters.foreach(kv => properties.setProperty(kv._1, kv._2)) - data.write - .mode(mode) - .jdbc(jdbcOptions.url, jdbcOptions.table, properties) - + JdbcUtils.saveTable(mode, parameters, data) JDBCRelation(jdbcOptions.url, jdbcOptions.table, parts, properties)(sqlContext.sparkSession) } @@ -71,8 +68,8 @@ class JdbcRelationProvider private def buildJDBCPartition(jdbcOptions: JDBCOptions): Array[Partition] = { if (jdbcOptions.partitionColumn != null && (jdbcOptions.lowerBound == null - || jdbcOptions.upperBound == null - || jdbcOptions.numPartitions == null)) { + || jdbcOptions.upperBound == null + || jdbcOptions.numPartitions == null)) { sys.error("Partitioning incompletely specified") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index cbd504603bbf..2ca04b77b10d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -25,7 +25,7 @@ import scala.util.Try import scala.util.control.NonFatal import org.apache.spark.internal.Logging -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Row, SaveMode} import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType} import org.apache.spark.sql.types._ @@ -370,13 +370,58 @@ object JdbcUtils extends Logging { } /** - * Saves the RDD to the database in a single transaction. + * Saves the content of the [[DataFrame]] to an external database table in a single transaction. + * + * @param mode Specifies the behavior when data or table already exists. 
+ * @param parameters Specifies the JDBC database connection arguments + * @param df Specifies the dataframe */ - def saveTable( - df: DataFrame, - url: String, - table: String, - properties: Properties) { + def saveTable(mode: SaveMode, parameters: Map[String, String], df: DataFrame): Unit = { + val jdbcOptions = new JDBCOptions(parameters) + val url = parameters("url") + val table = parameters("dbtable") + + val properties = new Properties() + parameters.foreach(kv => properties.setProperty(kv._1, kv._2)) + val conn = JdbcUtils.createConnectionFactory(url, properties)() + try { + var tableExists = JdbcUtils.tableExists(conn, url, table) + + if (tableExists) { + mode match { + case SaveMode.Ignore => return + case SaveMode.ErrorIfExists => sys.error(s"Table $table already exists.") + case SaveMode.Overwrite => + if (jdbcOptions.isTruncate && + JdbcUtils.isCascadingTruncateTable(url) == Option(false)) { + JdbcUtils.truncateTable(conn, table) + } else { + JdbcUtils.dropTable(conn, table) + tableExists = false + } + case SaveMode.Append => + } + } + + // Create the table if the table didn't exist. + if (!tableExists) { + val schema = JdbcUtils.schemaString(df, url) + // To allow certain options to append when create a new table, which can be + // table_options or partition_options. + // E.g., "CREATE TABLE t (name string) ENGINE=InnoDB DEFAULT CHARSET=utf8" + val createtblOptions = jdbcOptions.createTableOptions + val sql = s"CREATE TABLE $table ($schema) $createtblOptions" + val statement = conn.createStatement + try { + statement.executeUpdate(sql) + } finally { + statement.close() + } + } + } finally { + conn.close() + } + val dialect = JdbcDialects.get(url) val nullTypes: Array[Int] = df.schema.fields.map { field => getJdbcType(field.dataType, dialect).jdbcNullType From 07e316823ed17e89c3df0aaccf3fbb958afcfe3a Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 1 Sep 2016 14:42:03 -0700 Subject: [PATCH 5/5] clean code --- .../jdbc/JdbcRelationProvider.scala | 53 +++++++------------ 1 file changed, 19 insertions(+), 34 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala index d808717359cd..3163d0428238 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala @@ -19,8 +19,7 @@ package org.apache.spark.sql.execution.datasources.jdbc import java.util.Properties -import org.apache.spark.Partition -import org.apache.spark.sql._ +import org.apache.spark.sql.{DataFrame, SaveMode, SQLContext} import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister, RelationProvider} class JdbcRelationProvider @@ -35,37 +34,6 @@ class JdbcRelationProvider sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = { val jdbcOptions = new JDBCOptions(parameters) - val parts = buildJDBCPartition(jdbcOptions) - val properties = new Properties() // Additional properties that we will pass to getConnection - parameters.foreach(kv => properties.setProperty(kv._1, kv._2)) - JDBCRelation(jdbcOptions.url, jdbcOptions.table, parts, properties)(sqlContext.sparkSession) - } - - /** - * Save the DataFrame to the destination and return a relation with the given parameters based on - * the contents of the given DataFrame. 
- */ - override def createRelation( - sqlContext: SQLContext, - mode: SaveMode, - parameters: Map[String, String], - data: DataFrame): BaseRelation = { - val jdbcOptions = new JDBCOptions(parameters) - val parts = buildJDBCPartition(jdbcOptions) - val properties = new Properties() - parameters.foreach(kv => properties.setProperty(kv._1, kv._2)) - JdbcUtils.saveTable(mode, parameters, data) - JDBCRelation(jdbcOptions.url, jdbcOptions.table, parts, properties)(sqlContext.sparkSession) - } - - /** - * Build Partitions based on the user-provided options: - * - "partitionColumn": the column used to partition - * - "lowerBound": the lower bound of partition column - * - "upperBound": the upper bound of the partition column - * - "numPartitions": the number of partitions - */ - private def buildJDBCPartition(jdbcOptions: JDBCOptions): Array[Partition] = { if (jdbcOptions.partitionColumn != null && (jdbcOptions.lowerBound == null || jdbcOptions.upperBound == null @@ -82,6 +50,23 @@ class JdbcRelationProvider jdbcOptions.upperBound.toLong, jdbcOptions.numPartitions.toInt) } - JDBCRelation.columnPartition(partitionInfo) + val parts = JDBCRelation.columnPartition(partitionInfo) + val properties = new Properties() // Additional properties that we will pass to getConnection + parameters.foreach(kv => properties.setProperty(kv._1, kv._2)) + JDBCRelation(jdbcOptions.url, jdbcOptions.table, parts, properties)(sqlContext.sparkSession) } + + /** + * Save the DataFrame to the destination and return a relation with the given parameters based on + * the contents of the given DataFrame. + */ + override def createRelation( + sqlContext: SQLContext, + mode: SaveMode, + parameters: Map[String, String], + data: DataFrame): BaseRelation = { + JdbcUtils.saveTable(mode, parameters, data) + createRelation(sqlContext, parameters) + } + }
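
For reference, a minimal standalone sketch of the behavior this series enables: writing a DataFrame over JDBC through the generic format("jdbc")/save() path (now backed by CreatableRelationProvider and JdbcUtils.saveTable), then reading it back with the option-driven partitioning that createRelation turns into JDBCPartitioningInfo. The SparkSession setup, the H2 in-memory URL (DB_CLOSE_DELAY=-1 keeps the database alive across connections), table name, and credentials are illustrative assumptions modeled on JDBCWriteSuite, not part of the patches; the H2 driver is assumed to be on the classpath.

// Illustrative sketch only -- not part of the patch series.
// Assumes a Spark build with these patches applied and the H2 driver on the classpath.
import org.apache.spark.sql.{SaveMode, SparkSession}

object JdbcSaveSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("jdbc-save-sketch")
      .getOrCreate()
    import spark.implicits._

    // Assumed in-memory H2 database; DB_CLOSE_DELAY=-1 keeps it alive between
    // the write and read connections. URL, table, and credentials are made up
    // for illustration, mirroring the style of JDBCWriteSuite.
    val url = "jdbc:h2:mem:savetest;DB_CLOSE_DELAY=-1;user=testUser;password=testPass"
    val table = "SAVETEST"

    val df = Seq((1, 10), (2, 20)).toDF("id", "value")

    // New write path: the JDBC source now implements CreatableRelationProvider,
    // so the generic save() works and the SaveMode is honored by
    // JdbcUtils.saveTable (create / truncate / drop before insert).
    df.write
      .mode(SaveMode.Overwrite)
      .format("jdbc")
      .option("url", url)
      .option("dbtable", table)
      .save()

    // Option-driven read, exercising the partitioning options
    // (partitionColumn / lowerBound / upperBound / numPartitions)
    // that the provider parses via JDBCOptions.
    val readBack = spark.read
      .format("jdbc")
      .option("url", url)
      .option("dbtable", table)
      .option("partitionColumn", "id")
      .option("lowerBound", "1")
      .option("upperBound", "3")
      .option("numPartitions", "2")
      .load()

    assert(readBack.count() == 2)

    spark.stop()
  }
}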