From e477a8ed6d2ea331be357a6fbbb3d55c504971b1 Mon Sep 17 00:00:00 2001
From: Dongjoon Hyun
Date: Fri, 17 Aug 2018 05:27:32 -0700
Subject: [PATCH] [SPARK-25143][SQL] Support data source name mapping configuration

---
 .../execution/datasources/DataSource.scala    |  8 ++++++-
 .../sql/sources/ResolvedDataSourceSuite.scala | 24 +++++++++++++++++++
 2 files changed, 31 insertions(+), 1 deletion(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index b1a10fdb6020..757f7575c853 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -609,7 +609,13 @@ object DataSource extends Logging {
 
   /** Given a provider name, look up the data source class definition. */
   def lookupDataSource(provider: String, conf: SQLConf): Class[_] = {
-    val provider1 = backwardCompatibilityMap.getOrElse(provider, provider) match {
+    val customBackwardCompatibilityMap =
+      conf.getAllConfs
+        .filter(_._1.startsWith("spark.sql.datasource.map"))
+        .map { case (k, v) => (k.replaceFirst("^spark.sql.datasource.map.", ""), v) }
+    val compatibilityMap = backwardCompatibilityMap ++ customBackwardCompatibilityMap
+
+    val provider1 = compatibilityMap.getOrElse(provider, provider) match {
       case name if name.equalsIgnoreCase("orc") &&
           conf.getConf(SQLConf.ORC_IMPLEMENTATION) == "native" =>
         classOf[OrcFileFormat].getCanonicalName
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala
index 95460fa70d8f..de4f3085df0a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala
@@ -82,4 +82,28 @@ class ResolvedDataSourceSuite extends SparkFunSuite with SharedSQLContext {
     }
     assert(error.getMessage.contains("Failed to find data source: asfdwefasdfasdf."))
   }
+
+  test("support custom mapping for data source names") {
+    val csv = classOf[org.apache.spark.sql.execution.datasources.csv.CSVFileFormat]
+
+    // Map a new data source name to a built-in data source
+    withSQLConf("spark.sql.datasource.map.myDatasource" -> csv.getCanonicalName) {
+      assert(getProvidingClass("myDatasource") === csv)
+    }
+
+    // Map an existing built-in data source name to a new data source
+    val testDataSource = classOf[TestDataSource]
+    withSQLConf(
+      "spark.sql.datasource.map.org.apache.spark.sql.avro" -> testDataSource.getCanonicalName,
+      "spark.sql.datasource.map.com.databricks.spark.csv" -> testDataSource.getCanonicalName,
+      "spark.sql.datasource.map.com.databricks.spark.avro" -> testDataSource.getCanonicalName) {
+      assert(getProvidingClass("org.apache.spark.sql.avro") === testDataSource)
+      assert(getProvidingClass("com.databricks.spark.csv") === testDataSource)
+      assert(getProvidingClass("com.databricks.spark.avro") === testDataSource)
+    }
+  }
 }
+
+class TestDataSource extends DataSourceRegister {
+  override def shortName(): String = "test"
+}
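
Note (reviewer addition, not part of the patch): a minimal usage sketch of the new mapping, assuming this patch is applied. The config prefix "spark.sql.datasource.map." and the CSVFileFormat class come from the patch itself; the session setup, the mapped name "my.legacy.datasource", and the input path are hypothetical.

  // Any config key under "spark.sql.datasource.map." has that prefix
  // stripped and is merged into the name -> provider compatibility map
  // consulted by DataSource.lookupDataSource.
  import org.apache.spark.sql.SparkSession

  val spark = SparkSession.builder()
    .master("local[*]")
    .appName("DataSourceNameMappingSketch")
    // Redirect the hypothetical name "my.legacy.datasource" to the
    // built-in CSV implementation, as the new test does for "myDatasource".
    .config(
      "spark.sql.datasource.map.my.legacy.datasource",
      "org.apache.spark.sql.execution.datasources.csv.CSVFileFormat")
    .getOrCreate()

  // lookupDataSource now resolves "my.legacy.datasource" to CSVFileFormat,
  // so this read goes through the CSV reader.
  val df = spark.read
    .format("my.legacy.datasource")
    .option("header", "true")
    .load("/tmp/example.csv")  // hypothetical path

Because the custom entries are merged as backwardCompatibilityMap ++ customBackwardCompatibilityMap, a user-supplied mapping overrides a built-in compatibility entry with the same key, which is what the second half of the new test verifies for com.databricks.spark.csv and the avro names.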