diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index a9049a60f25e..3898f0e74a1e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -414,62 +414,10 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
   def jdbc(url: String, table: String, connectionProperties: Properties): Unit = {
     assertNotPartitioned("jdbc")
     assertNotBucketed("jdbc")
-
-    // to add required options like URL and dbtable
-    val params = extraOptions.toMap ++ Map("url" -> url, "dbtable" -> table)
-    val jdbcOptions = new JDBCOptions(params)
-    val jdbcUrl = jdbcOptions.url
-    val jdbcTable = jdbcOptions.table
-
-    val props = new Properties()
-    extraOptions.foreach { case (key, value) =>
-      props.put(key, value)
-    }
-    // connectionProperties should override settings in extraOptions
-    props.putAll(connectionProperties)
-    val conn = JdbcUtils.createConnectionFactory(jdbcUrl, props)()
-
-    try {
-      var tableExists = JdbcUtils.tableExists(conn, jdbcUrl, jdbcTable)
-
-      if (mode == SaveMode.Ignore && tableExists) {
-        return
-      }
-
-      if (mode == SaveMode.ErrorIfExists && tableExists) {
-        sys.error(s"Table $jdbcTable already exists.")
-      }
-
-      if (mode == SaveMode.Overwrite && tableExists) {
-        if (jdbcOptions.isTruncate &&
-            JdbcUtils.isCascadingTruncateTable(jdbcUrl) == Some(false)) {
-          JdbcUtils.truncateTable(conn, jdbcTable)
-        } else {
-          JdbcUtils.dropTable(conn, jdbcTable)
-          tableExists = false
-        }
-      }
-
-      // Create the table if the table didn't exist.
-      if (!tableExists) {
-        val schema = JdbcUtils.schemaString(df, jdbcUrl)
-        // To allow certain options to append when create a new table, which can be
-        // table_options or partition_options.
-        // E.g., "CREATE TABLE t (name string) ENGINE=InnoDB DEFAULT CHARSET=utf8"
-        val createtblOptions = jdbcOptions.createTableOptions
-        val sql = s"CREATE TABLE $jdbcTable ($schema) $createtblOptions"
-        val statement = conn.createStatement
-        try {
-          statement.executeUpdate(sql)
-        } finally {
-          statement.close()
-        }
-      }
-    } finally {
-      conn.close()
-    }
-
-    JdbcUtils.saveTable(df, jdbcUrl, jdbcTable, props)
+    // to add connectionProperties and required options like URL and dbtable
+    val params =
+      extraOptions.toMap ++ connectionProperties.asScala ++ Map("url" -> url, "dbtable" -> table)
+    JdbcUtils.saveTable(mode, params, df)
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
index 106ed1d44010..3163d0428238 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
@@ -19,10 +19,13 @@ package org.apache.spark.sql.execution.datasources.jdbc
 
 import java.util.Properties
 
-import org.apache.spark.sql.SQLContext
-import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister, RelationProvider}
+import org.apache.spark.sql.{DataFrame, SaveMode, SQLContext}
+import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister, RelationProvider}
 
-class JdbcRelationProvider extends RelationProvider with DataSourceRegister {
+class JdbcRelationProvider
+  extends RelationProvider
+  with CreatableRelationProvider
+  with DataSourceRegister {
 
   override def shortName(): String = "jdbc"
 
@@ -52,4 +55,18 @@ class JdbcRelationProvider extends RelationProvider with DataSourceRegister {
     parameters.foreach(kv => properties.setProperty(kv._1, kv._2))
     JDBCRelation(jdbcOptions.url, jdbcOptions.table, parts, properties)(sqlContext.sparkSession)
   }
+
+  /**
+   * Save the DataFrame to the destination and return a relation with the given parameters based on
+   * the contents of the given DataFrame.
+   */
+  override def createRelation(
+      sqlContext: SQLContext,
+      mode: SaveMode,
+      parameters: Map[String, String],
+      data: DataFrame): BaseRelation = {
+    JdbcUtils.saveTable(mode, parameters, data)
+    createRelation(sqlContext, parameters)
+  }
+
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index cbd504603bbf..2ca04b77b10d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -25,7 +25,7 @@ import scala.util.Try
 import scala.util.control.NonFatal
 
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Row, SaveMode}
 import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType}
 import org.apache.spark.sql.types._
 
@@ -370,13 +370,58 @@ object JdbcUtils extends Logging {
   }
 
   /**
-   * Saves the RDD to the database in a single transaction.
+   * Saves the content of the [[DataFrame]] to an external database table in a single transaction.
+   *
+   * @param mode Specifies the behavior when data or table already exists.
+   * @param parameters Specifies the JDBC database connection arguments
+   * @param df Specifies the dataframe
    */
-  def saveTable(
-      df: DataFrame,
-      url: String,
-      table: String,
-      properties: Properties) {
+  def saveTable(mode: SaveMode, parameters: Map[String, String], df: DataFrame): Unit = {
+    val jdbcOptions = new JDBCOptions(parameters)
+    val url = parameters("url")
+    val table = parameters("dbtable")
+
+    val properties = new Properties()
+    parameters.foreach(kv => properties.setProperty(kv._1, kv._2))
+    val conn = JdbcUtils.createConnectionFactory(url, properties)()
+    try {
+      var tableExists = JdbcUtils.tableExists(conn, url, table)
+
+      if (tableExists) {
+        mode match {
+          case SaveMode.Ignore => return
+          case SaveMode.ErrorIfExists => sys.error(s"Table $table already exists.")
+          case SaveMode.Overwrite =>
+            if (jdbcOptions.isTruncate &&
+                JdbcUtils.isCascadingTruncateTable(url) == Option(false)) {
+              JdbcUtils.truncateTable(conn, table)
+            } else {
+              JdbcUtils.dropTable(conn, table)
+              tableExists = false
+            }
+          case SaveMode.Append =>
+        }
+      }
+
+      // Create the table if the table didn't exist.
+      if (!tableExists) {
+        val schema = JdbcUtils.schemaString(df, url)
+        // To allow certain options to append when create a new table, which can be
+        // table_options or partition_options.
+        // E.g., "CREATE TABLE t (name string) ENGINE=InnoDB DEFAULT CHARSET=utf8"
+        val createtblOptions = jdbcOptions.createTableOptions
+        val sql = s"CREATE TABLE $table ($schema) $createtblOptions"
+        val statement = conn.createStatement
+        try {
+          statement.executeUpdate(sql)
+        } finally {
+          statement.close()
+        }
+      }
+    } finally {
+      conn.close()
+    }
+
     val dialect = JdbcDialects.get(url)
     val nullTypes: Array[Int] = df.schema.fields.map { field =>
       getJdbcType(field.dataType, dialect).jdbcNullType
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index 2d8ee338a980..17913048debc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -370,6 +370,18 @@ class JDBCSuite extends SparkFunSuite
     }
   }
 
+  test("load API") {
+    val dfUsingOption =
+      spark.read
+        .option("url", url)
+        .option("dbtable", "(SELECT * FROM TEST.PEOPLE)")
+        .option("user", "testUser")
+        .option("password", "testPass")
+        .format("jdbc")
+        .load()
+    assert(dfUsingOption.count == 3)
+  }
+
   test("Partitioning via JDBCPartitioningInfo API") {
     assert(
       spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3, new Properties())
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
index ff3309874f2e..dd10e11a8b65 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
@@ -208,4 +208,45 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
     assert(2 === spark.read.jdbc(url1, "TEST.PEOPLE1", properties).count())
     assert(2 === spark.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length)
   }
+
+  test("save API") {
+    import scala.collection.JavaConverters._
+
+    val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
+    val df2 = spark.createDataFrame(sparkContext.parallelize(arr1x2), schema2)
+
+    df.write.format("jdbc")
+      .option("url", url1)
+      .option("dbtable", "TEST.TRUNCATETEST")
+      .options(properties.asScala)
+      .save()
+    df2.write.mode(SaveMode.Overwrite).format("jdbc")
+      .option("url", url1)
+      .option("dbtable", "TEST.TRUNCATETEST")
+      .options(properties.asScala)
+      .save()
+    assert(1 === spark.read.jdbc(url1, "TEST.TRUNCATETEST", properties).count())
+    assert(2 === spark.read.jdbc(url1, "TEST.TRUNCATETEST", properties).collect()(0).length)
+  }
+
+  test("save API - error handling") {
+    import scala.collection.JavaConverters._
+    val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
+
+    var e = intercept[RuntimeException] {
+      df.write.format("jdbc")
+        .option("dbtable", "TEST.TRUNCATETEST")
+        .options(properties.asScala)
+        .save()
+    }.getMessage
+    assert(e.contains("Option 'url' not specified"))
+
+    e = intercept[org.h2.jdbc.JdbcSQLException] {
+      df.write.format("jdbc")
+        .option("dbtable", "TEST.TRUNCATETEST")
+        .option("url", url1)
+        .save()
+    }.getMessage
+    assert(e.contains("Wrong user name or password"))
+  }
 }
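
For reference, a minimal end-to-end sketch of the user-facing path this patch enables (not part of the diff): DataFrameWriter.save() with format("jdbc") now routes through JdbcRelationProvider.createRelation, which delegates to JdbcUtils.saveTable(mode, parameters, df). The in-memory H2 URL, table name, and credentials below are illustrative assumptions, and an H2 driver is assumed to be on the classpath.

// Sketch only: exercises the option-based JDBC write/read path added above.
// The URL, table name, and credentials are made-up values for illustration.
import org.apache.spark.sql.{SaveMode, SparkSession}

object JdbcSaveSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("jdbc-save-sketch")
      .master("local[*]")
      .getOrCreate()
    import spark.implicits._

    val df = Seq((1, "fred"), (2, "mary")).toDF("id", "name")

    // save() resolves the "jdbc" source to JdbcRelationProvider, whose
    // createRelation calls JdbcUtils.saveTable(mode, parameters, df).
    df.write
      .format("jdbc")
      .mode(SaveMode.Overwrite)
      .option("url", "jdbc:h2:mem:sketch;DB_CLOSE_DELAY=-1") // assumed in-memory H2 database
      .option("dbtable", "PEOPLE")                           // assumed table name
      .option("user", "sa")
      .option("password", "")
      .save()

    // The symmetric option-based read path covered by the "load API" test.
    val readBack = spark.read
      .format("jdbc")
      .option("url", "jdbc:h2:mem:sketch;DB_CLOSE_DELAY=-1")
      .option("dbtable", "PEOPLE")
      .option("user", "sa")
      .option("password", "")
      .load()
    assert(readBack.count() == 2)

    spark.stop()
  }
}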