From d92c9b6828946ab89c7d91f3d4707b443ae62f62 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 4 Jan 2016 11:20:35 -0800 Subject: [PATCH] Handle DriverWrapper when scanning registered drivers --- .../spark/redshift/RedshiftJDBCWrapper.scala | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/main/scala/com/databricks/spark/redshift/RedshiftJDBCWrapper.scala b/src/main/scala/com/databricks/spark/redshift/RedshiftJDBCWrapper.scala index c0dc9dad..6b3d2572 100644 --- a/src/main/scala/com/databricks/spark/redshift/RedshiftJDBCWrapper.scala +++ b/src/main/scala/com/databricks/spark/redshift/RedshiftJDBCWrapper.scala @@ -201,12 +201,23 @@ private[redshift] class JDBCWrapper { val subprotocol = url.stripPrefix("jdbc:").split(":")(0) val driverClass: String = getDriverClass(subprotocol, userProvidedDriverClass) registerDriver(driverClass) + val driverWrapperClass: Class[_] = if (SPARK_VERSION.startsWith("1.4")) { + Utils.classForName("org.apache.spark.sql.jdbc.package$DriverWrapper") + } else { // Spark 1.5.0+ + Utils.classForName("org.apache.spark.sql.execution.datasources.jdbc.DriverWrapper") + } + def getWrapped(d: Driver): Driver = { + require(driverWrapperClass.isAssignableFrom(d.getClass)) + driverWrapperClass.getDeclaredMethod("wrapped").invoke(d).asInstanceOf[Driver] + } // Note that we purposely don't call DriverManager.getConnection() here: we want to ensure // that an explicitly-specified user-provided driver class can take precedence over the default // class, but DriverManager.getConnection() might return a according to a different precedence. // At the same time, we don't want to create a driver-per-connection, so we use the // DriverManager's driver instances to handle that singleton logic for us. val driver: Driver = DriverManager.getDrivers.asScala.collectFirst { + case d if driverWrapperClass.isAssignableFrom(d.getClass) + && getWrapped(d).getClass.getCanonicalName == driverClass => d case d if d.getClass.getCanonicalName == driverClass => d }.getOrElse { throw new IllegalArgumentException(s"Did not find registered driver with class $driverClass")