diff --git a/project/SparkRedshiftBuild.scala b/project/SparkRedshiftBuild.scala
index 3dc1c9e0..07bf3d65 100644
--- a/project/SparkRedshiftBuild.scala
+++ b/project/SparkRedshiftBuild.scala
@@ -64,10 +64,10 @@ object SparkRedshiftBuild extends Build {
       "org.scalamock" %% "scalamock-scalatest-support" % "3.2" % "test"
     ),
     libraryDependencies ++= Seq(
-      "org.apache.hadoop" % "hadoop-client" % testHadoopVersion.value % "test",
-      "org.apache.spark" %% "spark-core" % testSparkVersion.value % "test" exclude("org.apache.hadoop", "hadoop-client"),
-      "org.apache.spark" %% "spark-sql" % testSparkVersion.value % "test" exclude("org.apache.hadoop", "hadoop-client"),
-      "org.apache.spark" %% "spark-hive" % testSparkVersion.value % "test" exclude("org.apache.hadoop", "hadoop-client")
+      "org.apache.hadoop" % "hadoop-client" % testHadoopVersion.value % "test" force(),
+      "org.apache.spark" %% "spark-core" % testSparkVersion.value % "test" exclude("org.apache.hadoop", "hadoop-client") force(),
+      "org.apache.spark" %% "spark-sql" % testSparkVersion.value % "test" exclude("org.apache.hadoop", "hadoop-client") force(),
+      "org.apache.spark" %% "spark-hive" % testSparkVersion.value % "test" exclude("org.apache.hadoop", "hadoop-client") force()
     ),
     ScoverageSbtPlugin.ScoverageKeys.coverageHighlighting := {
       if (scalaBinaryVersion.value == "2.10") false
@@ -75,6 +75,8 @@ object SparkRedshiftBuild extends Build {
     },
     logBuffered := false,
     // Display full-length stacktraces from ScalaTest:
-    testOptions in Test += Tests.Argument("-oF")
+    testOptions in Test += Tests.Argument("-oF"),
+    fork in Test := true,
+    javaOptions in Test ++= Seq("-Xms512M", "-Xmx2048M", "-XX:MaxPermSize=2048M")
   )
 }
diff --git a/src/it/scala/com/databricks/spark/redshift/DecimalIntegrationSuite.scala b/src/it/scala/com/databricks/spark/redshift/DecimalIntegrationSuite.scala
new file mode 100644
index 00000000..c224ab1c
--- /dev/null
+++ b/src/it/scala/com/databricks/spark/redshift/DecimalIntegrationSuite.scala
@@ -0,0 +1,82 @@
+/*
+ * Copyright 2015 Databricks
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.databricks.spark.redshift
+
+import org.apache.spark.sql.Row
+
+/**
+ * Integration tests for decimal support.
+ * For a reference on Redshift's DECIMAL type, see
+ * http://docs.aws.amazon.com/redshift/latest/dg/r_Numeric_types201.html
+ */
+class DecimalIntegrationSuite extends IntegrationSuiteBase {
+
+  private def testReadingDecimals(precision: Int, scale: Int, decimalStrings: Seq[String]): Unit = {
+    test(s"reading DECIMAL($precision, $scale)") {
+      val tableName = s"reading_decimal_${precision}_${scale}_$randomSuffix"
+      val expectedRows =
+        decimalStrings.map(d => Row(if (d == null) null else Conversions.parseDecimal(d)))
+      try {
+        conn.createStatement().executeUpdate(
+          s"CREATE TABLE $tableName (x DECIMAL($precision, $scale))")
+        for (x <- decimalStrings) {
+          conn.createStatement().executeUpdate(s"INSERT INTO $tableName VALUES ($x)")
+        }
+        conn.commit()
+        assert(DefaultJDBCWrapper.tableExists(conn, tableName))
+        val loadedDf = sqlContext.read
+          .format("com.databricks.spark.redshift")
+          .option("url", jdbcUrl)
+          .option("dbtable", tableName)
+          .option("tempdir", tempDir)
+          .option("aws_access_key_id", AWS_ACCESS_KEY_ID)
+          .option("aws_secret_access_key", AWS_SECRET_ACCESS_KEY)
+          .load()
+        checkAnswer(loadedDf, expectedRows)
+        checkAnswer(loadedDf.selectExpr("x + 0"), expectedRows)
+      } finally {
+        conn.prepareStatement(s"drop table if exists $tableName").executeUpdate()
+        conn.commit()
+      }
+    }
+  }
+
+  testReadingDecimals(19, 0, Seq(
+    // Max and min values of DECIMAL(19, 0) column according to Redshift docs:
+    "9223372036854775807", // 2^63 - 1
+    "-9223372036854775807",
+    "0",
+    "12345678910",
+    null
+  ))
+
+  testReadingDecimals(19, 4, Seq(
+    "922337203685477.5807",
+    "-922337203685477.5807",
+    "0",
+    "1234567.8910",
+    null
+  ))
+
+  testReadingDecimals(38, 4, Seq(
+    "922337203685477.5808",
+    "9999999999999999999999999999999999.0000",
+    "-9999999999999999999999999999999999.0000",
+    "0",
+    "1234567.8910",
+    null
+  ))
+}
diff --git a/src/it/scala/com/databricks/spark/redshift/IntegrationSuiteBase.scala b/src/it/scala/com/databricks/spark/redshift/IntegrationSuiteBase.scala
new file mode 100644
index 00000000..6aef9800
--- /dev/null
+++ b/src/it/scala/com/databricks/spark/redshift/IntegrationSuiteBase.scala
@@ -0,0 +1,112 @@
+/*
+ * Copyright 2015 Databricks
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.databricks.spark.redshift
+
+import java.net.URI
+import java.sql.Connection
+
+import scala.util.Random
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{Path, FileSystem}
+import org.apache.spark.SparkContext
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.hive.test.TestHiveContext
+import org.scalatest.{BeforeAndAfterEach, BeforeAndAfterAll, Matchers}
+
+
+/**
+ * Base class for writing integration tests which run against a real Redshift cluster.
+ */
+trait IntegrationSuiteBase
+  extends QueryTest
+  with Matchers
+  with BeforeAndAfterAll
+  with BeforeAndAfterEach {
+
+  private def loadConfigFromEnv(envVarName: String): String = {
+    Option(System.getenv(envVarName)).getOrElse {
+      fail(s"Must set $envVarName environment variable")
+    }
+  }
+
+  // The following configurations must be set in order to run these tests. In Travis, these
+  // environment variables are set using Travis's encrypted environment variables feature:
+  // http://docs.travis-ci.com/user/environment-variables/#Encrypted-Variables
+
+  // JDBC URL listed in the AWS console (should not contain username and password).
+  protected val AWS_REDSHIFT_JDBC_URL: String = loadConfigFromEnv("AWS_REDSHIFT_JDBC_URL")
+  protected val AWS_REDSHIFT_USER: String = loadConfigFromEnv("AWS_REDSHIFT_USER")
+  protected val AWS_REDSHIFT_PASSWORD: String = loadConfigFromEnv("AWS_REDSHIFT_PASSWORD")
+  protected val AWS_ACCESS_KEY_ID: String = loadConfigFromEnv("TEST_AWS_ACCESS_KEY_ID")
+  protected val AWS_SECRET_ACCESS_KEY: String = loadConfigFromEnv("TEST_AWS_SECRET_ACCESS_KEY")
+  // Path to a directory in S3 (e.g. 's3n://bucket-name/path/to/scratch/space').
+  private val AWS_S3_SCRATCH_SPACE: String = loadConfigFromEnv("AWS_S3_SCRATCH_SPACE")
+  require(AWS_S3_SCRATCH_SPACE.contains("s3n"), "must use s3n:// URL")
+
+  protected val jdbcUrl: String = {
+    s"$AWS_REDSHIFT_JDBC_URL?user=$AWS_REDSHIFT_USER&password=$AWS_REDSHIFT_PASSWORD"
+  }
+
+  /**
+   * Random suffix appended to table and directory names in order to avoid collisions
+   * between separate Travis builds.
+   */
+  protected val randomSuffix: String = Math.abs(Random.nextLong()).toString
+
+  protected val tempDir: String = AWS_S3_SCRATCH_SPACE + randomSuffix + "/"
+
+  /**
+   * The SparkContext, SQLContext, and JDBC connection shared by the tests in this suite.
+   */
+  protected var sc: SparkContext = _
+  protected var sqlContext: SQLContext = _
+  protected var conn: Connection = _
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    sc = new SparkContext("local", "RedshiftSourceSuite")
+    conn = DefaultJDBCWrapper.getConnector("com.amazon.redshift.jdbc4.Driver", jdbcUrl)
+  }
+
+  override def afterAll(): Unit = {
+    try {
+      val conf = new Configuration()
+      conf.set("fs.s3n.awsAccessKeyId", AWS_ACCESS_KEY_ID)
+      conf.set("fs.s3n.awsSecretAccessKey", AWS_SECRET_ACCESS_KEY)
+      val fs = FileSystem.get(URI.create(tempDir), conf)
+      fs.delete(new Path(tempDir), true)
+      fs.close()
+    } finally {
+      try {
+        conn.close()
+      } finally {
+        try {
+          sc.stop()
+        } finally {
+          super.afterAll()
+        }
+      }
+    }
+  }
+
+  override protected def beforeEach(): Unit = {
+    super.beforeEach()
+    sqlContext = new TestHiveContext(sc)
+  }
+}
diff --git a/src/it/scala/com/databricks/spark/redshift/RedshiftIntegrationSuite.scala b/src/it/scala/com/databricks/spark/redshift/RedshiftIntegrationSuite.scala
index c4429df0..9e9cf1f6 100644
--- a/src/it/scala/com/databricks/spark/redshift/RedshiftIntegrationSuite.scala
+++ b/src/it/scala/com/databricks/spark/redshift/RedshiftIntegrationSuite.scala
@@ -16,68 +16,15 @@ package com.databricks.spark.redshift
 
-import java.net.URI
-import java.sql.{SQLException, Connection}
+import java.sql.SQLException
 
-import scala.util.Random
-
-import org.scalatest.{BeforeAndAfterEach, BeforeAndAfterAll, Matchers}
-import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.{FileSystem, Path}
-
-import org.apache.spark.SparkContext
 import org.apache.spark.sql.{AnalysisException, Row, SQLContext, SaveMode}
-import org.apache.spark.sql.hive.test.TestHiveContext
 import org.apache.spark.sql.types._
 
 /**
  * End-to-end tests which run against a real Redshift cluster.
  */
-class RedshiftIntegrationSuite
-  extends QueryTest
-  with Matchers
-  with BeforeAndAfterAll
-  with BeforeAndAfterEach {
-
-  private def loadConfigFromEnv(envVarName: String): String = {
-    Option(System.getenv(envVarName)).getOrElse {
-      fail(s"Must set $envVarName environment variable")
-    }
-  }
-
-  // The following configurations must be set in order to run these tests. In Travis, these
-  // environment variables are set using Travis's encrypted environment variables feature:
-  // http://docs.travis-ci.com/user/environment-variables/#Encrypted-Variables
-
-  // JDBC URL listed in the AWS console (should not contain username and password).
-  private val AWS_REDSHIFT_JDBC_URL: String = loadConfigFromEnv("AWS_REDSHIFT_JDBC_URL")
-  private val AWS_REDSHIFT_USER: String = loadConfigFromEnv("AWS_REDSHIFT_USER")
-  private val AWS_REDSHIFT_PASSWORD: String = loadConfigFromEnv("AWS_REDSHIFT_PASSWORD")
-  private val AWS_ACCESS_KEY_ID: String = loadConfigFromEnv("TEST_AWS_ACCESS_KEY_ID")
-  private val AWS_SECRET_ACCESS_KEY: String = loadConfigFromEnv("TEST_AWS_SECRET_ACCESS_KEY")
-  // Path to a directory in S3 (e.g. 's3n://bucket-name/path/to/scratch/space').
-  private val AWS_S3_SCRATCH_SPACE: String = loadConfigFromEnv("AWS_S3_SCRATCH_SPACE")
-  require(AWS_S3_SCRATCH_SPACE.contains("s3n"), "must use s3n:// URL")
-
-  private val jdbcUrl: String = {
-    s"$AWS_REDSHIFT_JDBC_URL?user=$AWS_REDSHIFT_USER&password=$AWS_REDSHIFT_PASSWORD"
-  }
-
-  /**
-   * Random suffix appended appended to table and directory names in order to avoid collisions
-   * between separate Travis builds.
-   */
-  private val randomSuffix: String = Math.abs(Random.nextLong()).toString
-
-  private val tempDir: String = AWS_S3_SCRATCH_SPACE + randomSuffix + "/"
-
-  /**
-   * Spark Context with Hadoop file overridden to point at our local test data file for this suite,
-   * no-matter what temp directory was generated and requested.
-   */
-  private var sc: SparkContext = _
-  private var sqlContext: SQLContext = _
-  private var conn: Connection = _
+class RedshiftIntegrationSuite extends IntegrationSuiteBase {
 
   private val test_table: String = s"test_table_$randomSuffix"
   private val test_table2: String = s"test_table2_$randomSuffix"
@@ -85,9 +32,6 @@ class RedshiftIntegrationSuite
 
   override def beforeAll(): Unit = {
     super.beforeAll()
-    sc = new SparkContext("local", "RedshiftSourceSuite")
-
-    conn = DefaultJDBCWrapper.getConnector("com.amazon.redshift.jdbc4.Driver", jdbcUrl)
 
     conn.prepareStatement("drop table if exists test_table").executeUpdate()
     conn.prepareStatement("drop table if exists test_table2").executeUpdate()
@@ -133,31 +77,17 @@ class RedshiftIntegrationSuite
 
   override def afterAll(): Unit = {
     try {
-      val conf = new Configuration()
-      conf.set("fs.s3n.awsAccessKeyId", AWS_ACCESS_KEY_ID)
-      conf.set("fs.s3n.awsSecretAccessKey", AWS_SECRET_ACCESS_KEY)
-      val fs = FileSystem.get(URI.create(tempDir), conf)
-      fs.delete(new Path(tempDir), true)
-      fs.close()
+      conn.prepareStatement(s"drop table if exists $test_table").executeUpdate()
+      conn.prepareStatement(s"drop table if exists $test_table2").executeUpdate()
+      conn.prepareStatement(s"drop table if exists $test_table3").executeUpdate()
+      conn.commit()
     } finally {
-      try {
-        conn.prepareStatement(s"drop table if exists $test_table").executeUpdate()
-        conn.prepareStatement(s"drop table if exists $test_table2").executeUpdate()
-        conn.prepareStatement(s"drop table if exists $test_table3").executeUpdate()
-        conn.commit()
-        conn.close()
-      } finally {
-        try {
-          sc.stop()
-        } finally {
-          super.afterAll()
-        }
-      }
+      super.afterAll()
     }
   }
 
   override def beforeEach(): Unit = {
-    sqlContext = new TestHiveContext(sc)
+    super.beforeEach()
     sqlContext.sql(
       s"""
          | create temporary table test_table(
diff --git a/src/main/scala/com/databricks/spark/redshift/Conversions.scala b/src/main/scala/com/databricks/spark/redshift/Conversions.scala
index c3190cca..1e574a56 100644
--- a/src/main/scala/com/databricks/spark/redshift/Conversions.scala
+++ b/src/main/scala/com/databricks/spark/redshift/Conversions.scala
@@ -17,7 +17,7 @@ package com.databricks.spark.redshift
 
 import java.sql.Timestamp
-import java.text.{DateFormat, FieldPosition, ParsePosition, SimpleDateFormat}
+import java.text.{DecimalFormat, DateFormat, FieldPosition, ParsePosition, SimpleDateFormat}
 import java.util.Date
 
 import org.apache.spark.sql.types._
@@ -91,6 +91,15 @@ private[redshift] object Conversions {
     else throw new IllegalArgumentException(s"Expected 't' or 'f' but got '$s'")
   }
 
+  private[this] val redshiftDecimalFormat: DecimalFormat = new DecimalFormat()
+  redshiftDecimalFormat.setParseBigDecimal(true)
+
+  /**
+   * Parse a decimal using Redshift's UNLOAD decimal syntax
+   */
+  def parseDecimal(s: String): java.math.BigDecimal = {
+    redshiftDecimalFormat.parse(s).asInstanceOf[java.math.BigDecimal]
+  }
   /**
    * Construct a Row from the given array of strings, retrieved from Redshift UNLOAD.
    * The schema will be used for type mappings.
@@ -105,6 +114,7 @@ private[redshift] object Conversions {
       case DateType => parseDate(data)
       case DoubleType => data.toDouble
       case FloatType => data.toFloat
+      case dt: DecimalType => parseDecimal(data)
       case IntegerType => data.toInt
       case LongType => data.toLong
       case ShortType => data.toShort
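
A quick note on the new decimal path above: Conversions.parseDecimal uses java.text.DecimalFormat with setParseBigDecimal(true), so values unloaded by Redshift are read back as java.math.BigDecimal rather than taking a lossy detour through Double. The snippet below is a minimal standalone sketch of that same parsing behaviour, handy for sanity-checking inputs without a cluster; the object name DecimalParseSketch and the sample values (borrowed from DecimalIntegrationSuite) are illustrative only and not part of this patch, and it assumes a default locale whose decimal separator is '.', as in the Travis test environment.

    import java.text.DecimalFormat

    object DecimalParseSketch {
      // Same setup as the patch: once parseBigDecimal is enabled, DecimalFormat.parse
      // returns java.math.BigDecimal instead of a Double/Long, preserving precision.
      private val format = new DecimalFormat()
      format.setParseBigDecimal(true)

      def parse(s: String): java.math.BigDecimal =
        format.parse(s).asInstanceOf[java.math.BigDecimal]

      def main(args: Array[String]): Unit = {
        // Sample values borrowed from the DECIMAL(19, 4) cases in DecimalIntegrationSuite;
        // each line prints the input alongside the parsed BigDecimal.
        Seq("922337203685477.5807", "-922337203685477.5807", "0", "1234567.8910")
          .foreach(v => println(s"$v -> ${parse(v)}"))
      }
    }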