diff --git a/build.sbt b/build.sbt index 918627ec..f0cad803 100644 --- a/build.sbt +++ b/build.sbt @@ -41,6 +41,8 @@ libraryDependencies += "com.google.guava" % "guava" % "14.0.1" % Test libraryDependencies += "org.scalatest" %% "scalatest" % "2.1.5" % Test +libraryDependencies += "org.apache.commons" % "commons-csv" % "1.1" + libraryDependencies += "org.scalamock" %% "scalamock-scalatest-support" % "3.2" % Test ScoverageSbtPlugin.ScoverageKeys.coverageHighlighting := {
diff --git a/src/main/scala/com/databricks/spark/redshift/Conversions.scala b/src/main/scala/com/databricks/spark/redshift/Conversions.scala index cfe15ecd..bb5b3371 100644 --- a/src/main/scala/com/databricks/spark/redshift/Conversions.scala +++ b/src/main/scala/com/databricks/spark/redshift/Conversions.scala @@ -35,10 +35,91 @@ private object RedshiftBooleanParser extends JavaTokenParsers { def parseRedshiftBoolean(s: String): Boolean = parse(TRUE | FALSE, s).get } +/** + * Utility methods responsible for extracting information from data contained within a dataframe in order to generate + * a schema compatible with Redshift. + */ +object MetaSchema { + /** + * Map-Reduce task to calculate the longest string length in each string column of the dataframe. + * + * Note: This is used to generate N for the VARCHAR(N) field in the table schema to be loaded into Redshift. + * + * TODO: This should only be called once per load into Redshift. A cache, TraversableOnce, or some similar + * structure should be used to ensure this function is only called once. + * + * @param df DataFrame to be processed + * @return A Map[String, Int] representing an association between the column name and the length of that column's + * longest string + */ + private[redshift] def mapStrLengths(df:DataFrame) : Map[String, Int] = { + val schema:StructType = df.schema + + // For each row, filter the string columns and calculate the string length + // TODO: Other optimization strategies may be possible + val stringLengths = df.flatMap(row => + schema.collect { + case StructField(columnName, StringType, _, _) => (columnName, getStrLength(row, columnName)) + } + ).reduceByKey(Math.max(_, _)) + + stringLengths.collect().toMap + } + + /** + * Calculate the string length in columnName for the provided Row. Defensively returns 0 if the provided + * columnName is not a string column. + * + * This is a helper method that keeps mapStrLengths readable, and should not be used elsewhere. + * + * @param row Reference to a row of a dataframe + * @param columnName Name of the column + * @return Length of the string in that cell, falling back to 0 if the value is null or not a string. + */ + private[redshift] def getStrLength(row:Row, columnName:String): Int = { + row.getAs[String](columnName) match { + case field:String => field.length() + case _ => 0 + } + } + + /** + * Adds a "maxLength" -> Long field to column metadata. + * + * @param metadata metadata for a dataframe column + * @param length Length limit for content within that column + * @return new metadata object with added field + */ + private[redshift] def setStrLength(metadata:Metadata, length:Int) : Metadata = { + new MetadataBuilder().withMetadata(metadata).putLong("maxLength", length).build() + } + + /** + * Iterate through each string column in the schema, storing the longest string length in that column's + * metadata for later use. + */ + def computeEnhancedDf(df: DataFrame): DataFrame = { + // 1. Perform a full scan of each string column, storing its maximum string length in a Map + val stringLengthsByColumn = mapStrLengths(df) + + // 2. Generate an enhanced schema, with the metadata for each string column + val enhancedSchema = StructType( + df.schema map { + case StructField(name, StringType, nullable, meta) => + StructField(name, StringType, nullable, setStrLength(meta, stringLengthsByColumn(name))) + case other => other + } + ) + + // 3. Construct a new dataframe with a schema containing metadata with string lengths + df.sqlContext.createDataFrame(df.rdd, enhancedSchema) + } +} + /** * Data type conversions for Redshift unloaded data */ -private [redshift] object Conversions { +private[redshift] object Conversions { // Imports and exports with Redshift require that timestamps are represented // as strings, using the following formats @@ -58,7 +139,7 @@ private [redshift] object Conversions { } override def parse(source: String, pos: ParsePosition): Date = { - if(source.length < PATTERN_WITH_MILLIS.length) { + if (source.length < PATTERN_WITH_MILLIS.length) { redshiftTimestampFormatWithoutMillis.parse(source, pos) } else { redshiftTimestampFormatWithMillis.parse(source, pos) @@ -127,4 +208,4 @@ private [redshift] object Conversions { sqlContext.createDataFrame(df.rdd, schema) } -} \ No newline at end of file +}
diff --git a/src/main/scala/com/databricks/spark/redshift/Parameters.scala b/src/main/scala/com/databricks/spark/redshift/Parameters.scala index 4dcdc518..fcb35886 100644 --- a/src/main/scala/com/databricks/spark/redshift/Parameters.scala +++ b/src/main/scala/com/databricks/spark/redshift/Parameters.scala @@ -162,6 +162,21 @@ private [redshift] object Parameters extends Logging { */ def postActions = parameters("postactions").split(";") + /** + * How the maximum length for each text column is to be inferred (i.e. the 'N' in VARCHAR(N)). + * Redshift doesn't support variable-length TEXT like other SQL dialects, so columns containing text of unbounded + * length must either be scanned to determine the longest string across all rows for that column, or truncated + * to a fixed length. A number may also be passed to set the maximum number of characters explicitly. + * + * Examples: + * AUTO + * TRUNCATE(50) + * MAXLENGTH(4096) + * + * Defaults to 'AUTO' + */ + def stringLengths = parameters("stringlengths").toUpperCase() + /** + * Looks up "aws_access_key_id" and "aws_secret_access_key" in the parameter map + * and generates a credentials string for Redshift. If no credentials have been provided,
diff --git a/src/main/scala/com/databricks/spark/redshift/RedshiftWriter.scala b/src/main/scala/com/databricks/spark/redshift/RedshiftWriter.scala index 9e6983d8..b08f0a73 100644 --- a/src/main/scala/com/databricks/spark/redshift/RedshiftWriter.scala +++ b/src/main/scala/com/databricks/spark/redshift/RedshiftWriter.scala @@ -19,6 +19,8 @@ package com.databricks.spark.redshift import java.sql.{Connection, SQLException} import java.util.Properties +import org.apache.spark.sql.types._ + import scala.util.Random import com.databricks.spark.redshift.Parameters.MergedParameters @@ -32,11 +34,53 @@ import org.apache.spark.sql.{DataFrame, SQLContext} */ class RedshiftWriter(jdbcWrapper: JDBCWrapper) extends Logging { + def varcharStr(meta: Metadata): String = { + // TODO: Need a fallback when no "maxLength" metadata is present + val maxLength: Long = meta.getLong("maxLength") + + maxLength match { + case _: Long => s"VARCHAR($maxLength)" + } + } + + /** + * Compute a Redshift-compatible schema string for this dataframe. + */ + def schemaString(df: DataFrame): String = { + val sb = new StringBuilder() + + df.schema.fields foreach { + field => { + val name = field.name + val typ: String = + field match { + case StructField(_, IntegerType, _, _) => "INTEGER" + case StructField(_, LongType, _, _) => "BIGINT" + case StructField(_, DoubleType, _, _) => "DOUBLE PRECISION" + case StructField(_, FloatType, _, _) => "REAL" + case StructField(_, ShortType, _, _) => "INTEGER" + case StructField(_, BooleanType, _, _) => "BOOLEAN" + case StructField(_, StringType, _, metadata) => varcharStr(metadata) + case StructField(_, TimestampType, _, _) => "TIMESTAMP" + case StructField(_, DateType, _, _) => "DATE" + case StructField(_, t: DecimalType, _, _) => s"DECIMAL(${t.precision},${t.scale})" + case StructField(_, ByteType, _, _) => "BYTE" // TODO: REPLACEME (UNSUPPORTED BY REDSHIFT) + case StructField(_, BinaryType, _, _) => "BLOB" // TODO: REPLACEME (UNSUPPORTED BY REDSHIFT) + case _ => throw new IllegalArgumentException(s"Don't know how to save $field to JDBC") + } + val nullable = if (field.nullable) "" else "NOT NULL" + sb.append(s", $name $typ $nullable") + } + } + if (sb.length < 2) "" else sb.substring(2) + } + /** * Generate CREATE TABLE statement for Redshift */ - def createTableSql(data: DataFrame, params: MergedParameters) : String = { - val schemaSql = jdbcWrapper.schemaString(data, params.jdbcUrl) + def createTableSql(data: DataFrame, params: MergedParameters): String = { + val schemaSql = schemaString(MetaSchema.computeEnhancedDf(data)) + val distStyleDef = params.distStyle match { case Some(style) => s"DISTSTYLE $style" case None => "" @@ -47,7 +91,7 @@ class RedshiftWriter(jdbcWrapper: JDBCWrapper) extends Logging { } val sortKeyDef = params.sortKeySpec.getOrElse("") - s"CREATE TABLE IF NOT EXISTS ${params.table} ($schemaSql) $distStyleDef $distKeyDef $sortKeyDef" + s"CREATE TABLE IF NOT EXISTS ${params.table} ($schemaSql) $distStyleDef $distKeyDef $sortKeyDef".trim } /** @@ -63,7 +107,7 @@ class RedshiftWriter(jdbcWrapper: JDBCWrapper) extends Logging { * Sets up a staging table then runs the given action, passing the temporary table name * as a parameter.
*/ - def withStagingTable(conn:Connection, params: MergedParameters, action: (String) => Unit) { + def withStagingTable(conn: Connection, params: MergedParameters, action: (String) => Unit) { val randomSuffix = Math.abs(Random.nextInt()).toString val tempTable = s"${params.table}_staging_$randomSuffix" val backupTable = s"${params.table}_backup_$randomSuffix" @@ -93,10 +137,10 @@ class RedshiftWriter(jdbcWrapper: JDBCWrapper) extends Logging { * Perform the Redshift load, including deletion of existing data in the case of an overwrite, * and creating the table if it doesn't already exist. */ - def doRedshiftLoad(conn: Connection, data: DataFrame, params: MergedParameters) : Unit = { + def doRedshiftLoad(conn: Connection, data: DataFrame, params: MergedParameters): Unit = { // Overwrites must drop the table, in case there has been a schema update - if(params.overwrite) { + if (params.overwrite) { val deleteExisting = conn.prepareStatement(s"DROP TABLE IF EXISTS ${params.table}") deleteExisting.execute() } @@ -114,7 +158,7 @@ class RedshiftWriter(jdbcWrapper: JDBCWrapper) extends Logging { // Execute postActions params.postActions.foreach(action => { - val actionSql = if(action.contains("%s")) action.format(params.table) else action + val actionSql = if (action.contains("%s")) action.format(params.table) else action log.info("Executing postAction: " + actionSql) conn.prepareStatement(actionSql).execute() }) @@ -124,19 +168,21 @@ class RedshiftWriter(jdbcWrapper: JDBCWrapper) extends Logging { * Serialize temporary data to S3, ready for Redshift COPY */ def unloadData(sqlContext: SQLContext, data: DataFrame, tempPath: String): Unit = { - Conversions.datesToTimestamps(sqlContext, data).write.format("com.databricks.spark.avro").save(tempPath) + val enrichedData = Conversions.datesToTimestamps(sqlContext, data) // TODO .extractStringColumnLengths + + enrichedData.write.format("com.databricks.spark.avro").save(tempPath) } /** * Write a DataFrame to a Redshift table, using S3 and Avro serialization */ - def saveToRedshift(sqlContext: SQLContext, data: DataFrame, params: MergedParameters) : Unit = { + def saveToRedshift(sqlContext: SQLContext, data: DataFrame, params: MergedParameters): Unit = { val conn = jdbcWrapper.getConnector(params.jdbcDriver, params.jdbcUrl, new Properties()).apply() try { - if(params.overwrite && params.useStagingTable) { + if (params.overwrite && params.useStagingTable) { withStagingTable(conn, params, table => { - val updatedParams = MergedParameters(params.parameters updated ("dbtable", table)) + val updatedParams = MergedParameters(params.parameters updated("dbtable", table)) unloadData(sqlContext, data, updatedParams.tempPath) doRedshiftLoad(conn, data, updatedParams) }) diff --git a/src/test/scala/com/databricks/spark/redshift/ConversionsSuite.scala b/src/test/scala/com/databricks/spark/redshift/ConversionsSuite.scala index 50ac931a..eb295556 100644 --- a/src/test/scala/com/databricks/spark/redshift/ConversionsSuite.scala +++ b/src/test/scala/com/databricks/spark/redshift/ConversionsSuite.scala @@ -25,9 +25,9 @@ import org.apache.spark.sql.Row /** * Unit test for data type conversions */ -class ConversionsSuite extends FunSuite { +class ConversionsSuite extends MockDatabaseSuite { - val convertRow = Conversions.rowConverter(TestUtils.testSchema) + val convertRow = Conversions.rowConverter(testSchema) test("Data should be correctly converted") { val doubleMin = Double.MinValue.toString @@ -51,7 +51,7 @@ class ConversionsSuite extends FunSuite { } test("Row conversion 
handles null values") { - val emptyRow = List.fill(TestUtils.testSchema.length)(null).toArray[String] + val emptyRow = List.fill(testSchema.length)(null).toArray[String] assert(convertRow(emptyRow) == Row(emptyRow: _*)) } } diff --git a/src/test/scala/com/databricks/spark/redshift/MockDatabaseSuite.scala b/src/test/scala/com/databricks/spark/redshift/MockDatabaseSuite.scala new file mode 100644 index 00000000..cd781d5e --- /dev/null +++ b/src/test/scala/com/databricks/spark/redshift/MockDatabaseSuite.scala @@ -0,0 +1,127 @@ +package com.databricks.spark.redshift + +import java.sql.{SQLException, PreparedStatement, Connection} + +import com.databricks.spark.redshift.TestUtils._ +import org.apache.spark.sql.Row +import org.apache.spark.sql.jdbc.JDBCWrapper +import org.apache.spark.sql.types._ + +import org.scalamock.scalatest.MockFactory +import org.scalatest.FunSuite + +import scala.util.matching.Regex + +class MockDatabaseSuite extends FunSuite with MockFactory { + /** + * Makes a field for the test schema + */ + def makeField(name: String, typ: DataType) = { + val md = (new MetadataBuilder).putString("name", name).build() + StructField(name, typ, nullable = true, metadata = md) + } + + /** + * Simple schema that includes all data types we support + */ + lazy val testSchema = + StructType( + Seq( + makeField("testByte", ByteType), + makeField("testBool", BooleanType), + makeField("testDate", DateType), + makeField("testDouble", DoubleType), + makeField("testFloat", FloatType), + makeField("testInt", IntegerType), + makeField("testLong", LongType), + makeField("testShort", ShortType), + makeField("testString", StringType), + makeField("testTimestamp", TimestampType)) + ) + + /** + * Expected parsed output corresponding to the output of testData. + */ + val testData = + Array( + Row(1.toByte, true, toTimestamp(2015, 6, 1, 0, 0, 0), 1234152.123124981, + 1.0f, 42, 1239012341823719L, 23, "Unicode是樂趣", toTimestamp(2015, 6, 1, 0, 0, 0, 1)), + Row(1.toByte, false, toTimestamp(2015, 6, 2, 0, 0, 0), 0.0, 0.0f, 42, 1239012341823719L, -13, "asdf", + toTimestamp(2015, 6, 2, 0, 0, 0, 0)), + Row(0.toByte, null, toTimestamp(2015, 6, 3, 0, 0, 0), 0.0, -1.0f, 4141214, 1239012341823719L, null, "f", + toTimestamp(2015, 6, 3, 0, 0, 0)), + Row(0.toByte, false, null, -1234152.123124981, 100000.0f, null, 1239012341823719L, 24, "___|_123", null), + Row(List.fill(10)(null): _*)) + + def successfulStatement(pattern: Regex): PreparedStatement = { + val mockedConnection = mock[Connection] + + val mockedStatement = mock[PreparedStatement] + (mockedConnection.prepareStatement(_: String)) + .expects(where {(sql: String) => pattern.findFirstMatchIn(sql).nonEmpty}) + .returning(mockedStatement) + (mockedStatement.execute _).expects().returning(true) + + mockedStatement + } + + def failedStatement(pattern: Regex) : PreparedStatement = { + val mockedConnection = mock[Connection] + + val mockedStatement = mock[PreparedStatement] + (mockedConnection.prepareStatement(_: String)) + .expects(where {(sql: String) => pattern.findFirstMatchIn(sql).nonEmpty}) + .returning(mockedStatement) + + (mockedStatement.execute _) + .expects() + .throwing(new SQLException("Mocked Error")) + + mockedStatement + } + + /** + * Set up a mocked JDBCWrapper instance that expects a sequence of queries matching the given + * regular expressions will be executed, and that the connection returned will be closed. 
+ */ + def mockJdbcWrapper(expectedUrl: String, expectedQueries: Seq[Regex]): JDBCWrapper = { + val jdbcWrapper = mock[JDBCWrapper] + val mockedConnection = mock[Connection] + + (jdbcWrapper.getConnector _).expects(*, expectedUrl, *).returning(() => mockedConnection) + + inSequence { + expectedQueries foreach { r => + val mockedStatement = mock[PreparedStatement] + (mockedConnection.prepareStatement(_: String)) + .expects(where {(sql: String) => r.findFirstMatchIn(sql).nonEmpty}) + .returning(mockedStatement) + (mockedStatement.execute _).expects().returning(true) + } + + (mockedConnection.close _).expects() + } + + jdbcWrapper + } + + /** + * Prepare the JDBC wrapper for an UNLOAD test. + */ + def prepareUnloadTest(params: Map[String, String]) = { + val jdbcUrl = params("url") + val jdbcWrapper = mockJdbcWrapper(jdbcUrl, Seq("UNLOAD.*".r)) + + // We expect some extra calls to the JDBC wrapper, + // to register the driver and retrieve the schema. + (jdbcWrapper.registerDriver _) + .expects(*) + .anyNumberOfTimes() + (jdbcWrapper.resolveTable _) + .expects(jdbcUrl, "test_table", *) + .returning(testSchema) + .anyNumberOfTimes() + + jdbcWrapper + } +} diff --git a/src/test/scala/com/databricks/spark/redshift/ParametersSuite.scala b/src/test/scala/com/databricks/spark/redshift/ParametersSuite.scala index edea5bbb..a92f16b5 100644 --- a/src/test/scala/com/databricks/spark/redshift/ParametersSuite.scala +++ b/src/test/scala/com/databricks/spark/redshift/ParametersSuite.scala @@ -67,4 +67,29 @@ class ParametersSuite extends FunSuite with Matchers { checkMerge(Map("tempdir" -> "s3://foo/bar", "url" -> "jdbc:postgresql://foo/bar")) checkMerge(Map("dbtable" -> "test_table", "tempdir" -> "s3://foo/bar")) } + + test("Parameters for Redshift text/string column conversions") { + val params = + Map( + "tempdir" -> "s3://foo/bar", + "dbtable" -> "test_table", + "url" -> "jdbc:postgresql://foo/bar") + + Parameters.mergeParameters(params) + + val auto = params ++ Map("stringlengths" -> "auto") // Default + Parameters.mergeParameters(auto).stringLengths should equal("AUTO") + + val truncate = params ++ Map("stringlengths" -> "truncate") + Parameters.mergeParameters(truncate).stringLengths should equal("TRUNCATE") + + val optimistic = params ++ Map("stringlengths" -> "maxlength") + Parameters.mergeParameters(optimistic).stringLengths should equal("MAXLENGTH") + + val none = params ++ Map("stringlengths" -> "default") + Parameters.mergeParameters(none).stringLengths should equal("DEFAULT") + + val manual = params ++ Map("stringlengths" -> "manual") + Parameters.mergeParameters(manual).stringLengths should equal("MANUAL") + } } diff --git a/src/test/scala/com/databricks/spark/redshift/RedshiftInputFormatSuite.scala b/src/test/scala/com/databricks/spark/redshift/RedshiftInputFormatSuite.scala index 02ebaca7..1a704ed2 100644 --- a/src/test/scala/com/databricks/spark/redshift/RedshiftInputFormatSuite.scala +++ b/src/test/scala/com/databricks/spark/redshift/RedshiftInputFormatSuite.scala @@ -123,7 +123,8 @@ class RedshiftInputFormatSuite extends FunSuite with BeforeAndAfterAll { withTempDir { dir => val testRecords = Set( Seq("a\n", "TX", 1, 1.0, 1000L, 200000000000L), - Seq("b", "CA", 2, 2.0, 2000L, 1231412314L)) + Seq("b", "CA", 2, 2.0, 2000L, 1231412314L) + ) val escaped = escape(testRecords.map(_.map(_.toString)), DEFAULT_DELIMITER) writeToFile(escaped, new File(dir, "part-00000")) @@ -135,6 +136,7 @@ class RedshiftInputFormatSuite extends FunSuite with BeforeAndAfterAll { val srdd = sqlContext.redshiftFile( 
dir.toString, "name varchar(10) state text id integer score float big_score numeric(4, 0) some_long bigint") + val expectedSchema = StructType(Seq( StructField("name", StringType, nullable = true), StructField("state", StringType, nullable = true), @@ -142,13 +144,14 @@ class RedshiftInputFormatSuite extends FunSuite with BeforeAndAfterAll { StructField("score", DoubleType, nullable = true), StructField("big_score", LongType, nullable = true), StructField("some_long", LongType, nullable = true))) - assert(srdd.schema === expectedSchema) + val parsed = srdd.map { case Row(name: String, state: String, id: Int, score: Double, bigScore: Long, someLong: Long) => Seq(name, state, id, score, bigScore, someLong) }.collect().toSet + assert(srdd.schema === expectedSchema) assert(parsed === testRecords) } } diff --git a/src/test/scala/com/databricks/spark/redshift/RedshiftSourceSuite.scala b/src/test/scala/com/databricks/spark/redshift/RedshiftSourceSuite.scala index 0cb2566f..3467d6a6 100644 --- a/src/test/scala/com/databricks/spark/redshift/RedshiftSourceSuite.scala +++ b/src/test/scala/com/databricks/spark/redshift/RedshiftSourceSuite.scala @@ -23,7 +23,7 @@ import scala.util.matching.Regex import org.apache.hadoop.conf.Configuration import org.apache.hadoop.mapreduce.InputFormat -import org.scalamock.scalatest.MockFactory + import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers} import org.apache.spark.SparkContext @@ -51,9 +51,8 @@ class TestContext extends SparkContext("local", "RedshiftSourceSuite") { * Tests main DataFrame loading and writing functionality */ class RedshiftSourceSuite - extends FunSuite + extends MockDatabaseSuite with Matchers - with MockFactory with BeforeAndAfterAll { /** @@ -104,65 +103,13 @@ class RedshiftSourceSuite super.afterAll() } - /** - * Set up a mocked JDBCWrapper instance that expects a sequence of queries matching the given - * regular expressions will be executed, and that the connection returned will be closed. - */ - def mockJdbcWrapper(expectedUrl: String, expectedQueries: Seq[Regex]): JDBCWrapper = { - val jdbcWrapper = mock[JDBCWrapper] - val mockedConnection = mock[Connection] - - (jdbcWrapper.getConnector _).expects(*, expectedUrl, *).returning(() => mockedConnection) - - inSequence { - expectedQueries foreach { r => - val mockedStatement = mock[PreparedStatement] - (mockedConnection.prepareStatement(_: String)) - .expects(where {(sql: String) => r.findFirstMatchIn(sql).nonEmpty}) - .returning(mockedStatement) - (mockedStatement.execute _).expects().returning(true) - } - - (mockedConnection.close _).expects() - } - - jdbcWrapper - } - - /** - * Prepare the JDBC wrapper for an UNLOAD test. - */ - def prepareUnloadTest(params: Map[String, String]) = { - val jdbcUrl = params("url") - val jdbcWrapper = mockJdbcWrapper(jdbcUrl, Seq("UNLOAD.*".r)) - - // We expect some extra calls to the JDBC wrapper, - // to register the driver and retrieve the schema. 
- (jdbcWrapper.registerDriver _) - .expects(*) - .anyNumberOfTimes() - (jdbcWrapper.resolveTable _) - .expects(jdbcUrl, "test_table", *) - .returning(TestUtils.testSchema) - .anyNumberOfTimes() - - jdbcWrapper - } - test("DefaultSource can load Redshift UNLOAD output to a DataFrame") { - - val params = Map("url" -> "jdbc:postgresql://foo/bar", - "tempdir" -> "tmp", - "dbtable" -> "test_table", - "aws_access_key_id" -> "test1", - "aws_secret_access_key" -> "test2") - - val jdbcWrapper = prepareUnloadTest(params) + val jdbcWrapper = prepareUnloadTest(TestUtils.params) val testSqlContext = new SQLContext(sc) // Assert that we've loaded and converted all data in the test file val source = new DefaultSource(jdbcWrapper) - val relation = source.createRelation(testSqlContext, params) + val relation = source.createRelation(testSqlContext, TestUtils.params) val df = testSqlContext.baseRelationToDataFrame(relation) df.rdd.collect() zip expectedData foreach { @@ -172,19 +119,12 @@ class RedshiftSourceSuite } test("DefaultSource supports simple column filtering") { - - val params = Map("url" -> "jdbc:postgresql://foo/bar", - "tempdir" -> "tmp", - "dbtable" -> "test_table", - "aws_access_key_id" -> "test1", - "aws_secret_access_key" -> "test2") - - val jdbcWrapper = prepareUnloadTest(params) + val jdbcWrapper = prepareUnloadTest(TestUtils.params) val testSqlContext = new SQLContext(sc) // Construct the source with a custom schema val source = new DefaultSource(jdbcWrapper) - val relation = source.createRelation(testSqlContext, params, TestUtils.testSchema) + val relation = source.createRelation(testSqlContext, TestUtils.params, testSchema) val rdd = relation.asInstanceOf[PrunedFilteredScan].buildScan(Array("testByte", "testBool"), Array.empty[Filter]) val prunedExpectedValues = @@ -200,19 +140,12 @@ class RedshiftSourceSuite } test("DefaultSource supports user schema, pruned and filtered scans") { - - val params = Map("url" -> "jdbc:postgresql://foo/bar", - "tempdir" -> "tmp", - "dbtable" -> "test_table", - "aws_access_key_id" -> "test1", - "aws_secret_access_key" -> "test2") - - val jdbcWrapper = prepareUnloadTest(params) + val jdbcWrapper = prepareUnloadTest(TestUtils.params) val testSqlContext = new SQLContext(sc) // Construct the source with a custom schema val source = new DefaultSource(jdbcWrapper) - val relation = source.createRelation(testSqlContext, params, TestUtils.testSchema) + val relation = source.createRelation(testSqlContext, TestUtils.params, testSchema) // Define a simple filter to only include a subset of rows val filters: Array[Filter] = @@ -247,7 +180,7 @@ class RedshiftSourceSuite "distkey" -> "testInt") val rdd = sc.parallelize(expectedData.toSeq) - val df = testSqlContext.createDataFrame(rdd, TestUtils.testSchema) + val df = testSqlContext.createDataFrame(rdd, testSchema) val expectedCommands = Seq("DROP TABLE IF EXISTS test_table_staging_.*".r, @@ -294,7 +227,7 @@ class RedshiftSourceSuite "usestagingtable" -> "true") val rdd = sc.parallelize(expectedData.toSeq) - val df = testSqlContext.createDataFrame(rdd, TestUtils.testSchema) + val df = testSqlContext.createDataFrame(rdd, testSchema) val jdbcWrapper = mock[JDBCWrapper] val mockedConnection = mock[Connection] @@ -303,6 +236,7 @@ class RedshiftSourceSuite .expects(*, jdbcUrl, *) .returning(() => mockedConnection) + // TODO: extract to outer class def successfulStatement(pattern: Regex): PreparedStatement = { val mockedStatement = mock[PreparedStatement] (mockedConnection.prepareStatement(_: String)) @@ -313,6 +247,7 @@ class 
RedshiftSourceSuite mockedStatement } + // TODO: extract to outer class def failedStatement(pattern: Regex) : PreparedStatement = { val mockedStatement = mock[PreparedStatement] (mockedConnection.prepareStatement(_: String)) @@ -369,15 +304,9 @@ class RedshiftSourceSuite val testSqlContext = new SQLContext(sc) val jdbcUrl = "jdbc:postgresql://foo/bar" - val params = - Map("url" -> jdbcUrl, - "tempdir" -> tempDir, - "dbtable" -> "test_table", - "aws_access_key_id" -> "test1", - "aws_secret_access_key" -> "test2") val rdd = sc.parallelize(expectedData.toSeq) - val df = testSqlContext.createDataFrame(rdd, TestUtils.testSchema) + val df = testSqlContext.createDataFrame(rdd, testSchema) val expectedCommands = Seq("CREATE TABLE IF NOT EXISTS test_table .*".r, @@ -396,7 +325,7 @@ class RedshiftSourceSuite .anyNumberOfTimes() val source = new DefaultSource(jdbcWrapper) - val savedDf = source.createRelation(testSqlContext, SaveMode.Append, params, df) + val savedDf = source.createRelation(testSqlContext, SaveMode.Append, TestUtils.params, df) // This test is "appending" to an empty table, so we expect all our test data to be // the only content in the returned data frame @@ -419,7 +348,7 @@ class RedshiftSourceSuite "aws_secret_access_key" -> "test2") val rdd = sc.parallelize(expectedData.toSeq) - val df = testSqlContext.createDataFrame(rdd, TestUtils.testSchema) + val df = testSqlContext.createDataFrame(rdd, testSchema) // Check that SaveMode.ErrorIfExists throws an exception @@ -447,7 +376,7 @@ class RedshiftSourceSuite "aws_secret_access_key" -> "test2") val rdd = sc.parallelize(expectedData.toSeq) - val df = testSqlContext.createDataFrame(rdd, TestUtils.testSchema) + val df = testSqlContext.createDataFrame(rdd, testSchema) // Check that SaveMode.Ignore does nothing @@ -467,7 +396,7 @@ class RedshiftSourceSuite val rdd = sc.parallelize(expectedData) val testSqlContext = new SQLContext(sc) - val df = testSqlContext.createDataFrame(rdd, TestUtils.testSchema) + val df = testSqlContext.createDataFrame(rdd, testSchema) intercept[Exception] { df.saveAsRedshiftTable(invalid) @@ -478,9 +407,17 @@ class RedshiftSourceSuite } } + test("Basic string field extraction") { + val rdd = sc.parallelize(expectedData) + val testSqlContext = new SQLContext(sc) + val df = testSqlContext.createDataFrame(rdd, testSchema) + + val dfMetaSchema = MetaSchema.computeEnhancedDf(df) + + assert(dfMetaSchema.schema("testString").metadata.getLong("maxLength") == 10) + } + test("DefaultSource has default constructor, required by Data Source API") { new DefaultSource() } - - } diff --git a/src/test/scala/com/databricks/spark/redshift/SchemaGenerationSuite.scala b/src/test/scala/com/databricks/spark/redshift/SchemaGenerationSuite.scala new file mode 100644 index 00000000..a58d56f0 --- /dev/null +++ b/src/test/scala/com/databricks/spark/redshift/SchemaGenerationSuite.scala @@ -0,0 +1,98 @@ +package com.databricks.spark.redshift + +import java.io.File +import java.sql.Connection + +import com.databricks.spark.redshift.TestUtils._ + +import org.apache.spark.SparkContext +import org.apache.spark.sql.jdbc.JDBCWrapper +import org.apache.spark.sql.types._ +import org.apache.spark.sql.{SaveMode, DataFrame, SQLContext, Row} +import org.scalatest.{BeforeAndAfterAll, Matchers} +import scala.util.matching.Regex + +class SchemaGenerationSuite extends MockDatabaseSuite with Matchers with BeforeAndAfterAll { + // TODO: DRY ME + + /** + * Temporary folder for unloading data to + */ + val tempDir = { + var dir = 
File.createTempFile("spark_redshift_tests", "") + dir.delete() + dir.mkdirs() + dir.toURI.toString + } + + var sc: SparkContext = _ + var testSqlContext: SQLContext = _ + var df: DataFrame = _ + + override def beforeAll(): Unit = { + super.beforeAll() + + sc = new TestContext + testSqlContext = new SQLContext(sc) + + df = testSqlContext.createDataFrame(sc.parallelize(testData), testSchema) + } + + override def afterAll(): Unit = { + val temp = new File(tempDir) + val tempFiles = temp.listFiles() + if (tempFiles != null) tempFiles foreach { + case f => if (f != null) f.delete() + } + temp.delete() + + sc.stop() + super.afterAll() + } + + test("generating the table creation SQL") { + val expectedCommands = Seq("CREATE TABLE IF NOT EXISTS test_table .*".r) + + val mockWrapper = mock[JDBCWrapper] + val mockedConnection = mock[Connection] + // val mockWrapper = mockJdbcWrapper("url", Seq.empty[Regex]) + + val rsOutput: RedshiftWriter = new RedshiftWriter(mockWrapper) + + val params = Parameters.mergeParameters(TestUtils.params) + + rsOutput.createTableSql(df, params) should equal("CREATE TABLE IF NOT EXISTS test_table (testByte BYTE , testBool BOOLEAN , testDate DATE , testDouble DOUBLE PRECISION , testFloat REAL , testInt INTEGER , testLong BIGINT , testShort INTEGER , testString VARCHAR(10) , testTimestamp TIMESTAMP ) DISTSTYLE EVEN") + } + + test("Metaschema") { + val enhancedDataframe = MetaSchema.computeEnhancedDf(df) + + enhancedDataframe.schema("testString").metadata.getLong("maxLength") should equal(10) + } + + test("schema with multiple string columns") { + val schema = StructType( + Seq( + makeField("col1", StringType), + makeField("col2", StringType), + makeField("col3", StringType), + makeField("col4", IntegerType) + ) + ) + + val data = Array( + Row(null, null, null, null), + Row(null, "", "", 0), + Row(null, "A longer string", "", 0), + Row(null, "2", "", null) + ) + + val stringDf = testSqlContext.createDataFrame(sc.parallelize(data), schema) + + val enhancedDf = MetaSchema.computeEnhancedDf(stringDf) + + enhancedDf.schema("col1").metadata.getLong("maxLength") should equal(0) + enhancedDf.schema("col2").metadata.getLong("maxLength") should equal(15) + enhancedDf.schema("col3").metadata.getLong("maxLength") should equal(0) + } +} diff --git a/src/test/scala/com/databricks/spark/redshift/TestUtils.scala b/src/test/scala/com/databricks/spark/redshift/TestUtils.scala index b950467d..88d28bf3 100644 --- a/src/test/scala/com/databricks/spark/redshift/TestUtils.scala +++ b/src/test/scala/com/databricks/spark/redshift/TestUtils.scala @@ -19,38 +19,21 @@ package com.databricks.spark.redshift import java.sql.Timestamp import java.util.Calendar +import com.databricks.spark.redshift.Parameters.MergedParameters import org.apache.spark.sql.types._ /** * Helpers for Redshift tests that require common mocking */ object TestUtils { - - /** - * Makes a field for the test schema - */ - def makeField(name: String, typ: DataType) = { - val md = (new MetadataBuilder).putString("name", name).build() - StructField(name, typ, nullable = true, metadata = md) + def params: Map[String, String] = { + Map("url" -> "jdbc:postgresql://foo/bar", + "tempdir" -> "tmp", + "dbtable" -> "test_table", + "aws_access_key_id" -> "test1", + "aws_secret_access_key" -> "test2") } - /** - * Simple schema that includes all data types we support - */ - lazy val testSchema = - StructType( - Seq( - makeField("testByte", ByteType), - makeField("testBool", BooleanType), - makeField("testDate", DateType), - makeField("testDouble", 
DoubleType), - makeField("testFloat", FloatType), - makeField("testInt", IntegerType), - makeField("testLong", LongType), - makeField("testShort", ShortType), - makeField("testString", StringType), - makeField("testTimestamp", TimestampType))) - /** * Convert date components to a millisecond timestamp */
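
For reviewers, a minimal sketch of how the pieces added in this diff compose, as one might run it in a spark-shell on this branch. The Spark setup and the toy column names are illustrative assumptions, not part of the patch; only MetaSchema.computeEnhancedDf, the "maxLength" metadata key, and RedshiftWriter.schemaString/createTableSql come from the changes above.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.types._
import com.databricks.spark.redshift.MetaSchema

// Illustrative dataframe: one string column whose longest value is 6 characters.
val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("metaschema-sketch"))
val sqlContext = new SQLContext(sc)
val schema = StructType(Seq(
  StructField("id", IntegerType, nullable = true),
  StructField("name", StringType, nullable = true)))
val df = sqlContext.createDataFrame(
  sc.parallelize(Seq(Row(1, "ab"), Row(2, "abcdef"), Row(3, null))), schema)

// computeEnhancedDf scans every string column and records the longest observed
// length in that column's metadata under the "maxLength" key.
val enhanced = MetaSchema.computeEnhancedDf(df)
enhanced.schema("name").metadata.getLong("maxLength") // 6

// schemaString folds that metadata into the DDL via varcharStr, so the schema
// string for this frame contains "name VARCHAR(6)"; createTableSql calls
// computeEnhancedDf itself, so callers only pass the raw dataframe.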