From f2bf6ee494b0640a9b1f8aa929325fdea3256932 Mon Sep 17 00:00:00 2001
From: CK50
Date: Thu, 26 Nov 2015 04:22:43 -0800
Subject: [PATCH 1/7] Initial version

---
 .../datasources/jdbc/JdbcUtils.scala               | 14 ++---
 .../apache/spark/sql/jdbc/JdbcDialects.scala       | 19 +++++++
 .../sql/jdbc/ProgressCassandraDialect.scala        | 51 +++++++++++++++++++
 3 files changed, 74 insertions(+), 10 deletions(-)
 create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/jdbc/ProgressCassandraDialect.scala

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index 7375a5c09123..70c3fc43fc6f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -61,15 +61,9 @@ object JdbcUtils extends Logging {
   /**
    * Returns a PreparedStatement that inserts a row into table via conn.
    */
-  def insertStatement(conn: Connection, table: String, rddSchema: StructType): PreparedStatement = {
-    val sql = new StringBuilder(s"INSERT INTO $table VALUES (")
-    var fieldsLeft = rddSchema.fields.length
-    while (fieldsLeft > 0) {
-      sql.append("?")
-      if (fieldsLeft > 1) sql.append(", ") else sql.append(")")
-      fieldsLeft = fieldsLeft - 1
-    }
-    conn.prepareStatement(sql.toString())
+  def insertStatement(conn: Connection, dialect:JdbcDialect, table: String, rddSchema: StructType): PreparedStatement = {
+    val sql = dialect.getInsertStatement(table, rddSchema)
+    conn.prepareStatement(sql)
   }
 
   /**
@@ -127,7 +121,7 @@ object JdbcUtils extends Logging {
     var committed = false
     try {
       conn.setAutoCommit(false) // Everything in the same db transaction.
-      val stmt = insertStatement(conn, table, rddSchema)
+      val stmt = insertStatement(conn, dialect, table, rddSchema)
       try {
         var rowCount = 0
         while (iterator.hasNext) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index b3b2cb6178c5..310af8910fe1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -97,6 +97,24 @@ abstract class JdbcDialect extends Serializable {
     s"SELECT * FROM $table WHERE 1=0"
   }
 
+  /**
+   * Get the SQL statement for inserting a row into a given table. Dialects can
+   * override this method to return a statement that works best in a particular database.
+   * @param table The name of the table.
+   * @param rddSchema The schema of the RDD. Some dialects require column names for the INSERT statement.
+   * @return The SQL INSERT statement to use for inserting a row into the table.
+   */
+  def getInsertStatement(table: String, rddSchema: StructType): String = {
+    val sql = new StringBuilder(s"INSERT INTO $table VALUES (")
+    var fieldsLeft = rddSchema.fields.length
+    while (fieldsLeft > 0) {
+      sql.append("?")
+      if (fieldsLeft > 1) sql.append(", ") else sql.append(")")
+      fieldsLeft = fieldsLeft - 1
+    }
+    return sql.toString()
+  }
+
 }
 
 /**
@@ -140,6 +158,7 @@ object JdbcDialects {
   registerDialect(MsSqlServerDialect)
   registerDialect(DerbyDialect)
   registerDialect(OracleDialect)
+  registerDialect(ProgressCassandraDialect)
 
   /**
    * Fetch the JdbcDialect class corresponding to a given database url.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/ProgressCassandraDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/ProgressCassandraDialect.scala
new file mode 100644
index 000000000000..532a18c75ba7
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/ProgressCassandraDialect.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.jdbc
+
+import java.sql.Types
+
+import org.apache.spark.sql.types._
+
+
+private case object ProgressCassandraDialect extends JdbcDialect {
+
+  override def canHandle(url: String): Boolean = url.startsWith("jdbc:datadirect:cassandra") || url.startsWith("jdbc:weblogic:cassandra")
+
+  override def getInsertStatement(table: String, rddSchema: StructType): String = {
+    val sql = new StringBuilder(s"INSERT INTO $table ( ")
+    var fieldsLeft = rddSchema.fields.length
+    var i = 0
+    // Build list of column names
+    while (fieldsLeft > 0) {
+      sql.append(rddSchema.fields(i).name)
+      if (fieldsLeft > 1) sql.append(", ")
+      fieldsLeft = fieldsLeft - 1
+      i = i + 1
+    }
+    sql.append(" ) VALUES ( ")
+    // Build values clause
+    fieldsLeft = rddSchema.fields.length
+    while (fieldsLeft > 0) {
+      sql.append("?")
+      if (fieldsLeft > 1) sql.append(", ")
+      fieldsLeft = fieldsLeft - 1
+    }
+    sql.append(" ) ")
+    return sql.toString()
+  }
+}
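To make the new extension point concrete, here is a minimal spark-shell sketch of what the default getInsertStatement produces once this patch is applied. GenericDialect is a hypothetical stand-in defined only to expose the inherited default; it is not part of the patch.

```scala
import org.apache.spark.sql.jdbc.JdbcDialect
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

// Hypothetical dialect that keeps the inherited default getInsertStatement.
case object GenericDialect extends JdbcDialect {
  override def canHandle(url: String): Boolean = url.startsWith("jdbc:generic:")
}

val schema = StructType(Seq(
  StructField("id", IntegerType),
  StructField("name", StringType)))

// Default implementation: positional placeholders, no column names.
println(GenericDialect.getInsertStatement("people", schema))
// INSERT INTO people VALUES (?, ?)

// ProgressCassandraDialect above would instead generate:
//   INSERT INTO people ( id, name ) VALUES ( ?, ? )
```

The column-name form matters for drivers such as the DataDirect/WebLogic Cassandra drivers, which reject positional INSERT statements that omit the column list.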
From d176b374ba34fcdbce76354257565a9cfcdcd102 Mon Sep 17 00:00:00 2001
From: CK50
Date: Fri, 27 Nov 2015 11:22:17 -0800
Subject: [PATCH 2/7] Style changes and renamed from ProgressCassandraDialect to CassandraDialect

---
 ...ogressCassandraDialect.scala => CassandraDialect.scala} | 6 ++++--
 .../scala/org/apache/spark/sql/jdbc/JdbcDialects.scala     | 7 ++++---
 2 files changed, 8 insertions(+), 5 deletions(-)
 rename sql/core/src/main/scala/org/apache/spark/sql/jdbc/{ProgressCassandraDialect.scala => CassandraDialect.scala} (88%)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/ProgressCassandraDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/CassandraDialect.scala
similarity index 88%
rename from sql/core/src/main/scala/org/apache/spark/sql/jdbc/ProgressCassandraDialect.scala
rename to sql/core/src/main/scala/org/apache/spark/sql/jdbc/CassandraDialect.scala
index 532a18c75ba7..cff719f16629 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/ProgressCassandraDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/CassandraDialect.scala
@@ -22,9 +22,11 @@ import java.sql.Types
 
 import org.apache.spark.sql.types._
 
 
-private case object ProgressCassandraDialect extends JdbcDialect {
+private case object CassandraDialect extends JdbcDialect {
 
-  override def canHandle(url: String): Boolean = url.startsWith("jdbc:datadirect:cassandra") || url.startsWith("jdbc:weblogic:cassandra")
+  override def canHandle(url: String): Boolean =
+    url.startsWith("jdbc:datadirect:cassandra") ||
+    url.startsWith("jdbc:weblogic:cassandra")
 
   override def getInsertStatement(table: String, rddSchema: StructType): String = {
     val sql = new StringBuilder(s"INSERT INTO $table ( ")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index 310af8910fe1..d493d406f42a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -101,7 +101,8 @@ abstract class JdbcDialect extends Serializable {
    * Get the SQL statement for inserting a row into a given table. Dialects can
    * override this method to return a statement that works best in a particular database.
    * @param table The name of the table.
-   * @param rddSchema The schema of the RDD. Some dialects require column names for the INSERT statement.
+   * @param rddSchema The schema of the RDD. Some dialects require column names for
+   *                  the INSERT statement.
    * @return The SQL INSERT statement to use for inserting a row into the table.
    */
   def getInsertStatement(table: String, rddSchema: StructType): String = {
@@ -113,7 +114,7 @@ abstract class JdbcDialect extends Serializable {
       fieldsLeft = fieldsLeft - 1
     }
     return sql.toString()
-  }
+  }
 
 }
 
@@ -158,7 +159,7 @@ object JdbcDialects {
   registerDialect(MsSqlServerDialect)
   registerDialect(DerbyDialect)
   registerDialect(OracleDialect)
-  registerDialect(ProgressCassandraDialect)
+  registerDialect(CassandraDialect)
 
   /**
    * Fetch the JdbcDialect class corresponding to a given database url.
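With the rename in place, third parties can plug the same behavior into their own drivers through the existing registerDialect hook. A sketch under made-up names (MyDialect and the jdbc:mydb: prefix are illustrative only), written against the two-argument getInsertStatement signature as it stands after this patch:

```scala
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}
import org.apache.spark.sql.types.StructType

// Hypothetical third-party dialect that forces column names into the INSERT,
// the same idea as CassandraDialect but written with mkString instead of a while loop.
case object MyDialect extends JdbcDialect {
  override def canHandle(url: String): Boolean = url.startsWith("jdbc:mydb:")

  override def getInsertStatement(table: String, rddSchema: StructType): String = {
    val columns = rddSchema.fields.map(_.name).mkString(", ")
    val placeholders = rddSchema.fields.map(_ => "?").mkString(", ")
    s"INSERT INTO $table ( $columns ) VALUES ( $placeholders )"
  }
}

// Once registered, any URL matching canHandle picks up this dialect.
JdbcDialects.registerDialect(MyDialect)
```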
url.startsWith("jdbc:weblogic:cassandra") + override def canHandle(url: String): Boolean = + url.startsWith("jdbc:datadirect:cassandra") || + url.startsWith("jdbc:weblogic:cassandra") override def getInsertStatement(table: String, rddSchema: StructType): String = { val sql = new StringBuilder(s"INSERT INTO $table ( ") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 310af8910fe1..d493d406f42a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -101,7 +101,8 @@ abstract class JdbcDialect extends Serializable { * Get the SQL statement for inserting a row int oa given table. Dialects can * override this method to return a statement that works best in a particular database. * @param table The name of the table. - * @param rddSchema The schema of the RDD. Some dialects require column names for the INSERT statement. + * @param rddSchema The schema of the RDD. Some dialects require column names for + * the INSERT statement. * @return The SQL INSERT statement to use for inserting a row into the table. */ def getInsertStatement(table: String, rddSchema: StructType): String = { @@ -113,7 +114,7 @@ abstract class JdbcDialect extends Serializable { fieldsLeft = fieldsLeft - 1 } return sql.toString() - } + } } @@ -158,7 +159,7 @@ object JdbcDialects { registerDialect(MsSqlServerDialect) registerDialect(DerbyDialect) registerDialect(OracleDialect) - registerDialect(ProgressCassandraDialect) + registerDialect(CassandraDialect) /** * Fetch the JdbcDialect class corresponding to a given database url. From dd44460c58f031458c550ee36b211ea1b178dd28 Mon Sep 17 00:00:00 2001 From: CK50 Date: Fri, 27 Nov 2015 12:32:31 -0800 Subject: [PATCH 3/7] Clean style check run using Maven invocation --- .../spark/sql/execution/datasources/jdbc/JdbcUtils.scala | 5 ++++- .../scala/org/apache/spark/sql/jdbc/CassandraDialect.scala | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 70c3fc43fc6f..2a40dd67159e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -61,7 +61,10 @@ object JdbcUtils extends Logging { /** * Returns a PreparedStatement that inserts a row into table via conn. 
From 024475ab0038fa3576608786dec006120df22fcb Mon Sep 17 00:00:00 2001
From: CK50
Date: Thu, 26 Nov 2015 04:22:43 -0800
Subject: [PATCH 4/7] Initial version

---
 .../datasources/jdbc/JdbcUtils.scala               | 14 ++---
 .../apache/spark/sql/jdbc/JdbcDialects.scala       | 19 +++++++
 .../sql/jdbc/ProgressCassandraDialect.scala        | 51 +++++++++++++++++++
 3 files changed, 74 insertions(+), 10 deletions(-)
 create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/jdbc/ProgressCassandraDialect.scala

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index 252f1cfd5d9c..be83533c0e35 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -62,15 +62,9 @@ object JdbcUtils extends Logging {
   /**
    * Returns a PreparedStatement that inserts a row into table via conn.
    */
-  def insertStatement(conn: Connection, table: String, rddSchema: StructType): PreparedStatement = {
-    val sql = new StringBuilder(s"INSERT INTO $table VALUES (")
-    var fieldsLeft = rddSchema.fields.length
-    while (fieldsLeft > 0) {
-      sql.append("?")
-      if (fieldsLeft > 1) sql.append(", ") else sql.append(")")
-      fieldsLeft = fieldsLeft - 1
-    }
-    conn.prepareStatement(sql.toString())
+  def insertStatement(conn: Connection, dialect:JdbcDialect, table: String, rddSchema: StructType): PreparedStatement = {
+    val sql = dialect.getInsertStatement(table, rddSchema)
+    conn.prepareStatement(sql)
   }
 
   /**
@@ -139,7 +133,7 @@ object JdbcUtils extends Logging {
       if (supportsTransactions) {
         conn.setAutoCommit(false) // Everything in the same db transaction.
       }
-      val stmt = insertStatement(conn, table, rddSchema)
+      val stmt = insertStatement(conn, dialect, table, rddSchema)
       try {
         var rowCount = 0
         while (iterator.hasNext) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index 13db141f27db..166e88e85bbb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -108,6 +108,24 @@ abstract class JdbcDialect extends Serializable {
   def beforeFetch(connection: Connection, properties: Map[String, String]): Unit = {
   }
 
+  /**
+   * Get the SQL statement for inserting a row into a given table. Dialects can
+   * override this method to return a statement that works best in a particular database.
+   * @param table The name of the table.
+   * @param rddSchema The schema of the RDD. Some dialects require column names for the INSERT statement.
+   * @return The SQL INSERT statement to use for inserting a row into the table.
+   */
+  def getInsertStatement(table: String, rddSchema: StructType): String = {
+    val sql = new StringBuilder(s"INSERT INTO $table VALUES (")
+    var fieldsLeft = rddSchema.fields.length
+    while (fieldsLeft > 0) {
+      sql.append("?")
+      if (fieldsLeft > 1) sql.append(", ") else sql.append(")")
+      fieldsLeft = fieldsLeft - 1
+    }
+    return sql.toString()
+  }
+
 }
 
 /**
@@ -151,6 +169,7 @@ object JdbcDialects {
   registerDialect(MsSqlServerDialect)
   registerDialect(DerbyDialect)
   registerDialect(OracleDialect)
+  registerDialect(ProgressCassandraDialect)
 
   /**
    * Fetch the JdbcDialect class corresponding to a given database url.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/ProgressCassandraDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/ProgressCassandraDialect.scala
new file mode 100644
index 000000000000..532a18c75ba7
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/ProgressCassandraDialect.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.jdbc
+
+import java.sql.Types
+
+import org.apache.spark.sql.types._
+
+
+private case object ProgressCassandraDialect extends JdbcDialect {
+
+  override def canHandle(url: String): Boolean = url.startsWith("jdbc:datadirect:cassandra") || url.startsWith("jdbc:weblogic:cassandra")
+
+  override def getInsertStatement(table: String, rddSchema: StructType): String = {
+    val sql = new StringBuilder(s"INSERT INTO $table ( ")
+    var fieldsLeft = rddSchema.fields.length
+    var i = 0
+    // Build list of column names
+    while (fieldsLeft > 0) {
+      sql.append(rddSchema.fields(i).name)
+      if (fieldsLeft > 1) sql.append(", ")
+      fieldsLeft = fieldsLeft - 1
+      i = i + 1
+    }
+    sql.append(" ) VALUES ( ")
+    // Build values clause
+    fieldsLeft = rddSchema.fields.length
+    while (fieldsLeft > 0) {
+      sql.append("?")
+      if (fieldsLeft > 1) sql.append(", ")
+      fieldsLeft = fieldsLeft - 1
+    }
+    sql.append(" ) ")
+    return sql.toString()
+  }
+}
From 95b25388d22dd0865998d08e0322b7f0823ba593 Mon Sep 17 00:00:00 2001
From: CK50
Date: Fri, 27 Nov 2015 11:22:17 -0800
Subject: [PATCH 5/7] Style changes and renamed from ProgressCassandraDialect to CassandraDialect

---
 ...ogressCassandraDialect.scala => CassandraDialect.scala} | 6 ++++--
 .../scala/org/apache/spark/sql/jdbc/JdbcDialects.scala     | 7 ++++---
 2 files changed, 8 insertions(+), 5 deletions(-)
 rename sql/core/src/main/scala/org/apache/spark/sql/jdbc/{ProgressCassandraDialect.scala => CassandraDialect.scala} (88%)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/ProgressCassandraDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/CassandraDialect.scala
similarity index 88%
rename from sql/core/src/main/scala/org/apache/spark/sql/jdbc/ProgressCassandraDialect.scala
rename to sql/core/src/main/scala/org/apache/spark/sql/jdbc/CassandraDialect.scala
index 532a18c75ba7..cff719f16629 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/ProgressCassandraDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/CassandraDialect.scala
@@ -22,9 +22,11 @@ import java.sql.Types
 
 import org.apache.spark.sql.types._
 
 
-private case object ProgressCassandraDialect extends JdbcDialect {
+private case object CassandraDialect extends JdbcDialect {
 
-  override def canHandle(url: String): Boolean = url.startsWith("jdbc:datadirect:cassandra") || url.startsWith("jdbc:weblogic:cassandra")
+  override def canHandle(url: String): Boolean =
+    url.startsWith("jdbc:datadirect:cassandra") ||
+    url.startsWith("jdbc:weblogic:cassandra")
 
   override def getInsertStatement(table: String, rddSchema: StructType): String = {
     val sql = new StringBuilder(s"INSERT INTO $table ( ")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index 166e88e85bbb..4e94a4cbfdcd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -112,7 +112,8 @@ abstract class JdbcDialect extends Serializable {
    * Get the SQL statement for inserting a row into a given table. Dialects can
    * override this method to return a statement that works best in a particular database.
    * @param table The name of the table.
-   * @param rddSchema The schema of the RDD. Some dialects require column names for the INSERT statement.
+   * @param rddSchema The schema of the RDD. Some dialects require column names for
+   *                  the INSERT statement.
   * @return The SQL INSERT statement to use for inserting a row into the table.
    */
   def getInsertStatement(table: String, rddSchema: StructType): String = {
@@ -124,7 +125,7 @@ abstract class JdbcDialect extends Serializable {
       fieldsLeft = fieldsLeft - 1
     }
     return sql.toString()
-  }
+  }
 
 }
 
@@ -169,7 +170,7 @@ object JdbcDialects {
   registerDialect(MsSqlServerDialect)
   registerDialect(DerbyDialect)
   registerDialect(OracleDialect)
-  registerDialect(ProgressCassandraDialect)
+  registerDialect(CassandraDialect)
 
   /**
    * Fetch the JdbcDialect class corresponding to a given database url.
From be3e31c512ba7c1e7e159014431513cc6a959825 Mon Sep 17 00:00:00 2001
From: CK50
Date: Fri, 27 Nov 2015 12:32:31 -0800
Subject: [PATCH 6/7] Clean style check run using Maven invocation

---
 .../spark/sql/execution/datasources/jdbc/JdbcUtils.scala   | 5 ++++-
 .../scala/org/apache/spark/sql/jdbc/CassandraDialect.scala | 2 +-
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index be83533c0e35..23977df9d3e9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -62,7 +62,10 @@ object JdbcUtils extends Logging {
   /**
    * Returns a PreparedStatement that inserts a row into table via conn.
    */
-  def insertStatement(conn: Connection, dialect:JdbcDialect, table: String, rddSchema: StructType): PreparedStatement = {
+  def insertStatement(conn: Connection,
+      dialect: JdbcDialect,
+      table: String,
+      rddSchema: StructType): PreparedStatement = {
     val sql = dialect.getInsertStatement(table, rddSchema)
     conn.prepareStatement(sql)
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/CassandraDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/CassandraDialect.scala
index cff719f16629..de4a111a0868 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/CassandraDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/CassandraDialect.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.types._
 private case object CassandraDialect extends JdbcDialect {
 
   override def canHandle(url: String): Boolean =
-    url.startsWith("jdbc:datadirect:cassandra") || 
+    url.startsWith("jdbc:datadirect:cassandra") ||
     url.startsWith("jdbc:weblogic:cassandra")
 
   override def getInsertStatement(table: String, rddSchema: StructType): String = {
From cac053696fef6fe326c9855ab57f13899a9e7e82 Mon Sep 17 00:00:00 2001
From: CK50
Date: Tue, 15 Dec 2015 04:16:06 -0800
Subject: [PATCH 7/7] Added support for columnMapping to fix SPARK-12010

---
 python/pyspark/sql/readwriter.py                   | 11 ++++-
 .../org/apache/spark/sql/DataFrameWriter.scala     | 42 ++++++++++++++++++-
 .../execution/datasources/jdbc/JdbcUtils.scala     | 16 ++++---
 .../org/apache/spark/sql/jdbc/JdbcDialects.scala   | 38 ++++++++++-------
 .../org/apache/spark/sql/jdbc/JDBCWriteSuite.scala | 10 +++++
 5 files changed, 93 insertions(+), 24 deletions(-)

diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index 2e75f0c8a182..b5ab3c661dfa 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -494,7 +494,7 @@ def orc(self, path, mode=None, partitionBy=None):
         self._jwrite.orc(path)
 
     @since(1.4)
-    def jdbc(self, url, table, mode=None, properties=None):
+    def jdbc(self, url, table, mode=None, properties=None, columnMapping=None):
         """Saves the content of the :class:`DataFrame` to a external database table via JDBC.
 
         .. note:: Don't create too many partitions in parallel on a large cluster;\
@@ -511,13 +511,20 @@ def jdbc(self, url, table, mode=None, properties=None):
         :param properties: JDBC database connection arguments, a list of
                            arbitrary string tag/value. Normally at least a
                            "user" and "password" property should be included.
+        :param columnMapping: optional column name mapping from DataFrame field names
+                              to JDBC table column names.
         """
         if properties is None:
             properties = dict()
         jprop = JavaClass("java.util.Properties", self._sqlContext._sc._gateway._gateway_client)()
         for k in properties:
             jprop.setProperty(k, properties[k])
-        self._jwrite.mode(mode).jdbc(url, table, jprop)
+        if columnMapping is None:
+            columnMapping = dict()
+        jcolumnMapping = JavaClass("java.util.HashMap", self._sqlContext._sc._gateway._gateway_client)()
+        for k in columnMapping:
+            jcolumnMapping.put(k, columnMapping[k])
+        self._jwrite.mode(mode).jdbc(url, table, jprop, jcolumnMapping)
 
 
 def _test():
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 03867beb7822..4b04fdd90a1f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -253,6 +253,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
   }
 
   /**
+   * (Scala-specific)
    * Saves the content of the [[DataFrame]] to a external database table via JDBC. In the case the
    * table already exists in the external database, behavior of this function depends on the
    * save mode, specified by the `mode` function (default to throwing an exception).
@@ -265,10 +266,22 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    * @param connectionProperties JDBC database connection arguments, a list of arbitrary string
    *                             tag/value. Normally at least a "user" and "password" property
    *                             should be included.
+   * @param columnMapping Maps DataFrame column names to target table column names.
+   *                      This parameter can be omitted if the target table has been or will be
+   *                      created in this method and therefore the target table structure
+   *                      matches the DataFrame structure.
+   *                      This parameter is strongly recommended if the target table already
+   *                      exists and has been created outside of this method.
+   *                      If omitted, the SQL insert statement will not include column names,
+   *                      which means that the field ordering of the DataFrame must match
+   *                      the target table column ordering.
    *
    * @since 1.4.0
    */
-  def jdbc(url: String, table: String, connectionProperties: Properties): Unit = {
+  def jdbc(url: String,
+      table: String,
+      connectionProperties: Properties,
+      columnMapping: scala.collection.immutable.Map[String, String]): Unit = {
     val props = new Properties()
     extraOptions.foreach { case (key, value) =>
       props.put(key, value)
@@ -303,7 +316,32 @@ final class DataFrameWriter private[sql](df: DataFrame) {
       conn.close()
     }
 
-    JdbcUtils.saveTable(df, url, table, props)
+    JdbcUtils.saveTable(df, url, table, props, columnMapping)
+  }
+
+  /**
+   * (Java-specific) version of the jdbc method.
+   */
+  def jdbc(url: String,
+      table: String,
+      connectionProperties: Properties,
+      columnMapping: java.util.Map[String, String]): Unit = {
+    // Convert the Java map into an immutable Scala map
+    var sColumnMapping: scala.collection.immutable.Map[String, String] = null
+    if (columnMapping != null) {
+      sColumnMapping = collection.immutable.Map(columnMapping.asScala.toList: _*)
+    }
+    jdbc(url, table, connectionProperties, sColumnMapping)
+  }
+
+  /**
+   * Legacy three-parameter version of the jdbc method.
+   */
+  def jdbc(url: String,
+      table: String,
+      connectionProperties: Properties): Unit = {
+    val columnMapping: scala.collection.immutable.Map[String, String] = null
+    jdbc(url, table, connectionProperties, columnMapping)
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index 23977df9d3e9..febe3e73bfef 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -61,12 +61,15 @@ object JdbcUtils extends Logging {
 
   /**
    * Returns a PreparedStatement that inserts a row into table via conn.
+   * If a columnMapping is provided, it will be used to translate rdd
+   * column names into table column names.
    */
   def insertStatement(conn: Connection,
       dialect: JdbcDialect,
       table: String,
-      rddSchema: StructType): PreparedStatement = {
-    val sql = dialect.getInsertStatement(table, rddSchema)
+      rddSchema: StructType,
+      columnMapping: Map[String, String]): PreparedStatement = {
+    val sql = dialect.getInsertStatement(table, rddSchema, columnMapping)
     conn.prepareStatement(sql)
   }
 
@@ -119,6 +122,7 @@ object JdbcUtils extends Logging {
       iterator: Iterator[Row],
       rddSchema: StructType,
       nullTypes: Array[Int],
+      columnMapping: Map[String, String] = null,
      batchSize: Int,
       dialect: JdbcDialect): Iterator[Byte] = {
     val conn = getConnection()
@@ -136,7 +140,7 @@ object JdbcUtils extends Logging {
       if (supportsTransactions) {
         conn.setAutoCommit(false) // Everything in the same db transaction.
       }
-      val stmt = insertStatement(conn, dialect, table, rddSchema)
+      val stmt = insertStatement(conn, dialect, table, rddSchema, columnMapping)
       try {
         var rowCount = 0
         while (iterator.hasNext) {
@@ -231,7 +235,8 @@ object JdbcUtils extends Logging {
       df: DataFrame,
       url: String,
       table: String,
-      properties: Properties = new Properties()) {
+      properties: Properties = new Properties(),
+      columnMapping: Map[String, String] = null) {
     val dialect = JdbcDialects.get(url)
     val nullTypes: Array[Int] = df.schema.fields.map { field =>
       getJdbcType(field.dataType, dialect).jdbcNullType
@@ -242,7 +247,8 @@ object JdbcUtils extends Logging {
     val getConnection: () => Connection = JDBCRDD.getConnector(driver, url, properties)
     val batchSize = properties.getProperty("batchsize", "1000").toInt
     df.foreachPartition { iterator =>
-      savePartition(getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect)
+      savePartition(getConnection, table, iterator, rddSchema, nullTypes,
+        columnMapping, batchSize, dialect)
     }
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index 4e94a4cbfdcd..57013c1b2e94 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -109,22 +109,31 @@ abstract class JdbcDialect extends Serializable {
   }
 
   /**
-   * Get the SQL statement for inserting a row into a given table. Dialects can
-   * override this method to return a statement that works best in a particular database.
-   * @param table The name of the table.
-   * @param rddSchema The schema of the RDD. Some dialects require column names for
-   *                  the INSERT statement.
-   * @return The SQL INSERT statement to use for inserting a row into the table.
+   * Get the SQL statement that should be used to insert new records into the table.
+   * Dialects can override this method to return a statement that works best in a particular
+   * database.
+   * @param table The name of the table.
+   * @param rddSchema The schema of the DataFrame to be inserted.
+   * @param columnMapping An optional mapping from DataFrame field names to database column
+   *                      names.
+   * @return The SQL statement to use for inserting into the table.
    */
-  def getInsertStatement(table: String, rddSchema: StructType): String = {
-    val sql = new StringBuilder(s"INSERT INTO $table VALUES (")
-    var fieldsLeft = rddSchema.fields.length
-    while (fieldsLeft > 0) {
-      sql.append("?")
-      if (fieldsLeft > 1) sql.append(", ") else sql.append(")")
-      fieldsLeft = fieldsLeft - 1
+  def getInsertStatement(table: String,
+      rddSchema: StructType,
+      columnMapping: Map[String, String] = null): String = {
+    if (columnMapping == null) {
+      return rddSchema.fields.map(field => "?")
+        .mkString(s"INSERT INTO $table VALUES (", ", ", ")")
+    } else {
+      // Fields without a mapping entry fall back to the DataFrame field name
+      return rddSchema.fields.map(
+        field => columnMapping.get(field.name) match {
+          case Some(name) => name
+          case None => field.name
+        }
+      ).mkString(s"INSERT INTO $table ( ", ", ", " ) ") +
+        rddSchema.fields.map(field => "?").mkString("VALUES ( ", ", ", " )")
     }
-    return sql.toString()
   }
 
 }
@@ -170,7 +179,6 @@ object JdbcDialects {
   registerDialect(MsSqlServerDialect)
   registerDialect(DerbyDialect)
   registerDialect(OracleDialect)
-  registerDialect(CassandraDialect)
 
   /**
    * Fetch the JdbcDialect class corresponding to a given database url.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
index e23ee6693133..bf28dacfb07b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
@@ -96,6 +96,16 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
       2 === sqlContext.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).collect()(0).length)
   }
 
+  test("Basic CREATE with columnMapping") {
+    val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
+
+    val columnMapping = Map("name" -> "name", "id" -> "id")
+    df.write.jdbc(url, "TEST.BASICCREATETEST", new Properties, columnMapping)
+    assert(2 === sqlContext.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).count)
+    assert(
+      2 === sqlContext.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).collect()(0).length)
+  }
+
   test("CREATE with overwrite") {
     val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x3), schema3)
     val df2 = sqlContext.createDataFrame(sparkContext.parallelize(arr1x2), schema2)
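To close, a sketch of how the new columnMapping parameter is meant to be used from Scala, modeled on the "Basic CREATE with columnMapping" test above. The H2 URL, table, and column names are illustrative, and df is assumed to be a DataFrame with fields id and name:

```scala
import java.util.Properties

// Pre-existing table PEOPLE(PERSON_ID, PERSON_NAME) created outside of Spark.
// The mapping translates DataFrame field names into the table's column names.
val columnMapping = Map("id" -> "PERSON_ID", "name" -> "PERSON_NAME")

val props = new Properties()
props.setProperty("user", "scott")
props.setProperty("password", "tiger")

// With the mapping, the generated statement names the target columns:
//   INSERT INTO PEOPLE ( PERSON_ID, PERSON_NAME ) VALUES ( ?, ? )
df.write.mode("append").jdbc("jdbc:h2:mem:testdb", "PEOPLE", props, columnMapping)

// The legacy three-argument overload keeps the old positional behavior:
//   INSERT INTO PEOPLE VALUES (?, ?)
df.write.mode("append").jdbc("jdbc:h2:mem:testdb", "PEOPLE", props)
```

Because the mapping travels through JdbcDialect.getInsertStatement, the same call works for any registered dialect, which is why this patch also drops the Cassandra-specific dialect registration.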