diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index ad23dae7c6b7..4bab58aff006 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1407,6 +1407,13 @@ the following case-insensitive options: This is a JDBC writer related option. When SaveMode.Overwrite is enabled, this option causes Spark to truncate an existing table instead of dropping and recreating it. This can be more efficient, and prevents the table metadata (e.g., indices) from being removed. However, it will not work in some cases, such as when the new data has a different schema. It defaults to false. This option applies only to writing. + + + cascadeTruncate + + This is a JDBC writer related option. If enabled and supported by the JDBC database (PostgreSQL and Oracle at the moment), this options allows execution of a TRUNCATE TABLE t CASCADE (in the case of PostgreSQL a TRUNCATE TABLE ONLY t CASCADE is executed to prevent inadvertently truncating descendant tables). This will affect other tables, and thus should be used with care. This option applies only to writing. It defaults to the default cascading truncate behaviour of the JDBC database in question, specified in the isCascadeTruncate in each JDBCDialect. + + createTableOptions diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala index eea966d30948..574aed4958fd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala @@ -157,6 +157,8 @@ class JDBCOptions( // ------------------------------------------------------------ // if to truncate the table from the JDBC database val isTruncate = parameters.getOrElse(JDBC_TRUNCATE, "false").toBoolean + + val isCascadeTruncate: Option[Boolean] = parameters.get(JDBC_CASCADE_TRUNCATE).map(_.toBoolean) // the create table option , which can be table_options or partition_options. // E.g., "CREATE TABLE t (name string) ENGINE=InnoDB DEFAULT CHARSET=utf8" // TODO: to reuse the existing partition parameters for those partition specific options @@ -225,6 +227,7 @@ object JDBCOptions { val JDBC_QUERY_TIMEOUT = newOption("queryTimeout") val JDBC_BATCH_FETCH_SIZE = newOption("fetchsize") val JDBC_TRUNCATE = newOption("truncate") + val JDBC_CASCADE_TRUNCATE = newOption("cascadeTruncate") val JDBC_CREATE_TABLE_OPTIONS = newOption("createTableOptions") val JDBC_CREATE_TABLE_COLUMN_TYPES = newOption("createTableColumnTypes") val JDBC_CUSTOM_DATAFRAME_COLUMN_TYPES = newOption("customSchema") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index b81737eda475..b908753cd2f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -105,7 +105,12 @@ object JdbcUtils extends Logging { val statement = conn.createStatement try { statement.setQueryTimeout(options.queryTimeout) - statement.executeUpdate(dialect.getTruncateQuery(options.table)) + val truncateQuery = if (options.isCascadeTruncate.isDefined) { + dialect.getTruncateQuery(options.table, options.isCascadeTruncate) + } else { + dialect.getTruncateQuery(options.table) + } + statement.executeUpdate(truncateQuery) } finally { statement.close() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala index 8b92c8b4f56b..3a3246a1b1d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala @@ -64,7 +64,16 @@ private class AggregatedDialect(dialects: List[JdbcDialect]) extends JdbcDialect } } - override def getTruncateQuery(table: String): String = { - dialects.head.getTruncateQuery(table) + /** + * The SQL query used to truncate a table. + * @param table The table to truncate. + * @param cascade Whether or not to cascade the truncation. Default value is the + * value of isCascadingTruncateTable() + * @return The SQL query to use for truncating a table + */ + override def getTruncateQuery( + table: String, + cascade: Option[Boolean] = isCascadingTruncateTable): String = { + dialects.head.getTruncateQuery(table, cascade) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala index 84f68e779c38..d13c29ed46bd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala @@ -41,4 +41,6 @@ private object DerbyDialect extends JdbcDialect { Option(JdbcType("DECIMAL(31,5)", java.sql.Types.DECIMAL)) case _ => None } + + override def isCascadingTruncateTable(): Option[Boolean] = Some(false) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 83d87a11810c..f76c1fae562c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -22,6 +22,7 @@ import java.sql.{Connection, Date, Timestamp} import org.apache.commons.lang3.StringUtils import org.apache.spark.annotation.{DeveloperApi, InterfaceStability, Since} +import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions import org.apache.spark.sql.types._ /** @@ -120,12 +121,27 @@ abstract class JdbcDialect extends Serializable { * The SQL query that should be used to truncate a table. Dialects can override this method to * return a query that is suitable for a particular database. For PostgreSQL, for instance, * a different query is used to prevent "TRUNCATE" affecting other tables. - * @param table The name of the table. + * @param table The table to truncate * @return The SQL query to use for truncating a table */ @Since("2.3.0") def getTruncateQuery(table: String): String = { - s"TRUNCATE TABLE $table" + getTruncateQuery(table, isCascadingTruncateTable) + } + + /** + * The SQL query that should be used to truncate a table. Dialects can override this method to + * return a query that is suitable for a particular database. For PostgreSQL, for instance, + * a different query is used to prevent "TRUNCATE" affecting other tables. + * @param table The table to truncate + * @param cascade Whether or not to cascade the truncation + * @return The SQL query to use for truncating a table + */ + @Since("2.4.0") + def getTruncateQuery( + table: String, + cascade: Option[Boolean] = isCascadingTruncateTable): String = { + s"TRUNCATE TABLE $table" } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala index 6ef77f24460b..f4a6d0a4d2e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala @@ -95,4 +95,20 @@ private case object OracleDialect extends JdbcDialect { } override def isCascadingTruncateTable(): Option[Boolean] = Some(false) + + /** + * The SQL query used to truncate a table. + * @param table The table to truncate + * @param cascade Whether or not to cascade the truncation. Default value is the + * value of isCascadingTruncateTable() + * @return The SQL query to use for truncating a table + */ + override def getTruncateQuery( + table: String, + cascade: Option[Boolean] = isCascadingTruncateTable): String = { + cascade match { + case Some(true) => s"TRUNCATE TABLE $table CASCADE" + case _ => s"TRUNCATE TABLE $table" + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala index 13a2035f4d0c..f8d2bc8e0f13 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala @@ -85,15 +85,27 @@ private object PostgresDialect extends JdbcDialect { s"SELECT 1 FROM $table LIMIT 1" } + override def isCascadingTruncateTable(): Option[Boolean] = Some(false) + /** - * The SQL query used to truncate a table. For Postgres, the default behaviour is to - * also truncate any descendant tables. As this is a (possibly unwanted) side-effect, - * the Postgres dialect adds 'ONLY' to truncate only the table in question - * @param table The name of the table. - * @return The SQL query to use for truncating a table - */ - override def getTruncateQuery(table: String): String = { - s"TRUNCATE TABLE ONLY $table" + * The SQL query used to truncate a table. For Postgres, the default behaviour is to + * also truncate any descendant tables. As this is a (possibly unwanted) side-effect, + * the Postgres dialect adds 'ONLY' to truncate only the table in question + * @param table The table to truncate + * @param cascade Whether or not to cascade the truncation. Default value is the value of + * isCascadingTruncateTable(). Cascading a truncation will truncate tables + * with a foreign key relationship to the target table. However, it will not + * truncate tables with an inheritance relationship to the target table, as + * the truncate query always includes "ONLY" to prevent this behaviour. + * @return The SQL query to use for truncating a table + */ + override def getTruncateQuery( + table: String, + cascade: Option[Boolean] = isCascadingTruncateTable): String = { + cascade match { + case Some(true) => s"TRUNCATE TABLE ONLY $table CASCADE" + case _ => s"TRUNCATE TABLE ONLY $table" + } } override def beforeFetch(connection: Connection, properties: Map[String, String]): Unit = { @@ -110,5 +122,4 @@ private object PostgresDialect extends JdbcDialect { } } - override def isCascadingTruncateTable(): Option[Boolean] = Some(false) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala index 5749b791fca2..6c17bd7ed9ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala @@ -31,4 +31,22 @@ private case object TeradataDialect extends JdbcDialect { case BooleanType => Option(JdbcType("CHAR(1)", java.sql.Types.CHAR)) case _ => None } + + // Teradata does not support cascading a truncation + override def isCascadingTruncateTable(): Option[Boolean] = Some(false) + + /** + * The SQL query used to truncate a table. Teradata does not support the 'TRUNCATE' syntax that + * other dialects use. Instead, we need to use a 'DELETE FROM' statement. + * @param table The table to truncate. + * @param cascade Whether or not to cascade the truncation. Default value is the + * value of isCascadingTruncateTable(). Teradata does not support cascading a + * 'DELETE FROM' statement (and as mentioned, does not support 'TRUNCATE' syntax) + * @return The SQL query to use for truncating a table + */ + override def getTruncateQuery( + table: String, + cascade: Option[Boolean] = isCascadingTruncateTable): String = { + s"DELETE FROM $table ALL" + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 0389273d6cdf..09facb9bef8d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -861,19 +861,51 @@ class JDBCSuite extends QueryTest } test("truncate table query by jdbc dialect") { - val MySQL = JdbcDialects.get("jdbc:mysql://127.0.0.1/db") - val Postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") + val mysql = JdbcDialects.get("jdbc:mysql://127.0.0.1/db") + val postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") val db2 = JdbcDialects.get("jdbc:db2://127.0.0.1/db") val h2 = JdbcDialects.get(url) val derby = JdbcDialects.get("jdbc:derby:db") + val oracle = JdbcDialects.get("jdbc:oracle://127.0.0.1/db") + val teradata = JdbcDialects.get("jdbc:teradata://127.0.0.1/db") + val table = "weblogs" val defaultQuery = s"TRUNCATE TABLE $table" val postgresQuery = s"TRUNCATE TABLE ONLY $table" - assert(MySQL.getTruncateQuery(table) == defaultQuery) - assert(Postgres.getTruncateQuery(table) == postgresQuery) - assert(db2.getTruncateQuery(table) == defaultQuery) - assert(h2.getTruncateQuery(table) == defaultQuery) - assert(derby.getTruncateQuery(table) == defaultQuery) + val teradataQuery = s"DELETE FROM $table ALL" + + Seq(mysql, db2, h2, derby).foreach{ dialect => + assert(dialect.getTruncateQuery(table, Some(true)) == defaultQuery) + } + + assert(postgres.getTruncateQuery(table) == postgresQuery) + assert(oracle.getTruncateQuery(table) == defaultQuery) + assert(teradata.getTruncateQuery(table) == teradataQuery) + } + + test("SPARK-22880: Truncate table with CASCADE by jdbc dialect") { + // cascade in a truncate should only be applied for databases that support this, + // even if the parameter is passed. + val mysql = JdbcDialects.get("jdbc:mysql://127.0.0.1/db") + val postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") + val db2 = JdbcDialects.get("jdbc:db2://127.0.0.1/db") + val h2 = JdbcDialects.get(url) + val derby = JdbcDialects.get("jdbc:derby:db") + val oracle = JdbcDialects.get("jdbc:oracle://127.0.0.1/db") + val teradata = JdbcDialects.get("jdbc:teradata://127.0.0.1/db") + + val table = "weblogs" + val defaultQuery = s"TRUNCATE TABLE $table" + val postgresQuery = s"TRUNCATE TABLE ONLY $table CASCADE" + val oracleQuery = s"TRUNCATE TABLE $table CASCADE" + val teradataQuery = s"DELETE FROM $table ALL" + + Seq(mysql, db2, h2, derby).foreach{ dialect => + assert(dialect.getTruncateQuery(table, Some(true)) == defaultQuery) + } + assert(postgres.getTruncateQuery(table, Some(true)) == postgresQuery) + assert(oracle.getTruncateQuery(table, Some(true)) == oracleQuery) + assert(teradata.getTruncateQuery(table, Some(true)) == teradataQuery) } test("Test DataFrame.where for Date and Timestamp") {