diff --git a/src/it/scala/com/databricks/spark/redshift/RedshiftIntegrationSuite.scala b/src/it/scala/com/databricks/spark/redshift/RedshiftIntegrationSuite.scala
index 81b11ca3..bcd494e0 100644
--- a/src/it/scala/com/databricks/spark/redshift/RedshiftIntegrationSuite.scala
+++ b/src/it/scala/com/databricks/spark/redshift/RedshiftIntegrationSuite.scala
@@ -23,6 +23,7 @@ import java.util.Properties
 
 import scala.util.Random
 
 import org.scalatest.{BeforeAndAfterEach, BeforeAndAfterAll, Matchers}
+import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{FileSystem, Path}
 import org.apache.spark.SparkContext
@@ -86,10 +87,6 @@ class RedshiftIntegrationSuite
   override def beforeAll(): Unit = {
     super.beforeAll()
     sc = new SparkContext("local", "RedshiftSourceSuite")
-    sc.hadoopConfiguration.set("fs.s3.awsAccessKeyId", AWS_ACCESS_KEY_ID)
-    sc.hadoopConfiguration.set("fs.s3.awsSecretAccessKey", AWS_SECRET_ACCESS_KEY)
-    sc.hadoopConfiguration.set("fs.s3n.awsAccessKeyId", AWS_ACCESS_KEY_ID)
-    sc.hadoopConfiguration.set("fs.s3n.awsSecretAccessKey", AWS_SECRET_ACCESS_KEY)
     conn = DefaultJDBCWrapper.getConnector(
       "com.amazon.redshift.jdbc4.Driver", jdbcUrl, new Properties())()
 
@@ -138,7 +135,10 @@ class RedshiftIntegrationSuite
 
   override def afterAll(): Unit = {
     try {
-      val fs = FileSystem.get(URI.create(tempDir), sc.hadoopConfiguration)
+      val conf = new Configuration()
+      conf.set("fs.s3n.awsAccessKeyId", AWS_ACCESS_KEY_ID)
+      conf.set("fs.s3n.awsSecretAccessKey", AWS_SECRET_ACCESS_KEY)
+      val fs = FileSystem.get(URI.create(tempDir), conf)
       fs.delete(new Path(tempDir), true)
       fs.close()
     } finally {
@@ -178,6 +178,8 @@ class RedshiftIntegrationSuite
         | options(
         | url \"$jdbcUrl\",
         | tempdir \"$tempDir\",
+        | aws_access_key_id \"$AWS_ACCESS_KEY_ID\",
+        | aws_secret_access_key \"$AWS_SECRET_ACCESS_KEY\",
         | dbtable \"$test_table\"
         | )
       """.stripMargin
@@ -201,6 +203,8 @@ class RedshiftIntegrationSuite
         | options(
         | url \"$jdbcUrl\",
         | tempdir \"$tempDir\",
+        | aws_access_key_id \"$AWS_ACCESS_KEY_ID\",
+        | aws_secret_access_key \"$AWS_SECRET_ACCESS_KEY\",
         | dbtable \"$test_table2\"
         | )
       """.stripMargin
@@ -224,6 +228,8 @@ class RedshiftIntegrationSuite
         | options(
         | url \"$jdbcUrl\",
         | tempdir \"$tempDir\",
+        | aws_access_key_id \"$AWS_ACCESS_KEY_ID\",
+        | aws_secret_access_key \"$AWS_SECRET_ACCESS_KEY\",
         | dbtable \"$test_table3\"
         | )
       """.stripMargin
@@ -254,6 +260,8 @@ class RedshiftIntegrationSuite
       .option("url", jdbcUrl)
       .option("dbtable", query)
       .option("tempdir", tempDir)
+      .option("aws_access_key_id", AWS_ACCESS_KEY_ID)
+      .option("aws_secret_access_key", AWS_SECRET_ACCESS_KEY)
       .load()
     checkAnswer(loadedDf, Seq(Row(1, true)))
   }
@@ -276,6 +284,8 @@ class RedshiftIntegrationSuite
       .option("url", jdbcUrl)
       .option("query", query)
       .option("tempdir", tempDir)
+      .option("aws_access_key_id", AWS_ACCESS_KEY_ID)
+      .option("aws_secret_access_key", AWS_SECRET_ACCESS_KEY)
       .load()
     checkAnswer(loadedDf, Seq(Row(1, true)))
   }
@@ -286,6 +296,8 @@ class RedshiftIntegrationSuite
       .option("url", jdbcUrl)
       .option("query", s"select testbool, count(*) from $test_table group by testbool")
       .option("tempdir", tempDir)
+      .option("aws_access_key_id", AWS_ACCESS_KEY_ID)
+      .option("aws_secret_access_key", AWS_SECRET_ACCESS_KEY)
       .load()
     checkAnswer(loadedDf, Seq(Row(true, 1), Row(false, 2), Row(null, 2)))
   }
@@ -327,6 +339,8 @@ class RedshiftIntegrationSuite
         .option("url", jdbcUrl)
         .option("dbtable", tableName)
         .option("tempdir", tempDir)
+        .option("aws_access_key_id", AWS_ACCESS_KEY_ID)
+        .option("aws_secret_access_key", AWS_SECRET_ACCESS_KEY)
         .mode(SaveMode.ErrorIfExists)
         .save()
 
@@ -336,6 +350,8 @@ class RedshiftIntegrationSuite
         .option("url", jdbcUrl)
         .option("dbtable", tableName)
         .option("tempdir", tempDir)
+        .option("aws_access_key_id", AWS_ACCESS_KEY_ID)
+        .option("aws_secret_access_key", AWS_SECRET_ACCESS_KEY)
         .load()
       checkAnswer(loadedDf, TestUtils.expectedData)
     } finally {
@@ -354,6 +370,8 @@ class RedshiftIntegrationSuite
         .option("url", jdbcUrl)
         .option("dbtable", tableName)
         .option("tempdir", tempDir)
+        .option("aws_access_key_id", AWS_ACCESS_KEY_ID)
+        .option("aws_secret_access_key", AWS_SECRET_ACCESS_KEY)
         .mode(SaveMode.ErrorIfExists)
         .save()
 
@@ -363,6 +381,8 @@ class RedshiftIntegrationSuite
         .option("url", jdbcUrl)
         .option("dbtable", tableName)
         .option("tempdir", tempDir)
+        .option("aws_access_key_id", AWS_ACCESS_KEY_ID)
+        .option("aws_secret_access_key", AWS_SECRET_ACCESS_KEY)
         .load()
       assert(loadedDf.schema.length === 1)
       assert(loadedDf.columns === Seq("a"))
@@ -384,6 +404,8 @@ class RedshiftIntegrationSuite
         .option("url", jdbcUrl)
         .option("dbtable", tableName)
         .option("tempdir", tempDir)
+        .option("aws_access_key_id", AWS_ACCESS_KEY_ID)
+        .option("aws_secret_access_key", AWS_SECRET_ACCESS_KEY)
         .mode(SaveMode.ErrorIfExists)
         .save()
     }
@@ -404,6 +426,8 @@ class RedshiftIntegrationSuite
         .option("url", jdbcUrl)
         .option("dbtable", tableName)
         .option("tempdir", tempDir)
+        .option("aws_access_key_id", AWS_ACCESS_KEY_ID)
+        .option("aws_secret_access_key", AWS_SECRET_ACCESS_KEY)
         .mode(SaveMode.Overwrite)
         .save()
 
@@ -413,6 +437,8 @@ class RedshiftIntegrationSuite
         .option("url", jdbcUrl)
         .option("dbtable", tableName)
         .option("tempdir", tempDir)
+        .option("aws_access_key_id", AWS_ACCESS_KEY_ID)
+        .option("aws_secret_access_key", AWS_SECRET_ACCESS_KEY)
         .load()
       checkAnswer(loadedDf, TestUtils.expectedData)
     } finally {
@@ -433,6 +459,8 @@ class RedshiftIntegrationSuite
       .option("url", jdbcUrl)
       .option("dbtable", test_table3)
       .option("tempdir", tempDir)
+      .option("aws_access_key_id", AWS_ACCESS_KEY_ID)
+      .option("aws_secret_access_key", AWS_SECRET_ACCESS_KEY)
       .mode(SaveMode.Append)
       .saveAsTable(test_table3)
 
@@ -453,6 +481,8 @@ class RedshiftIntegrationSuite
       .option("url", jdbcUrl)
       .option("dbtable", test_table)
       .option("tempdir", tempDir)
+      .option("aws_access_key_id", AWS_ACCESS_KEY_ID)
+      .option("aws_secret_access_key", AWS_SECRET_ACCESS_KEY)
       .mode(SaveMode.ErrorIfExists)
       .saveAsTable(test_table)
   }
@@ -466,6 +496,8 @@ class RedshiftIntegrationSuite
       .option("url", jdbcUrl)
       .option("dbtable", test_table)
       .option("tempdir", tempDir)
+      .option("aws_access_key_id", AWS_ACCESS_KEY_ID)
+      .option("aws_secret_access_key", AWS_SECRET_ACCESS_KEY)
       .mode(SaveMode.Ignore)
       .saveAsTable(test_table)
 
diff --git a/src/main/scala/com/databricks/spark/redshift/Parameters.scala b/src/main/scala/com/databricks/spark/redshift/Parameters.scala
index 5e7a0204..d2e9701c 100644
--- a/src/main/scala/com/databricks/spark/redshift/Parameters.scala
+++ b/src/main/scala/com/databricks/spark/redshift/Parameters.scala
@@ -188,6 +188,26 @@ private[redshift] object Parameters extends Logging {
      * available.
      */
     def credentialsString(configuration: Configuration): String = {
+      val ((_, accessKeyId), (_, secretAccessKey)) = credentialsTuple(configuration)
+      s"aws_access_key_id=$accessKeyId;aws_secret_access_key=$secretAccessKey"
+    }
+
+    /**
+     * Looks up "aws_access_key_id" and "aws_secret_access_key" in the parameter map and sets the
+     * corresponding Hadoop credential properties for the tempDir scheme, without overwriting
+     * values that are already set. If no credentials have been provided, this function will
+     * instead try using the Hadoop Configuration's existing `fs.*` settings for the provided
+     * tempDir scheme, and if that also fails, it finally tries the AWS
+     * DefaultCredentialsProviderChain, which makes use of standard system properties,
+     * environment variables, or IAM role configuration if available.
+     */
+    def setCredentials(configuration: Configuration): Unit = {
+      val ((accessKeyIdProp, accessKeyId), (secretAccessKeyProp, secretAccessKey)) =
+        credentialsTuple(configuration)
+      configuration.setIfUnset(accessKeyIdProp, accessKeyId)
+      configuration.setIfUnset(secretAccessKeyProp, secretAccessKey)
+    }
+
+    private def credentialsTuple(configuration: Configuration) = {
       val scheme = new URI(tempDir).getScheme
       val hadoopConfPrefix = s"fs.$scheme"
 
@@ -212,14 +232,8 @@ private[redshift] object Parameters extends Logging {
         }
       }
 
-      val credentials = s"aws_access_key_id=$accessKeyId;aws_secret_access_key=$secretAccessKey"
-
-      if (parameters.contains("aws_security_token")) {
-        val securityToken = parameters("aws_security_token")
-        credentials + s";token=$securityToken"
-      } else {
-        credentials
-      }
+      ((s"$hadoopConfPrefix.awsAccessKeyId", accessKeyId),
+        (s"$hadoopConfPrefix.awsSecretAccessKey", secretAccessKey))
     }
   }
 }
diff --git a/src/main/scala/com/databricks/spark/redshift/RedshiftRelation.scala b/src/main/scala/com/databricks/spark/redshift/RedshiftRelation.scala
index a7bc275d..862c00db 100644
--- a/src/main/scala/com/databricks/spark/redshift/RedshiftRelation.scala
+++ b/src/main/scala/com/databricks/spark/redshift/RedshiftRelation.scala
@@ -18,14 +18,15 @@ package com.databricks.spark.redshift
 
 import java.util.Properties
 
-import com.databricks.spark.redshift.Parameters.MergedParameters
-
+import org.apache.hadoop.conf.Configuration
 import org.apache.spark.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.{DataFrame, Row, SQLContext}
 
+import com.databricks.spark.redshift.Parameters.MergedParameters
+
 /**
  * Data Source API implementation for Amazon Redshift database tables
  */
@@ -87,8 +88,10 @@ private[redshift] case class RedshiftRelation(
 
   private def makeRdd(schema: StructType): RDD[Row] = {
     val sc = sqlContext.sparkContext
+    val hadoopConf = new Configuration(sc.hadoopConfiguration)
+    params.setCredentials(hadoopConf)
     val rdd = sc.newAPIHadoopFile(params.tempPath, classOf[RedshiftInputFormat],
-      classOf[java.lang.Long], classOf[Array[String]], sc.hadoopConfiguration)
+      classOf[java.lang.Long], classOf[Array[String]], hadoopConf)
     rdd.values.map(Conversions.rowConverter(schema))
   }
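Reviewer note (not part of the diff): a minimal, hypothetical usage sketch of the new "aws_access_key_id" / "aws_secret_access_key" data source options, mirroring the integration tests above. The JDBC URL, tempdir, table name, and the use of environment variables below are illustrative assumptions only; the option names and fallback behaviour come from Parameters.setCredentials and credentialsString.

// Hypothetical sketch: read a Redshift table with credentials passed as data source options
// instead of being set on sc.hadoopConfiguration. All literals below are placeholders.
import org.apache.spark.sql.{DataFrame, SQLContext}

def readWithExplicitCredentials(sqlContext: SQLContext): DataFrame = {
  // RedshiftRelation.makeRdd copies these options onto a private Hadoop Configuration via
  // Parameters.setCredentials, so the tempdir filesystem can authenticate without mutating the
  // shared sc.hadoopConfiguration; if the options are omitted, existing fs.* settings and the
  // AWS DefaultCredentialsProviderChain are used as fallbacks.
  sqlContext.read
    .format("com.databricks.spark.redshift")
    .option("url", "jdbc:redshift://examplehost:5439/dev?user=user&password=pass") // placeholder
    .option("dbtable", "example_table")                                            // placeholder
    .option("tempdir", "s3n://example-bucket/tmp/")                                // placeholder
    .option("aws_access_key_id", sys.env("AWS_ACCESS_KEY_ID"))
    .option("aws_secret_access_key", sys.env("AWS_SECRET_ACCESS_KEY"))
    .load()
}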