60 changes: 4 additions & 56 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -414,62 +414,10 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
def jdbc(url: String, table: String, connectionProperties: Properties): Unit = {
assertNotPartitioned("jdbc")
assertNotBucketed("jdbc")

// to add required options like URL and dbtable
val params = extraOptions.toMap ++ Map("url" -> url, "dbtable" -> table)
val jdbcOptions = new JDBCOptions(params)
val jdbcUrl = jdbcOptions.url
val jdbcTable = jdbcOptions.table

val props = new Properties()
extraOptions.foreach { case (key, value) =>
props.put(key, value)
}
// connectionProperties should override settings in extraOptions
props.putAll(connectionProperties)
val conn = JdbcUtils.createConnectionFactory(jdbcUrl, props)()

try {
var tableExists = JdbcUtils.tableExists(conn, jdbcUrl, jdbcTable)

if (mode == SaveMode.Ignore && tableExists) {
return
}

if (mode == SaveMode.ErrorIfExists && tableExists) {
sys.error(s"Table $jdbcTable already exists.")
}

if (mode == SaveMode.Overwrite && tableExists) {
if (jdbcOptions.isTruncate &&
JdbcUtils.isCascadingTruncateTable(jdbcUrl) == Some(false)) {
JdbcUtils.truncateTable(conn, jdbcTable)
} else {
JdbcUtils.dropTable(conn, jdbcTable)
tableExists = false
}
}

// Create the table if the table didn't exist.
if (!tableExists) {
val schema = JdbcUtils.schemaString(df, jdbcUrl)
// To allow certain options to append when create a new table, which can be
// table_options or partition_options.
// E.g., "CREATE TABLE t (name string) ENGINE=InnoDB DEFAULT CHARSET=utf8"
val createtblOptions = jdbcOptions.createTableOptions
val sql = s"CREATE TABLE $jdbcTable ($schema) $createtblOptions"
val statement = conn.createStatement
try {
statement.executeUpdate(sql)
} finally {
statement.close()
}
}
} finally {
conn.close()
}

JdbcUtils.saveTable(df, jdbcUrl, jdbcTable, props)
// Add connectionProperties and the required options, such as URL and dbtable.
val params =
extraOptions.toMap ++ connectionProperties.asScala ++ Map("url" -> url, "dbtable" -> table)
JdbcUtils.saveTable(mode, params, df)
}
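
For context, a minimal usage sketch of the two write paths that now converge on JdbcUtils.saveTable; the H2 URL, table name, and credentials are illustrative placeholders, and an existing SparkSession with a DataFrame `df` is assumed.

import java.util.Properties
import org.apache.spark.sql.SaveMode

// Assumed context: a DataFrame `df` produced by an existing SparkSession.
val props = new Properties()
props.setProperty("user", "testUser")       // placeholder credentials
props.setProperty("password", "testPass")

// Properties-based API refactored above; delegates to JdbcUtils.saveTable(mode, params, df).
df.write.mode(SaveMode.Append).jdbc("jdbc:h2:mem:testdb", "TEST.PEOPLE", props)

// Generic DataSource API exercised by the new tests; reaches the same JdbcUtils.saveTable
// through JdbcRelationProvider.createRelation.
df.write.format("jdbc")
  .option("url", "jdbc:h2:mem:testdb")
  .option("dbtable", "TEST.PEOPLE")
  .option("user", "testUser")
  .option("password", "testPass")
  .mode(SaveMode.Append)
  .save()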

/**
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
@@ -19,10 +19,13 @@ package org.apache.spark.sql.execution.datasources.jdbc

import java.util.Properties

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister, RelationProvider}
import org.apache.spark.sql.{DataFrame, SaveMode, SQLContext}
import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister, RelationProvider}

class JdbcRelationProvider extends RelationProvider with DataSourceRegister {
class JdbcRelationProvider
extends RelationProvider
with CreatableRelationProvider
with DataSourceRegister {

override def shortName(): String = "jdbc"

@@ -52,4 +55,18 @@ class JdbcRelationProvider extends RelationProvider with DataSourceRegister {
parameters.foreach(kv => properties.setProperty(kv._1, kv._2))
JDBCRelation(jdbcOptions.url, jdbcOptions.table, parts, properties)(sqlContext.sparkSession)
}

/**
* Saves the contents of the given DataFrame to the destination described by the given
* parameters, and returns a relation over the saved data.
*/
override def createRelation(
sqlContext: SQLContext,
mode: SaveMode,
parameters: Map[String, String],
data: DataFrame): BaseRelation = {
JdbcUtils.saveTable(mode, parameters, data)
createRelation(sqlContext, parameters)
}

}
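
A hedged sketch of how this new createRelation is reached: format("jdbc") resolves to this provider through DataSourceRegister (shortName "jdbc"), and a .save() call with a SaveMode passes the options and DataFrame into createRelation, which writes via JdbcUtils.saveTable and then returns the read-side JDBCRelation. The URL and table below are placeholders.

import org.apache.spark.sql.SaveMode

// Assumed context: a DataFrame `df` from an existing SparkSession.
df.write
  .format("jdbc")                          // resolved to JdbcRelationProvider via shortName()
  .mode(SaveMode.Overwrite)                // forwarded as the `mode` argument of createRelation
  .option("url", "jdbc:h2:mem:testdb")     // placeholder URL
  .option("dbtable", "TEST.PEOPLE")        // placeholder table
  .save()                                  // ends in createRelation(sqlContext, mode, parameters, df)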
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -25,7 +25,7 @@ import scala.util.Try
import scala.util.control.NonFatal

import org.apache.spark.internal.Logging
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.{DataFrame, Row, SaveMode}
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType}
import org.apache.spark.sql.types._

@@ -370,13 +370,58 @@ object JdbcUtils extends Logging {
}

/**
* Saves the RDD to the database in a single transaction.
* Saves the content of the [[DataFrame]] to an external database table in a single transaction.
*
* @param mode Specifies the behavior when data or table already exists.
* @param parameters Specifies the JDBC database connection arguments.
* @param df Specifies the DataFrame to be saved.
*/
def saveTable(
df: DataFrame,
url: String,
table: String,
properties: Properties) {
def saveTable(mode: SaveMode, parameters: Map[String, String], df: DataFrame): Unit = {
val jdbcOptions = new JDBCOptions(parameters)
val url = parameters("url")
val table = parameters("dbtable")

val properties = new Properties()
parameters.foreach(kv => properties.setProperty(kv._1, kv._2))
val conn = JdbcUtils.createConnectionFactory(url, properties)()
try {
var tableExists = JdbcUtils.tableExists(conn, url, table)

if (tableExists) {
mode match {
case SaveMode.Ignore => return
case SaveMode.ErrorIfExists => sys.error(s"Table $table already exists.")
case SaveMode.Overwrite =>
if (jdbcOptions.isTruncate &&
JdbcUtils.isCascadingTruncateTable(url) == Option(false)) {
JdbcUtils.truncateTable(conn, table)
} else {
JdbcUtils.dropTable(conn, table)
tableExists = false
}
case SaveMode.Append =>
}
}

// Create the table if the table didn't exist.
if (!tableExists) {
val schema = JdbcUtils.schemaString(df, url)
// To allow certain options to be appended when creating a new table, such as
// table_options or partition_options.
// E.g., "CREATE TABLE t (name string) ENGINE=InnoDB DEFAULT CHARSET=utf8"
val createtblOptions = jdbcOptions.createTableOptions
val sql = s"CREATE TABLE $table ($schema) $createtblOptions"
val statement = conn.createStatement
try {
statement.executeUpdate(sql)
} finally {
statement.close()
}
}
} finally {
conn.close()
}

val dialect = JdbcDialects.get(url)
val nullTypes: Array[Int] = df.schema.fields.map { field =>
getJdbcType(field.dataType, dialect).jdbcNullType
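
A hedged sketch of the Overwrite branch above: assuming the writer option key "truncate" is what backs JDBCOptions.isTruncate, an existing table on a dialect whose TRUNCATE does not cascade is truncated and refilled rather than dropped and recreated. The URL and table are placeholders.

import org.apache.spark.sql.SaveMode

// Assumed context: a DataFrame `df` overwriting an already existing table.
df.write
  .format("jdbc")
  .mode(SaveMode.Overwrite)
  .option("url", "jdbc:h2:mem:testdb")     // placeholder URL
  .option("dbtable", "TEST.PEOPLE")        // placeholder table
  .option("truncate", "true")              // assumption: this key maps to JDBCOptions.isTruncate
  .save()                                  // truncateTable + insert instead of dropTable + CREATE TABLE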
12 changes: 12 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -370,6 +370,18 @@ class JDBCSuite extends SparkFunSuite
}
}

test("load API") {
val dfUsingOption =
spark.read
.option("url", url)
.option("dbtable", "(SELECT * FROM TEST.PEOPLE)")
.option("user", "testUser")
.option("password", "testPass")
.format("jdbc")
.load()
assert(dfUsingOption.count == 3)
}

test("Partitioning via JDBCPartitioningInfo API") {
assert(
spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3, new Properties())
sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
@@ -208,4 +208,45 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
assert(2 === spark.read.jdbc(url1, "TEST.PEOPLE1", properties).count())
assert(2 === spark.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length)
}

test("save API") {
import scala.collection.JavaConverters._

val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
val df2 = spark.createDataFrame(sparkContext.parallelize(arr1x2), schema2)

df.write.format("jdbc")
.option("url", url1)
.option("dbtable", "TEST.TRUNCATETEST")
.options(properties.asScala)
.save()
df2.write.mode(SaveMode.Overwrite).format("jdbc")
.option("url", url1)
.option("dbtable", "TEST.TRUNCATETEST")
.options(properties.asScala)
.save()
assert(1 === spark.read.jdbc(url1, "TEST.TRUNCATETEST", properties).count())
assert(2 === spark.read.jdbc(url1, "TEST.TRUNCATETEST", properties).collect()(0).length)
}

test("save API - error handling") {
import scala.collection.JavaConverters._
val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2)

var e = intercept[RuntimeException] {
df.write.format("jdbc")
.option("dbtable", "TEST.TRUNCATETEST")
.options(properties.asScala)
.save()
}.getMessage
assert(e.contains("Option 'url' not specified"))

e = intercept[org.h2.jdbc.JdbcSQLException] {
df.write.format("jdbc")
.option("dbtable", "TEST.TRUNCATETEST")
.option("url", url1)
.save()
}.getMessage
assert(e.contains("Wrong user name or password"))
}
}