diff --git a/src/it/scala/com/databricks/spark/redshift/RedshiftIntegrationSuite.scala b/src/it/scala/com/databricks/spark/redshift/RedshiftIntegrationSuite.scala
index 73fe02ee..cd59e9a6 100644
--- a/src/it/scala/com/databricks/spark/redshift/RedshiftIntegrationSuite.scala
+++ b/src/it/scala/com/databricks/spark/redshift/RedshiftIntegrationSuite.scala
@@ -379,6 +379,34 @@ class RedshiftIntegrationSuite extends IntegrationSuiteBase {
     }
   }
 
+  test("SaveMode.Overwrite with schema-qualified table name (#97)") {
+    val tableName = s"overwrite_schema_qualified_table_name$randomSuffix"
+    val df = sqlContext.createDataFrame(sc.parallelize(Seq(Row(1))),
+      StructType(StructField("a", IntegerType) :: Nil))
+    try {
+      // Ensure that the table exists:
+      df.write
+        .format("com.databricks.spark.redshift")
+        .option("url", jdbcUrl)
+        .option("dbtable", tableName)
+        .option("tempdir", tempDir)
+        .mode(SaveMode.ErrorIfExists)
+        .save()
+      assert(DefaultJDBCWrapper.tableExists(conn, s"PUBLIC.$tableName"))
+      // Try overwriting that table while using the schema-qualified table name:
+      df.write
+        .format("com.databricks.spark.redshift")
+        .option("url", jdbcUrl)
+        .option("dbtable", s"PUBLIC.$tableName")
+        .option("tempdir", tempDir)
+        .mode(SaveMode.Overwrite)
+        .save()
+    } finally {
+      conn.prepareStatement(s"drop table if exists $tableName").executeUpdate()
+      conn.commit()
+    }
+  }
+
   test("SaveMode.Overwrite with non-existent table") {
     testRoundtripSaveAndLoad(
       s"overwrite_non_existent_table$randomSuffix",
diff --git a/src/main/scala/com/databricks/spark/redshift/DefaultSource.scala b/src/main/scala/com/databricks/spark/redshift/DefaultSource.scala
index 7e1cf498..a6ed44a7 100644
--- a/src/main/scala/com/databricks/spark/redshift/DefaultSource.scala
+++ b/src/main/scala/com/databricks/spark/redshift/DefaultSource.scala
@@ -77,7 +77,7 @@ class DefaultSource(jdbcWrapper: JDBCWrapper, s3ClientFactory: AWSCredentials =>
     def tableExists: Boolean = {
       val conn = jdbcWrapper.getConnector(params.jdbcDriver, params.jdbcUrl)
       try {
-        jdbcWrapper.tableExists(conn, table)
+        jdbcWrapper.tableExists(conn, table.toString)
       } finally {
         conn.close()
       }
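A note on the `table.toString` call above: `table` is now a `TableName` (the wrapper class added later in this diff), whose `toString` renders the quoted, schema-qualified identifier and defaults to the `PUBLIC` schema when none was given. A minimal sketch of the behavior `DefaultSource` relies on here; the table names are illustrative, and the expected strings match the `TableNameSuite` cases at the end of this diff:

```scala
// TableName.parseFromEscaped normalizes a user-supplied name, and toString
// renders the quoted, schema-qualified form that is safe to splice into SQL.
val qualified = TableName.parseFromEscaped("my_schema.my_table")
assert(qualified.toString == "\"my_schema\".\"my_table\"")

// A name with no explicit schema is resolved against PUBLIC:
val bare = TableName.parseFromEscaped("my_table")
assert(bare.toString == "\"PUBLIC\".\"my_table\"")
```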
diff --git a/src/main/scala/com/databricks/spark/redshift/Parameters.scala b/src/main/scala/com/databricks/spark/redshift/Parameters.scala
index 754b431a..78be9840 100644
--- a/src/main/scala/com/databricks/spark/redshift/Parameters.scala
+++ b/src/main/scala/com/databricks/spark/redshift/Parameters.scala
@@ -80,12 +80,27 @@
     /**
      * The Redshift table to be used as the target when loading or writing data.
      */
-    def table: Option[String] = parameters.get("dbtable")
+    def table: Option[TableName] = parameters.get("dbtable").map(_.trim).flatMap { dbtable =>
+      // We technically allow queries to be passed using `dbtable` as long as they are wrapped
+      // in parentheses. Valid SQL identifiers may contain parentheses but cannot begin with them,
+      // so there is no ambiguity in ignoring subqueries here and leaving their handling up to
+      // the `query` function defined below.
+      if (dbtable.startsWith("(") && dbtable.endsWith(")")) {
+        None
+      } else {
+        Some(TableName.parseFromEscaped(dbtable))
+      }
+    }
 
     /**
      * The Redshift query to be used as the target when loading data.
      */
-    def query: Option[String] = parameters.get("query")
+    def query: Option[String] = parameters.get("query").orElse {
+      parameters.get("dbtable")
+        .map(_.trim)
+        .filter(t => t.startsWith("(") && t.endsWith(")"))
+        .map(t => t.drop(1).dropRight(1))
+    }
 
     /**
      * A JDBC URL, of the format:
diff --git a/src/main/scala/com/databricks/spark/redshift/RedshiftRelation.scala b/src/main/scala/com/databricks/spark/redshift/RedshiftRelation.scala
index a138208f..08886d66 100644
--- a/src/main/scala/com/databricks/spark/redshift/RedshiftRelation.scala
+++ b/src/main/scala/com/databricks/spark/redshift/RedshiftRelation.scala
@@ -51,7 +51,8 @@ private[redshift] case class RedshiftRelation(
   override lazy val schema: StructType = {
     userSchema.getOrElse {
-      val tableNameOrSubquery = params.query.map(q => s"($q)").orElse(params.table).get
+      val tableNameOrSubquery =
+        params.query.map(q => s"($q)").orElse(params.table.map(_.toString)).get
       val conn = jdbcWrapper.getConnector(params.jdbcDriver, params.jdbcUrl)
       try {
         jdbcWrapper.resolveTable(conn, tableNameOrSubquery)
@@ -136,7 +137,7 @@
     // Since the query passed to UNLOAD will be enclosed in single quotes, we need to escape
     // any single quotes that appear in the query itself
     val tableNameOrSubquery: String = {
-      val unescaped = params.query.map(q => s"($q)").orElse(params.table).get
+      val unescaped = params.query.map(q => s"($q)").orElse(params.table.map(_.toString)).get
       unescaped.replace("'", "\\'")
     }
     s"SELECT $columnList FROM $tableNameOrSubquery $whereClause"
diff --git a/src/main/scala/com/databricks/spark/redshift/RedshiftWriter.scala b/src/main/scala/com/databricks/spark/redshift/RedshiftWriter.scala
index 5ebcfe23..2dae7e0a 100644
--- a/src/main/scala/com/databricks/spark/redshift/RedshiftWriter.scala
+++ b/src/main/scala/com/databricks/spark/redshift/RedshiftWriter.scala
@@ -81,28 +81,31 @@ private[redshift] class RedshiftWriter(
    */
   private def withStagingTable(
       conn: Connection,
-      table: String,
+      table: TableName,
       action: (String) => Unit) {
     val randomSuffix = Math.abs(Random.nextInt()).toString
-    val tempTable = s"${table}_staging_$randomSuffix"
-    val backupTable = s"${table}_backup_$randomSuffix"
+    val tempTable =
+      table.copy(unescapedTableName = s"${table.unescapedTableName}_staging_$randomSuffix")
+    val backupTable =
+      table.copy(unescapedTableName = s"${table.unescapedTableName}_backup_$randomSuffix")
     log.info("Loading new Redshift data to: " + tempTable)
     log.info("Existing data will be backed up in: " + backupTable)
 
     try {
-      action(tempTable)
+      action(tempTable.toString)
 
-      if (jdbcWrapper.tableExists(conn, table)) {
+      if (jdbcWrapper.tableExists(conn, table.toString)) {
         conn.prepareStatement(
           s"""
            | BEGIN;
-           | ALTER TABLE $table RENAME TO $backupTable;
-           | ALTER TABLE $tempTable RENAME TO $table;
+           | ALTER TABLE $table RENAME TO ${backupTable.escapedTableName};
+           | ALTER TABLE $tempTable RENAME TO ${table.escapedTableName};
            | DROP TABLE $backupTable;
            | END;
          """.stripMargin.trim).execute()
       } else {
-        conn.prepareStatement(s"ALTER TABLE $tempTable RENAME TO $table").execute()
+        conn.prepareStatement(
+          s"ALTER TABLE $tempTable RENAME TO ${table.escapedTableName}").execute()
       }
     } finally {
       conn.prepareStatement(s"DROP TABLE IF EXISTS $tempTable").execute()
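The asymmetry in the rename statements above is deliberate: a renamed table stays in its original schema, so while the table being renamed is referenced by its fully qualified `toString`, the new name must be the bare `escapedTableName`. A sketch of the SQL this produces, assuming a hypothetical table `sales.events` and suffix `123`:

```scala
// Illustration only; the table name and suffix are hypothetical.
val table = TableName("sales", "events")
val suffix = "123"
val tempTable =
  table.copy(unescapedTableName = s"${table.unescapedTableName}_staging_$suffix")
val backupTable =
  table.copy(unescapedTableName = s"${table.unescapedTableName}_backup_$suffix")
// Interpolating these into the writer's template yields:
//   BEGIN;
//   ALTER TABLE "sales"."events" RENAME TO "events_backup_123";
//   ALTER TABLE "sales"."events_staging_123" RENAME TO "events";
//   DROP TABLE "sales"."events_backup_123";
//   END;
```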
diff --git a/src/main/scala/com/databricks/spark/redshift/TableName.scala b/src/main/scala/com/databricks/spark/redshift/TableName.scala
new file mode 100644
index 00000000..d4a3d12e
--- /dev/null
+++ b/src/main/scala/com/databricks/spark/redshift/TableName.scala
@@ -0,0 +1,77 @@
+/*
+ * Copyright 2015 Databricks
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.databricks.spark.redshift
+
+import scala.collection.mutable.ArrayBuffer
+
+/**
+ * Wrapper class for representing the name of a Redshift table.
+ */
+private[redshift] case class TableName(unescapedSchemaName: String, unescapedTableName: String) {
+  private def quote(str: String) = '"' + str.replace("\"", "\"\"") + '"'
+  def escapedSchemaName: String = quote(unescapedSchemaName)
+  def escapedTableName: String = quote(unescapedTableName)
+  override def toString: String = s"$escapedSchemaName.$escapedTableName"
+}
+
+private[redshift] object TableName {
+  /**
+   * Parses a table name which is assumed to have been escaped according to Redshift's rules for
+   * delimited identifiers.
+   */
+  def parseFromEscaped(str: String): TableName = {
+    def dropOuterQuotes(s: String) =
+      if (s.startsWith("\"") && s.endsWith("\"")) s.drop(1).dropRight(1) else s
+    def unescapeQuotes(s: String) = s.replace("\"\"", "\"")
+    def unescape(s: String) = unescapeQuotes(dropOuterQuotes(s))
+    splitByDots(str) match {
+      case Seq(tableName) => TableName("PUBLIC", unescape(tableName))
+      case Seq(schemaName, tableName) => TableName(unescape(schemaName), unescape(tableName))
+      case other => throw new IllegalArgumentException(s"Could not parse table name from '$str'")
+    }
+  }
+
+  /**
+   * Split by dots (.) while obeying our identifier quoting rules in order to allow dots to appear
+   * inside of quoted identifiers.
+   */
+  private def splitByDots(str: String): Seq[String] = {
+    val parts: ArrayBuffer[String] = ArrayBuffer.empty
+    val sb = new StringBuilder
+    var inQuotes: Boolean = false
+    for (c <- str) c match {
+      case '"' =>
+        // Note that double quotes are escaped by pairs of double quotes (""), so we don't need
+        // any extra code to handle them; we'll be back in inQuotes=true after seeing the pair.
+        sb.append('"')
+        inQuotes = !inQuotes
+      case '.' =>
+        if (!inQuotes) {
+          parts.append(sb.toString())
+          sb.clear()
+        } else {
+          sb.append('.')
+        }
+      case other =>
+        sb.append(other)
+    }
+    if (sb.nonEmpty) {
+      parts.append(sb.toString())
+    }
+    parts
+  }
+}
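To make the parsing rules concrete, here are the mappings `parseFromEscaped` produces; these mirror the cases asserted in `TableNameSuite` at the end of this diff:

```scala
TableName.parseFromEscaped("foo.bar")             // TableName("foo", "bar")
TableName.parseFromEscaped("foo")                 // TableName("PUBLIC", "foo")
TableName.parseFromEscaped("\"foo\"")             // TableName("PUBLIC", "foo")
// A dot inside a quoted identifier is part of the name, not a separator:
TableName.parseFromEscaped("\"foo.bar\".baz")     // TableName("foo.bar", "baz")
// Doubled quotes inside a quoted identifier unescape to a literal quote:
TableName.parseFromEscaped("\"foo\"\".bar\".baz") // TableName("foo\".bar", "baz")
```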
diff --git a/src/test/scala/com/databricks/spark/redshift/ParametersSuite.scala b/src/test/scala/com/databricks/spark/redshift/ParametersSuite.scala
index 41ede140..471d42df 100644
--- a/src/test/scala/com/databricks/spark/redshift/ParametersSuite.scala
+++ b/src/test/scala/com/databricks/spark/redshift/ParametersSuite.scala
@@ -26,7 +26,7 @@ class ParametersSuite extends FunSuite with Matchers {
   test("Minimal valid parameter map is accepted") {
     val params = Map(
       "tempdir" -> "s3://foo/bar",
-      "dbtable" -> "test_table",
+      "dbtable" -> "test_schema.test_table",
       "url" -> "jdbc:redshift://foo/bar")
 
     val mergedParams = Parameters.mergeParameters(params)
@@ -34,7 +34,7 @@
     mergedParams.rootTempDir should startWith (params("tempdir"))
     mergedParams.createPerQueryTempDir() should startWith (params("tempdir"))
     mergedParams.jdbcUrl shouldBe params("url")
-    mergedParams.table shouldBe Some(params("dbtable"))
+    mergedParams.table shouldBe Some(TableName("test_schema", "test_table"))
 
     // Check that the defaults have been added
     Parameters.DEFAULT_PARAMETERS foreach {
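One consequence of the new `Parameters` logic that the suite above does not exercise directly: a parenthesized `dbtable` value is treated as a subquery and surfaces through `query` rather than `table`. A hedged sketch of that routing; the parameter values are illustrative:

```scala
val merged = Parameters.mergeParameters(Map(
  "tempdir" -> "s3://foo/bar",
  "dbtable" -> "(SELECT * FROM test_table WHERE testint > 0)",
  "url" -> "jdbc:redshift://foo/bar"))
// The wrapping parentheses route the value to `query`, with the parens stripped:
assert(merged.table.isEmpty)
assert(merged.query == Some("SELECT * FROM test_table WHERE testint > 0"))
```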
diff --git a/src/test/scala/com/databricks/spark/redshift/RedshiftSourceSuite.scala b/src/test/scala/com/databricks/spark/redshift/RedshiftSourceSuite.scala
index fddcef21..1aa50f86 100644
--- a/src/test/scala/com/databricks/spark/redshift/RedshiftSourceSuite.scala
+++ b/src/test/scala/com/databricks/spark/redshift/RedshiftSourceSuite.scala
@@ -133,12 +133,13 @@
       "UNLOAD \\('SELECT \"testbyte\", \"testbool\", \"testdate\", \"testdouble\"," +
       " \"testfloat\", \"testint\", \"testlong\", \"testshort\", \"teststring\", " +
       "\"testtimestamp\" " +
-      "FROM test_table '\\) " +
+      "FROM \"PUBLIC\".\"test_table\" '\\) " +
       "TO '.*' " +
       "WITH CREDENTIALS 'aws_access_key_id=test1;aws_secret_access_key=test2' " +
       "ESCAPE").r
-    val mockRedshift =
-      new MockRedshift(defaultParams("url"), Map("test_table" -> TestUtils.testSchema))
+    val mockRedshift = new MockRedshift(
+      defaultParams("url"),
+      Map(TableName.parseFromEscaped("test_table").toString -> TestUtils.testSchema))
 
     // Assert that we've loaded and converted all data in the test file
     val source = new DefaultSource(mockRedshift.jdbcWrapper, _ => mockS3Client)
@@ -190,7 +191,7 @@
   test("DefaultSource supports simple column filtering") {
     val expectedQuery = (
-      "UNLOAD \\('SELECT \"testbyte\", \"testbool\" FROM test_table '\\) " +
+      "UNLOAD \\('SELECT \"testbyte\", \"testbool\" FROM \"PUBLIC\".\"test_table\" '\\) " +
       "TO '.*' " +
       "WITH CREDENTIALS 'aws_access_key_id=test1;aws_secret_access_key=test2' " +
       "ESCAPE").r
@@ -217,7 +218,7 @@
     // scalastyle:off
     val expectedQuery = (
       "UNLOAD \\('SELECT \"testbyte\", \"testbool\" " +
-      "FROM test_table " +
+      "FROM \"PUBLIC\".\"test_table\" " +
       "WHERE \"testbool\" = true " +
       "AND \"teststring\" = \\\\'Unicode\\\\'\\\\'s樂趣\\\\' " +
       "AND \"testdouble\" > 1000.0 " +
@@ -228,8 +229,9 @@
       "WITH CREDENTIALS 'aws_access_key_id=test1;aws_secret_access_key=test2' " +
       "ESCAPE").r
     // scalastyle:on
-    val mockRedshift =
-      new MockRedshift(defaultParams("url"), Map("test_table" -> TestUtils.testSchema))
+    val mockRedshift = new MockRedshift(
+      defaultParams("url"),
+      Map(TableName.parseFromEscaped("test_table").toString -> TestUtils.testSchema))
 
     // Construct the source with a custom schema
     val source = new DefaultSource(mockRedshift.jdbcWrapper, _ => mockS3Client)
@@ -262,21 +264,23 @@
       "distkey" -> "testint")
 
     val expectedCommands = Seq(
-      "DROP TABLE IF EXISTS test_table_staging_.*".r,
-      "CREATE TABLE IF NOT EXISTS test_table_staging.* DISTSTYLE KEY DISTKEY \\(testint\\).*".r,
-      "COPY test_table_staging_.*".r,
-      "GRANT SELECT ON test_table_staging.+ TO jeremy".r,
+      "DROP TABLE IF EXISTS \"PUBLIC\"\\.\"test_table_staging_.*\"".r,
+      ("CREATE TABLE IF NOT EXISTS \"PUBLIC\"\\.\"test_table_staging.*" +
+        " DISTSTYLE KEY DISTKEY \\(testint\\).*").r,
+      "COPY \"PUBLIC\"\\.\"test_table_staging_.*\"".r,
+      "GRANT SELECT ON \"PUBLIC\"\\.\"test_table_staging.+\" TO jeremy".r,
       """
       | BEGIN;
-      | ALTER TABLE test_table RENAME TO test_table_backup_.*;
-      | ALTER TABLE test_table_staging_.* RENAME TO test_table;
-      | DROP TABLE test_table_backup_.*;
+      | ALTER TABLE "PUBLIC"\."test_table" RENAME TO "test_table_backup_.*";
+      | ALTER TABLE "PUBLIC"\."test_table_staging_.*" RENAME TO "test_table";
+      | DROP TABLE "PUBLIC"\."test_table_backup_.*";
      | END;
      """.stripMargin.trim.r,
-      "DROP TABLE IF EXISTS test_table_staging_.*".r)
+      "DROP TABLE IF EXISTS \"PUBLIC\"\\.\"test_table_staging_.*\"".r)
 
-    val mockRedshift =
-      new MockRedshift(defaultParams("url"), Map("test_table" -> TestUtils.testSchema))
+    val mockRedshift = new MockRedshift(
+      defaultParams("url"),
+      Map(TableName.parseFromEscaped("test_table").toString -> TestUtils.testSchema))
 
     val relation = RedshiftRelation(
       mockRedshift.jdbcWrapper,
@@ -297,8 +301,9 @@
   }
 
   test("Cannot write table with column names that become ambiguous under case insensitivity") {
-    val mockRedshift =
-      new MockRedshift(defaultParams("url"), Map("test_table" -> TestUtils.testSchema))
+    val mockRedshift = new MockRedshift(
+      defaultParams("url"),
+      Map(TableName.parseFromEscaped("test_table").toString -> TestUtils.testSchema))
     val schema = StructType(Seq(StructField("a", IntegerType), StructField("A", IntegerType)))
     val df = testSqlContext.createDataFrame(sc.emptyRDD[Row], schema)
@@ -317,15 +322,15 @@
     val mockRedshift = new MockRedshift(
       defaultParams("url"),
-      Map("test_table" -> TestUtils.testSchema),
-      jdbcQueriesThatShouldFail = Seq("COPY test_table_staging_.*".r))
+      Map(TableName.parseFromEscaped("test_table").toString -> TestUtils.testSchema),
+      jdbcQueriesThatShouldFail = Seq("COPY \"PUBLIC\".\"test_table_staging_.*\"".r))
 
     val expectedCommands = Seq(
-      "DROP TABLE IF EXISTS test_table_staging_.*".r,
-      "CREATE TABLE IF NOT EXISTS test_table_staging.*".r,
-      "COPY test_table_staging_.*".r,
+      "DROP TABLE IF EXISTS \"PUBLIC\".\"test_table_staging_.*\"".r,
+      "CREATE TABLE IF NOT EXISTS \"PUBLIC\".\"test_table_staging_.*\"".r,
+      "COPY \"PUBLIC\".\"test_table_staging_.*\"".r,
       ".*FROM stl_load_errors.*".r,
-      "DROP TABLE IF EXISTS test_table_staging.*".r
+      "DROP TABLE IF EXISTS \"PUBLIC\".\"test_table_staging_.*\"".r
     )
 
     val source = new DefaultSource(mockRedshift.jdbcWrapper, _ => mockS3Client)
@@ -340,10 +345,12 @@
   test("Append SaveMode doesn't destroy existing data") {
     val expectedCommands =
-      Seq("CREATE TABLE IF NOT EXISTS test_table .*".r,
-          "COPY test_table .*".r)
+      Seq("CREATE TABLE IF NOT EXISTS \"PUBLIC\".\"test_table\" .*".r,
+          "COPY \"PUBLIC\".\"test_table\" .*".r)
 
-    val mockRedshift = new MockRedshift(defaultParams("url"), Map(defaultParams("dbtable") -> null))
+    val mockRedshift = new MockRedshift(
+      defaultParams("url"),
+      Map(TableName.parseFromEscaped(defaultParams("dbtable")).toString -> null))
 
     val source = new DefaultSource(mockRedshift.jdbcWrapper, _ => mockS3Client)
     val savedDf =
@@ -373,13 +380,15 @@
     val createTableCommand =
       DefaultRedshiftWriter.createTableSql(df, MergedParameters.apply(defaultParams)).trim
     val expectedCreateTableCommand =
-      """CREATE TABLE IF NOT EXISTS test_table ("long_str" VARCHAR(512),""" +
+      """CREATE TABLE IF NOT EXISTS "PUBLIC"."test_table" ("long_str" VARCHAR(512),""" +
       """ "short_str" VARCHAR(10), "default_str" TEXT)"""
     assert(createTableCommand === expectedCreateTableCommand)
   }
 
   test("Respect SaveMode.ErrorIfExists when table exists") {
-    val mockRedshift = new MockRedshift(defaultParams("url"), Map(defaultParams("dbtable") -> null))
+    val mockRedshift = new MockRedshift(
+      defaultParams("url"),
+      Map(TableName.parseFromEscaped(defaultParams("dbtable")).toString -> null))
     val errIfExistsSource = new DefaultSource(mockRedshift.jdbcWrapper, _ => mockS3Client)
     intercept[Exception] {
       errIfExistsSource.createRelation(
@@ -390,7 +399,9 @@
   }
 
   test("Do nothing when table exists if SaveMode = Ignore") {
-    val mockRedshift = new MockRedshift(defaultParams("url"), Map(defaultParams("dbtable") -> null))
+    val mockRedshift = new MockRedshift(
+      defaultParams("url"),
+      Map(TableName.parseFromEscaped(defaultParams("dbtable")).toString -> null))
     val ignoreSource = new DefaultSource(mockRedshift.jdbcWrapper, _ => mockS3Client)
     ignoreSource.createRelation(testSqlContext, SaveMode.Ignore, defaultParams, expectedDataDF)
     mockRedshift.verifyThatConnectionsWereClosed()
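The pattern running through all of the updated expectations above: every generated statement now references the table through its rendered `TableName`, so `MockRedshift`'s schema map must be keyed by that same rendering. A small sketch of the key in play, assuming the suite's default `test_table`:

```scala
// The mock is keyed by the rendered name, matching what the connector emits
// in its UNLOAD / COPY / CREATE TABLE statements:
val key = TableName.parseFromEscaped("test_table").toString
assert(key == "\"PUBLIC\".\"test_table\"")
```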
diff --git a/src/test/scala/com/databricks/spark/redshift/TableNameSuite.scala b/src/test/scala/com/databricks/spark/redshift/TableNameSuite.scala
new file mode 100644
index 00000000..24c935f3
--- /dev/null
+++ b/src/test/scala/com/databricks/spark/redshift/TableNameSuite.scala
@@ -0,0 +1,37 @@
+/*
+ * Copyright 2015 Databricks
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.databricks.spark.redshift
+
+import org.scalatest.FunSuite
+
+class TableNameSuite extends FunSuite {
+  test("TableName.parseFromEscaped") {
+    assert(TableName.parseFromEscaped("foo.bar") === TableName("foo", "bar"))
+    assert(TableName.parseFromEscaped("foo") === TableName("PUBLIC", "foo"))
+    assert(TableName.parseFromEscaped("\"foo\"") === TableName("PUBLIC", "foo"))
+    assert(TableName.parseFromEscaped("\"\"\"foo\"\"\".bar") === TableName("\"foo\"", "bar"))
+    // Dots (.) can also appear inside of valid identifiers.
+    assert(TableName.parseFromEscaped("\"foo.bar\".baz") === TableName("foo.bar", "baz"))
+    assert(TableName.parseFromEscaped("\"foo\"\".bar\".baz") === TableName("foo\".bar", "baz"))
+  }
+
+  test("TableName.toString") {
+    assert(TableName("foo", "bar").toString === """"foo"."bar"""")
+    assert(TableName("PUBLIC", "bar").toString === """"PUBLIC"."bar"""")
+    assert(TableName("\"foo\"", "bar").toString === "\"\"\"foo\"\"\".\"bar\"")
+  }
+}
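A property implied by the two tests above, though not asserted in this diff: rendering a `TableName` and re-parsing the result is the identity, even for names that themselves contain quotes. A sketch:

```scala
val names = Seq(
  TableName("foo", "bar"),
  TableName("PUBLIC", "bar"),
  TableName("\"foo\"", "bar"))
names.foreach { n =>
  assert(TableName.parseFromEscaped(n.toString) == n)
}
```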