2 changes: 2 additions & 0 deletions build.sbt
@@ -41,6 +41,8 @@ libraryDependencies += "com.google.guava" % "guava" % "14.0.1" % Test

libraryDependencies += "org.scalatest" %% "scalatest" % "2.1.5" % Test

libraryDependencies += "org.apache.commons" % "commons-csv" % "1.1"

libraryDependencies += "org.scalamock" %% "scalamock-scalatest-support" % "3.2" % Test

ScoverageSbtPlugin.ScoverageKeys.coverageHighlighting := {
87 changes: 84 additions & 3 deletions src/main/scala/com/databricks/spark/redshift/Conversions.scala
@@ -35,10 +35,91 @@ private object RedshiftBooleanParser extends JavaTokenParsers {
def parseRedshiftBoolean(s: String): Boolean = parse(TRUE | FALSE, s).get
}

/**
 * Utility methods for extracting information from the data in a DataFrame in order to generate
 * a Redshift-compatible schema.
 */
object MetaSchema {
  /**
   * Map-reduce job that calculates the longest string length in each string column of the DataFrame,
   * across all rows.
   *
   * Note: This is used to generate the N for the VARCHAR(N) fields in the table schema loaded into Redshift.
   *
   * TODO: This should only be called once per load into Redshift. A cache, TraversableOnce, or some similar
   * structure should be used to enforce that this function is only called once.
   *
   * @param df DataFrame to be processed
   * @return A Map[String, Int] associating each string column name with the length of that column's
   *         longest string
   */
  private[redshift] def mapStrLengths(df: DataFrame): Map[String, Int] = {
    val schema: StructType = df.schema

// For each row, filter the string columns and calculate the string length
// TODO: Other optimization strategies may be possible
val stringLengths = df.flatMap(row =>
schema.collect {
case StructField(columnName, StringType, _, _) => (columnName, getStrLength(row, columnName))
}
).reduceByKey(Math.max(_, _))

stringLengths.collect().toMap
}
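
  // Illustrative result (hypothetical data, not part of the patch): for a DataFrame whose string column
  // "name" has longest value "Jonathan" and whose string column "city" has longest value "Amsterdam",
  // this returns Map("name" -> 8, "city" -> 9).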

/**
   * Calculate the length of the string in columnName for the provided Row. Defensively returns 0 if the
   * provided columnName is not a string column.
   *
   * This is a helper that keeps mapStrLengths readable, and should not be used elsewhere.
   *
   * @param row Reference to a row of a DataFrame
   * @param columnName Name of the column
   * @return Length of the string in the cell, falling back to 0 if the cell is null or does not contain a string.
   */
  private[redshift] def getStrLength(row: Row, columnName: String): Int = {
row.getAs[String](columnName) match {
case field:String => field.length()
case _ => 0
}
}

/**
* Adds a "maxLength" -> Int field to column metadata.
*
* @param metadata metadata for a dataframe column
* @param length Length limit for content within that column
* @return new metadata object with added field
*/
  private[redshift] def setStrLength(metadata: Metadata, length: Int): Metadata = {
new MetadataBuilder().withMetadata(metadata).putLong("maxLength", length).build()
}
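
  // Illustrative example (not part of the patch): setStrLength(Metadata.empty, 42) yields a Metadata object
  // whose getLong("maxLength") returns 42, leaving any keys already present in the input metadata untouched.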

/**
   * Iterate through each string column in the schema, storing the longest string length in that column's
   * metadata for later use.
*/
def computeEnhancedDf(df: DataFrame): DataFrame = {
    // 1. Perform a full scan of each string column, storing its maximum string length within a Map
val stringLengthsByColumn = mapStrLengths(df)

// 2. Generate an enhanced schema, with the metadata for each string column
val enhancedSchema = StructType(
df.schema map {
case StructField(name, StringType, nullable, meta) =>
StructField(name, StringType, nullable, setStrLength(meta, stringLengthsByColumn(name)))
case other => other
}
)

// 3. Construct a new dataframe with a schema containing metadata with string lengths
df.sqlContext.createDataFrame(df.rdd, enhancedSchema)
}
}
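
A minimal usage sketch (the DataFrame, column names and printed values below are illustrative, not part of the patch): computeEnhancedDf rescans the string columns and records each one's longest length under a "maxLength" metadata key, which the writer later reads back when building the schema string.

// Hypothetical usage, assuming a DataFrame `df` with one or more string columns:
val enhanced = MetaSchema.computeEnhancedDf(df)
enhanced.schema.foreach { field =>
  if (field.dataType == StringType && field.metadata.contains("maxLength")) {
    // e.g. prints "name -> VARCHAR(8)" for a column whose longest value has 8 characters
    println(s"${field.name} -> VARCHAR(${field.metadata.getLong("maxLength")})")
  }
}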

/**
* Data type conversions for Redshift unloaded data
*/
private [redshift] object Conversions {
private[redshift] object Conversions {

// Imports and exports with Redshift require that timestamps are represented
// as strings, using the following formats
@@ -58,7 +139,7 @@ private [redshift] object Conversions {
}

override def parse(source: String, pos: ParsePosition): Date = {
if(source.length < PATTERN_WITH_MILLIS.length) {
if (source.length < PATTERN_WITH_MILLIS.length) {
redshiftTimestampFormatWithoutMillis.parse(source, pos)
} else {
redshiftTimestampFormatWithMillis.parse(source, pos)
@@ -127,4 +208,4 @@ private [redshift] object Conversions {

sqlContext.createDataFrame(df.rdd, schema)
}
}
}
15 changes: 15 additions & 0 deletions src/main/scala/com/databricks/spark/redshift/Parameters.scala
@@ -162,6 +162,21 @@ private [redshift] object Parameters extends Logging {
*/
def postActions = parameters("postactions").split(";")

[Author comment] These configuration parameters are still a work in progress.

  /**
   * How the maximum length of each text column is to be inferred (i.e. the 'N' in VARCHAR(N)).
   *
   * Redshift doesn't support variable-length TEXT like other SQL dialects, so columns containing text of
   * unbounded length must either be scanned to determine the longest string across all rows for that column,
   * or truncated to a fixed length. A number may also be passed to this parameter to set the maximum number
   * of characters explicitly.
   *
   * Examples:
   *   AUTO
   *   TRUNCATE(50)
   *   MAXLENGTH(4096)
   *
   * Defaults to 'AUTO'
   */
  def stringLengths = parameters("stringlengths").toUpperCase()
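
A hedged sketch of how the new parameter might be passed from the caller's side (only "stringlengths" and "dbtable" appear in this patch; the format string and the DataFrameWriter call are the usual spark-redshift pattern, and the table name is a placeholder):

// Illustrative only: add the new option alongside whatever options are already being supplied.
df.write
  .format("com.databricks.spark.redshift")
  .option("dbtable", "my_table")               // placeholder table name
  .option("stringlengths", "MAXLENGTH(4096)")  // or "AUTO" (the default) or "TRUNCATE(50)"
  .save()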

/**
* Looks up "aws_access_key_id" and "aws_secret_access_key" in the parameter map
* and generates a credentials string for Redshift. If no credentials have been provided,
68 changes: 57 additions & 11 deletions src/main/scala/com/databricks/spark/redshift/RedshiftWriter.scala
@@ -19,6 +19,8 @@ package com.databricks.spark.redshift
import java.sql.{Connection, SQLException}
import java.util.Properties

import org.apache.spark.sql.types._

import scala.util.Random

import com.databricks.spark.redshift.Parameters.MergedParameters
@@ -32,11 +34,53 @@ import org.apache.spark.sql.{DataFrame, SQLContext}
*/
class RedshiftWriter(jdbcWrapper: JDBCWrapper) extends Logging {

  def varcharStr(meta: Metadata): String = {
    // TODO: Need a fallback when no "maxLength" metadata is present
    val maxLength: Long = meta.getLong("maxLength")
    s"VARCHAR($maxLength)"
  }

/**
   * Compute a Redshift-compatible schema string for this DataFrame.
*/
def schemaString(df: DataFrame): String = {
val sb = new StringBuilder()

df.schema.fields foreach {
field => {
val name = field.name
val typ: String =
field match {
case StructField(_, IntegerType, _, _) => "INTEGER"
case StructField(_, LongType, _, _) => "BIGINT"
case StructField(_, DoubleType, _, _) => "DOUBLE PRECISION"
case StructField(_, FloatType, _, _) => "REAL"
case StructField(_, ShortType, _, _) => "INTEGER"
case StructField(_, BooleanType, _, _) => "BOOLEAN"
case StructField(_, StringType, _, metadata) => varcharStr(metadata)
case StructField(_, TimestampType, _, _) => "TIMESTAMP"
case StructField(_, DateType, _, _) => "DATE"
            case StructField(_, t: DecimalType, _, _) => s"DECIMAL(${t.precision},${t.scale})"
case StructField(_, ByteType, _, _) => "BYTE" // TODO: REPLACEME (UNSUPPORTED BY REDSHIFT)
case StructField(_, BinaryType, _, _) => "BLOB" // TODO: REPLACEME (UNSUPPORTED BY REDSHIFT)
case _ => throw new IllegalArgumentException(s"Don't know how to save $field to JDBC")
}
val nullable = if (field.nullable) "" else "NOT NULL"
sb.append(s", $name $typ $nullable")
}
}
if (sb.length < 2) "" else sb.substring(2)
}
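
  // Illustrative example (columns made up): for a schema with a nullable IntegerType column "id" and a
  // non-nullable StringType column "name" whose "maxLength" metadata is 12, this returns roughly:
  //   id INTEGER , name VARCHAR(12) NOT NULL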

/**
* Generate CREATE TABLE statement for Redshift
*/
def createTableSql(data: DataFrame, params: MergedParameters) : String = {
val schemaSql = jdbcWrapper.schemaString(data, params.jdbcUrl)
def createTableSql(data: DataFrame, params: MergedParameters): String = {
    val schemaSql = schemaString(MetaSchema.computeEnhancedDf(data))

val distStyleDef = params.distStyle match {
case Some(style) => s"DISTSTYLE $style"
case None => ""
@@ -47,7 +91,7 @@ class RedshiftWriter(jdbcWrapper: JDBCWrapper) extends Logging {
}
val sortKeyDef = params.sortKeySpec.getOrElse("")

s"CREATE TABLE IF NOT EXISTS ${params.table} ($schemaSql) $distStyleDef $distKeyDef $sortKeyDef"
s"CREATE TABLE IF NOT EXISTS ${params.table} ($schemaSql) $distStyleDef $distKeyDef $sortKeyDef".trim
}

/**
@@ -63,7 +107,7 @@ class RedshiftWriter(jdbcWrapper: JDBCWrapper) extends Logging {
* Sets up a staging table then runs the given action, passing the temporary table name
* as a parameter.
*/
def withStagingTable(conn:Connection, params: MergedParameters, action: (String) => Unit) {
def withStagingTable(conn: Connection, params: MergedParameters, action: (String) => Unit) {
val randomSuffix = Math.abs(Random.nextInt()).toString
val tempTable = s"${params.table}_staging_$randomSuffix"
val backupTable = s"${params.table}_backup_$randomSuffix"
@@ -93,10 +137,10 @@ class RedshiftWriter(jdbcWrapper: JDBCWrapper) extends Logging {
* Perform the Redshift load, including deletion of existing data in the case of an overwrite,
* and creating the table if it doesn't already exist.
*/
def doRedshiftLoad(conn: Connection, data: DataFrame, params: MergedParameters) : Unit = {
def doRedshiftLoad(conn: Connection, data: DataFrame, params: MergedParameters): Unit = {

// Overwrites must drop the table, in case there has been a schema update
if(params.overwrite) {
if (params.overwrite) {
val deleteExisting = conn.prepareStatement(s"DROP TABLE IF EXISTS ${params.table}")
deleteExisting.execute()
}
@@ -114,7 +158,7 @@ class RedshiftWriter(jdbcWrapper: JDBCWrapper) extends Logging {

// Execute postActions
params.postActions.foreach(action => {
val actionSql = if(action.contains("%s")) action.format(params.table) else action
val actionSql = if (action.contains("%s")) action.format(params.table) else action
log.info("Executing postAction: " + actionSql)
conn.prepareStatement(actionSql).execute()
})
@@ -124,19 +168,21 @@ class RedshiftWriter(jdbcWrapper: JDBCWrapper) extends Logging {
* Serialize temporary data to S3, ready for Redshift COPY
*/
def unloadData(sqlContext: SQLContext, data: DataFrame, tempPath: String): Unit = {
Conversions.datesToTimestamps(sqlContext, data).write.format("com.databricks.spark.avro").save(tempPath)
val enrichedData = Conversions.datesToTimestamps(sqlContext, data) // TODO .extractStringColumnLengths

enrichedData.write.format("com.databricks.spark.avro").save(tempPath)
}

/**
* Write a DataFrame to a Redshift table, using S3 and Avro serialization
*/
def saveToRedshift(sqlContext: SQLContext, data: DataFrame, params: MergedParameters) : Unit = {
def saveToRedshift(sqlContext: SQLContext, data: DataFrame, params: MergedParameters): Unit = {
val conn = jdbcWrapper.getConnector(params.jdbcDriver, params.jdbcUrl, new Properties()).apply()

try {
if(params.overwrite && params.useStagingTable) {
if (params.overwrite && params.useStagingTable) {
withStagingTable(conn, params, table => {
val updatedParams = MergedParameters(params.parameters updated ("dbtable", table))
val updatedParams = MergedParameters(params.parameters updated("dbtable", table))
unloadData(sqlContext, data, updatedParams.tempPath)
doRedshiftLoad(conn, data, updatedParams)
})
@@ -25,9 +25,9 @@ import org.apache.spark.sql.Row
/**
* Unit test for data type conversions
*/
class ConversionsSuite extends FunSuite {
class ConversionsSuite extends MockDatabaseSuite {

val convertRow = Conversions.rowConverter(TestUtils.testSchema)
val convertRow = Conversions.rowConverter(testSchema)

test("Data should be correctly converted") {
val doubleMin = Double.MinValue.toString
@@ -51,7 +51,7 @@ class ConversionsSuite extends MockDatabaseSuite {
}

test("Row conversion handles null values") {
val emptyRow = List.fill(TestUtils.testSchema.length)(null).toArray[String]
val emptyRow = List.fill(testSchema.length)(null).toArray[String]
assert(convertRow(emptyRow) == Row(emptyRow: _*))
}
}