@@ -23,6 +23,7 @@ import java.util.Properties
import scala.util.Random

import org.scalatest.{BeforeAndAfterEach, BeforeAndAfterAll, Matchers}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}

import org.apache.spark.SparkContext
@@ -86,10 +87,6 @@ class RedshiftIntegrationSuite
override def beforeAll(): Unit = {
super.beforeAll()
sc = new SparkContext("local", "RedshiftSourceSuite")
sc.hadoopConfiguration.set("fs.s3.awsAccessKeyId", AWS_ACCESS_KEY_ID)
sc.hadoopConfiguration.set("fs.s3.awsSecretAccessKey", AWS_SECRET_ACCESS_KEY)
sc.hadoopConfiguration.set("fs.s3n.awsAccessKeyId", AWS_ACCESS_KEY_ID)
sc.hadoopConfiguration.set("fs.s3n.awsSecretAccessKey", AWS_SECRET_ACCESS_KEY)

conn = DefaultJDBCWrapper.getConnector(
"com.amazon.redshift.jdbc4.Driver", jdbcUrl, new Properties())()
@@ -138,7 +135,10 @@ class RedshiftIntegrationSuite

override def afterAll(): Unit = {
try {
val fs = FileSystem.get(URI.create(tempDir), sc.hadoopConfiguration)
val conf = new Configuration()
conf.set("fs.s3n.awsAccessKeyId", AWS_ACCESS_KEY_ID)
conf.set("fs.s3n.awsSecretAccessKey", AWS_SECRET_ACCESS_KEY)
val fs = FileSystem.get(URI.create(tempDir), conf)
fs.delete(new Path(tempDir), true)
fs.close()
} finally {
@@ -178,6 +178,8 @@ class RedshiftIntegrationSuite
| options(
| url \"$jdbcUrl\",
| tempdir \"$tempDir\",
| aws_access_key_id \"$AWS_ACCESS_KEY_ID\",
| aws_secret_access_key \"$AWS_SECRET_ACCESS_KEY\",
| dbtable \"$test_table\"
| )
""".stripMargin
@@ -201,6 +203,8 @@ class RedshiftIntegrationSuite
| options(
| url \"$jdbcUrl\",
| tempdir \"$tempDir\",
| aws_access_key_id \"$AWS_ACCESS_KEY_ID\",
| aws_secret_access_key \"$AWS_SECRET_ACCESS_KEY\",
| dbtable \"$test_table2\"
| )
""".stripMargin
@@ -224,6 +228,8 @@ class RedshiftIntegrationSuite
| options(
| url \"$jdbcUrl\",
| tempdir \"$tempDir\",
| aws_access_key_id \"$AWS_ACCESS_KEY_ID\",
| aws_secret_access_key \"$AWS_SECRET_ACCESS_KEY\",
| dbtable \"$test_table3\"
| )
""".stripMargin
@@ -254,6 +260,8 @@ class RedshiftIntegrationSuite
.option("url", jdbcUrl)
.option("dbtable", query)
.option("tempdir", tempDir)
.option("aws_access_key_id", AWS_ACCESS_KEY_ID)
.option("aws_secret_access_key", AWS_SECRET_ACCESS_KEY)
.load()
checkAnswer(loadedDf, Seq(Row(1, true)))
}
@@ -276,6 +284,8 @@ class RedshiftIntegrationSuite
.option("url", jdbcUrl)
.option("query", query)
.option("tempdir", tempDir)
.option("aws_access_key_id", AWS_ACCESS_KEY_ID)
.option("aws_secret_access_key", AWS_SECRET_ACCESS_KEY)
.load()
checkAnswer(loadedDf, Seq(Row(1, true)))
}
@@ -286,6 +296,8 @@ class RedshiftIntegrationSuite
.option("url", jdbcUrl)
.option("query", s"select testbool, count(*) from $test_table group by testbool")
.option("tempdir", tempDir)
.option("aws_access_key_id", AWS_ACCESS_KEY_ID)
.option("aws_secret_access_key", AWS_SECRET_ACCESS_KEY)
.load()
checkAnswer(loadedDf, Seq(Row(true, 1), Row(false, 2), Row(null, 2)))
}
@@ -327,6 +339,8 @@ class RedshiftIntegrationSuite
.option("url", jdbcUrl)
.option("dbtable", tableName)
.option("tempdir", tempDir)
.option("aws_access_key_id", AWS_ACCESS_KEY_ID)
.option("aws_secret_access_key", AWS_SECRET_ACCESS_KEY)
.mode(SaveMode.ErrorIfExists)
.save()

@@ -336,6 +350,8 @@ class RedshiftIntegrationSuite
.option("url", jdbcUrl)
.option("dbtable", tableName)
.option("tempdir", tempDir)
.option("aws_access_key_id", AWS_ACCESS_KEY_ID)
.option("aws_secret_access_key", AWS_SECRET_ACCESS_KEY)
.load()
checkAnswer(loadedDf, TestUtils.expectedData)
} finally {
@@ -354,6 +370,8 @@ class RedshiftIntegrationSuite
.option("url", jdbcUrl)
.option("dbtable", tableName)
.option("tempdir", tempDir)
.option("aws_access_key_id", AWS_ACCESS_KEY_ID)
.option("aws_secret_access_key", AWS_SECRET_ACCESS_KEY)
.mode(SaveMode.ErrorIfExists)
.save()

@@ -363,6 +381,8 @@ class RedshiftIntegrationSuite
.option("url", jdbcUrl)
.option("dbtable", tableName)
.option("tempdir", tempDir)
.option("aws_access_key_id", AWS_ACCESS_KEY_ID)
.option("aws_secret_access_key", AWS_SECRET_ACCESS_KEY)
.load()
assert(loadedDf.schema.length === 1)
assert(loadedDf.columns === Seq("a"))
@@ -384,6 +404,8 @@ class RedshiftIntegrationSuite
.option("url", jdbcUrl)
.option("dbtable", tableName)
.option("tempdir", tempDir)
.option("aws_access_key_id", AWS_ACCESS_KEY_ID)
.option("aws_secret_access_key", AWS_SECRET_ACCESS_KEY)
.mode(SaveMode.ErrorIfExists)
.save()
}
@@ -404,6 +426,8 @@ class RedshiftIntegrationSuite
.option("url", jdbcUrl)
.option("dbtable", tableName)
.option("tempdir", tempDir)
.option("aws_access_key_id", AWS_ACCESS_KEY_ID)
.option("aws_secret_access_key", AWS_SECRET_ACCESS_KEY)
.mode(SaveMode.Overwrite)
.save()

@@ -413,6 +437,8 @@ class RedshiftIntegrationSuite
.option("url", jdbcUrl)
.option("dbtable", tableName)
.option("tempdir", tempDir)
.option("aws_access_key_id", AWS_ACCESS_KEY_ID)
.option("aws_secret_access_key", AWS_SECRET_ACCESS_KEY)
.load()
checkAnswer(loadedDf, TestUtils.expectedData)
} finally {
@@ -433,6 +459,8 @@ class RedshiftIntegrationSuite
.option("url", jdbcUrl)
.option("dbtable", test_table3)
.option("tempdir", tempDir)
.option("aws_access_key_id", AWS_ACCESS_KEY_ID)
.option("aws_secret_access_key", AWS_SECRET_ACCESS_KEY)
.mode(SaveMode.Append)
.saveAsTable(test_table3)

@@ -453,6 +481,8 @@ class RedshiftIntegrationSuite
.option("url", jdbcUrl)
.option("dbtable", test_table)
.option("tempdir", tempDir)
.option("aws_access_key_id", AWS_ACCESS_KEY_ID)
.option("aws_secret_access_key", AWS_SECRET_ACCESS_KEY)
.mode(SaveMode.ErrorIfExists)
.saveAsTable(test_table)
}
@@ -466,6 +496,8 @@ class RedshiftIntegrationSuite
.option("url", jdbcUrl)
.option("dbtable", test_table)
.option("tempdir", tempDir)
.option("aws_access_key_id", AWS_ACCESS_KEY_ID)
.option("aws_secret_access_key", AWS_SECRET_ACCESS_KEY)
.mode(SaveMode.Ignore)
.saveAsTable(test_table)

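Taken together, the test changes above exercise the new user-facing pattern: AWS credentials are passed as data source options ("aws_access_key_id" / "aws_secret_access_key") rather than being set globally on the SparkContext's Hadoop configuration. The sketch below shows a read using the new options; it assumes an existing sqlContext, and the URL, table, and temp dir values are placeholders rather than values taken from this change.

// Minimal sketch of a read with the new credential options (placeholder connection values).
val awsAccessKeyId = sys.env("AWS_ACCESS_KEY_ID")           // or any other source of credentials
val awsSecretAccessKey = sys.env("AWS_SECRET_ACCESS_KEY")
val df = sqlContext.read
  .format("com.databricks.spark.redshift")
  .option("url", "jdbc:redshift://host:5439/db?user=user&password=pass")
  .option("dbtable", "my_table")
  .option("tempdir", "s3n://my-bucket/spark-redshift-tmp/")
  .option("aws_access_key_id", awsAccessKeyId)
  .option("aws_secret_access_key", awsSecretAccessKey)
  .load()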
30 changes: 22 additions & 8 deletions src/main/scala/com/databricks/spark/redshift/Parameters.scala
@@ -188,6 +188,26 @@ private[redshift] object Parameters extends Logging {
* available.
*/
def credentialsString(configuration: Configuration): String = {
val ((_, accessKeyId), (_, secretAccessKey)) = credentialsTuple(configuration)
s"aws_access_key_id=$accessKeyId;aws_secret_access_key=$secretAccessKey"
}
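For context, the string built above is the format Redshift accepts in a CREDENTIALS clause. The statement-building code is not part of this diff, so the helper below is only an illustrative sketch of how such a string is typically embedded in an UNLOAD statement; the function name and options are assumptions, not code from this change.

// Illustrative only: embeds the "aws_access_key_id=...;aws_secret_access_key=..." string
// returned by credentialsString into a Redshift UNLOAD statement. The query is assumed to
// already have its single quotes escaped.
def unloadSql(query: String, tempDir: String, credentials: String): String =
  s"UNLOAD ('$query') TO '$tempDir' CREDENTIALS '$credentials' ESCAPE"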

/**
* Looks up "aws_access_key_id" and "aws_secret_access_key" in the parameter map and, where they
* are not already present, sets them on the given Hadoop Configuration as the scheme-specific
* "fs.<scheme>.awsAccessKeyId" / "fs.<scheme>.awsSecretAccessKey" properties for the tempDir
* scheme. If no credentials have been provided, this function will instead try the
* Configuration's existing `fs.*` settings, and if those also fail it finally tries the AWS
* DefaultCredentialsProviderChain, which makes use of standard system properties, environment
* variables, or IAM role configuration if available.
*/
def setCredentials(configuration: Configuration): Unit = {
val ((accessKeyIdProp, accessKeyId), (secretAccessKeyProp, secretAccessKey)) =
credentialsTuple(configuration)
configuration.setIfUnset(accessKeyIdProp, accessKeyId)
configuration.setIfUnset(secretAccessKeyProp, secretAccessKey)
}

private def credentialsTuple(configuration: Configuration) = {
val scheme = new URI(tempDir).getScheme
val hadoopConfPrefix = s"fs.$scheme"

@@ -212,14 +232,8 @@
}
}

val credentials = s"aws_access_key_id=$accessKeyId;aws_secret_access_key=$secretAccessKey"

if (parameters.contains("aws_security_token")) {
val securityToken = parameters("aws_security_token")
credentials + s";token=$securityToken"
} else {
credentials
}
((s"$hadoopConfPrefix.awsAccessKeyId", accessKeyId),
(s"$hadoopConfPrefix.awsSecretAccessKey", secretAccessKey))
}
}
}
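A minimal sketch of what setCredentials does from a caller's point of view, assuming "params" is a MergedParameters built from options that include an s3n tempdir plus explicit "aws_access_key_id" and "aws_secret_access_key" values (the literal key values below are illustrative):

// Sketch: setCredentials copies the resolved credentials into the scheme-specific
// Hadoop properties, but only for keys that are not already set on the configuration.
val conf = new org.apache.hadoop.conf.Configuration()
params.setCredentials(conf)
assert(conf.get("fs.s3n.awsAccessKeyId") == "MY_ACCESS_KEY_ID")
assert(conf.get("fs.s3n.awsSecretAccessKey") == "MY_SECRET_ACCESS_KEY")

The same pattern shows up in RedshiftRelation.makeRdd below: the credentials are applied to a copy of sc.hadoopConfiguration, so they stay scoped to that relation's read instead of leaking to every job on the shared SparkContext, which is effectively what the old integration-suite setup in beforeAll did.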
src/main/scala/com/databricks/spark/redshift/RedshiftRelation.scala
@@ -18,14 +18,15 @@ package com.databricks.spark.redshift

import java.util.Properties

import com.databricks.spark.redshift.Parameters.MergedParameters

import org.apache.hadoop.conf.Configuration
import org.apache.spark.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row, SQLContext}

import com.databricks.spark.redshift.Parameters.MergedParameters

/**
* Data Source API implementation for Amazon Redshift database tables
*/
@@ -87,8 +88,10 @@ private[redshift] case class RedshiftRelation(

private def makeRdd(schema: StructType): RDD[Row] = {
val sc = sqlContext.sparkContext
val hadoopConf = new Configuration(sc.hadoopConfiguration)
params.setCredentials(hadoopConf)
val rdd = sc.newAPIHadoopFile(params.tempPath, classOf[RedshiftInputFormat],
classOf[java.lang.Long], classOf[Array[String]], sc.hadoopConfiguration)
classOf[java.lang.Long], classOf[Array[String]], hadoopConf)
rdd.values.map(Conversions.rowConverter(schema))
}
