12 changes: 7 additions & 5 deletions project/SparkRedshiftBuild.scala
@@ -64,17 +64,19 @@ object SparkRedshiftBuild extends Build {
"org.scalamock" %% "scalamock-scalatest-support" % "3.2" % "test"
),
libraryDependencies ++= Seq(
"org.apache.hadoop" % "hadoop-client" % testHadoopVersion.value % "test",
"org.apache.spark" %% "spark-core" % testSparkVersion.value % "test" exclude("org.apache.hadoop", "hadoop-client"),
"org.apache.spark" %% "spark-sql" % testSparkVersion.value % "test" exclude("org.apache.hadoop", "hadoop-client"),
"org.apache.spark" %% "spark-hive" % testSparkVersion.value % "test" exclude("org.apache.hadoop", "hadoop-client")
"org.apache.hadoop" % "hadoop-client" % testHadoopVersion.value % "test" force(),
"org.apache.spark" %% "spark-core" % testSparkVersion.value % "test" exclude("org.apache.hadoop", "hadoop-client") force(),
"org.apache.spark" %% "spark-sql" % testSparkVersion.value % "test" exclude("org.apache.hadoop", "hadoop-client") force(),
"org.apache.spark" %% "spark-hive" % testSparkVersion.value % "test" exclude("org.apache.hadoop", "hadoop-client") force()
),
ScoverageSbtPlugin.ScoverageKeys.coverageHighlighting := {
if (scalaBinaryVersion.value == "2.10") false
else false
},
logBuffered := false,
// Display full-length stacktraces from ScalaTest:
testOptions in Test += Tests.Argument("-oF")
testOptions in Test += Tests.Argument("-oF"),
fork in Test := true,
javaOptions in Test ++= Seq("-Xms512M", "-Xmx2048M", "-XX:MaxPermSize=2048M")
)
}
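
A note on the build changes above: the javaOptions added for Test only take effect when the tests run in a forked JVM, which is why fork in Test := true accompanies the larger heap and PermGen limits, and the force() calls pin hadoop-client to testHadoopVersion so that Spark's transitive Hadoop dependency cannot override it. As a hedged alternative sketch (not part of this PR, assuming sbt 0.13.x where dependencyOverrides is a Set[ModuleID] setting), the same pinning could be expressed once instead of per dependency; the line below would sit alongside the other settings in SparkRedshiftBuild:

// Sketch only: pin hadoop-client without repeating force() on each dependency.
// Assumes sbt 0.13.x and the testHadoopVersion setting key defined in this build.
dependencyOverrides += "org.apache.hadoop" % "hadoop-client" % testHadoopVersion.value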
src/it/scala/com/databricks/spark/redshift/DecimalIntegrationSuite.scala
@@ -0,0 +1,82 @@
/*
* Copyright 2015 Databricks
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.databricks.spark.redshift

import org.apache.spark.sql.Row

/**
* Integration tests for decimal support. For a reference on Redshift's DECIMAL type, see
* http://docs.aws.amazon.com/redshift/latest/dg/r_Numeric_types201.html
*/
class DecimalIntegrationSuite extends IntegrationSuiteBase {

private def testReadingDecimals(precision: Int, scale: Int, decimalStrings: Seq[String]): Unit = {
test(s"reading DECIMAL($precision, $scale") {
val tableName = s"reading_decimal_${precision}_${scale}_$randomSuffix"
val expectedRows =
decimalStrings.map(d => Row(if (d == null) null else Conversions.parseDecimal(d)))
try {
conn.createStatement().executeUpdate(
s"CREATE TABLE $tableName (x DECIMAL($precision, $scale))")
for (x <- decimalStrings) {
conn.createStatement().executeUpdate(s"INSERT INTO $tableName VALUES ($x)")
}
conn.commit()
assert(DefaultJDBCWrapper.tableExists(conn, tableName))
val loadedDf = sqlContext.read
.format("com.databricks.spark.redshift")
.option("url", jdbcUrl)
.option("dbtable", tableName)
.option("tempdir", tempDir)
.option("aws_access_key_id", AWS_ACCESS_KEY_ID)
.option("aws_secret_access_key", AWS_SECRET_ACCESS_KEY)
.load()
checkAnswer(loadedDf, expectedRows)
checkAnswer(loadedDf.selectExpr("x + 0"), expectedRows)
} finally {
conn.prepareStatement(s"drop table if exists $tableName").executeUpdate()
conn.commit()
}
}
}

testReadingDecimals(19, 0, Seq(
// Max and min values of DECIMAL(19, 0) column according to Redshift docs:
"9223372036854775807", // 2^63 - 1
"-9223372036854775807",
"0",
"12345678910",
null
))

testReadingDecimals(19, 4, Seq(
"922337203685477.5807",
"-922337203685477.5807",
"0",
"1234567.8910",
null
))

testReadingDecimals(38, 4, Seq(
"922337203685477.5808",
"9999999999999999999999999999999999.0000",
"-9999999999999999999999999999999999.0000",
"0",
"1234567.8910",
null
))
}
112 changes: 112 additions & 0 deletions src/it/scala/com/databricks/spark/redshift/IntegrationSuiteBase.scala
@@ -0,0 +1,112 @@
/*
* Copyright 2015 Databricks
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.databricks.spark.redshift

import java.net.URI
import java.sql.Connection

import scala.util.Random

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{Path, FileSystem}
import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.hive.test.TestHiveContext
import org.scalatest.{BeforeAndAfterEach, BeforeAndAfterAll, Matchers}


/**
* Base class for writing integration tests which run against a real Redshift cluster.
*/
trait IntegrationSuiteBase
extends QueryTest
with Matchers
with BeforeAndAfterAll
with BeforeAndAfterEach {

private def loadConfigFromEnv(envVarName: String): String = {
Option(System.getenv(envVarName)).getOrElse {
fail(s"Must set $envVarName environment variable")
}
}

// The following configurations must be set in order to run these tests. In Travis, these
// environment variables are set using Travis's encrypted environment variables feature:
// http://docs.travis-ci.com/user/environment-variables/#Encrypted-Variables

// JDBC URL listed in the AWS console (should not contain username and password).
protected val AWS_REDSHIFT_JDBC_URL: String = loadConfigFromEnv("AWS_REDSHIFT_JDBC_URL")
protected val AWS_REDSHIFT_USER: String = loadConfigFromEnv("AWS_REDSHIFT_USER")
protected val AWS_REDSHIFT_PASSWORD: String = loadConfigFromEnv("AWS_REDSHIFT_PASSWORD")
protected val AWS_ACCESS_KEY_ID: String = loadConfigFromEnv("TEST_AWS_ACCESS_KEY_ID")
protected val AWS_SECRET_ACCESS_KEY: String = loadConfigFromEnv("TEST_AWS_SECRET_ACCESS_KEY")
// Path to a directory in S3 (e.g. 's3n://bucket-name/path/to/scratch/space').
private val AWS_S3_SCRATCH_SPACE: String = loadConfigFromEnv("AWS_S3_SCRATCH_SPACE")
require(AWS_S3_SCRATCH_SPACE.contains("s3n"), "must use s3n:// URL")

protected val jdbcUrl: String = {
s"$AWS_REDSHIFT_JDBC_URL?user=$AWS_REDSHIFT_USER&password=$AWS_REDSHIFT_PASSWORD"
}

/**
* Random suffix appended to table and directory names in order to avoid collisions
* between separate Travis builds.
*/
protected val randomSuffix: String = Math.abs(Random.nextLong()).toString

protected val tempDir: String = AWS_S3_SCRATCH_SPACE + randomSuffix + "/"

/**
* Shared SparkContext, SQLContext, and JDBC connection used by the tests in this suite;
* sc and conn are initialized in beforeAll(), sqlContext in beforeEach().
*/
protected var sc: SparkContext = _
protected var sqlContext: SQLContext = _
protected var conn: Connection = _

override def beforeAll(): Unit = {
super.beforeAll()
sc = new SparkContext("local", "RedshiftSourceSuite")
conn = DefaultJDBCWrapper.getConnector("com.amazon.redshift.jdbc4.Driver", jdbcUrl)
}

override def afterAll(): Unit = {
try {
val conf = new Configuration()
conf.set("fs.s3n.awsAccessKeyId", AWS_ACCESS_KEY_ID)
conf.set("fs.s3n.awsSecretAccessKey", AWS_SECRET_ACCESS_KEY)
val fs = FileSystem.get(URI.create(tempDir), conf)
fs.delete(new Path(tempDir), true)
fs.close()
} finally {
try {
conn.close()
} finally {
try {
sc.stop()
} finally {
super.afterAll()
}
}
}
}

override protected def beforeEach(): Unit = {
super.beforeEach()
sqlContext = new TestHiveContext(sc)
}
}
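
With the environment-variable loading, SparkContext and JDBC connection setup, and S3 scratch-space cleanup now factored into IntegrationSuiteBase, a concrete suite only has to define its tests. The following is a minimal hedged sketch of such a suite (the class name, table name, and inserted value are hypothetical; randomSuffix, conn, sqlContext, jdbcUrl, tempDir, the AWS_* credentials, and checkAnswer all come from the base trait and QueryTest):

package com.databricks.spark.redshift

import org.apache.spark.sql.Row

// Sketch only: a minimal integration suite built on IntegrationSuiteBase.
class ExampleIntegrationSuite extends IntegrationSuiteBase {

  test("round-trips a single integer column") {
    val tableName = s"example_table_$randomSuffix" // hypothetical table name
    try {
      conn.createStatement().executeUpdate(s"CREATE TABLE $tableName (x INT)")
      conn.createStatement().executeUpdate(s"INSERT INTO $tableName VALUES (42)")
      conn.commit()
      val df = sqlContext.read
        .format("com.databricks.spark.redshift")
        .option("url", jdbcUrl)
        .option("dbtable", tableName)
        .option("tempdir", tempDir)
        .option("aws_access_key_id", AWS_ACCESS_KEY_ID)
        .option("aws_secret_access_key", AWS_SECRET_ACCESS_KEY)
        .load()
      checkAnswer(df, Seq(Row(42)))
    } finally {
      conn.prepareStatement(s"drop table if exists $tableName").executeUpdate()
      conn.commit()
    }
  }
}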
src/it/scala/com/databricks/spark/redshift/RedshiftIntegrationSuite.scala
@@ -16,78 +16,22 @@

package com.databricks.spark.redshift

import java.net.URI
import java.sql.{SQLException, Connection}
import java.sql.SQLException

import scala.util.Random

import org.scalatest.{BeforeAndAfterEach, BeforeAndAfterAll, Matchers}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}

import org.apache.spark.SparkContext
import org.apache.spark.sql.{AnalysisException, Row, SQLContext, SaveMode}
import org.apache.spark.sql.hive.test.TestHiveContext
import org.apache.spark.sql.types._

/**
* End-to-end tests which run against a real Redshift cluster.
*/
class RedshiftIntegrationSuite
extends QueryTest
with Matchers
with BeforeAndAfterAll
with BeforeAndAfterEach {

private def loadConfigFromEnv(envVarName: String): String = {
Option(System.getenv(envVarName)).getOrElse {
fail(s"Must set $envVarName environment variable")
}
}

// The following configurations must be set in order to run these tests. In Travis, these
// environment variables are set using Travis's encrypted environment variables feature:
// http://docs.travis-ci.com/user/environment-variables/#Encrypted-Variables

// JDBC URL listed in the AWS console (should not contain username and password).
private val AWS_REDSHIFT_JDBC_URL: String = loadConfigFromEnv("AWS_REDSHIFT_JDBC_URL")
private val AWS_REDSHIFT_USER: String = loadConfigFromEnv("AWS_REDSHIFT_USER")
private val AWS_REDSHIFT_PASSWORD: String = loadConfigFromEnv("AWS_REDSHIFT_PASSWORD")
private val AWS_ACCESS_KEY_ID: String = loadConfigFromEnv("TEST_AWS_ACCESS_KEY_ID")
private val AWS_SECRET_ACCESS_KEY: String = loadConfigFromEnv("TEST_AWS_SECRET_ACCESS_KEY")
// Path to a directory in S3 (e.g. 's3n://bucket-name/path/to/scratch/space').
private val AWS_S3_SCRATCH_SPACE: String = loadConfigFromEnv("AWS_S3_SCRATCH_SPACE")
require(AWS_S3_SCRATCH_SPACE.contains("s3n"), "must use s3n:// URL")

private val jdbcUrl: String = {
s"$AWS_REDSHIFT_JDBC_URL?user=$AWS_REDSHIFT_USER&password=$AWS_REDSHIFT_PASSWORD"
}

/**
* Random suffix appended appended to table and directory names in order to avoid collisions
* between separate Travis builds.
*/
private val randomSuffix: String = Math.abs(Random.nextLong()).toString

private val tempDir: String = AWS_S3_SCRATCH_SPACE + randomSuffix + "/"

/**
* Spark Context with Hadoop file overridden to point at our local test data file for this suite,
* no-matter what temp directory was generated and requested.
*/
private var sc: SparkContext = _
private var sqlContext: SQLContext = _
private var conn: Connection = _
class RedshiftIntegrationSuite extends IntegrationSuiteBase {

private val test_table: String = s"test_table_$randomSuffix"
private val test_table2: String = s"test_table2_$randomSuffix"
private val test_table3: String = s"test_table3_$randomSuffix"

override def beforeAll(): Unit = {
super.beforeAll()
sc = new SparkContext("local", "RedshiftSourceSuite")

conn = DefaultJDBCWrapper.getConnector("com.amazon.redshift.jdbc4.Driver", jdbcUrl)

conn.prepareStatement("drop table if exists test_table").executeUpdate()
conn.prepareStatement("drop table if exists test_table2").executeUpdate()
@@ -133,31 +77,17 @@ class RedshiftIntegrationSuite

override def afterAll(): Unit = {
try {
val conf = new Configuration()
conf.set("fs.s3n.awsAccessKeyId", AWS_ACCESS_KEY_ID)
conf.set("fs.s3n.awsSecretAccessKey", AWS_SECRET_ACCESS_KEY)
val fs = FileSystem.get(URI.create(tempDir), conf)
fs.delete(new Path(tempDir), true)
fs.close()
conn.prepareStatement(s"drop table if exists $test_table").executeUpdate()
conn.prepareStatement(s"drop table if exists $test_table2").executeUpdate()
conn.prepareStatement(s"drop table if exists $test_table3").executeUpdate()
conn.commit()
} finally {
try {
conn.prepareStatement(s"drop table if exists $test_table").executeUpdate()
conn.prepareStatement(s"drop table if exists $test_table2").executeUpdate()
conn.prepareStatement(s"drop table if exists $test_table3").executeUpdate()
conn.commit()
conn.close()
} finally {
try {
sc.stop()
} finally {
super.afterAll()
}
}
super.afterAll()
}
}

override def beforeEach(): Unit = {
sqlContext = new TestHiveContext(sc)
super.beforeEach()
sqlContext.sql(
s"""
| create temporary table test_table(
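
One detail worth noting in the refactored beforeEach above: sqlContext is only created inside IntegrationSuiteBase.beforeEach(), so the override has to call super.beforeEach() before issuing any sqlContext.sql statements, as the diff does. A hedged sketch of that ordering constraint (the suite name and query are illustrative only):

package com.databricks.spark.redshift

// Sketch only: super.beforeEach() initializes sqlContext (a TestHiveContext),
// so it must run before the subclass touches sqlContext.
class OrderingSketchSuite extends IntegrationSuiteBase {
  override def beforeEach(): Unit = {
    super.beforeEach()
    sqlContext.sql("SELECT 1").collect() // hypothetical statement; safe only after super.beforeEach()
  }
}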
12 changes: 11 additions & 1 deletion src/main/scala/com/databricks/spark/redshift/Conversions.scala
@@ -17,7 +17,7 @@
package com.databricks.spark.redshift

import java.sql.Timestamp
import java.text.{DateFormat, FieldPosition, ParsePosition, SimpleDateFormat}
import java.text.{DecimalFormat, DateFormat, FieldPosition, ParsePosition, SimpleDateFormat}
import java.util.Date

import org.apache.spark.sql.types._
@@ -91,6 +91,15 @@ private[redshift] object Conversions {
else throw new IllegalArgumentException(s"Expected 't' or 'f' but got '$s'")
}

private[this] val redshiftDecimalFormat: DecimalFormat = new DecimalFormat()
redshiftDecimalFormat.setParseBigDecimal(true)

/**
* Parse a decimal using Redshift's UNLOAD decimal syntax
*/
def parseDecimal(s: String): java.math.BigDecimal = {
redshiftDecimalFormat.parse(s).asInstanceOf[java.math.BigDecimal]
}
/**
* Construct a Row from the given array of strings, retrieved from Redshift UNLOAD.
* The schema will be used for type mappings.
@@ -105,6 +114,7 @@
case DateType => parseDate(data)
case DoubleType => data.toDouble
case FloatType => data.toFloat
case dt: DecimalType => parseDecimal(data)
case IntegerType => data.toInt
case LongType => data.toLong
case ShortType => data.toShort
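
For reference on the parseDecimal helper added above: calling setParseBigDecimal(true) on a java.text.DecimalFormat makes parse() return a java.math.BigDecimal, which is what both parseDecimal and the new DecimalType branch of convertRow rely on. Below is a small standalone sketch of that behavior (it assumes a default locale whose decimal separator is '.', as the production code implicitly does; the object name is arbitrary and the sample value is taken from the decimal test data above):

package com.databricks.spark.redshift

import java.text.DecimalFormat

// Sketch only: demonstrates the DecimalFormat behavior that parseDecimal depends on.
object ParseDecimalSketch {
  def main(args: Array[String]): Unit = {
    val format = new DecimalFormat()
    format.setParseBigDecimal(true)
    // With setParseBigDecimal(true), parse() returns a java.math.BigDecimal,
    // so the cast mirrors Conversions.parseDecimal.
    val parsed = format.parse("922337203685477.5807").asInstanceOf[java.math.BigDecimal]
    println(parsed)         // expected: 922337203685477.5807
    println(parsed.scale()) // expected: 4
  }
}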