From 541645246deb399dfba9c2f86f4199c3dec81c95 Mon Sep 17 00:00:00 2001 From: Daniel van der Ende Date: Fri, 22 Dec 2017 16:58:28 +0100 Subject: [PATCH 1/4] [SPARK-22880][SQL] Add cascadeTruncate option to JDBC datasource This commit adds the cascadeTruncate option to the JDBC datasource API, for databases that support this functionality (PostgreSQL and Oracle at the moment). This allows for applying a cascading truncate that affects tables that have foreign key constraints on the table being truncated. --- docs/sql-programming-guide.md | 7 +++++ .../datasources/jdbc/JDBCOptions.scala | 3 +++ .../datasources/jdbc/JdbcUtils.scala | 2 +- .../spark/sql/jdbc/AggregatedDialect.scala | 4 +-- .../apache/spark/sql/jdbc/JdbcDialects.scala | 6 +++-- .../apache/spark/sql/jdbc/OracleDialect.scala | 10 +++++++ .../spark/sql/jdbc/PostgresDialect.scala | 7 ++--- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 27 +++++++++++++++++++ 8 files changed, 58 insertions(+), 8 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 91e43678481d..055685127e02 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1372,6 +1372,13 @@ the following case-insensitive options: This is a JDBC writer related option. When SaveMode.Overwrite is enabled, this option causes Spark to truncate an existing table instead of dropping and recreating it. This can be more efficient, and prevents the table metadata (e.g., indices) from being removed. However, it will not work in some cases, such as when the new data has a different schema. It defaults to false. This option applies only to writing. + + + cascadeTruncate + + This is a JDBC writer related option. If enabled and supported by the JDBC database (PostgreSQL and Oracle at the moment), this options allows execution of a TRUNCATE TABLE t CASCADE. This will affect other tables, and thus should be used with case. This option applies only to writing. 
+ + createTableOptions diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala index b4e5d169066d..7e34377d881d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala @@ -119,6 +119,8 @@ class JDBCOptions( // ------------------------------------------------------------ // if to truncate the table from the JDBC database val isTruncate = parameters.getOrElse(JDBC_TRUNCATE, "false").toBoolean + + val isCascadeTruncate = parameters.getOrElse(JDBC_CASCADE_TRUNCATE, "false").toBoolean // the create table option , which can be table_options or partition_options. // E.g., "CREATE TABLE t (name string) ENGINE=InnoDB DEFAULT CHARSET=utf8" // TODO: to reuse the existing partition parameters for those partition specific options @@ -162,6 +164,7 @@ object JDBCOptions { val JDBC_NUM_PARTITIONS = newOption("numPartitions") val JDBC_BATCH_FETCH_SIZE = newOption("fetchsize") val JDBC_TRUNCATE = newOption("truncate") + val JDBC_CASCADE_TRUNCATE = newOption("cascadeTruncate") val JDBC_CREATE_TABLE_OPTIONS = newOption("createTableOptions") val JDBC_CREATE_TABLE_COLUMN_TYPES = newOption("createTableColumnTypes") val JDBC_CUSTOM_DATAFRAME_COLUMN_TYPES = newOption("customSchema") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index e6dc2fda4eb1..ba9100bf8fc9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -102,7 +102,7 @@ object JdbcUtils extends Logging { val dialect = JdbcDialects.get(options.url) val statement = 
conn.createStatement try { - statement.executeUpdate(dialect.getTruncateQuery(options.table)) + statement.executeUpdate(dialect.getTruncateQuery(options.table, options.isCascadeTruncate)) } finally { statement.close() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala index 8b92c8b4f56b..a5a75c4ed018 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala @@ -64,7 +64,7 @@ private class AggregatedDialect(dialects: List[JdbcDialect]) extends JdbcDialect } } - override def getTruncateQuery(table: String): String = { - dialects.head.getTruncateQuery(table) + override def getTruncateQuery(table: String, cascade: Boolean = false): String = { + dialects.head.getTruncateQuery(table, cascade) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 83d87a11810c..96bf670e0166 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -22,6 +22,7 @@ import java.sql.{Connection, Date, Timestamp} import org.apache.commons.lang3.StringUtils import org.apache.spark.annotation.{DeveloperApi, InterfaceStability, Since} +import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions import org.apache.spark.sql.types._ /** @@ -120,11 +121,12 @@ abstract class JdbcDialect extends Serializable { * The SQL query that should be used to truncate a table. Dialects can override this method to * return a query that is suitable for a particular database. For PostgreSQL, for instance, * a different query is used to prevent "TRUNCATE" affecting other tables. - * @param table The name of the table. 
+ * @param table The table to truncate + * @param cascade (OPTIONAL) Whether or not to cascade the truncation. Default: false * @return The SQL query to use for truncating a table */ @Since("2.3.0") - def getTruncateQuery(table: String): String = { + def getTruncateQuery(table: String, cascade: Boolean = false): String = { s"TRUNCATE TABLE $table" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala index 6ef77f24460b..81eeb327848f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala @@ -94,5 +94,15 @@ private case object OracleDialect extends JdbcDialect { case _ => value } + /** + * The SQL query used to truncate a table. + * @param table The JDBCOptions. + * @param cascade (OPTIONAL) Whether or not to cascade the truncation. Default: false + * @return The SQL query to use for truncating a table + */ + override def getTruncateQuery(table: String, cascade: Boolean = false): String = { + s"TRUNCATE TABLE $table${if (cascade) " CASCADE" else ""}" + } + override def isCascadingTruncateTable(): Option[Boolean] = Some(false) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala index 13a2035f4d0c..0acd05d79934 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala @@ -89,11 +89,12 @@ private object PostgresDialect extends JdbcDialect { * The SQL query used to truncate a table. For Postgres, the default behaviour is to * also truncate any descendant tables. As this is a (possibly unwanted) side-effect, * the Postgres dialect adds 'ONLY' to truncate only the table in question - * @param table The name of the table. 
+ * @param table The table to truncate + * @param cascade (OPTIONAL) Whether or not to cascade the truncation. Default: false * @return The SQL query to use for truncating a table */ - override def getTruncateQuery(table: String): String = { - s"TRUNCATE TABLE ONLY $table" + override def getTruncateQuery(table: String, cascade: Boolean = false): String = { + s"TRUNCATE TABLE ONLY $table${if (cascade) " CASCADE" else ""}" } override def beforeFetch(connection: Connection, properties: Map[String, String]): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 5238adce4a69..55c80b29d653 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -860,14 +860,41 @@ class JDBCSuite extends SparkFunSuite val db2 = JdbcDialects.get("jdbc:db2://127.0.0.1/db") val h2 = JdbcDialects.get(url) val derby = JdbcDialects.get("jdbc:derby:db") + val oracle = JdbcDialects.get("jdbc:oracle://127.0.0.1/db") + val table = "weblogs" val defaultQuery = s"TRUNCATE TABLE $table" val postgresQuery = s"TRUNCATE TABLE ONLY $table" + assert(MySQL.getTruncateQuery(table) == defaultQuery) assert(Postgres.getTruncateQuery(table) == postgresQuery) assert(db2.getTruncateQuery(table) == defaultQuery) assert(h2.getTruncateQuery(table) == defaultQuery) assert(derby.getTruncateQuery(table) == defaultQuery) + assert(oracle.getTruncateQuery(table) == defaultQuery) + } + + test("SPARK-22880: Truncate table with CASCADE by jdbc dialect") { + // cascade in a truncate should only be applied for databases that support this, + // even if the parameter is passed. 
+ val MySQL = JdbcDialects.get("jdbc:mysql://127.0.0.1/db") + val Postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") + val db2 = JdbcDialects.get("jdbc:db2://127.0.0.1/db") + val h2 = JdbcDialects.get(url) + val derby = JdbcDialects.get("jdbc:derby:db") + val oracle = JdbcDialects.get("jdbc:oracle://127.0.0.1/db") + + val table = "weblogs" + val defaultQuery = s"TRUNCATE TABLE $table" + val postgresQuery = s"TRUNCATE TABLE ONLY $table CASCADE" + val oracleQuery = s"TRUNCATE TABLE $table CASCADE" + + assert(MySQL.getTruncateQuery(table, true) == defaultQuery) + assert(Postgres.getTruncateQuery(table, true) == postgresQuery) + assert(db2.getTruncateQuery(table, true) == defaultQuery) + assert(h2.getTruncateQuery(table, true) == defaultQuery) + assert(derby.getTruncateQuery(table, true) == defaultQuery) + assert(oracle.getTruncateQuery(table, true) == oracleQuery) } test("Test DataFrame.where for Date and Timestamp") { From c262993d6cd7bb2a10069d4701693850dd4c9389 Mon Sep 17 00:00:00 2001 From: Daniel van der Ende Date: Wed, 3 Jan 2018 18:00:46 +0100 Subject: [PATCH 2/4] Change cascade parameter to option This allows us to use the isCascadingTruncateTable function as default for each dialect --- docs/sql-programming-guide.md | 2 +- .../datasources/jdbc/JDBCOptions.scala | 2 +- .../datasources/jdbc/JdbcUtils.scala | 7 +- .../spark/sql/jdbc/AggregatedDialect.scala | 11 +++- .../apache/spark/sql/jdbc/DerbyDialect.scala | 2 + .../apache/spark/sql/jdbc/JdbcDialects.scala | 19 +++++- .../apache/spark/sql/jdbc/OracleDialect.scala | 23 ++++--- .../spark/sql/jdbc/PostgresDialect.scala | 30 ++++++--- .../spark/sql/jdbc/TeradataDialect.scala | 15 +++++ .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 65 ++++++++++--------- 10 files changed, 121 insertions(+), 55 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 055685127e02..1a8705c47a6a 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1376,7 
+1376,7 @@ the following case-insensitive options: cascadeTruncate - This is a JDBC writer related option. If enabled and supported by the JDBC database (PostgreSQL and Oracle at the moment), this options allows execution of a TRUNCATE TABLE t CASCADE. This will affect other tables, and thus should be used with case. This option applies only to writing. + This is a JDBC writer related option. If enabled and supported by the JDBC database (PostgreSQL and Oracle at the moment), this option allows execution of a TRUNCATE TABLE t CASCADE (in the case of PostgreSQL a TRUNCATE TABLE ONLY t CASCADE is executed to prevent inadvertently truncating descendant tables). This will affect other tables, and thus should be used with care. This option applies only to writing. It defaults to the default cascading truncate behaviour of the JDBC database in question, specified in the isCascadeTruncate in each JDBCDialect. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala index 7e34377d881d..8aa9ae642728 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala @@ -120,7 +120,7 @@ class JDBCOptions( // if to truncate the table from the JDBC database val isTruncate = parameters.getOrElse(JDBC_TRUNCATE, "false").toBoolean - val isCascadeTruncate = parameters.getOrElse(JDBC_CASCADE_TRUNCATE, "false").toBoolean + val isCascadeTruncate: Option[Boolean] = parameters.get(JDBC_CASCADE_TRUNCATE).map(_.toBoolean) // the create table option , which can be table_options or partition_options.
// E.g., "CREATE TABLE t (name string) ENGINE=InnoDB DEFAULT CHARSET=utf8" // TODO: to reuse the existing partition parameters for those partition specific options diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index ba9100bf8fc9..be3b24b554d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -102,7 +102,12 @@ object JdbcUtils extends Logging { val dialect = JdbcDialects.get(options.url) val statement = conn.createStatement try { - statement.executeUpdate(dialect.getTruncateQuery(options.table, options.isCascadeTruncate)) + if (options.isCascadeTruncate.isDefined) { + statement.executeUpdate(dialect.getTruncateQuery(options.table, + options.isCascadeTruncate)) + } else { + statement.executeUpdate(dialect.getTruncateQuery(options.table)) + } } finally { statement.close() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala index a5a75c4ed018..3a3246a1b1d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala @@ -64,7 +64,16 @@ private class AggregatedDialect(dialects: List[JdbcDialect]) extends JdbcDialect } } - override def getTruncateQuery(table: String, cascade: Boolean = false): String = { + /** + * The SQL query used to truncate a table. + * @param table The table to truncate. + * @param cascade Whether or not to cascade the truncation. 
Default value is the + * value of isCascadingTruncateTable() + * @return The SQL query to use for truncating a table + */ + override def getTruncateQuery( + table: String, + cascade: Option[Boolean] = isCascadingTruncateTable): String = { dialects.head.getTruncateQuery(table, cascade) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala index 84f68e779c38..d13c29ed46bd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala @@ -41,4 +41,6 @@ private object DerbyDialect extends JdbcDialect { Option(JdbcType("DECIMAL(31,5)", java.sql.Types.DECIMAL)) case _ => None } + + override def isCascadingTruncateTable(): Option[Boolean] = Some(false) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 96bf670e0166..48c42fbfe1b0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -122,14 +122,29 @@ abstract class JdbcDialect extends Serializable { * return a query that is suitable for a particular database. For PostgreSQL, for instance, * a different query is used to prevent "TRUNCATE" affecting other tables. * @param table The table to truncate - * @param cascade (OPTIONAL) Whether or not to cascade the truncation. Default: false * @return The SQL query to use for truncating a table */ @Since("2.3.0") - def getTruncateQuery(table: String, cascade: Boolean = false): String = { + def getTruncateQuery(table: String): String = { + getTruncateQuery(table, isCascadingTruncateTable) + } + + /** + * The SQL query that should be used to truncate a table. Dialects can override this method to + * return a query that is suitable for a particular database. 
For PostgreSQL, for instance, + * a different query is used to prevent "TRUNCATE" affecting other tables. + * @param table The table to truncate + * @param cascade Whether or not to cascade the truncation + * @return The SQL query to use for truncating a table + */ + @Since("2.4.0") + def getTruncateQuery( + table: String, + cascade: Option[Boolean] = isCascadingTruncateTable): String = { s"TRUNCATE TABLE $table" } + /** * Override connection specific properties to run before a select is made. This is in place to * allow dialects that need special treatment to optimize behavior. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala index 81eeb327848f..37e1e119cb37 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala @@ -94,15 +94,22 @@ private case object OracleDialect extends JdbcDialect { case _ => value } + override def isCascadingTruncateTable(): Option[Boolean] = Some(false) + /** - * The SQL query used to truncate a table. - * @param table The JDBCOptions. - * @param cascade (OPTIONAL) Whether or not to cascade the truncation. Default: false - * @return The SQL query to use for truncating a table - */ - override def getTruncateQuery(table: String, cascade: Boolean = false): String = { - s"TRUNCATE TABLE $table${if (cascade) " CASCADE" else ""}" + * The SQL query used to truncate a table. + * @param table The table to truncate + * @param cascade Whether or not to cascade the truncation. 
Default value is the + * value of isCascadingTruncateTable() + * @return The SQL query to use for truncating a table + */ + override def getTruncateQuery( + table: String, + cascade: Option[Boolean] = isCascadingTruncateTable): String = { + cascade match { + case Some(true) => s"TRUNCATE TABLE $table CASCADE" + case _ => s"TRUNCATE TABLE $table" + } } - override def isCascadingTruncateTable(): Option[Boolean] = Some(false) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala index 0acd05d79934..f8d2bc8e0f13 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala @@ -85,16 +85,27 @@ private object PostgresDialect extends JdbcDialect { s"SELECT 1 FROM $table LIMIT 1" } + override def isCascadingTruncateTable(): Option[Boolean] = Some(false) + /** - * The SQL query used to truncate a table. For Postgres, the default behaviour is to - * also truncate any descendant tables. As this is a (possibly unwanted) side-effect, - * the Postgres dialect adds 'ONLY' to truncate only the table in question - * @param table The table to truncate - * @param cascade (OPTIONAL) Whether or not to cascade the truncation. Default: false - * @return The SQL query to use for truncating a table - */ - override def getTruncateQuery(table: String, cascade: Boolean = false): String = { - s"TRUNCATE TABLE ONLY $table${if (cascade) " CASCADE" else ""}" + * The SQL query used to truncate a table. For Postgres, the default behaviour is to + * also truncate any descendant tables. As this is a (possibly unwanted) side-effect, + * the Postgres dialect adds 'ONLY' to truncate only the table in question + * @param table The table to truncate + * @param cascade Whether or not to cascade the truncation. Default value is the value of + * isCascadingTruncateTable(). 
Cascading a truncation will truncate tables + * with a foreign key relationship to the target table. However, it will not + * truncate tables with an inheritance relationship to the target table, as + * the truncate query always includes "ONLY" to prevent this behaviour. + * @return The SQL query to use for truncating a table + */ + override def getTruncateQuery( + table: String, + cascade: Option[Boolean] = isCascadingTruncateTable): String = { + cascade match { + case Some(true) => s"TRUNCATE TABLE ONLY $table CASCADE" + case _ => s"TRUNCATE TABLE ONLY $table" + } } override def beforeFetch(connection: Connection, properties: Map[String, String]): Unit = { @@ -111,5 +122,4 @@ private object PostgresDialect extends JdbcDialect { } } - override def isCascadingTruncateTable(): Option[Boolean] = Some(false) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala index 5749b791fca2..aa97633c0d7c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala @@ -31,4 +31,19 @@ private case object TeradataDialect extends JdbcDialect { case BooleanType => Option(JdbcType("CHAR(1)", java.sql.Types.CHAR)) case _ => None } + + override def isCascadingTruncateTable(): Option[Boolean] = Some(false) + + /** + * The SQL query used to truncate a table. + * @param table The table to truncate. + * @param cascade Whether or not to cascade the truncation. Default value is the + * value of isCascadingTruncateTable(). 
Ignored for Teradata as it is unsupported + * @return The SQL query to use for truncating a table + */ + override def getTruncateQuery( + table: String, + cascade: Option[Boolean] = isCascadingTruncateTable): String = { + s"TRUNCATE TABLE $table" + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 55c80b29d653..0636f57debb1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -684,17 +684,17 @@ class JDBCSuite extends SparkFunSuite } test("quote column names by jdbc dialect") { - val MySQL = JdbcDialects.get("jdbc:mysql://127.0.0.1/db") - val Postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") - val Derby = JdbcDialects.get("jdbc:derby:db") + val mysql = JdbcDialects.get("jdbc:mysql://127.0.0.1/db") + val postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") + val derby = JdbcDialects.get("jdbc:derby:db") val columns = Seq("abc", "key") - val MySQLColumns = columns.map(MySQL.quoteIdentifier(_)) - val PostgresColumns = columns.map(Postgres.quoteIdentifier(_)) - val DerbyColumns = columns.map(Derby.quoteIdentifier(_)) - assert(MySQLColumns === Seq("`abc`", "`key`")) - assert(PostgresColumns === Seq(""""abc"""", """"key"""")) - assert(DerbyColumns === Seq(""""abc"""", """"key"""")) + val mySQLColumns = columns.map(mysql.quoteIdentifier(_)) + val postgresColumns = columns.map(postgres.quoteIdentifier(_)) + val derbyColumns = columns.map(derby.quoteIdentifier(_)) + assert(mySQLColumns === Seq("`abc`", "`key`")) + assert(postgresColumns === Seq(""""abc"""", """"key"""")) + assert(derbyColumns === Seq(""""abc"""", """"key"""")) } test("compile filters") { @@ -805,13 +805,13 @@ class JDBCSuite extends SparkFunSuite } test("PostgresDialect type mapping") { - val Postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") - 
assert(Postgres.getCatalystType(java.sql.Types.OTHER, "json", 1, null) === Some(StringType)) - assert(Postgres.getCatalystType(java.sql.Types.OTHER, "jsonb", 1, null) === Some(StringType)) - assert(Postgres.getJDBCType(FloatType).map(_.databaseTypeDefinition).get == "FLOAT4") - assert(Postgres.getJDBCType(DoubleType).map(_.databaseTypeDefinition).get == "FLOAT8") + val postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") + assert(postgres.getCatalystType(java.sql.Types.OTHER, "json", 1, null) === Some(StringType)) + assert(postgres.getCatalystType(java.sql.Types.OTHER, "jsonb", 1, null) === Some(StringType)) + assert(postgres.getJDBCType(FloatType).map(_.databaseTypeDefinition).get == "FLOAT4") + assert(postgres.getJDBCType(DoubleType).map(_.databaseTypeDefinition).get == "FLOAT8") val errMsg = intercept[IllegalArgumentException] { - Postgres.getJDBCType(ByteType) + postgres.getJDBCType(ByteType) } assert(errMsg.getMessage contains "Unsupported type in postgresql: ByteType") } @@ -839,35 +839,36 @@ class JDBCSuite extends SparkFunSuite } test("table exists query by jdbc dialect") { - val MySQL = JdbcDialects.get("jdbc:mysql://127.0.0.1/db") - val Postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") + val mysql = JdbcDialects.get("jdbc:mysql://127.0.0.1/db") + val postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") val db2 = JdbcDialects.get("jdbc:db2://127.0.0.1/db") val h2 = JdbcDialects.get(url) val derby = JdbcDialects.get("jdbc:derby:db") val table = "weblogs" val defaultQuery = s"SELECT * FROM $table WHERE 1=0" val limitQuery = s"SELECT 1 FROM $table LIMIT 1" - assert(MySQL.getTableExistsQuery(table) == limitQuery) - assert(Postgres.getTableExistsQuery(table) == limitQuery) + assert(mysql.getTableExistsQuery(table) == limitQuery) + assert(postgres.getTableExistsQuery(table) == limitQuery) assert(db2.getTableExistsQuery(table) == defaultQuery) assert(h2.getTableExistsQuery(table) == defaultQuery) 
assert(derby.getTableExistsQuery(table) == defaultQuery) } test("truncate table query by jdbc dialect") { - val MySQL = JdbcDialects.get("jdbc:mysql://127.0.0.1/db") - val Postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") + val mysql = JdbcDialects.get("jdbc:mysql://127.0.0.1/db") + val postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") val db2 = JdbcDialects.get("jdbc:db2://127.0.0.1/db") val h2 = JdbcDialects.get(url) val derby = JdbcDialects.get("jdbc:derby:db") val oracle = JdbcDialects.get("jdbc:oracle://127.0.0.1/db") + val table = "weblogs" val defaultQuery = s"TRUNCATE TABLE $table" val postgresQuery = s"TRUNCATE TABLE ONLY $table" - assert(MySQL.getTruncateQuery(table) == defaultQuery) - assert(Postgres.getTruncateQuery(table) == postgresQuery) + assert(mysql.getTruncateQuery(table) == defaultQuery) + assert(postgres.getTruncateQuery(table) == postgresQuery) assert(db2.getTruncateQuery(table) == defaultQuery) assert(h2.getTruncateQuery(table) == defaultQuery) assert(derby.getTruncateQuery(table) == defaultQuery) @@ -877,24 +878,26 @@ class JDBCSuite extends SparkFunSuite test("SPARK-22880: Truncate table with CASCADE by jdbc dialect") { // cascade in a truncate should only be applied for databases that support this, // even if the parameter is passed. 
- val MySQL = JdbcDialects.get("jdbc:mysql://127.0.0.1/db") - val Postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") + val mysql = JdbcDialects.get("jdbc:mysql://127.0.0.1/db") + val postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") val db2 = JdbcDialects.get("jdbc:db2://127.0.0.1/db") val h2 = JdbcDialects.get(url) val derby = JdbcDialects.get("jdbc:derby:db") val oracle = JdbcDialects.get("jdbc:oracle://127.0.0.1/db") + val teradata = JdbcDialects.get("jdbc:teradata://127.0.0.1/db") val table = "weblogs" val defaultQuery = s"TRUNCATE TABLE $table" val postgresQuery = s"TRUNCATE TABLE ONLY $table CASCADE" val oracleQuery = s"TRUNCATE TABLE $table CASCADE" - assert(MySQL.getTruncateQuery(table, true) == defaultQuery) - assert(Postgres.getTruncateQuery(table, true) == postgresQuery) - assert(db2.getTruncateQuery(table, true) == defaultQuery) - assert(h2.getTruncateQuery(table, true) == defaultQuery) - assert(derby.getTruncateQuery(table, true) == defaultQuery) - assert(oracle.getTruncateQuery(table, true) == oracleQuery) + assert(mysql.getTruncateQuery(table, Some(true)) == defaultQuery) + assert(postgres.getTruncateQuery(table, Some(true)) == postgresQuery) + assert(db2.getTruncateQuery(table, Some(true)) == defaultQuery) + assert(h2.getTruncateQuery(table, Some(true)) == defaultQuery) + assert(derby.getTruncateQuery(table, Some(true)) == defaultQuery) + assert(oracle.getTruncateQuery(table, Some(true)) == oracleQuery) + assert(teradata.getTruncateQuery(table, Some(true)) == defaultQuery) } test("Test DataFrame.where for Date and Timestamp") { From 7e0ff071bc366b75b1f77e66a7b7e24a6d5790c9 Mon Sep 17 00:00:00 2001 From: Daniel van der Ende Date: Wed, 21 Feb 2018 10:54:12 +0100 Subject: [PATCH 3/4] Use correct truncation syntax for Teradata Also update the tests accordingly --- .../apache/spark/sql/jdbc/JdbcDialects.scala | 1 - .../apache/spark/sql/jdbc/OracleDialect.scala | 1 - .../spark/sql/jdbc/TeradataDialect.scala | 9 ++- 
.../org/apache/spark/sql/jdbc/JDBCSuite.scala | 60 ++++++++++--------- 4 files changed, 37 insertions(+), 34 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 48c42fbfe1b0..9c7258753920 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -144,7 +144,6 @@ abstract class JdbcDialect extends Serializable { s"TRUNCATE TABLE $table" } - /** * Override connection specific properties to run before a select is made. This is in place to * allow dialects that need special treatment to optimize behavior. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala index 37e1e119cb37..f4a6d0a4d2e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala @@ -111,5 +111,4 @@ private case object OracleDialect extends JdbcDialect { case _ => s"TRUNCATE TABLE $table" } } - } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala index aa97633c0d7c..6c17bd7ed9ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala @@ -32,18 +32,21 @@ private case object TeradataDialect extends JdbcDialect { case _ => None } + // Teradata does not support cascading a truncation override def isCascadingTruncateTable(): Option[Boolean] = Some(false) /** - * The SQL query used to truncate a table. + * The SQL query used to truncate a table. Teradata does not support the 'TRUNCATE' syntax that + * other dialects use. Instead, we need to use a 'DELETE FROM' statement. 
* @param table The table to truncate. * @param cascade Whether or not to cascade the truncation. Default value is the - * value of isCascadingTruncateTable(). Ignored for Teradata as it is unsupported + * value of isCascadingTruncateTable(). Teradata does not support cascading a + * 'DELETE FROM' statement (and as mentioned, does not support 'TRUNCATE' syntax) * @return The SQL query to use for truncating a table */ override def getTruncateQuery( table: String, cascade: Option[Boolean] = isCascadingTruncateTable): String = { - s"TRUNCATE TABLE $table" + s"DELETE FROM $table ALL" } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 0636f57debb1..72ecc11f57db 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -684,17 +684,17 @@ class JDBCSuite extends SparkFunSuite } test("quote column names by jdbc dialect") { - val mysql = JdbcDialects.get("jdbc:mysql://127.0.0.1/db") - val postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") - val derby = JdbcDialects.get("jdbc:derby:db") + val MySQL = JdbcDialects.get("jdbc:mysql://127.0.0.1/db") + val Postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") + val Derby = JdbcDialects.get("jdbc:derby:db") val columns = Seq("abc", "key") - val mySQLColumns = columns.map(mysql.quoteIdentifier(_)) - val postgresColumns = columns.map(postgres.quoteIdentifier(_)) - val derbyColumns = columns.map(derby.quoteIdentifier(_)) - assert(mySQLColumns === Seq("`abc`", "`key`")) - assert(postgresColumns === Seq(""""abc"""", """"key"""")) - assert(derbyColumns === Seq(""""abc"""", """"key"""")) + val MySQLColumns = columns.map(MySQL.quoteIdentifier(_)) + val PostgresColumns = columns.map(Postgres.quoteIdentifier(_)) + val DerbyColumns = columns.map(Derby.quoteIdentifier(_)) + assert(MySQLColumns === Seq("`abc`", "`key`")) + 
assert(PostgresColumns === Seq(""""abc"""", """"key"""")) + assert(DerbyColumns === Seq(""""abc"""", """"key"""")) } test("compile filters") { @@ -805,13 +805,13 @@ class JDBCSuite extends SparkFunSuite } test("PostgresDialect type mapping") { - val postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") - assert(postgres.getCatalystType(java.sql.Types.OTHER, "json", 1, null) === Some(StringType)) - assert(postgres.getCatalystType(java.sql.Types.OTHER, "jsonb", 1, null) === Some(StringType)) - assert(postgres.getJDBCType(FloatType).map(_.databaseTypeDefinition).get == "FLOAT4") - assert(postgres.getJDBCType(DoubleType).map(_.databaseTypeDefinition).get == "FLOAT8") + val Postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") + assert(Postgres.getCatalystType(java.sql.Types.OTHER, "json", 1, null) === Some(StringType)) + assert(Postgres.getCatalystType(java.sql.Types.OTHER, "jsonb", 1, null) === Some(StringType)) + assert(Postgres.getJDBCType(FloatType).map(_.databaseTypeDefinition).get == "FLOAT4") + assert(Postgres.getJDBCType(DoubleType).map(_.databaseTypeDefinition).get == "FLOAT8") val errMsg = intercept[IllegalArgumentException] { - postgres.getJDBCType(ByteType) + Postgres.getJDBCType(ByteType) } assert(errMsg.getMessage contains "Unsupported type in postgresql: ByteType") } @@ -839,16 +839,16 @@ class JDBCSuite extends SparkFunSuite } test("table exists query by jdbc dialect") { - val mysql = JdbcDialects.get("jdbc:mysql://127.0.0.1/db") - val postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") + val MySQL = JdbcDialects.get("jdbc:mysql://127.0.0.1/db") + val Postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") val db2 = JdbcDialects.get("jdbc:db2://127.0.0.1/db") val h2 = JdbcDialects.get(url) val derby = JdbcDialects.get("jdbc:derby:db") val table = "weblogs" val defaultQuery = s"SELECT * FROM $table WHERE 1=0" val limitQuery = s"SELECT 1 FROM $table LIMIT 1" - assert(mysql.getTableExistsQuery(table) == limitQuery) - 
assert(postgres.getTableExistsQuery(table) == limitQuery) + assert(MySQL.getTableExistsQuery(table) == limitQuery) + assert(Postgres.getTableExistsQuery(table) == limitQuery) assert(db2.getTableExistsQuery(table) == defaultQuery) assert(h2.getTableExistsQuery(table) == defaultQuery) assert(derby.getTableExistsQuery(table) == defaultQuery) @@ -861,18 +861,20 @@ class JDBCSuite extends SparkFunSuite val h2 = JdbcDialects.get(url) val derby = JdbcDialects.get("jdbc:derby:db") val oracle = JdbcDialects.get("jdbc:oracle://127.0.0.1/db") - + val teradata = JdbcDialects.get("jdbc:teradata://127.0.0.1/db") val table = "weblogs" val defaultQuery = s"TRUNCATE TABLE $table" val postgresQuery = s"TRUNCATE TABLE ONLY $table" + val teradataQuery = s"DELETE FROM $table ALL" + + Seq(mysql, db2, h2, derby).foreach{ dialect => + assert(dialect.getTruncateQuery(table, Some(true)) == defaultQuery) + } - assert(mysql.getTruncateQuery(table) == defaultQuery) assert(postgres.getTruncateQuery(table) == postgresQuery) - assert(db2.getTruncateQuery(table) == defaultQuery) - assert(h2.getTruncateQuery(table) == defaultQuery) - assert(derby.getTruncateQuery(table) == defaultQuery) assert(oracle.getTruncateQuery(table) == defaultQuery) + assert(teradata.getTruncateQuery(table) == teradataQuery) } test("SPARK-22880: Truncate table with CASCADE by jdbc dialect") { @@ -890,14 +892,14 @@ class JDBCSuite extends SparkFunSuite val defaultQuery = s"TRUNCATE TABLE $table" val postgresQuery = s"TRUNCATE TABLE ONLY $table CASCADE" val oracleQuery = s"TRUNCATE TABLE $table CASCADE" + val teradataQuery = s"DELETE FROM $table ALL" - assert(mysql.getTruncateQuery(table, Some(true)) == defaultQuery) + Seq(mysql, db2, h2, derby).foreach{ dialect => + assert(dialect.getTruncateQuery(table, Some(true)) == defaultQuery) + } assert(postgres.getTruncateQuery(table, Some(true)) == postgresQuery) - assert(db2.getTruncateQuery(table, Some(true)) == defaultQuery) - assert(h2.getTruncateQuery(table, Some(true)) == 
defaultQuery) - assert(derby.getTruncateQuery(table, Some(true)) == defaultQuery) assert(oracle.getTruncateQuery(table, Some(true)) == oracleQuery) - assert(teradata.getTruncateQuery(table, Some(true)) == defaultQuery) + assert(teradata.getTruncateQuery(table, Some(true)) == teradataQuery) } test("Test DataFrame.where for Date and Timestamp") { From a365f79b2f29326621a4cd0177780e66c56eaceb Mon Sep 17 00:00:00 2001 From: Daniel van der Ende Date: Fri, 20 Jul 2018 09:27:12 +0200 Subject: [PATCH 4/4] Minor fixes --- .../spark/sql/execution/datasources/jdbc/JdbcUtils.scala | 8 ++++---- .../scala/org/apache/spark/sql/jdbc/JdbcDialects.scala | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 45c72d39d4c8..b908753cd2f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -105,12 +105,12 @@ object JdbcUtils extends Logging { val statement = conn.createStatement try { statement.setQueryTimeout(options.queryTimeout) - if (options.isCascadeTruncate.isDefined) { - statement.executeUpdate(dialect.getTruncateQuery(options.table, - options.isCascadeTruncate)) + val truncateQuery = if (options.isCascadeTruncate.isDefined) { + dialect.getTruncateQuery(options.table, options.isCascadeTruncate) } else { - statement.executeUpdate(dialect.getTruncateQuery(options.table)) + dialect.getTruncateQuery(options.table) } + statement.executeUpdate(truncateQuery) } finally { statement.close() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 9c7258753920..f76c1fae562c 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -141,7 +141,7 @@ abstract class JdbcDialect extends Serializable { def getTruncateQuery( table: String, cascade: Option[Boolean] = isCascadingTruncateTable): String = { - s"TRUNCATE TABLE $table" + s"TRUNCATE TABLE $table" } /**