From 9ca66d9bc62db3519276cfe5c88d20ccaab69ada Mon Sep 17 00:00:00 2001
From: Rene Treffer
Date: Mon, 13 Apr 2015 23:38:59 +0200
Subject: [PATCH 1/4] [SPARK-6888] Export driver quirks

Make it possible to (temporarily) override the driver quirks. This can be
used to work around problems with specific schemas, or to add support for
new JDBC drivers on the fly.
---
 .../apache/spark/sql/jdbc/DriverQuirks.scala  | 60 +++++++++++++++----
 1 file changed, 49 insertions(+), 11 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DriverQuirks.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DriverQuirks.scala
index 1704be7fcbd3..96e635c7cfc8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DriverQuirks.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DriverQuirks.scala
@@ -39,33 +39,70 @@ import java.sql.Types
  * if `getJDBCType` returns `(null, None)`, the default type handling is used
  * for the given Catalyst type.
  */
-private[sql] abstract class DriverQuirks {
+abstract class DriverQuirks {
+  def canHandle(url : String): Boolean
   def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType
   def getJDBCType(dt: DataType): (String, Option[Int])
 }
 
-private[sql] object DriverQuirks {
+object DriverQuirks {
+
+  private var quirks = List[DriverQuirks]()
+
+  def registerQuirks(quirk: DriverQuirks) {
+    quirks = quirk :: quirks
+  }
+
+  def unregisterQuirks(quirk : DriverQuirks) {
+    quirks = quirks.filterNot(_ == quirk)
+  }
+
+  registerQuirks(new MySQLQuirks())
+  registerQuirks(new PostgresQuirks())
+
   /**
    * Fetch the DriverQuirks class corresponding to a given database url.
    */
   def get(url: String): DriverQuirks = {
-    if (url.substring(0, 10).equals("jdbc:mysql")) {
-      new MySQLQuirks()
-    } else if (url.substring(0, 15).equals("jdbc:postgresql")) {
-      new PostgresQuirks()
-    } else {
-      new NoQuirks()
+    val matchingQuirks = quirks.filter(_.canHandle(url))
+    matchingQuirks.length match {
+      case 0 => new NoQuirks()
+      case 1 => matchingQuirks.head
+      case _ => new AggregatedQuirks(matchingQuirks)
     }
   }
 }
 
-private[sql] class NoQuirks extends DriverQuirks {
+class AggregatedQuirks(quirks: List[DriverQuirks]) extends DriverQuirks {
+  def canHandle(url : String): Boolean =
+    quirks.foldLeft(true)((l,r) => l && r.canHandle(url))
+  def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder) : DataType =
+    quirks.foldLeft(null.asInstanceOf[DataType])((l,r) =>
+      if (l != null) {
+        l
+      } else {
+        r.getCatalystType(sqlType, typeName, size, md)
+      }
+    )
+  def getJDBCType(dt: DataType): (String, Option[Int]) =
+    quirks.foldLeft(null.asInstanceOf[(String, Option[Int])])((l,r) =>
+      if (l != null) {
+        l
+      } else {
+        r.getJDBCType(dt)
+      }
+    )
+}
+
+class NoQuirks extends DriverQuirks {
+  def canHandle(url : String): Boolean = true
   def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType =
     null
   def getJDBCType(dt: DataType): (String, Option[Int]) = (null, None)
 }
 
-private[sql] class PostgresQuirks extends DriverQuirks {
+class PostgresQuirks extends DriverQuirks {
+  def canHandle(url: String): Boolean = url.startsWith("jdbc:postgresql")
   def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType = {
     if (sqlType == Types.BIT && typeName.equals("bit") && size != 1) {
       BinaryType
@@ -84,7 +121,8 @@ private[sql] class PostgresQuirks extends DriverQuirks {
   }
 }
 
-private[sql] class MySQLQuirks extends DriverQuirks {
+class MySQLQuirks extends DriverQuirks {
+  def canHandle(url : String): Boolean = url.startsWith("jdbc:mysql")
   def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType = {
     if (sqlType == Types.VARBINARY && typeName.equals("BIT") && size != 1) {
       // This could instead be a BinaryType if we'd rather return bit-vectors of up to 64 bits as
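
The extension point added by this patch can be exercised directly from user
code. Below is a minimal sketch (e.g. in a spark-shell session) of how it
could look; the hypothetical SQLite url prefix and the TEXT mapping are
illustrative assumptions, not part of this patch:

    import java.sql.Types
    import org.apache.spark.sql.jdbc.DriverQuirks
    import org.apache.spark.sql.types._

    // Hypothetical quirks for a driver Spark does not know about.
    object SQLiteQuirks extends DriverQuirks {
      def canHandle(url: String): Boolean = url.startsWith("jdbc:sqlite")
      // Returning null defers to the default JDBC-to-Catalyst mapping.
      def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType =
        if (sqlType == Types.VARCHAR && typeName.equals("TEXT")) StringType else null
      // Returning (null, None) defers to the default Catalyst-to-JDBC mapping.
      def getJDBCType(dt: DataType): (String, Option[Int]) = dt match {
        case StringType => ("TEXT", Some(Types.VARCHAR))
        case _ => (null, None)
      }
    }

    // Takes effect for DataFrames created after registration.
    DriverQuirks.registerQuirks(SQLiteQuirks)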
From 8e0b5e390bddbc0d1ff3e423ce827b2c75bc5b3f Mon Sep 17 00:00:00 2001
From: Rene Treffer
Date: Wed, 15 Apr 2015 23:11:55 +0200
Subject: [PATCH 2/4] [SPARK-6888] Add tests for custom driver quirks
 (register/unregister/type mapping)

---
 .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 34 +++++++++++++++++++
 1 file changed, 34 insertions(+)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index 592ed4b23b7d..299a527b1570 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -22,6 +22,7 @@ import java.sql.DriverManager
 import java.util.{Calendar, GregorianCalendar, Properties}
 
 import org.apache.spark.sql.test._
+import org.apache.spark.sql.types._
 import org.h2.jdbc.JdbcSQLException
 import org.scalatest.{FunSuite, BeforeAndAfter}
 import TestSQLContext._
@@ -34,6 +35,14 @@ class JDBCSuite extends FunSuite with BeforeAndAfter {
   val testBytes = Array[Byte](99.toByte, 134.toByte, 135.toByte, 200.toByte, 205.toByte)
 
+  val testDriverQuirks = new DriverQuirks {
+    def canHandle(url: String) = url.startsWith("jdbc:h2")
+    def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType = {
+      StringType
+    }
+    def getJDBCType(dt: org.apache.spark.sql.types.DataType): (String, Option[Int]) = (null, None)
+  }
+
   before {
     Class.forName("org.h2.Driver")
     // Extra properties that will be specified for our database. We need these to test
@@ -282,4 +291,29 @@ class JDBCSuite extends FunSuite with BeforeAndAfter {
         """.stripMargin.replaceAll("\n", " "))
     }
   }
+
+  test("Remap types via DriverQuirks") {
+    DriverQuirks.registerQuirks(testDriverQuirks)
+    val df = TestSQLContext.jdbc(urlWithUserAndPass, "TEST.PEOPLE")
+    assert(df.schema.filter(
+      _.dataType != org.apache.spark.sql.types.StringType
+    ).isEmpty)
+    val rows = df.collect()
+    assert(rows(0).get(0).isInstanceOf[String])
+    assert(rows(0).get(1).isInstanceOf[String])
+    DriverQuirks.unregisterQuirks(testDriverQuirks)
+  }
+
+  test("Default quirks registration") {
+    assert(DriverQuirks.get("jdbc:mysql://127.0.0.1/db").isInstanceOf[MySQLQuirks])
+    assert(DriverQuirks.get("jdbc:postgresql://127.0.0.1/db").isInstanceOf[PostgresQuirks])
+    assert(DriverQuirks.get("test.invalid").isInstanceOf[NoQuirks])
+  }
+
+  test("Quirk unregister") {
+    DriverQuirks.registerQuirks(testDriverQuirks)
+    DriverQuirks.unregisterQuirks(testDriverQuirks)
+    assert(DriverQuirks.get(urlWithUserAndPass).isInstanceOf[NoQuirks])
+  }
+
 }
From 7f234842f2eb12dabfa5cd6b9145f678a26a0b0a Mon Sep 17 00:00:00 2001
From: Rene Treffer
Date: Wed, 15 Apr 2015 23:12:43 +0200
Subject: [PATCH 3/4] [SPARK-6888] Fix driver quirks handling

---
 .../apache/spark/sql/jdbc/DriverQuirks.scala  | 34 +++++++++----------
 1 file changed, 16 insertions(+), 18 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DriverQuirks.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DriverQuirks.scala
index 96e635c7cfc8..65d7a3b5e087 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DriverQuirks.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DriverQuirks.scala
@@ -49,11 +49,11 @@ object DriverQuirks {
 
   private var quirks = List[DriverQuirks]()
 
-  def registerQuirks(quirk: DriverQuirks) {
+  def registerQuirks(quirk: DriverQuirks) : Unit = {
     quirks = quirk :: quirks
   }
 
-  def unregisterQuirks(quirk : DriverQuirks) {
+  def unregisterQuirks(quirk : DriverQuirks) : Unit = {
     quirks = quirks.filterNot(_ == quirk)
   }
 
@@ -74,24 +74,22 @@ object DriverQuirks {
 }
 
 class AggregatedQuirks(quirks: List[DriverQuirks]) extends DriverQuirks {
+
+  require(!quirks.isEmpty)
+
   def canHandle(url : String): Boolean =
-    quirks.foldLeft(true)((l,r) => l && r.canHandle(url))
-  def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder) : DataType =
-    quirks.foldLeft(null.asInstanceOf[DataType])((l,r) =>
-      if (l != null) {
-        l
-      } else {
-        r.getCatalystType(sqlType, typeName, size, md)
-      }
-    )
+    quirks.map(_.canHandle(url)).reduce(_ && _)
+
+  def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType =
+    quirks.map(_.getCatalystType(sqlType, typeName, size, md)).collectFirst {
+      case dataType if dataType != null => dataType
+    }.orNull
+
   def getJDBCType(dt: DataType): (String, Option[Int]) =
-    quirks.foldLeft(null.asInstanceOf[(String, Option[Int])])((l,r) =>
-      if (l != null) {
-        l
-      } else {
-        r.getJDBCType(dt)
-      }
-    )
+    quirks.map(_.getJDBCType(dt)).collectFirst {
+      case t @ (typeName,sqlType) if typeName != null || sqlType.isDefined => t
+    }.getOrElse((null, None))
+
 }
 
 class NoQuirks extends DriverQuirks {
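
With this fix, AggregatedQuirks scans its list in order and takes the first
non-neutral answer (a non-null Catalyst type, or a defined JDBC type). A small
sketch of the resolution behaviour, using two throwaway quirks that exist only
for illustration:

    import org.apache.spark.sql.jdbc.{AggregatedQuirks, DriverQuirks}
    import org.apache.spark.sql.types._

    val passThrough = new DriverQuirks {
      def canHandle(url: String): Boolean = true
      def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType =
        null                              // neutral: defer to the next quirk
      def getJDBCType(dt: DataType): (String, Option[Int]) = (null, None)
    }
    val allStrings = new DriverQuirks {
      def canHandle(url: String): Boolean = true
      def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType =
        StringType                        // always answers
      def getJDBCType(dt: DataType): (String, Option[Int]) = (null, None)
    }

    val agg = new AggregatedQuirks(List(passThrough, allStrings))
    // passThrough returns null, so the aggregate falls through to allStrings.
    assert(agg.getCatalystType(java.sql.Types.OTHER, "inet", 0, new MetadataBuilder()) == StringType)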
From 22d65cac9bb22a9cdda5019042acca0c66e46270 Mon Sep 17 00:00:00 2001
From: Rene Treffer
Date: Thu, 16 Apr 2015 23:49:19 +0200
Subject: [PATCH 4/4] [SPARK-6888] Rename driver quirks to jdbc dialect and
 add tests + scaladoc

---
 .../apache/spark/sql/jdbc/DriverQuirks.scala  | 135 ------------
 .../org/apache/spark/sql/jdbc/JDBCRDD.scala   |   8 +-
 .../apache/spark/sql/jdbc/JdbcDialects.scala  | 201 ++++++++++++++++++
 .../org/apache/spark/sql/jdbc/jdbc.scala      |   8 +-
 .../org/apache/spark/sql/jdbc/JDBCSuite.scala |  45 ++--
 5 files changed, 240 insertions(+), 157 deletions(-)
 delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/jdbc/DriverQuirks.scala
 create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DriverQuirks.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DriverQuirks.scala
deleted file mode 100644
index 65d7a3b5e087..000000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DriverQuirks.scala
+++ /dev/null
@@ -1,135 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.jdbc
-
-import org.apache.spark.sql.types._
-
-import java.sql.Types
-
-
-/**
- * Encapsulates workarounds for the extensions, quirks, and bugs in various
- * databases. Lots of databases define types that aren't explicitly supported
- * by the JDBC spec. Some JDBC drivers also report inaccurate
- * information---for instance, BIT(n>1) being reported as a BIT type is quite
- * common, even though BIT in JDBC is meant for single-bit values. Also, there
- * does not appear to be a standard name for an unbounded string or binary
- * type; we use BLOB and CLOB by default but override with database-specific
- * alternatives when these are absent or do not behave correctly.
- *
- * Currently, the only thing DriverQuirks does is handle type mapping.
- * `getCatalystType` is used when reading from a JDBC table and `getJDBCType`
- * is used when writing to a JDBC table. If `getCatalystType` returns `null`,
- * the default type handling is used for the given JDBC type. Similarly,
- * if `getJDBCType` returns `(null, None)`, the default type handling is used
- * for the given Catalyst type.
- */
-abstract class DriverQuirks {
-  def canHandle(url : String): Boolean
-  def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType
-  def getJDBCType(dt: DataType): (String, Option[Int])
-}
-
-object DriverQuirks {
-
-  private var quirks = List[DriverQuirks]()
-
-  def registerQuirks(quirk: DriverQuirks) : Unit = {
-    quirks = quirk :: quirks
-  }
-
-  def unregisterQuirks(quirk : DriverQuirks) : Unit = {
-    quirks = quirks.filterNot(_ == quirk)
-  }
-
-  registerQuirks(new MySQLQuirks())
-  registerQuirks(new PostgresQuirks())
-
-  /**
-   * Fetch the DriverQuirks class corresponding to a given database url.
-   */
-  def get(url: String): DriverQuirks = {
-    val matchingQuirks = quirks.filter(_.canHandle(url))
-    matchingQuirks.length match {
-      case 0 => new NoQuirks()
-      case 1 => matchingQuirks.head
-      case _ => new AggregatedQuirks(matchingQuirks)
-    }
-  }
-}
-
-class AggregatedQuirks(quirks: List[DriverQuirks]) extends DriverQuirks {
-
-  require(!quirks.isEmpty)
-
-  def canHandle(url : String): Boolean =
-    quirks.map(_.canHandle(url)).reduce(_ && _)
-
-  def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType =
-    quirks.map(_.getCatalystType(sqlType, typeName, size, md)).collectFirst {
-      case dataType if dataType != null => dataType
-    }.orNull
-
-  def getJDBCType(dt: DataType): (String, Option[Int]) =
-    quirks.map(_.getJDBCType(dt)).collectFirst {
-      case t @ (typeName,sqlType) if typeName != null || sqlType.isDefined => t
-    }.getOrElse((null, None))
-
-}
-
-class NoQuirks extends DriverQuirks {
-  def canHandle(url : String): Boolean = true
-  def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType =
-    null
-  def getJDBCType(dt: DataType): (String, Option[Int]) = (null, None)
-}
-
-class PostgresQuirks extends DriverQuirks {
-  def canHandle(url: String): Boolean = url.startsWith("jdbc:postgresql")
-  def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType = {
-    if (sqlType == Types.BIT && typeName.equals("bit") && size != 1) {
-      BinaryType
-    } else if (sqlType == Types.OTHER && typeName.equals("cidr")) {
-      StringType
-    } else if (sqlType == Types.OTHER && typeName.equals("inet")) {
-      StringType
-    } else null
-  }
-
-  def getJDBCType(dt: DataType): (String, Option[Int]) = dt match {
-    case StringType => ("TEXT", Some(java.sql.Types.CHAR))
-    case BinaryType => ("BYTEA", Some(java.sql.Types.BINARY))
-    case BooleanType => ("BOOLEAN", Some(java.sql.Types.BOOLEAN))
-    case _ => (null, None)
-  }
-}
-
-class MySQLQuirks extends DriverQuirks {
-  def canHandle(url : String): Boolean = url.startsWith("jdbc:mysql")
-  def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType = {
-    if (sqlType == Types.VARBINARY && typeName.equals("BIT") && size != 1) {
-      // This could instead be a BinaryType if we'd rather return bit-vectors of up to 64 bits as
-      // byte arrays instead of longs.
-      md.putLong("binarylong", 1)
-      LongType
-    } else if (sqlType == Types.BIT && typeName.equals("TINYINT")) {
-      BooleanType
-    } else null
-  }
-  def getJDBCType(dt: DataType): (String, Option[Int]) = (null, None)
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
index 463e1dcc268b..0f1674487725 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.sources._
 private[sql] object JDBCRDD extends Logging {
   /**
    * Maps a JDBC type to a Catalyst type. This function is called only when
-   * the DriverQuirks class corresponding to your database driver returns null.
+   * the JdbcDialect class corresponding to your database driver returns null.
    *
    * @param sqlType - A field of java.sql.Types
    * @return The Catalyst type corresponding to sqlType.
@@ -40,7 +40,7 @@ private[sql] object JDBCRDD extends Logging {
       case java.sql.Types.ARRAY         => null
       case java.sql.Types.BIGINT        => LongType
      case java.sql.Types.BINARY        => BinaryType
-      case java.sql.Types.BIT           => BooleanType // Per JDBC; Quirks handles quirky drivers.
+      case java.sql.Types.BIT           => BooleanType // Per JDBC; JdbcDialect handles quirky drivers.
       case java.sql.Types.BLOB          => BinaryType
       case java.sql.Types.BOOLEAN       => BooleanType
       case java.sql.Types.CHAR          => StringType
@@ -92,7 +92,7 @@ private[sql] object JDBCRDD extends Logging {
    * @throws SQLException if the table contains an unsupported type.
    */
   def resolveTable(url: String, table: String, properties: Properties): StructType = {
-    val quirks = DriverQuirks.get(url)
+    val dialect = JdbcDialects.get(url)
     val conn: Connection = DriverManager.getConnection(url, properties)
     try {
       val rs = conn.prepareStatement(s"SELECT * FROM $table WHERE 1=0").executeQuery()
@@ -108,7 +108,7 @@ private[sql] object JDBCRDD extends Logging {
         val fieldSize = rsmd.getPrecision(i + 1)
         val nullable = rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls
         val metadata = new MetadataBuilder().putString("name", columnName)
-        var columnType = quirks.getCatalystType(dataType, typeName, fieldSize, metadata)
+        var columnType = dialect.getCatalystType(dataType, typeName, fieldSize, metadata)
         if (columnType == null) columnType = getCatalystType(dataType)
         fields(i) = StructField(columnName, columnType, nullable, metadata.build())
         i = i + 1
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
new file mode 100644
index 000000000000..1ca81ad5d580
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -0,0 +1,201 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.jdbc
+
+import org.apache.spark.sql.types._
+import org.apache.spark.annotation.DeveloperApi
+
+import java.sql.Types
+
+
+/**
+ * :: DeveloperApi ::
+ * Encapsulates everything (extensions, workarounds, quirks) needed to handle
+ * the SQL dialect of a certain database or jdbc driver.
+ * Lots of databases define types that aren't explicitly supported
+ * by the JDBC spec. Some JDBC drivers also report inaccurate
+ * information---for instance, BIT(n>1) being reported as a BIT type is quite
+ * common, even though BIT in JDBC is meant for single-bit values. Also, there
+ * does not appear to be a standard name for an unbounded string or binary
+ * type; we use BLOB and CLOB by default but override with database-specific
+ * alternatives when these are absent or do not behave correctly.
+ *
+ * Currently, the only thing done by the dialect is type mapping.
+ * `getCatalystType` is used when reading from a JDBC table and `getJDBCType`
+ * is used when writing to a JDBC table. If `getCatalystType` returns `null`,
+ * the default type handling is used for the given JDBC type. Similarly,
+ * if `getJDBCType` returns `(null, None)`, the default type handling is used
+ * for the given Catalyst type.
+ */
+@DeveloperApi
+abstract class JdbcDialect {
+  /**
+   * Check if this dialect instance can handle a certain jdbc url.
+   * @param url the jdbc url.
+   * @return True if the dialect can be applied on the given jdbc url.
+   * @throws NullPointerException if the url is null.
+   */
+  def canHandle(url : String): Boolean
+
+  /**
+   * Get the custom datatype mapping for the given jdbc meta information.
+   * @param sqlType The sql type (see java.sql.Types)
+   * @param typeName The sql type name (e.g. "BIGINT UNSIGNED")
+   * @param size The size of the type.
+   * @param md Result metadata associated with this type.
+   * @return The actual DataType (subclasses of [[org.apache.spark.sql.types.DataType]])
+   *         or null if the default type mapping should be used.
+   */
+  def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType = null
+
+  /**
+   * Retrieve the jdbc / sql type for a given datatype.
+   * @param dt The datatype (e.g. [[org.apache.spark.sql.types.StringType]])
+   * @return A tuple of sql type name and sql type, or {{{(null, None)}}} for no change.
+   */
+  def getJDBCType(dt: DataType): (String, Option[Int]) = (null, None)
+}
+
+/**
+ * :: DeveloperApi ::
+ * Registry of dialects that apply to every new jdbc [[org.apache.spark.sql.DataFrame]].
+ *
+ * If multiple matching dialects are registered then all matching ones will be
+ * tried in reverse order of registration. A user-added dialect will thus be
+ * applied first, overriding the defaults.
+ *
+ * Note that all new dialects are applied to new jdbc DataFrames only. Make
+ * sure to register your dialects first.
+ */
+@DeveloperApi
+object JdbcDialects {
+
+  private var dialects = List[JdbcDialect]()
+
+  /**
+   * Register a dialect for use on all new matching jdbc [[org.apache.spark.sql.DataFrame]].
+   * Re-adding an existing dialect will move it to the front of the list.
+   * @param dialect The new dialect.
+   */
+  def registerDialect(dialect: JdbcDialect) : Unit = {
+    dialects = dialect :: dialects.filterNot(_ == dialect)
+  }
+
+  /**
+   * Unregister a dialect. Does nothing if the dialect is not registered.
+   * @param dialect The jdbc dialect.
+   */
+  def unregisterDialect(dialect : JdbcDialect) : Unit = {
+    dialects = dialects.filterNot(_ == dialect)
+  }
+
+  registerDialect(MySQLDialect)
+  registerDialect(PostgresDialect)
+
+  /**
+   * Fetch the JdbcDialect class corresponding to a given database url.
+   */
+  private[sql] def get(url: String): JdbcDialect = {
+    val matchingDialects = dialects.filter(_.canHandle(url))
+    matchingDialects.length match {
+      case 0 => NoopDialect
+      case 1 => matchingDialects.head
+      case _ => new AggregatedDialect(matchingDialects)
+    }
+  }
+}
+
+/**
+ * :: DeveloperApi ::
+ * AggregatedDialect can unify multiple dialects into one virtual Dialect.
+ * Dialects are tried in order, and the first dialect that does not return a
+ * neutral element will win.
+ * @param dialects List of dialects.
+ */
+@DeveloperApi
+class AggregatedDialect(dialects: List[JdbcDialect]) extends JdbcDialect {
+
+  require(!dialects.isEmpty)
+
+  def canHandle(url : String): Boolean =
+    dialects.map(_.canHandle(url)).reduce(_ && _)
+
+  override def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType =
+    dialects.map(_.getCatalystType(sqlType, typeName, size, md)).collectFirst {
+      case dataType if dataType != null => dataType
+    }.orNull
+
+  override def getJDBCType(dt: DataType): (String, Option[Int]) =
+    dialects.map(_.getJDBCType(dt)).collectFirst {
+      case t @ (typeName,sqlType) if typeName != null || sqlType.isDefined => t
+    }.getOrElse((null, None))
+
+}
+
+/**
+ * :: DeveloperApi ::
+ * NOOP dialect object, always returning the neutral element.
+ */
+@DeveloperApi
+case object NoopDialect extends JdbcDialect {
+  def canHandle(url : String): Boolean = true
+}
+
+/**
+ * :: DeveloperApi ::
+ * Default postgres dialect, mapping bit/cidr/inet on read and string/binary/boolean on write.
+ */
+@DeveloperApi
+case object PostgresDialect extends JdbcDialect {
+  def canHandle(url: String): Boolean = url.startsWith("jdbc:postgresql")
+  override def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType = {
+    if (sqlType == Types.BIT && typeName.equals("bit") && size != 1) {
+      BinaryType
+    } else if (sqlType == Types.OTHER && typeName.equals("cidr")) {
+      StringType
+    } else if (sqlType == Types.OTHER && typeName.equals("inet")) {
+      StringType
+    } else null
+  }
+
+  override def getJDBCType(dt: DataType): (String, Option[Int]) = dt match {
+    case StringType => ("TEXT", Some(java.sql.Types.CHAR))
+    case BinaryType => ("BYTEA", Some(java.sql.Types.BINARY))
+    case BooleanType => ("BOOLEAN", Some(java.sql.Types.BOOLEAN))
+    case _ => (null, None)
+  }
+}
+
+/**
+ * :: DeveloperApi ::
+ * Default mysql dialect to read bit/bitsets correctly.
+ */
+@DeveloperApi
+case object MySQLDialect extends JdbcDialect {
+  def canHandle(url : String): Boolean = url.startsWith("jdbc:mysql")
+  override def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType = {
+    if (sqlType == Types.VARBINARY && typeName.equals("BIT") && size != 1) {
+      // This could instead be a BinaryType if we'd rather return bit-vectors of up to 64 bits as
+      // byte arrays instead of longs.
+      md.putLong("binarylong", 1)
+      LongType
+    } else if (sqlType == Types.BIT && typeName.equals("TINYINT")) {
+      BooleanType
+    } else null
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala
index 34f864f5fda7..a7e6aa99e361 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala
@@ -123,10 +123,10 @@ package object jdbc {
    */
   def schemaString(df: DataFrame, url: String): String = {
     val sb = new StringBuilder()
-    val quirks = DriverQuirks.get(url)
+    val dialect = JdbcDialects.get(url)
     df.schema.fields foreach { field => {
       val name = field.name
-      var typ: String = quirks.getJDBCType(field.dataType)._1
+      var typ: String = dialect.getJDBCType(field.dataType)._1
       if (typ == null) typ = field.dataType match {
         case IntegerType => "INTEGER"
         case LongType => "BIGINT"
@@ -152,9 +152,9 @@ package object jdbc {
    * Saves the RDD to the database in a single transaction.
    */
   def saveTable(df: DataFrame, url: String, table: String) {
-    val quirks = DriverQuirks.get(url)
+    val dialect = JdbcDialects.get(url)
     var nullTypes: Array[Int] = df.schema.fields.map(field => {
-      var nullType: Option[Int] = quirks.getJDBCType(field.dataType)._2
+      var nullType: Option[Int] = dialect.getJDBCType(field.dataType)._2
       if (nullType.isEmpty) {
         field.dataType match {
           case IntegerType => java.sql.Types.INTEGER
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index 299a527b1570..4fa810316e64 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -35,12 +35,11 @@ class JDBCSuite extends FunSuite with BeforeAndAfter {
   val testBytes = Array[Byte](99.toByte, 134.toByte, 135.toByte, 200.toByte, 205.toByte)
 
-  val testDriverQuirks = new DriverQuirks {
+  val testH2Dialect = new JdbcDialect {
     def canHandle(url: String) = url.startsWith("jdbc:h2")
-    def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType = {
+    override def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType = {
       StringType
     }
-    def getJDBCType(dt: org.apache.spark.sql.types.DataType): (String, Option[Int]) = (null, None)
   }
 
   before {
@@ -292,8 +291,8 @@ class JDBCSuite extends FunSuite with BeforeAndAfter {
     }
   }
 
-  test("Remap types via DriverQuirks") {
-    DriverQuirks.registerQuirks(testDriverQuirks)
+  test("Remap types via JdbcDialects") {
+    JdbcDialects.registerDialect(testH2Dialect)
     val df = TestSQLContext.jdbc(urlWithUserAndPass, "TEST.PEOPLE")
     assert(df.schema.filter(
       _.dataType != org.apache.spark.sql.types.StringType
@@ -301,19 +300,37 @@ class JDBCSuite extends FunSuite with BeforeAndAfter {
     ).isEmpty)
     val rows = df.collect()
     assert(rows(0).get(0).isInstanceOf[String])
     assert(rows(0).get(1).isInstanceOf[String])
-    DriverQuirks.unregisterQuirks(testDriverQuirks)
+    JdbcDialects.unregisterDialect(testH2Dialect)
   }
 
-  test("Default quirks registration") {
-    assert(DriverQuirks.get("jdbc:mysql://127.0.0.1/db").isInstanceOf[MySQLQuirks])
-    assert(DriverQuirks.get("jdbc:postgresql://127.0.0.1/db").isInstanceOf[PostgresQuirks])
-    assert(DriverQuirks.get("test.invalid").isInstanceOf[NoQuirks])
+  test("Default jdbc dialect registration") {
+    assert(JdbcDialects.get("jdbc:mysql://127.0.0.1/db") == MySQLDialect)
+    assert(JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") == PostgresDialect)
+    assert(JdbcDialects.get("test.invalid") == NoopDialect)
   }
 
-  test("Quirk unregister") {
-    DriverQuirks.registerQuirks(testDriverQuirks)
-    DriverQuirks.unregisterQuirks(testDriverQuirks)
-    assert(DriverQuirks.get(urlWithUserAndPass).isInstanceOf[NoQuirks])
+  test("Dialect unregister") {
+    JdbcDialects.registerDialect(testH2Dialect)
+    JdbcDialects.unregisterDialect(testH2Dialect)
+    assert(JdbcDialects.get(urlWithUserAndPass) == NoopDialect)
+  }
+
+  test("Aggregated dialects") {
+    val agg = new AggregatedDialect(List(new JdbcDialect {
+      def canHandle(url: String) = url.startsWith("jdbc:h2:")
+      override def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType = {
+        if (sqlType % 2 == 0) {
+          LongType
+        } else {
+          null
+        }
+      }
+    }, testH2Dialect))
+    assert(agg.canHandle("jdbc:h2:xxx"))
+    assert(!agg.canHandle("jdbc:h2"))
+    assert(agg.getCatalystType(0,"",1,null) == LongType)
+    assert(agg.getCatalystType(1,"",1,null) == StringType)
   }
+
 }
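
After the rename, the developer-facing workflow documented in
JdbcDialects.scala above looks roughly as follows. This is a sketch (e.g. in
a spark-shell session); the `jdbc:db2` prefix and the CLOB/CHAR(1) mappings
are assumptions for illustration, not defaults shipped by the patch:

    import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}
    import org.apache.spark.sql.types._

    // Hypothetical dialect for a database without built-in support.
    case object Db2Dialect extends JdbcDialect {
      def canHandle(url: String): Boolean = url.startsWith("jdbc:db2")
      // Only the write-side mapping is overridden; reads keep the defaults,
      // because getCatalystType already returns null by default.
      override def getJDBCType(dt: DataType): (String, Option[Int]) = dt match {
        case StringType => ("CLOB", Some(java.sql.Types.CLOB))
        case BooleanType => ("CHAR(1)", Some(java.sql.Types.CHAR))
        case _ => (null, None)
      }
    }

    JdbcDialects.registerDialect(Db2Dialect)   // applies to DataFrames created from now on
    // ... create and save jdbc DataFrames here ...
    JdbcDialects.unregisterDialect(Db2Dialect) // restore the default behaviour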