From 2959a9cced747eeaff0be2e4c0866dc3494273c4 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 19 Aug 2015 15:01:08 -0700 Subject: [PATCH 1/3] Copy Spark code to minimize reliance on private APIs. --- .../spark/redshift/DefaultSource.scala | 1 - .../spark/redshift/RedshiftJDBCWrapper.scala | 210 ++++++++++++++++++ .../spark/redshift/RedshiftRelation.scala | 1 - .../spark/redshift/RedshiftWriter.scala | 3 +- .../databricks/spark/redshift/package.scala | 2 +- .../spark/sql/jdbc/RedshiftJDBCWrapper.scala | 50 ----- .../spark/redshift/RedshiftSourceSuite.scala | 7 +- 7 files changed, 215 insertions(+), 59 deletions(-) create mode 100644 src/main/scala/com/databricks/spark/redshift/RedshiftJDBCWrapper.scala delete mode 100644 src/main/scala/org/apache/spark/sql/jdbc/RedshiftJDBCWrapper.scala diff --git a/src/main/scala/com/databricks/spark/redshift/DefaultSource.scala b/src/main/scala/com/databricks/spark/redshift/DefaultSource.scala index cc67d4fb..8fb490da 100644 --- a/src/main/scala/com/databricks/spark/redshift/DefaultSource.scala +++ b/src/main/scala/com/databricks/spark/redshift/DefaultSource.scala @@ -19,7 +19,6 @@ package com.databricks.spark.redshift import java.util.Properties import org.apache.spark.Logging -import org.apache.spark.sql.jdbc.{DefaultJDBCWrapper, JDBCWrapper} import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, RelationProvider, SchemaRelationProvider} import org.apache.spark.sql.types.StructType import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode} diff --git a/src/main/scala/com/databricks/spark/redshift/RedshiftJDBCWrapper.scala b/src/main/scala/com/databricks/spark/redshift/RedshiftJDBCWrapper.scala new file mode 100644 index 00000000..f273b3a0 --- /dev/null +++ b/src/main/scala/com/databricks/spark/redshift/RedshiftJDBCWrapper.scala @@ -0,0 +1,210 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.databricks.spark.redshift + +import java.sql.{Connection, DriverManager, ResultSetMetaData, SQLException} +import java.util.Properties + +import scala.util.Try + +import org.apache.spark.Logging +import org.apache.spark.sql.types._ + +/** + * Shim which exposes some JDBC helper functions. + */ +private[redshift] class JDBCWrapper extends Logging { + + def registerDriver(driverClass: String): Unit = { + val classLoader = + Option(Thread.currentThread().getContextClassLoader).getOrElse(this.getClass.getClassLoader) + val className = "org.apache.spark.sql.jdbc.DriverRegistry$" + // scalastyle:off + val driverRegistryClass = Class.forName(className, true, classLoader) + // scalastyle:on + driverRegistryClass.getMethod("register", classOf[String]).invoke(driverClass) + } + + /** + * Takes a (schema, table) specification and returns the table's Catalyst + * schema. 
+ * + * @param url - The JDBC url to fetch information from. + * @param table - The table name of the desired table. This may also be a + * SQL query wrapped in parentheses. + * + * @return A StructType giving the table's Catalyst schema. + * @throws SQLException if the table specification is garbage. + * @throws SQLException if the table contains an unsupported type. + */ + def resolveTable(url: String, table: String, properties: Properties): StructType = { + val conn: Connection = DriverManager.getConnection(url, properties) + try { + val rs = conn.prepareStatement(s"SELECT * FROM $table WHERE 1=0").executeQuery() + try { + val rsmd = rs.getMetaData + val ncols = rsmd.getColumnCount + val fields = new Array[StructField](ncols) + var i = 0 + while (i < ncols) { + val columnName = rsmd.getColumnLabel(i + 1) + val dataType = rsmd.getColumnType(i + 1) + val typeName = rsmd.getColumnTypeName(i + 1) + val fieldSize = rsmd.getPrecision(i + 1) + val fieldScale = rsmd.getScale(i + 1) + val isSigned = rsmd.isSigned(i + 1) + val nullable = rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls + val metadata = new MetadataBuilder().putString("name", columnName) + val columnType = getCatalystType(dataType, fieldSize, fieldScale, isSigned) + fields(i) = StructField(columnName, columnType, nullable, metadata.build()) + i = i + 1 + } + return new StructType(fields) + } finally { + rs.close() + } + } finally { + conn.close() + } + + throw new RuntimeException("This line is unreachable.") + } + + /** + * Given a driver string and an url, return a function that loads the + * specified driver string then returns a connection to the JDBC url. + * getConnector is run on the driver code, while the function it returns + * is run on the executor. + * + * @param driver - The class name of the JDBC driver for the given url. + * @param url - The JDBC url to connect to. + * + * @return A function that loads the driver and connects to the url. + */ + def getConnector(driver: String, url: String, properties: Properties): () => Connection = { + () => { + try { + if (driver != null) registerDriver(driver) + } catch { + case e: ClassNotFoundException => + logWarning(s"Couldn't find class $driver", e) + } + DriverManager.getConnection(url, properties) + } + } + + /** + * Compute the SQL schema string for the given Spark SQL Schema. + */ + def schemaString(schema: StructType): String = { + val sb = new StringBuilder() + schema.fields.foreach { field => { + val name = field.name + val typ: String = field.dataType match { + case IntegerType => "INTEGER" + case LongType => "BIGINT" + case DoubleType => "DOUBLE PRECISION" + case FloatType => "REAL" + case ShortType => "INTEGER" + case ByteType => "SMALLINT" // Redshift does not support the BYTE type. + case BooleanType => "BIT(1)" + case StringType => "TEXT" + case BinaryType => "BLOB" + case TimestampType => "TIMESTAMP" + case DateType => "DATE" + case t: DecimalType => s"DECIMAL(${t.precision},${t.scale})" + case _ => throw new IllegalArgumentException(s"Don't know how to save $field to JDBC") + } + val nullable = if (field.nullable) "" else "NOT NULL" + sb.append(s", $name $typ $nullable") + }} + if (sb.length < 2) "" else sb.substring(2) + } + + /** + * Returns true if the table already exists in the JDBC database. + */ + def tableExists(conn: Connection, table: String): Boolean = { + // Somewhat hacky, but there isn't a good way to identify whether a table exists for all + // SQL database systems, considering "table" could also include the database name. 
+ Try(conn.prepareStatement(s"SELECT 1 FROM $table LIMIT 1").executeQuery().next()).isSuccess + } + + /** + * Maps a JDBC type to a Catalyst type. + * + * @param sqlType - A field of java.sql.Types + * @return The Catalyst type corresponding to sqlType. + */ + private def getCatalystType( + sqlType: Int, + precision: Int, + scale: Int, + signed: Boolean): DataType = { + val answer = sqlType match { + // scalastyle:off + case java.sql.Types.ARRAY => null + case java.sql.Types.BIGINT => if (signed) { LongType } else { DecimalType(20,0) } + case java.sql.Types.BINARY => BinaryType + case java.sql.Types.BIT => BooleanType // @see JdbcDialect for quirks + case java.sql.Types.BLOB => BinaryType + case java.sql.Types.BOOLEAN => BooleanType + case java.sql.Types.CHAR => StringType + case java.sql.Types.CLOB => StringType + case java.sql.Types.DATALINK => null + case java.sql.Types.DATE => DateType + case java.sql.Types.DECIMAL + if precision != 0 || scale != 0 => DecimalType(precision, scale) + case java.sql.Types.DECIMAL => DecimalType(38, 18) // Spark 1.5.0 default + case java.sql.Types.DISTINCT => null + case java.sql.Types.DOUBLE => DoubleType + case java.sql.Types.FLOAT => FloatType + case java.sql.Types.INTEGER => if (signed) { IntegerType } else { LongType } + case java.sql.Types.JAVA_OBJECT => null + case java.sql.Types.LONGNVARCHAR => StringType + case java.sql.Types.LONGVARBINARY => BinaryType + case java.sql.Types.LONGVARCHAR => StringType + case java.sql.Types.NCHAR => StringType + case java.sql.Types.NCLOB => StringType + case java.sql.Types.NULL => null + case java.sql.Types.NUMERIC + if precision != 0 || scale != 0 => DecimalType(precision, scale) + case java.sql.Types.NUMERIC => DecimalType(38, 18) // Spark 1.5.0 default + case java.sql.Types.NVARCHAR => StringType + case java.sql.Types.OTHER => null + case java.sql.Types.REAL => DoubleType + case java.sql.Types.REF => StringType + case java.sql.Types.ROWID => LongType + case java.sql.Types.SMALLINT => IntegerType + case java.sql.Types.SQLXML => StringType + case java.sql.Types.STRUCT => StringType + case java.sql.Types.TIME => TimestampType + case java.sql.Types.TIMESTAMP => TimestampType + case java.sql.Types.TINYINT => IntegerType + case java.sql.Types.VARBINARY => BinaryType + case java.sql.Types.VARCHAR => StringType + case _ => null + // scalastyle:on + } + + if (answer == null) throw new SQLException("Unsupported type " + sqlType) + answer + } +} + +private[redshift] object DefaultJDBCWrapper extends JDBCWrapper diff --git a/src/main/scala/com/databricks/spark/redshift/RedshiftRelation.scala b/src/main/scala/com/databricks/spark/redshift/RedshiftRelation.scala index 23804fb7..20380aec 100644 --- a/src/main/scala/com/databricks/spark/redshift/RedshiftRelation.scala +++ b/src/main/scala/com/databricks/spark/redshift/RedshiftRelation.scala @@ -22,7 +22,6 @@ import com.databricks.spark.redshift.Parameters.MergedParameters import org.apache.spark.Logging import org.apache.spark.rdd.RDD -import org.apache.spark.sql.jdbc.JDBCWrapper import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.sql.{DataFrame, Row, SQLContext} diff --git a/src/main/scala/com/databricks/spark/redshift/RedshiftWriter.scala b/src/main/scala/com/databricks/spark/redshift/RedshiftWriter.scala index 1de3c142..bffd2239 100644 --- a/src/main/scala/com/databricks/spark/redshift/RedshiftWriter.scala +++ b/src/main/scala/com/databricks/spark/redshift/RedshiftWriter.scala @@ -24,7 +24,6 @@ import scala.util.Random import 
com.databricks.spark.redshift.Parameters.MergedParameters import org.apache.spark.Logging -import org.apache.spark.sql.jdbc.{DefaultJDBCWrapper, JDBCWrapper} import org.apache.spark.sql.{DataFrame, SQLContext} /** @@ -36,7 +35,7 @@ class RedshiftWriter(jdbcWrapper: JDBCWrapper) extends Logging { * Generate CREATE TABLE statement for Redshift */ def createTableSql(data: DataFrame, params: MergedParameters): String = { - val schemaSql = jdbcWrapper.schemaString(data, params.jdbcUrl) + val schemaSql = jdbcWrapper.schemaString(data.schema) val distStyleDef = params.distStyle match { case Some(style) => s"DISTSTYLE $style" case None => "" diff --git a/src/main/scala/com/databricks/spark/redshift/package.scala b/src/main/scala/com/databricks/spark/redshift/package.scala index 32fa4051..129cfb8b 100644 --- a/src/main/scala/com/databricks/spark/redshift/package.scala +++ b/src/main/scala/com/databricks/spark/redshift/package.scala @@ -17,8 +17,8 @@ package com.databricks.spark +import com.databricks.spark.redshift.DefaultJDBCWrapper import org.apache.spark.sql.functions._ -import org.apache.spark.sql.jdbc.DefaultJDBCWrapper import org.apache.spark.sql.types.{StringType, StructField, StructType} import org.apache.spark.sql.{DataFrame, Row, SQLContext} diff --git a/src/main/scala/org/apache/spark/sql/jdbc/RedshiftJDBCWrapper.scala b/src/main/scala/org/apache/spark/sql/jdbc/RedshiftJDBCWrapper.scala deleted file mode 100644 index eb9c4e98..00000000 --- a/src/main/scala/org/apache/spark/sql/jdbc/RedshiftJDBCWrapper.scala +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Copyright 2015 TouchType Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.jdbc - -import java.sql.Connection -import java.util.Properties - -import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.types.StructType - -/** - * Hack to access some private JDBC SQL functionality - */ -class JDBCWrapper { - def schemaString(dataFrame: DataFrame, url: String): String = { - JDBCWriteDetails.schemaString(dataFrame, url) - } - - def registerDriver(driverClass: String): Unit = { - DriverRegistry.register(driverClass) - } - - def resolveTable(jdbcUrl: String, table: String, properties: Properties): StructType = { - JDBCRDD.resolveTable(jdbcUrl, table, properties) - } - - def getConnector(driver: String, url: String, properties: Properties): () => Connection = { - JDBCRDD.getConnector(driver, url, properties) - } - - def tableExists(conn: Connection, table: String): Boolean = { - JdbcUtils.tableExists(conn, table) - } -} - -object DefaultJDBCWrapper extends JDBCWrapper diff --git a/src/test/scala/com/databricks/spark/redshift/RedshiftSourceSuite.scala b/src/test/scala/com/databricks/spark/redshift/RedshiftSourceSuite.scala index 2f38e00b..a216330b 100644 --- a/src/test/scala/com/databricks/spark/redshift/RedshiftSourceSuite.scala +++ b/src/test/scala/com/databricks/spark/redshift/RedshiftSourceSuite.scala @@ -29,7 +29,6 @@ import org.scalatest.{BeforeAndAfterEach, BeforeAndAfterAll, FunSuite, Matchers} import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD -import org.apache.spark.sql.jdbc.JDBCWrapper import org.apache.spark.sql.sources._ import org.apache.spark.sql.{DataFrame, Row, SQLContext, SaveMode} @@ -273,7 +272,7 @@ class RedshiftSourceSuite .anyNumberOfTimes() (jdbcWrapper.schemaString _) - .expects(*, params("url")) + .expects(*) .returning("schema") .anyNumberOfTimes() @@ -325,7 +324,7 @@ class RedshiftSourceSuite .anyNumberOfTimes() (jdbcWrapper.schemaString _) - .expects(*, params("url")) + .expects(*) .anyNumberOfTimes() inSequence { @@ -371,7 +370,7 @@ class RedshiftSourceSuite .anyNumberOfTimes() (jdbcWrapper.schemaString _) - .expects(*, defaultParams("url")) + .expects(*) .returning("schema") .anyNumberOfTimes() From e75ee3f9876138a512fefcee46579faeb8f571b6 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 19 Aug 2015 15:38:08 -0700 Subject: [PATCH 2/3] Fix two minor bugs and address review comments. --- .../spark/redshift/RedshiftJDBCWrapper.scala | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/main/scala/com/databricks/spark/redshift/RedshiftJDBCWrapper.scala b/src/main/scala/com/databricks/spark/redshift/RedshiftJDBCWrapper.scala index f273b3a0..ef901ade 100644 --- a/src/main/scala/com/databricks/spark/redshift/RedshiftJDBCWrapper.scala +++ b/src/main/scala/com/databricks/spark/redshift/RedshiftJDBCWrapper.scala @@ -26,14 +26,20 @@ import org.apache.spark.Logging import org.apache.spark.sql.types._ /** - * Shim which exposes some JDBC helper functions. + * Shim which exposes some JDBC helper functions. Most of this code is copied from Spark SQL, with + * minor modifications for Redshift-specific features and limitations. */ private[redshift] class JDBCWrapper extends Logging { def registerDriver(driverClass: String): Unit = { + // DriverRegistry.register() is one of the few pieces of private Spark functionality which + // we need to rely on. This class was relocated in Spark 1.5.0, so we need to use reflection + // in order to support both Spark 1.4.x and 1.5.x. 
+ // TODO: once 1.5.0 snapshots are on Maven, update this to switch the class name based on + // SPARK_VERSION. val classLoader = Option(Thread.currentThread().getContextClassLoader).getOrElse(this.getClass.getClassLoader) - val className = "org.apache.spark.sql.jdbc.DriverRegistry$" + val className = "org.apache.spark.sql.jdbc.package$DriverRegistry$" // scalastyle:off val driverRegistryClass = Class.forName(className, true, classLoader) // scalastyle:on @@ -122,7 +128,7 @@ private[redshift] class JDBCWrapper extends Logging { case FloatType => "REAL" case ShortType => "INTEGER" case ByteType => "SMALLINT" // Redshift does not support the BYTE type. - case BooleanType => "BIT(1)" + case BooleanType => "BOOLEAN" case StringType => "TEXT" case BinaryType => "BLOB" case TimestampType => "TIMESTAMP" @@ -156,6 +162,7 @@ private[redshift] class JDBCWrapper extends Logging { precision: Int, scale: Int, signed: Boolean): DataType = { + // TODO: cleanup types which are irrelevant for Redshift. val answer = sqlType match { // scalastyle:off case java.sql.Types.ARRAY => null From d04a5577609b5e1bf0c2227e9c0266af51e7d466 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 19 Aug 2015 15:54:10 -0700 Subject: [PATCH 3/3] Fix problem in reflection code --- .../com/databricks/spark/redshift/RedshiftJDBCWrapper.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/main/scala/com/databricks/spark/redshift/RedshiftJDBCWrapper.scala b/src/main/scala/com/databricks/spark/redshift/RedshiftJDBCWrapper.scala index ef901ade..4598994b 100644 --- a/src/main/scala/com/databricks/spark/redshift/RedshiftJDBCWrapper.scala +++ b/src/main/scala/com/databricks/spark/redshift/RedshiftJDBCWrapper.scala @@ -43,7 +43,9 @@ private[redshift] class JDBCWrapper extends Logging { // scalastyle:off val driverRegistryClass = Class.forName(className, true, classLoader) // scalastyle:on - driverRegistryClass.getMethod("register", classOf[String]).invoke(driverClass) + val registerMethod = driverRegistryClass.getDeclaredMethod("register", classOf[String]) + val companionObject = driverRegistryClass.getDeclaredField("MODULE$").get(null) + registerMethod.invoke(companionObject, driverClass) } /**
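
[Illustration only — not part of the patches.] For readers following the reflection change across the three patches, below is a minimal, self-contained sketch of the companion-object pattern that PATCH 3 arrives at. DemoRegistry, ReflectionSketch, registerViaReflection, and the driver-class string are made-up stand-ins (Spark's real DriverRegistry lives at the fully qualified name used in the patch); the point is only that a Scala object compiles to a class whose name ends in "$" and whose singleton instance sits in a static MODULE$ field, so the method has to be invoked on that instance — PATCH 1's bug was passing driverClass as the receiver and no argument at all.

// Stand-in for Spark's private DriverRegistry object (hypothetical, for the demo only).
object DemoRegistry {
  def register(driverClass: String): Unit = println(s"registered $driverClass")
}

object ReflectionSketch {
  // Mirrors the shape of JDBCWrapper.registerDriver after PATCH 3.
  def registerViaReflection(objectClassName: String, driverClass: String): Unit = {
    val classLoader =
      Option(Thread.currentThread().getContextClassLoader).getOrElse(getClass.getClassLoader)
    val clazz = Class.forName(objectClassName, true, classLoader)
    // A Scala object's singleton instance is exposed through a static MODULE$ field.
    val companion = clazz.getDeclaredField("MODULE$").get(null)
    // Invoke register(driverClass) on the singleton instance.
    clazz.getDeclaredMethod("register", classOf[String]).invoke(companion, driverClass)
  }

  def main(args: Array[String]): Unit = {
    // "DemoRegistry$" is the runtime class name of the top-level object above
    // (compiled without a package); the driver string is just an example value.
    registerViaReflection("DemoRegistry$", "com.example.FakeDriver")
  }
}

If the SPARK_VERSION-based switch mentioned in the TODO were added later, only the objectClassName argument would change; the MODULE$ lookup and invocation stay the same.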
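
[Illustration only — not part of the patches.] The copied wrapper also relies on two plain-SQL probes: "SELECT * FROM <table> WHERE 1=0" to fetch column metadata without reading any rows (resolveTable), and "SELECT 1 FROM <table> LIMIT 1" wrapped in Try as a portable existence check (tableExists). The sketch below exercises both against an in-memory H2 database; the H2 dependency on the classpath, the jdbc:h2:mem:demo URL, and the table/column names are assumptions of the example, not anything the patches introduce.

import java.sql.{DriverManager, ResultSetMetaData}
import scala.util.Try

object ProbeSketch {
  def main(args: Array[String]): Unit = {
    // Assumes the H2 driver jar is on the classpath; DriverManager loads it via SPI.
    val conn = DriverManager.getConnection("jdbc:h2:mem:demo")
    try {
      conn.createStatement().execute("CREATE TABLE t (id INT NOT NULL, name VARCHAR(32))")

      // tableExists-style probe: any SQLException (missing table, bad name) maps to false.
      val exists =
        Try(conn.prepareStatement("SELECT 1 FROM t LIMIT 1").executeQuery().next()).isSuccess
      println(s"table exists: $exists")

      // resolveTable-style probe: WHERE 1=0 returns no rows but still exposes full metadata,
      // which is what getCatalystType consumes to build the StructType.
      val rsmd = conn.prepareStatement("SELECT * FROM t WHERE 1=0").executeQuery().getMetaData
      (1 to rsmd.getColumnCount).foreach { i =>
        val nullable = rsmd.isNullable(i) != ResultSetMetaData.columnNoNulls
        println(s"${rsmd.getColumnLabel(i)}: ${rsmd.getColumnTypeName(i)} (nullable = $nullable)")
      }
    } finally {
      conn.close()
    }
  }
}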