diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index 2e75f0c8a182..b5ab3c661dfa 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -494,7 +494,7 @@ def orc(self, path, mode=None, partitionBy=None):
         self._jwrite.orc(path)
 
     @since(1.4)
-    def jdbc(self, url, table, mode=None, properties=None):
+    def jdbc(self, url, table, mode=None, properties=None, columnMapping=None):
         """Saves the content of the :class:`DataFrame` to a external database table via JDBC.
 
         .. note:: Don't create too many partitions in parallel on a large cluster;\
@@ -511,13 +511,20 @@ def jdbc(self, url, table, mode=None, properties=None):
         :param properties: JDBC database connection arguments, a list of arbitrary string
                            tag/value. Normally at least a "user" and "password" property
                            should be included.
+        :param columnMapping: optional mapping from DataFrame field names to
+                              JDBC table column names.
         """
         if properties is None:
             properties = dict()
         jprop = JavaClass("java.util.Properties", self._sqlContext._sc._gateway._gateway_client)()
         for k in properties:
             jprop.setProperty(k, properties[k])
-        self._jwrite.mode(mode).jdbc(url, table, jprop)
+        jcolumnMapping = None
+        if columnMapping is not None:
+            jcolumnMapping = JavaClass("java.util.HashMap", self._sqlContext._sc._gateway._gateway_client)()
+            for k in columnMapping:
+                jcolumnMapping.put(k, columnMapping[k])
+        self._jwrite.mode(mode).jdbc(url, table, jprop, jcolumnMapping)
 
 
 def _test():
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 03867beb7822..4b04fdd90a1f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -253,6 +253,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
   }
 
   /**
+   * (Scala-specific)
    * Saves the content of the [[DataFrame]] to a external database table via JDBC. In the case the
    * table already exists in the external database, behavior of this function depends on the
    * save mode, specified by the `mode` function (default to throwing an exception).
@@ -265,10 +266,22 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    * @param connectionProperties JDBC database connection arguments, a list of arbitrary string
    *                             tag/value. Normally at least a "user" and "password" property
    *                             should be included.
+   * @param columnMapping Maps DataFrame column names to target table column names.
+   *                      This parameter can be omitted if the target table has been or will be
+   *                      created by this method, in which case the target table structure
+   *                      matches the DataFrame structure.
+   *                      This parameter is strongly recommended if the target table already
+   *                      exists and has been created outside of this method.
+   *                      If omitted, the SQL insert statement will not include column names,
+   *                      which means that the field ordering of the DataFrame must match
+   *                      the target table column ordering.
    *
    * @since 1.4.0
    */
-  def jdbc(url: String, table: String, connectionProperties: Properties): Unit = {
+  def jdbc(url: String,
+      table: String,
+      connectionProperties: Properties,
+      columnMapping: scala.collection.immutable.Map[String, String]): Unit = {
     val props = new Properties()
     extraOptions.foreach { case (key, value) =>
       props.put(key, value)
     }
@@ -303,7 +316,32 @@ final class DataFrameWriter private[sql](df: DataFrame) {
       conn.close()
     }
 
-    JdbcUtils.saveTable(df, url, table, props)
+    JdbcUtils.saveTable(df, url, table, props, columnMapping)
+  }
+
+  /**
+   * (Java-specific) version of the jdbc method.
+   */
+  def jdbc(url: String,
+      table: String,
+      connectionProperties: Properties,
+      columnMapping: java.util.Map[String, String]): Unit = {
+    // Convert the Java Map into an immutable Scala Map
+    var sColumnMapping: scala.collection.immutable.Map[String, String] = null
+    if (columnMapping != null) {
+      sColumnMapping = collection.immutable.Map(columnMapping.asScala.toList: _*)
+    }
+    jdbc(url, table, connectionProperties, sColumnMapping)
+  }
+
+  /**
+   * Legacy three-parameter version of the jdbc method.
+   */
+  def jdbc(url: String,
+      table: String,
+      connectionProperties: Properties): Unit = {
+    val columnMapping: scala.collection.immutable.Map[String, String] = null
+    jdbc(url, table, connectionProperties, columnMapping)
   }
 
   /**
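
For illustration only (not part of the patch): a minimal usage sketch of the new Scala overload. The H2 URL, the PEOPLE table, the credentials and the column names are invented for the example, and sqlContext is assumed to be in scope as it is in the test suite further below.

    // DataFrame whose field names differ from the columns of an existing table.
    val df = sqlContext.createDataFrame(Seq(("alice", 1), ("bob", 2))).toDF("name", "id")

    val props = new java.util.Properties()
    props.setProperty("user", "test")
    props.setProperty("password", "secret")

    // Translate DataFrame field names to the column names of the existing PEOPLE table,
    // then append the rows through the new four-argument overload.
    val columnMapping = Map("name" -> "FULL_NAME", "id" -> "PERSON_ID")
    df.write.mode("append").jdbc("jdbc:h2:mem:testdb", "PEOPLE", props, columnMapping)
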
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index 252f1cfd5d9c..febe3e73bfef 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -61,16 +61,16 @@ object JdbcUtils extends Logging {
 
   /**
    * Returns a PreparedStatement that inserts a row into table via conn.
+   * If a columnMapping is provided, it will be used to translate rdd
+   * column names into table column names.
    */
-  def insertStatement(conn: Connection, table: String, rddSchema: StructType): PreparedStatement = {
-    val sql = new StringBuilder(s"INSERT INTO $table VALUES (")
-    var fieldsLeft = rddSchema.fields.length
-    while (fieldsLeft > 0) {
-      sql.append("?")
-      if (fieldsLeft > 1) sql.append(", ") else sql.append(")")
-      fieldsLeft = fieldsLeft - 1
-    }
-    conn.prepareStatement(sql.toString())
+  def insertStatement(conn: Connection,
+      dialect: JdbcDialect,
+      table: String,
+      rddSchema: StructType,
+      columnMapping: Map[String, String]): PreparedStatement = {
+    val sql = dialect.getInsertStatement(table, rddSchema, columnMapping)
+    conn.prepareStatement(sql)
   }
 
   /**
@@ -122,6 +122,7 @@ object JdbcUtils extends Logging {
       iterator: Iterator[Row],
       rddSchema: StructType,
       nullTypes: Array[Int],
+      columnMapping: Map[String, String] = null,
       batchSize: Int,
       dialect: JdbcDialect): Iterator[Byte] = {
     val conn = getConnection()
@@ -139,7 +140,7 @@ object JdbcUtils extends Logging {
       if (supportsTransactions) {
         conn.setAutoCommit(false) // Everything in the same db transaction.
       }
-      val stmt = insertStatement(conn, table, rddSchema)
+      val stmt = insertStatement(conn, dialect, table, rddSchema, columnMapping)
       try {
         var rowCount = 0
         while (iterator.hasNext) {
@@ -234,7 +235,8 @@ object JdbcUtils extends Logging {
       df: DataFrame,
       url: String,
       table: String,
-      properties: Properties = new Properties()) {
+      properties: Properties = new Properties(),
+      columnMapping: Map[String, String] = null) {
     val dialect = JdbcDialects.get(url)
     val nullTypes: Array[Int] = df.schema.fields.map { field =>
       getJdbcType(field.dataType, dialect).jdbcNullType
@@ -245,7 +247,8 @@ object JdbcUtils extends Logging {
     val getConnection: () => Connection = JDBCRDD.getConnector(driver, url, properties)
     val batchSize = properties.getProperty("batchsize", "1000").toInt
     df.foreachPartition { iterator =>
-      savePartition(getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect)
+      savePartition(getConnection, table, iterator, rddSchema, nullTypes,
+        columnMapping, batchSize, dialect)
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/CassandraDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/CassandraDialect.scala
new file mode 100644
index 000000000000..de4a111a0868
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/CassandraDialect.scala
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.jdbc
+
+import java.sql.Types
+
+import org.apache.spark.sql.types._
+
+
+private case object CassandraDialect extends JdbcDialect {
+
+  override def canHandle(url: String): Boolean =
+    url.startsWith("jdbc:datadirect:cassandra") ||
+    url.startsWith("jdbc:weblogic:cassandra")
+
+  override def getInsertStatement(table: String,
+      rddSchema: StructType,
+      columnMapping: Map[String, String]): String = {
+    val sql = new StringBuilder(s"INSERT INTO $table ( ")
+    var fieldsLeft = rddSchema.fields.length
+    var i = 0
+    // Build the list of column names, translating DataFrame field names
+    // through columnMapping when one is provided
+    while (fieldsLeft > 0) {
+      val fieldName = rddSchema.fields(i).name
+      if (columnMapping != null) {
+        sql.append(columnMapping.getOrElse(fieldName, fieldName))
+      } else {
+        sql.append(fieldName)
+      }
+      if (fieldsLeft > 1) sql.append(", ")
+      fieldsLeft = fieldsLeft - 1
+      i = i + 1
+    }
+    sql.append(" ) VALUES ( ")
+    // Build the values clause, one placeholder per column
+    fieldsLeft = rddSchema.fields.length
+    while (fieldsLeft > 0) {
+      sql.append("?")
+      if (fieldsLeft > 1) sql.append(", ")
+      fieldsLeft = fieldsLeft - 1
+    }
+    sql.append(" )")
+    sql.toString()
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index 13db141f27db..57013c1b2e94 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -108,6 +108,34 @@ abstract class JdbcDialect extends Serializable {
   def beforeFetch(connection: Connection, properties: Map[String, String]): Unit = {
   }
 
+  /**
+   * Get the SQL statement that should be used to insert new records into the table.
+   * Dialects can override this method to return a statement that works best in a particular
+   * database.
+   * @param table The name of the table.
+   * @param rddSchema The schema of the DataFrame to be inserted.
+   * @param columnMapping An optional mapping from DataFrame field names to database column
+   *                      names.
+   * @return The SQL statement to use for inserting into the table.
+   */
+  def getInsertStatement(table: String,
+      rddSchema: StructType,
+      columnMapping: Map[String, String] = null): String = {
+    if (columnMapping == null) {
+      rddSchema.fields.map(field => "?")
+        .mkString(s"INSERT INTO $table VALUES ( ", ", ", " )")
+    } else {
+      rddSchema.fields.map(
+        field => columnMapping.get(field.name) match {
+          case Some(name) => name
+          case None => field.name
+        }
+      ).mkString(s"INSERT INTO $table ( ", ", ", " )") +
+        rddSchema.fields.map(field => "?").mkString(" VALUES ( ", ", ", " )")
+    }
+  }
+
 }
 
 /**
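
As an aside (not part of the patch), a rough sketch of the SQL the default getInsertStatement would produce. The schema, table name and H2 URL are invented for the example; since no dialect is registered for H2, JdbcDialects.get falls back to the default implementation shown above.

    import org.apache.spark.sql.jdbc.JdbcDialects
    import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

    val schema = StructType(Seq(
      StructField("name", StringType),
      StructField("id", IntegerType)))
    val dialect = JdbcDialects.get("jdbc:h2:mem:testdb")

    // Without a mapping the generated statement is positional:
    //   INSERT INTO PEOPLE VALUES ( ?, ? )
    println(dialect.getInsertStatement("PEOPLE", schema))

    // With a mapping, DataFrame field names are translated to table column names:
    //   INSERT INTO PEOPLE ( FULL_NAME, PERSON_ID ) VALUES ( ?, ? )
    println(dialect.getInsertStatement("PEOPLE", schema,
      Map("name" -> "FULL_NAME", "id" -> "PERSON_ID")))
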
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
index e23ee6693133..bf28dacfb07b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
@@ -96,6 +96,16 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
       2 === sqlContext.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).collect()(0).length)
   }
 
+  test("Basic CREATE with columnMapping") {
+    val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
+
+    val columnMapping = Map("name" -> "name", "id" -> "id")
+    df.write.jdbc(url, "TEST.BASICCREATETEST", new Properties, columnMapping)
+    assert(2 === sqlContext.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).count)
+    assert(
+      2 === sqlContext.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).collect()(0).length)
+  }
+
   test("CREATE with overwrite") {
     val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x3), schema3)
     val df2 = sqlContext.createDataFrame(sparkContext.parallelize(arr1x2), schema2)