diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index fb9e11df1883..f2b84810175e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -78,7 +78,8 @@ object JdbcUtils extends Logging with SQLConfHelper {
    * Drops a table from the JDBC database.
    */
   def dropTable(conn: Connection, table: String, options: JDBCOptions): Unit = {
-    executeStatement(conn, options, s"DROP TABLE $table")
+    val dialect = JdbcDialects.get(options.url)
+    executeStatement(conn, options, dialect.dropTable(table))
   }
 
   /**
@@ -114,22 +115,19 @@ object JdbcUtils extends Logging with SQLConfHelper {
       isCaseSensitive: Boolean,
       dialect: JdbcDialect): String = {
     val columns = if (tableSchema.isEmpty) {
-      rddSchema.fields.map(x => dialect.quoteIdentifier(x.name)).mkString(",")
+      rddSchema.fields
     } else {
       // The generated insert statement needs to follow rddSchema's column sequence and
       // tableSchema's column names. When appending data into some case-sensitive DBMSs like
       // PostgreSQL/Oracle, we need to respect the existing case-sensitive column names instead of
       // RDD column names for user convenience.
-      val tableColumnNames = tableSchema.get.fieldNames
       rddSchema.fields.map { col =>
-        val normalizedName = tableColumnNames.find(f => conf.resolver(f, col.name)).getOrElse {
+        tableSchema.get.find(f => conf.resolver(f.name, col.name)).getOrElse {
           throw QueryCompilationErrors.columnNotFoundInSchemaError(col, tableSchema)
         }
-        dialect.quoteIdentifier(normalizedName)
-      }.mkString(",")
+      }
     }
-    val placeholders = rddSchema.fields.map(_ => "?").mkString(",")
-    s"INSERT INTO $table ($columns) VALUES ($placeholders)"
+    dialect.insertIntoTable(table, columns)
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index 22625523a042..37c378c294c0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -193,6 +193,24 @@ abstract class JdbcDialect extends Serializable with Logging {
     statement.executeUpdate(s"CREATE TABLE $tableName ($strSchema) $createTableOptions")
   }
 
+  /**
+   * Returns an INSERT SQL statement template for inserting a row into the target table via a
+   * JDBC connection. Uses "?" as a placeholder for each value to be inserted.
+   * E.g. `INSERT INTO t ("name", "age", "gender") VALUES (?, ?, ?)`
+   *
+   * @param table The name of the table.
+   * @param fields The fields of the row that will be inserted.
+   * @return The SQL query to use for inserting data into the table.
+   */
+  @Since("4.0.0")
+  def insertIntoTable(
+      table: String,
+      fields: Array[StructField]): String = {
+    val placeholders = fields.map(_ => "?").mkString(",")
+    val columns = fields.map(x => quoteIdentifier(x.name)).mkString(",")
+    s"INSERT INTO $table ($columns) VALUES ($placeholders)"
+  }
+
   /**
    * Get the SQL query that should be used to find if the given table exists. Dialects can
    * override this method to return a query that works best in a particular database.
@@ -542,6 +560,17 @@ abstract class JdbcDialect extends Serializable with Logging {
     }
   }
 
+  /**
+   * Build a SQL statement to drop the given table.
+   *
+   * @param table the table name
+   * @return The SQL statement to use to drop the table.
+   */
+  @Since("4.0.0")
+  def dropTable(table: String): String = {
+    s"DROP TABLE $table"
+  }
+
   /**
    * Build a create index SQL statement.
    *
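For reference, a minimal sketch of how a third-party dialect could use the two new hooks. The `MyDatabaseDialect` name, the `jdbc:mydb` URL prefix, and the `IF EXISTS` / `ON CONFLICT DO NOTHING` clauses are illustrative assumptions (PostgreSQL-style syntax), not part of this patch:

import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}
import org.apache.spark.sql.types.StructField

// Hypothetical dialect overriding the new dropTable/insertIntoTable hooks.
case object MyDatabaseDialect extends JdbcDialect {
  // canHandle is the only abstract member of JdbcDialect.
  override def canHandle(url: String): Boolean = url.startsWith("jdbc:mydb")

  // Tolerate a missing table (assumes the database supports IF EXISTS).
  override def dropTable(table: String): String =
    s"DROP TABLE IF EXISTS $table"

  // Reuse the default column quoting and "?" placeholders from the base
  // implementation, then append a conflict clause so appends skip
  // duplicate keys (assumes ON CONFLICT DO NOTHING is supported).
  override def insertIntoTable(table: String, fields: Array[StructField]): String =
    super.insertIntoTable(table, fields) + " ON CONFLICT DO NOTHING"
}

// Register the dialect so JdbcDialects.get(url) resolves it for matching URLs,
// which is how JdbcUtils.dropTable and getInsertStatement pick it up above.
JdbcDialects.registerDialect(MyDatabaseDialect)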