Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,34 @@ class RedshiftIntegrationSuite extends IntegrationSuiteBase {
}
}

test("SaveMode.Overwrite with schema-qualified table name (#97)") {
val tableName = s"overwrite_schema_qualified_table_name$randomSuffix"
val df = sqlContext.createDataFrame(sc.parallelize(Seq(Row(1))),
StructType(StructField("a", IntegerType) :: Nil))
try {
// Ensure that the table exists:
df.write
.format("com.databricks.spark.redshift")
.option("url", jdbcUrl)
.option("dbtable", tableName)
.option("tempdir", tempDir)
.mode(SaveMode.ErrorIfExists)
.save()
assert(DefaultJDBCWrapper.tableExists(conn, s"PUBLIC.$tableName"))
// Try overwriting that table while using the schema-qualified table name:
df.write
.format("com.databricks.spark.redshift")
.option("url", jdbcUrl)
.option("dbtable", s"PUBLIC.$tableName")
.option("tempdir", tempDir)
.mode(SaveMode.Overwrite)
.save()
} finally {
conn.prepareStatement(s"drop table if exists $tableName").executeUpdate()
conn.commit()
}
}

test("SaveMode.Overwrite with non-existent table") {
testRoundtripSaveAndLoad(
s"overwrite_non_existent_table$randomSuffix",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class DefaultSource(jdbcWrapper: JDBCWrapper, s3ClientFactory: AWSCredentials =>
def tableExists: Boolean = {
val conn = jdbcWrapper.getConnector(params.jdbcDriver, params.jdbcUrl)
try {
jdbcWrapper.tableExists(conn, table)
jdbcWrapper.tableExists(conn, table.toString)
} finally {
conn.close()
}
Expand Down
19 changes: 17 additions & 2 deletions src/main/scala/com/databricks/spark/redshift/Parameters.scala
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,27 @@ private[redshift] object Parameters {
/**
* The Redshift table to be used as the target when loading or writing data.
*/
def table: Option[String] = parameters.get("dbtable")
def table: Option[TableName] = parameters.get("dbtable").map(_.trim).flatMap { dbtable =>
// We technically allow queries to be passed using `dbtable` as long as they are wrapped
// in parentheses. Valid SQL identifiers may contain parentheses but cannot begin with them,
// so there is no ambiguity in ignoring subqeries here and leaving their handling up to
// the `query` function defined below.
if (dbtable.startsWith("(") && dbtable.endsWith(")")) {
None
} else {
Some(TableName.parseFromEscaped(dbtable))
}
}

/**
* The Redshift query to be used as the target when loading data.
*/
def query: Option[String] = parameters.get("query")
def query: Option[String] = parameters.get("query").orElse {
parameters.get("dbtable")
.map(_.trim)
.filter(t => t.startsWith("(") && t.endsWith(")"))
.map(t => t.drop(1).dropRight(1))
}

/**
* A JDBC URL, of the format:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ private[redshift] case class RedshiftRelation(

override lazy val schema: StructType = {
userSchema.getOrElse {
val tableNameOrSubquery = params.query.map(q => s"($q)").orElse(params.table).get
val tableNameOrSubquery =
params.query.map(q => s"($q)").orElse(params.table.map(_.toString)).get
val conn = jdbcWrapper.getConnector(params.jdbcDriver, params.jdbcUrl)
try {
jdbcWrapper.resolveTable(conn, tableNameOrSubquery)
Expand Down Expand Up @@ -136,7 +137,7 @@ private[redshift] case class RedshiftRelation(
// Since the query passed to UNLOAD will be enclosed in single quotes, we need to escape
// any single quotes that appear in the query itself
val tableNameOrSubquery: String = {
val unescaped = params.query.map(q => s"($q)").orElse(params.table).get
val unescaped = params.query.map(q => s"($q)").orElse(params.table.map(_.toString)).get
unescaped.replace("'", "\\'")
}
s"SELECT $columnList FROM $tableNameOrSubquery $whereClause"
Expand Down
19 changes: 11 additions & 8 deletions src/main/scala/com/databricks/spark/redshift/RedshiftWriter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -81,28 +81,31 @@ private[redshift] class RedshiftWriter(
*/
private def withStagingTable(
conn: Connection,
table: String,
table: TableName,
action: (String) => Unit) {
val randomSuffix = Math.abs(Random.nextInt()).toString
val tempTable = s"${table}_staging_$randomSuffix"
val backupTable = s"${table}_backup_$randomSuffix"
val tempTable =
table.copy(unescapedTableName = s"${table.unescapedTableName}_staging_$randomSuffix")
val backupTable =
table.copy(unescapedTableName = s"${table.unescapedTableName}_backup_$randomSuffix")
log.info("Loading new Redshift data to: " + tempTable)
log.info("Existing data will be backed up in: " + backupTable)

try {
action(tempTable)
action(tempTable.toString)

if (jdbcWrapper.tableExists(conn, table)) {
if (jdbcWrapper.tableExists(conn, table.toString)) {
conn.prepareStatement(
s"""
| BEGIN;
| ALTER TABLE $table RENAME TO $backupTable;
| ALTER TABLE $tempTable RENAME TO $table;
| ALTER TABLE $table RENAME TO ${backupTable.escapedTableName};
| ALTER TABLE $tempTable RENAME TO ${table.escapedTableName};
| DROP TABLE $backupTable;
| END;
""".stripMargin.trim).execute()
} else {
conn.prepareStatement(s"ALTER TABLE $tempTable RENAME TO $table").execute()
conn.prepareStatement(
s"ALTER TABLE $tempTable RENAME TO ${table.escapedTableName}").execute()
}
} finally {
conn.prepareStatement(s"DROP TABLE IF EXISTS $tempTable").execute()
Expand Down
77 changes: 77 additions & 0 deletions src/main/scala/com/databricks/spark/redshift/TableName.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* Copyright 2015 Databricks
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.databricks.spark.redshift

import scala.collection.mutable.ArrayBuffer

/**
* Wrapper class for representing the name of a Redshift table.
*/
private[redshift] case class TableName(unescapedSchemaName: String, unescapedTableName: String) {
private def quote(str: String) = '"' + str.replace("\"", "\"\"") + '"'
def escapedSchemaName: String = quote(unescapedSchemaName)
def escapedTableName: String = quote(unescapedTableName)
override def toString: String = s"$escapedSchemaName.$escapedTableName"
}

private[redshift] object TableName {
/**
* Parses a table name which is assumed to have been escaped according to Redshift's rules for
* delimited identifiers.
*/
def parseFromEscaped(str: String): TableName = {
def dropOuterQuotes(s: String) =
if (s.startsWith("\"") && s.endsWith("\"")) s.drop(1).dropRight(1) else s
def unescapeQuotes(s: String) = s.replace("\"\"", "\"")
def unescape(s: String) = unescapeQuotes(dropOuterQuotes(s))
splitByDots(str) match {
case Seq(tableName) => TableName("PUBLIC", unescape(tableName))
case Seq(schemaName, tableName) => TableName(unescape(schemaName), unescape(tableName))
case other => throw new IllegalArgumentException(s"Could not parse table name from '$str'")
}
}

/**
* Split by dots (.) while obeying our identifier quoting rules in order to allow dots to appear
* inside of quoted identifiers.
*/
private def splitByDots(str: String): Seq[String] = {
val parts: ArrayBuffer[String] = ArrayBuffer.empty
val sb = new StringBuilder
var inQuotes: Boolean = false
for (c <- str) c match {
case '"' =>
// Note that double quotes are escaped by pairs of double quotes (""), so we don't need
// any extra code to handle them; we'll be back in inQuotes=true after seeing the pair.
sb.append('"')
inQuotes = !inQuotes
case '.' =>
if (!inQuotes) {
parts.append(sb.toString())
sb.clear()
} else {
sb.append('.')
}
case other =>
sb.append(other)
}
if (sb.nonEmpty) {
parts.append(sb.toString())
}
parts
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,15 @@ class ParametersSuite extends FunSuite with Matchers {
test("Minimal valid parameter map is accepted") {
val params = Map(
"tempdir" -> "s3://foo/bar",
"dbtable" -> "test_table",
"dbtable" -> "test_schema.test_table",
"url" -> "jdbc:redshift://foo/bar")

val mergedParams = Parameters.mergeParameters(params)

mergedParams.rootTempDir should startWith (params("tempdir"))
mergedParams.createPerQueryTempDir() should startWith (params("tempdir"))
mergedParams.jdbcUrl shouldBe params("url")
mergedParams.table shouldBe Some(params("dbtable"))
mergedParams.table shouldBe Some(TableName("test_schema", "test_table"))

// Check that the defaults have been added
Parameters.DEFAULT_PARAMETERS foreach {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,12 +133,13 @@ class RedshiftSourceSuite
"UNLOAD \\('SELECT \"testbyte\", \"testbool\", \"testdate\", \"testdouble\"," +
" \"testfloat\", \"testint\", \"testlong\", \"testshort\", \"teststring\", " +
"\"testtimestamp\" " +
"FROM test_table '\\) " +
"FROM \"PUBLIC\".\"test_table\" '\\) " +
"TO '.*' " +
"WITH CREDENTIALS 'aws_access_key_id=test1;aws_secret_access_key=test2' " +
"ESCAPE").r
val mockRedshift =
new MockRedshift(defaultParams("url"), Map("test_table" -> TestUtils.testSchema))
val mockRedshift = new MockRedshift(
defaultParams("url"),
Map(TableName.parseFromEscaped("test_table").toString -> TestUtils.testSchema))

// Assert that we've loaded and converted all data in the test file
val source = new DefaultSource(mockRedshift.jdbcWrapper, _ => mockS3Client)
Expand Down Expand Up @@ -190,7 +191,7 @@ class RedshiftSourceSuite

test("DefaultSource supports simple column filtering") {
val expectedQuery = (
"UNLOAD \\('SELECT \"testbyte\", \"testbool\" FROM test_table '\\) " +
"UNLOAD \\('SELECT \"testbyte\", \"testbool\" FROM \"PUBLIC\".\"test_table\" '\\) " +
"TO '.*' " +
"WITH CREDENTIALS 'aws_access_key_id=test1;aws_secret_access_key=test2' " +
"ESCAPE").r
Expand All @@ -217,7 +218,7 @@ class RedshiftSourceSuite
// scalastyle:off
val expectedQuery = (
"UNLOAD \\('SELECT \"testbyte\", \"testbool\" " +
"FROM test_table " +
"FROM \"PUBLIC\".\"test_table\" " +
"WHERE \"testbool\" = true " +
"AND \"teststring\" = \\\\'Unicode\\\\'\\\\'s樂趣\\\\' " +
"AND \"testdouble\" > 1000.0 " +
Expand All @@ -228,8 +229,9 @@ class RedshiftSourceSuite
"WITH CREDENTIALS 'aws_access_key_id=test1;aws_secret_access_key=test2' " +
"ESCAPE").r
// scalastyle:on
val mockRedshift =
new MockRedshift(defaultParams("url"), Map("test_table" -> TestUtils.testSchema))
val mockRedshift = new MockRedshift(
defaultParams("url"),
Map(TableName.parseFromEscaped("test_table").toString -> TestUtils.testSchema))

// Construct the source with a custom schema
val source = new DefaultSource(mockRedshift.jdbcWrapper, _ => mockS3Client)
Expand Down Expand Up @@ -262,21 +264,23 @@ class RedshiftSourceSuite
"distkey" -> "testint")

val expectedCommands = Seq(
"DROP TABLE IF EXISTS test_table_staging_.*".r,
"CREATE TABLE IF NOT EXISTS test_table_staging.* DISTSTYLE KEY DISTKEY \\(testint\\).*".r,
"COPY test_table_staging_.*".r,
"GRANT SELECT ON test_table_staging.+ TO jeremy".r,
"DROP TABLE IF EXISTS \"PUBLIC\"\\.\"test_table_staging_.*\"".r,
("CREATE TABLE IF NOT EXISTS \"PUBLIC\"\\.\"test_table_staging.*" +
" DISTSTYLE KEY DISTKEY \\(testint\\).*").r,
"COPY \"PUBLIC\"\\.\"test_table_staging_.*\"".r,
"GRANT SELECT ON \"PUBLIC\"\\.\"test_table_staging.+\" TO jeremy".r,
"""
| BEGIN;
| ALTER TABLE test_table RENAME TO test_table_backup_.*;
| ALTER TABLE test_table_staging_.* RENAME TO test_table;
| DROP TABLE test_table_backup_.*;
| ALTER TABLE "PUBLIC"\."test_table" RENAME TO "test_table_backup_.*";
| ALTER TABLE "PUBLIC"\."test_table_staging_.*" RENAME TO "test_table";
| DROP TABLE "PUBLIC"\."test_table_backup_.*";
| END;
""".stripMargin.trim.r,
"DROP TABLE IF EXISTS test_table_staging_.*".r)
"DROP TABLE IF EXISTS \"PUBLIC\"\\.\"test_table_staging_.*\"".r)

val mockRedshift =
new MockRedshift(defaultParams("url"), Map("test_table" -> TestUtils.testSchema))
val mockRedshift = new MockRedshift(
defaultParams("url"),
Map(TableName.parseFromEscaped("test_table").toString -> TestUtils.testSchema))

val relation = RedshiftRelation(
mockRedshift.jdbcWrapper,
Expand All @@ -297,8 +301,9 @@ class RedshiftSourceSuite
}

test("Cannot write table with column names that become ambiguous under case insensitivity") {
val mockRedshift =
new MockRedshift(defaultParams("url"), Map("test_table" -> TestUtils.testSchema))
val mockRedshift = new MockRedshift(
defaultParams("url"),
Map(TableName.parseFromEscaped("test_table").toString -> TestUtils.testSchema))

val schema = StructType(Seq(StructField("a", IntegerType), StructField("A", IntegerType)))
val df = testSqlContext.createDataFrame(sc.emptyRDD[Row], schema)
Expand All @@ -317,15 +322,15 @@ class RedshiftSourceSuite

val mockRedshift = new MockRedshift(
defaultParams("url"),
Map("test_table" -> TestUtils.testSchema),
jdbcQueriesThatShouldFail = Seq("COPY test_table_staging_.*".r))
Map(TableName.parseFromEscaped("test_table").toString -> TestUtils.testSchema),
jdbcQueriesThatShouldFail = Seq("COPY \"PUBLIC\".\"test_table_staging_.*\"".r))

val expectedCommands = Seq(
"DROP TABLE IF EXISTS test_table_staging_.*".r,
"CREATE TABLE IF NOT EXISTS test_table_staging.*".r,
"COPY test_table_staging_.*".r,
"DROP TABLE IF EXISTS \"PUBLIC\".\"test_table_staging_.*\"".r,
"CREATE TABLE IF NOT EXISTS \"PUBLIC\".\"test_table_staging_.*\"".r,
"COPY \"PUBLIC\".\"test_table_staging_.*\"".r,
".*FROM stl_load_errors.*".r,
"DROP TABLE IF EXISTS test_table_staging.*".r
"DROP TABLE IF EXISTS \"PUBLIC\".\"test_table_staging_.*\"".r
)

val source = new DefaultSource(mockRedshift.jdbcWrapper, _ => mockS3Client)
Expand All @@ -340,10 +345,12 @@ class RedshiftSourceSuite

test("Append SaveMode doesn't destroy existing data") {
val expectedCommands =
Seq("CREATE TABLE IF NOT EXISTS test_table .*".r,
"COPY test_table .*".r)
Seq("CREATE TABLE IF NOT EXISTS \"PUBLIC\".\"test_table\" .*".r,
"COPY \"PUBLIC\".\"test_table\" .*".r)

val mockRedshift = new MockRedshift(defaultParams("url"), Map(defaultParams("dbtable") -> null))
val mockRedshift = new MockRedshift(
defaultParams("url"),
Map(TableName.parseFromEscaped(defaultParams("dbtable")).toString -> null))

val source = new DefaultSource(mockRedshift.jdbcWrapper, _ => mockS3Client)
val savedDf =
Expand Down Expand Up @@ -373,13 +380,15 @@ class RedshiftSourceSuite
val createTableCommand =
DefaultRedshiftWriter.createTableSql(df, MergedParameters.apply(defaultParams)).trim
val expectedCreateTableCommand =
"""CREATE TABLE IF NOT EXISTS test_table ("long_str" VARCHAR(512),""" +
"""CREATE TABLE IF NOT EXISTS "PUBLIC"."test_table" ("long_str" VARCHAR(512),""" +
""" "short_str" VARCHAR(10), "default_str" TEXT)"""
assert(createTableCommand === expectedCreateTableCommand)
}

test("Respect SaveMode.ErrorIfExists when table exists") {
val mockRedshift = new MockRedshift(defaultParams("url"), Map(defaultParams("dbtable") -> null))
val mockRedshift = new MockRedshift(
defaultParams("url"),
Map(TableName.parseFromEscaped(defaultParams("dbtable")).toString -> null))
val errIfExistsSource = new DefaultSource(mockRedshift.jdbcWrapper, _ => mockS3Client)
intercept[Exception] {
errIfExistsSource.createRelation(
Expand All @@ -390,7 +399,9 @@ class RedshiftSourceSuite
}

test("Do nothing when table exists if SaveMode = Ignore") {
val mockRedshift = new MockRedshift(defaultParams("url"), Map(defaultParams("dbtable") -> null))
val mockRedshift = new MockRedshift(
defaultParams("url"),
Map(TableName.parseFromEscaped(defaultParams("dbtable")).toString -> null))
val ignoreSource = new DefaultSource(mockRedshift.jdbcWrapper, _ => mockS3Client)
ignoreSource.createRelation(testSqlContext, SaveMode.Ignore, defaultParams, expectedDataDF)
mockRedshift.verifyThatConnectionsWereClosed()
Expand Down
Loading