@@ -19,7 +19,6 @@ package com.databricks.spark.redshift
import java.util.Properties

import org.apache.spark.Logging
-import org.apache.spark.sql.jdbc.{DefaultJDBCWrapper, JDBCWrapper}
import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, RelationProvider, SchemaRelationProvider}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}
219 changes: 219 additions & 0 deletions src/main/scala/com/databricks/spark/redshift/RedshiftJDBCWrapper.scala
@@ -0,0 +1,219 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.databricks.spark.redshift

import java.sql.{Connection, DriverManager, ResultSetMetaData, SQLException}
import java.util.Properties

import scala.util.Try

import org.apache.spark.Logging
import org.apache.spark.sql.types._

/**
* Shim which exposes some JDBC helper functions. Most of this code is copied from Spark SQL, with
* minor modifications for Redshift-specific features and limitations.
*/
private[redshift] class JDBCWrapper extends Logging {

def registerDriver(driverClass: String): Unit = {
Reviewer comment (Contributor): Can you add a comment about what is going on here and why we are using reflection?

// DriverRegistry.register() is one of the few pieces of private Spark functionality which
// we need to rely on. This class was relocated in Spark 1.5.0, so we need to use reflection
// in order to support both Spark 1.4.x and 1.5.x.
// TODO: once 1.5.0 snapshots are on Maven, update this to switch the class name based on
// SPARK_VERSION.
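// Concretely, the reflective calls below are equivalent to writing
// org.apache.spark.sql.jdbc.DriverRegistry.register(driverClass) against Spark 1.4.x.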
val classLoader =
Option(Thread.currentThread().getContextClassLoader).getOrElse(this.getClass.getClassLoader)
val className = "org.apache.spark.sql.jdbc.package$DriverRegistry$"
// scalastyle:off
val driverRegistryClass = Class.forName(className, true, classLoader)
// scalastyle:on
val registerMethod = driverRegistryClass.getDeclaredMethod("register", classOf[String])
val companionObject = driverRegistryClass.getDeclaredField("MODULE$").get(null)
registerMethod.invoke(companionObject, driverClass)
}

/**
* Takes a (schema, table) specification and returns the table's Catalyst
* schema.
*
* @param url - The JDBC url to fetch information from.
* @param table - The table name of the desired table. This may also be a
* SQL query wrapped in parentheses.
*
* @return A StructType giving the table's Catalyst schema.
* @throws SQLException if the table specification is garbage.
* @throws SQLException if the table contains an unsupported type.
*/
def resolveTable(url: String, table: String, properties: Properties): StructType = {
val conn: Connection = DriverManager.getConnection(url, properties)
try {
val rs = conn.prepareStatement(s"SELECT * FROM $table WHERE 1=0").executeQuery()
try {
val rsmd = rs.getMetaData
val ncols = rsmd.getColumnCount
val fields = new Array[StructField](ncols)
var i = 0
while (i < ncols) {
val columnName = rsmd.getColumnLabel(i + 1)
val dataType = rsmd.getColumnType(i + 1)
val typeName = rsmd.getColumnTypeName(i + 1)
val fieldSize = rsmd.getPrecision(i + 1)
val fieldScale = rsmd.getScale(i + 1)
val isSigned = rsmd.isSigned(i + 1)
val nullable = rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls
val metadata = new MetadataBuilder().putString("name", columnName)
val columnType = getCatalystType(dataType, fieldSize, fieldScale, isSigned)
fields(i) = StructField(columnName, columnType, nullable, metadata.build())
i = i + 1
}
return new StructType(fields)
} finally {
rs.close()
}
} finally {
conn.close()
}

throw new RuntimeException("This line is unreachable.")
}
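// Illustrative call (placeholder endpoint and table name, not part of the original change):
//   resolveTable("jdbc:redshift://examplecluster:5439/dev?user=alice&password=secret", "events", new Properties())
// returns the Catalyst StructType of `events` without fetching any rows.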

/**
* Given a driver string and an url, return a function that loads the
* specified driver string then returns a connection to the JDBC url.
* getConnector is run on the driver code, while the function it returns
* is run on the executor.
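* Because registration happens inside the returned function, each executor JVM registers
* the driver with its own DriverManager before opening a connection.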
*
* @param driver - The class name of the JDBC driver for the given url.
* @param url - The JDBC url to connect to.
*
* @return A function that loads the driver and connects to the url.
*/
def getConnector(driver: String, url: String, properties: Properties): () => Connection = {
() => {
try {
if (driver != null) registerDriver(driver)
} catch {
case e: ClassNotFoundException =>
logWarning(s"Couldn't find class $driver", e)
}
DriverManager.getConnection(url, properties)
}
}

/**
* Compute the SQL schema string for the given Spark SQL Schema.
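* For example, StructType(Seq(StructField("id", LongType, nullable = false),
* StructField("name", StringType))) renders (modulo trailing whitespace) as
* "id BIGINT NOT NULL, name TEXT".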
*/
def schemaString(schema: StructType): String = {
val sb = new StringBuilder()
schema.fields.foreach { field => {
val name = field.name
val typ: String = field.dataType match {
case IntegerType => "INTEGER"
case LongType => "BIGINT"
case DoubleType => "DOUBLE PRECISION"
case FloatType => "REAL"
case ShortType => "INTEGER"
case ByteType => "SMALLINT" // Redshift does not support the BYTE type.
case BooleanType => "BOOLEAN"
case StringType => "TEXT"
case BinaryType => "BLOB"
case TimestampType => "TIMESTAMP"
case DateType => "DATE"
case t: DecimalType => s"DECIMAL(${t.precision},${t.scale})"
case _ => throw new IllegalArgumentException(s"Don't know how to save $field to JDBC")
}
val nullable = if (field.nullable) "" else "NOT NULL"
sb.append(s", $name $typ $nullable")
}}
if (sb.length < 2) "" else sb.substring(2)
}

/**
* Returns true if the table already exists in the JDBC database.
*/
def tableExists(conn: Connection, table: String): Boolean = {
// Somewhat hacky, but there isn't a good way to identify whether a table exists for all
// SQL database systems, considering "table" could also include the database name.
Try(conn.prepareStatement(s"SELECT 1 FROM $table LIMIT 1").executeQuery().next()).isSuccess
}

/**
* Maps a JDBC type to a Catalyst type.
*
* @param sqlType - A field of java.sql.Types
* @return The Catalyst type corresponding to sqlType.
*/
private def getCatalystType(
sqlType: Int,
precision: Int,
scale: Int,
signed: Boolean): DataType = {
// TODO: cleanup types which are irrelevant for Redshift.
val answer = sqlType match {
// scalastyle:off
case java.sql.Types.ARRAY => null
case java.sql.Types.BIGINT => if (signed) { LongType } else { DecimalType(20,0) }
case java.sql.Types.BINARY => BinaryType
case java.sql.Types.BIT => BooleanType // @see JdbcDialect for quirks
case java.sql.Types.BLOB => BinaryType
case java.sql.Types.BOOLEAN => BooleanType
case java.sql.Types.CHAR => StringType
case java.sql.Types.CLOB => StringType
case java.sql.Types.DATALINK => null
case java.sql.Types.DATE => DateType
case java.sql.Types.DECIMAL
if precision != 0 || scale != 0 => DecimalType(precision, scale)
case java.sql.Types.DECIMAL => DecimalType(38, 18) // Spark 1.5.0 default
case java.sql.Types.DISTINCT => null
case java.sql.Types.DOUBLE => DoubleType
case java.sql.Types.FLOAT => FloatType
case java.sql.Types.INTEGER => if (signed) { IntegerType } else { LongType }
case java.sql.Types.JAVA_OBJECT => null
case java.sql.Types.LONGNVARCHAR => StringType
case java.sql.Types.LONGVARBINARY => BinaryType
case java.sql.Types.LONGVARCHAR => StringType
case java.sql.Types.NCHAR => StringType
case java.sql.Types.NCLOB => StringType
case java.sql.Types.NULL => null
case java.sql.Types.NUMERIC
if precision != 0 || scale != 0 => DecimalType(precision, scale)
case java.sql.Types.NUMERIC => DecimalType(38, 18) // Spark 1.5.0 default
case java.sql.Types.NVARCHAR => StringType
case java.sql.Types.OTHER => null
case java.sql.Types.REAL => DoubleType
case java.sql.Types.REF => StringType
case java.sql.Types.ROWID => LongType
case java.sql.Types.SMALLINT => IntegerType
case java.sql.Types.SQLXML => StringType
case java.sql.Types.STRUCT => StringType
case java.sql.Types.TIME => TimestampType
case java.sql.Types.TIMESTAMP => TimestampType
case java.sql.Types.TINYINT => IntegerType
case java.sql.Types.VARBINARY => BinaryType
case java.sql.Types.VARCHAR => StringType
case _ => null
// scalastyle:on
Reviewer comment (Contributor): Maybe add a TODO to clean up the ones that are irrelevant for Redshift.

}

if (answer == null) throw new SQLException("Unsupported type " + sqlType)
answer
}
}

private[redshift] object DefaultJDBCWrapper extends JDBCWrapper
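
Aside, not part of the diff: a minimal sketch of how the new wrapper's pieces fit together. The JDBC URL, driver class, and table name below are placeholders, and since DefaultJDBCWrapper is private[redshift], such code would have to live inside that package.

    package com.databricks.spark.redshift

    import java.util.Properties

    object WrapperUsageSketch {
      def main(args: Array[String]): Unit = {
        // Placeholder connection details; substitute real values.
        val url = "jdbc:redshift://examplecluster:5439/dev?user=alice&password=secret"
        val props = new Properties()

        // Register a JDBC driver (the PostgreSQL driver here, as an example) and
        // obtain a thunk that opens connections on demand.
        val connect = DefaultJDBCWrapper.getConnector("org.postgresql.Driver", url, props)
        val conn = connect()
        try {
          if (DefaultJDBCWrapper.tableExists(conn, "my_table")) {
            // Resolve the Catalyst schema without reading any rows (SELECT * ... WHERE 1=0).
            val schema = DefaultJDBCWrapper.resolveTable(url, "my_table", props)
            println(schema.treeString)
            // Render the same schema back as Redshift column definitions.
            println(DefaultJDBCWrapper.schemaString(schema))
          }
        } finally {
          conn.close()
        }
      }
    }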
@@ -22,7 +22,6 @@ import com.databricks.spark.redshift.Parameters.MergedParameters

import org.apache.spark.Logging
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.jdbc.JDBCWrapper
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
@@ -24,7 +24,6 @@ import scala.util.Random
import com.databricks.spark.redshift.Parameters.MergedParameters

import org.apache.spark.Logging
-import org.apache.spark.sql.jdbc.{DefaultJDBCWrapper, JDBCWrapper}
import org.apache.spark.sql.{DataFrame, SQLContext}

/**
@@ -36,7 +35,7 @@ class RedshiftWriter(jdbcWrapper: JDBCWrapper) extends Logging {
* Generate CREATE TABLE statement for Redshift
*/
def createTableSql(data: DataFrame, params: MergedParameters): String = {
-val schemaSql = jdbcWrapper.schemaString(data, params.jdbcUrl)
+val schemaSql = jdbcWrapper.schemaString(data.schema)
val distStyleDef = params.distStyle match {
case Some(style) => s"DISTSTYLE $style"
case None => ""
2 changes: 1 addition & 1 deletion src/main/scala/com/databricks/spark/redshift/package.scala
@@ -17,8 +17,8 @@

package com.databricks.spark

+import com.databricks.spark.redshift.DefaultJDBCWrapper
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.jdbc.DefaultJDBCWrapper
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Row, SQLContext}

50 changes: 0 additions & 50 deletions src/main/scala/org/apache/spark/sql/jdbc/RedshiftJDBCWrapper.scala

This file was deleted.

@@ -29,7 +29,6 @@ import org.scalatest.{BeforeAndAfterEach, BeforeAndAfterAll, FunSuite, Matchers}

import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.jdbc.JDBCWrapper
import org.apache.spark.sql.sources._
import org.apache.spark.sql.{DataFrame, Row, SQLContext, SaveMode}

@@ -273,7 +272,7 @@ class RedshiftSourceSuite
.anyNumberOfTimes()

(jdbcWrapper.schemaString _)
-.expects(*, params("url"))
+.expects(*)
.returning("schema")
.anyNumberOfTimes()

@@ -325,7 +324,7 @@ class RedshiftSourceSuite
.anyNumberOfTimes()

(jdbcWrapper.schemaString _)
-.expects(*, params("url"))
+.expects(*)
.anyNumberOfTimes()

inSequence {
@@ -371,7 +370,7 @@ class RedshiftSourceSuite
.anyNumberOfTimes()

(jdbcWrapper.schemaString _)
-.expects(*, defaultParams("url"))
+.expects(*)
.returning("schema")
.anyNumberOfTimes()
