From 356c55eb5d50db1ab0fe9f15285cf31d993fad8a Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Fri, 1 Dec 2017 20:17:58 +0800 Subject: [PATCH 01/14] propagate session configs to data source options. --- .../spark/sql/sources/v2/ConfigSupport.java | 37 ++++++++++++++++++ .../apache/spark/sql/DataFrameReader.scala | 39 +++++++++++++++++-- .../sql/sources/v2/DataSourceV2Suite.scala | 26 ++++++++++++- 3 files changed, 97 insertions(+), 5 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/ConfigSupport.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ConfigSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ConfigSupport.java new file mode 100644 index 0000000000000..894bc9e1d3e67 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ConfigSupport.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader; + +import java.util.List; + +/** + * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to + * propagate session configs with chosen key-prefixes to the particular data source. + */ +@InterfaceStability.Evolving +public interface ConfigSupport { + + /** + * Create a list of key-prefixes, all session configs that match at least one of the prefixes + * will be propagated to the data source options. 
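+   *
+   * For example, a hypothetical implementation could return
+   * {@code Arrays.asList("spark.sql.parquet")}, so that a session config such as
+   * "spark.sql.parquet.compression.codec" would be copied into the data source options
+   * with its key unchanged.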
+ */ + List getConfigPrefixes(); +} \ No newline at end of file diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 17966eecfc051..77af21b59c6ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import java.util.{Locale, Properties} import scala.collection.JavaConverters._ +import scala.collection.immutable import org.apache.spark.Partition import org.apache.spark.annotation.InterfaceStability @@ -33,7 +34,8 @@ import org.apache.spark.sql.execution.datasources.csv._ import org.apache.spark.sql.execution.datasources.jdbc._ import org.apache.spark.sql.execution.datasources.json.TextInputJsonDataSource import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation -import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options, ReadSupport, ReadSupportWithSchema} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.unsafe.types.UTF8String @@ -169,6 +171,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { option("path", path).load(Seq.empty: _*) // force invocation of `load(...varargs...)` } + import DataFrameReader._ + /** * Loads input in as a `DataFrame`, for data sources that support multiple paths. * Only works if the source is a HadoopFsRelationProvider. @@ -184,9 +188,16 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { val cls = DataSource.lookupDataSource(source) if (classOf[DataSourceV2].isAssignableFrom(cls)) { - val options = new DataSourceV2Options(extraOptions.asJava) + val dataSource = cls.newInstance() + val options = dataSource match { + case cs: ConfigSupport => + val confs = withSessionConfig(cs, sparkSession.sessionState.conf) + new DataSourceV2Options((confs ++ extraOptions).asJava) + case _ => + new DataSourceV2Options(extraOptions.asJava) + } - val reader = (cls.newInstance(), userSpecifiedSchema) match { + val reader = (dataSource, userSpecifiedSchema) match { case (ds: ReadSupportWithSchema, Some(schema)) => ds.createReader(schema, options) @@ -732,3 +743,25 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { private val extraOptions = new scala.collection.mutable.HashMap[String, String] } + +private[sql] object DataFrameReader { + + /** + * Helper method to filter session configs with config key that matches at least one of the given + * prefixes. + * + * @param cs the config key-prefixes that should be filtered. + * @param conf the session conf + * @return an immutable map that contains all the session configs that should be propagated to + * the data source. 
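+   *
+   * For example, if `cs.getConfigPrefixes` returns a single prefix `"spark.sql.parquet"`,
+   * an entry such as `spark.sql.parquet.compression.codec -> snappy` is kept in the result
+   * as-is, while entries under other prefixes are dropped.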
+ */ + def withSessionConfig( + cs: ConfigSupport, + conf: SQLConf): immutable.Map[String, String] = { + val prefixes = cs.getConfigPrefixes + require(prefixes != null, "The config key-prefixes cann't be null.") + conf.getAllConfs.filterKeys { confKey => + prefixes.asScala.exists(confKey.startsWith(_)) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index ab37e4984bd1f..2d6e7fc1fbbe7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -20,10 +20,10 @@ package org.apache.spark.sql.sources.v2 import java.util.{ArrayList, List => JList} import test.org.apache.spark.sql.sources.v2._ - import org.apache.spark.SparkException -import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.{AnalysisException, DataFrameReader, QueryTest, Row} import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.{Filter, GreaterThan} import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.test.SharedSQLContext @@ -43,6 +43,21 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } } + test("simple implementation with config support") { + withSQLConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "false", + SQLConf.PARQUET_INT96_AS_TIMESTAMP.key -> "true", + SQLConf.PARALLEL_PARTITION_DISCOVERY_THRESHOLD.key -> "32", + SQLConf.PARALLEL_PARTITION_DISCOVERY_PARALLELISM.key -> "10000") { + val cs = classOf[DataSourceV2WithConfig].newInstance().asInstanceOf[ConfigSupport] + val confs = DataFrameReader.withSessionConfig(cs, SQLConf.get) + assert(confs.size == 3) + assert(confs.keySet.filter(_.startsWith("spark.sql.parquet")).size == 2) + assert(confs.keySet.filter( + _.startsWith("spark.sql.sources.parallelPartitionDiscovery.threshold")).size == 1) + assert(confs.keySet.filter(_.startsWith("not.exist.prefix")).size == 0) + } + } + test("advanced implementation") { Seq(classOf[AdvancedDataSourceV2], classOf[JavaAdvancedDataSourceV2]).foreach { cls => withClue(cls.getName) { @@ -179,7 +194,14 @@ class SimpleReadTask(start: Int, end: Int) extends ReadTask[Row] with DataReader override def close(): Unit = {} } +class DataSourceV2WithConfig extends SimpleDataSourceV2 with ConfigSupport { + override def getConfigPrefixes: JList[String] = { + java.util.Arrays.asList( + "spark.sql.parquet", + "spark.sql.sources.parallelPartitionDiscovery.threshold") + } +} class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport { From b076a69b0180825d1726019ef6213d1be2324c26 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Fri, 1 Dec 2017 21:36:22 +0800 Subject: [PATCH 02/14] fix scalastyle --- .../org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index 2d6e7fc1fbbe7..1af895372df40 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.sources.v2 import java.util.{ArrayList, List => JList} import test.org.apache.spark.sql.sources.v2._ + import 
org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, DataFrameReader, QueryTest, Row} import org.apache.spark.sql.catalyst.expressions.UnsafeRow From eaa6cae1fd61f215555238acdcfc7477d559cc47 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Fri, 1 Dec 2017 22:25:08 +0800 Subject: [PATCH 03/14] style fix --- .../java/org/apache/spark/sql/sources/v2/ConfigSupport.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ConfigSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ConfigSupport.java index 894bc9e1d3e67..c4dbd1a28391f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ConfigSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ConfigSupport.java @@ -34,4 +34,4 @@ public interface ConfigSupport { * will be propagated to the data source options. */ List getConfigPrefixes(); -} \ No newline at end of file +} From ec5723c194474c85af0c4bd6265c6f7b0781881e Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Mon, 4 Dec 2017 23:44:13 +0800 Subject: [PATCH 04/14] move around function withSessionConfig. --- .../apache/spark/sql/DataFrameReader.scala | 27 +---------- .../v2/DataSourceV2ConfigSupport.scala | 46 +++++++++++++++++++ .../sql/sources/v2/DataSourceV2Suite.scala | 5 +- 3 files changed, 51 insertions(+), 27 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ConfigSupport.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 77af21b59c6ad..83fdf9b3f2739 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql import java.util.{Locale, Properties} import scala.collection.JavaConverters._ -import scala.collection.immutable import org.apache.spark.Partition import org.apache.spark.annotation.InterfaceStability @@ -33,8 +32,8 @@ import org.apache.spark.sql.execution.datasources.{DataSource, FailureSafeParser import org.apache.spark.sql.execution.datasources.csv._ import org.apache.spark.sql.execution.datasources.jdbc._ import org.apache.spark.sql.execution.datasources.json.TextInputJsonDataSource +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ConfigSupport import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.unsafe.types.UTF8String @@ -171,7 +170,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { option("path", path).load(Seq.empty: _*) // force invocation of `load(...varargs...)` } - import DataFrameReader._ + import DataSourceV2ConfigSupport._ /** * Loads input in as a `DataFrame`, for data sources that support multiple paths. @@ -743,25 +742,3 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { private val extraOptions = new scala.collection.mutable.HashMap[String, String] } - -private[sql] object DataFrameReader { - - /** - * Helper method to filter session configs with config key that matches at least one of the given - * prefixes. - * - * @param cs the config key-prefixes that should be filtered. 
- * @param conf the session conf - * @return an immutable map that contains all the session configs that should be propagated to - * the data source. - */ - def withSessionConfig( - cs: ConfigSupport, - conf: SQLConf): immutable.Map[String, String] = { - val prefixes = cs.getConfigPrefixes - require(prefixes != null, "The config key-prefixes cann't be null.") - conf.getAllConfs.filterKeys { confKey => - prefixes.asScala.exists(confKey.startsWith(_)) - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ConfigSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ConfigSupport.scala new file mode 100644 index 0000000000000..4d9e06f6f4578 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ConfigSupport.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import scala.collection.JavaConverters._ +import scala.collection.immutable + +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.v2.ConfigSupport + +private[sql] object DataSourceV2ConfigSupport { + + /** + * Helper method to filter session configs with config key that matches at least one of the given + * prefixes. + * + * @param cs the config key-prefixes that should be filtered. + * @param conf the session conf + * @return an immutable map that contains all the session configs that should be propagated to + * the data source. 
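+   *
+   * A rough usage sketch, mirroring the call site in `DataFrameReader`:
+   * {{{
+   *   val confs = DataSourceV2ConfigSupport.withSessionConfig(cs, sparkSession.sessionState.conf)
+   *   val options = new DataSourceV2Options((confs ++ extraOptions).asJava)
+   * }}}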
+ */ + def withSessionConfig( + cs: ConfigSupport, + conf: SQLConf): immutable.Map[String, String] = { + val prefixes = cs.getConfigPrefixes + require(prefixes != null, "The config key-prefixes cann't be null.") + conf.getAllConfs.filterKeys { confKey => + prefixes.asScala.exists(confKey.startsWith(_)) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index 1af895372df40..5ddf41c38de09 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -22,8 +22,9 @@ import java.util.{ArrayList, List => JList} import test.org.apache.spark.sql.sources.v2._ import org.apache.spark.SparkException -import org.apache.spark.sql.{AnalysisException, DataFrameReader, QueryTest, Row} +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ConfigSupport import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.{Filter, GreaterThan} import org.apache.spark.sql.sources.v2.reader._ @@ -50,7 +51,7 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { SQLConf.PARALLEL_PARTITION_DISCOVERY_THRESHOLD.key -> "32", SQLConf.PARALLEL_PARTITION_DISCOVERY_PARALLELISM.key -> "10000") { val cs = classOf[DataSourceV2WithConfig].newInstance().asInstanceOf[ConfigSupport] - val confs = DataFrameReader.withSessionConfig(cs, SQLConf.get) + val confs = DataSourceV2ConfigSupport.withSessionConfig(cs, SQLConf.get) assert(confs.size == 3) assert(confs.keySet.filter(_.startsWith("spark.sql.parquet")).size == 2) assert(confs.keySet.filter( From 84df37e60d308a46efe30c391f2dcdb01bb4e4e9 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Wed, 6 Dec 2017 22:40:38 +0800 Subject: [PATCH 05/14] update ConfigSupport --- .../spark/sql/sources/v2/ConfigSupport.java | 19 ++++++++++++- .../apache/spark/sql/DataFrameReader.scala | 2 +- .../v2/DataSourceV2ConfigSupport.scala | 28 ++++++++++++++++--- .../sql/sources/v2/DataSourceV2Suite.scala | 17 +++++++---- 4 files changed, 55 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ConfigSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ConfigSupport.java index c4dbd1a28391f..cf221997455ab 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ConfigSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ConfigSupport.java @@ -18,9 +18,9 @@ package org.apache.spark.sql.sources.v2; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader; import java.util.List; +import java.util.Map; /** * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to @@ -32,6 +32,23 @@ public interface ConfigSupport { /** * Create a list of key-prefixes, all session configs that match at least one of the prefixes * will be propagated to the data source options. + * If the returned list is empty, no session config will be propagated. */ List getConfigPrefixes(); + + /** + * Create a mapping from session config names to data source option names. If a propagated + * session config's key doesn't exist in this mapping, the "spark.sql.${source}" prefix will + * be trimmed. 
For example, if the data source name is "parquet", perform the following config + * key mapping by default: + * "spark.sql.parquet.int96AsTimestamp" -> "int96AsTimestamp", + * "spark.sql.parquet.compression.codec" -> "compression.codec", + * "spark.sql.columnNameOfCorruptRecord" -> "columnNameOfCorruptRecord". + * + * If the mapping is specified, for example, the returned map contains an entry + * ("spark.sql.columnNameOfCorruptRecord" -> "colNameCorrupt"), then the session config + * "spark.sql.columnNameOfCorruptRecord" will be converted to "colNameCorrupt" in + * [[DataSourceV2Options]]. + */ + Map getConfigMapping(); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 83fdf9b3f2739..59204dc68b96a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -190,7 +190,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { val dataSource = cls.newInstance() val options = dataSource match { case cs: ConfigSupport => - val confs = withSessionConfig(cs, sparkSession.sessionState.conf) + val confs = withSessionConfig(cs, source, sparkSession.sessionState.conf) new DataSourceV2Options((confs ++ extraOptions).asJava) case _ => new DataSourceV2Options(extraOptions.asJava) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ConfigSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ConfigSupport.scala index 4d9e06f6f4578..96610c172d4d4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ConfigSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ConfigSupport.scala @@ -17,30 +17,50 @@ package org.apache.spark.sql.execution.datasources.v2 +import java.util.regex.Pattern + import scala.collection.JavaConverters._ import scala.collection.immutable +import org.apache.spark.internal.Logging import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.v2.ConfigSupport -private[sql] object DataSourceV2ConfigSupport { +private[sql] object DataSourceV2ConfigSupport extends Logging { /** - * Helper method to filter session configs with config key that matches at least one of the given - * prefixes. + * Helper method to propagate session configs with config key that matches at least one of the + * given prefixes to the corresponding data source options. * - * @param cs the config key-prefixes that should be filtered. + * @param cs the session config propagate help class + * @param source the data source format * @param conf the session conf * @return an immutable map that contains all the session configs that should be propagated to * the data source. */ def withSessionConfig( cs: ConfigSupport, + source: String, conf: SQLConf): immutable.Map[String, String] = { val prefixes = cs.getConfigPrefixes require(prefixes != null, "The config key-prefixes cann't be null.") + val mapping = cs.getConfigMapping.asScala + + val pattern = Pattern.compile(s"spark\\.sql(\\.$source)?\\.(.*)") conf.getAllConfs.filterKeys { confKey => prefixes.asScala.exists(confKey.startsWith(_)) + }.map{ entry => + val newKey = mapping.get(entry._1).getOrElse { + val m = pattern.matcher(entry._1) + if (m.matches()) { + m.group(2) + } else { + // Unable to recognize the session config key. 
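+            // The key is kept unchanged in this case, e.g. a prefix-matched key like
+            // "my.custom.prefix.option" that does not follow the "spark.sql[.$source].*" pattern.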
+ logWarning(s"Unrecognizable session config name ${entry._1}.") + entry._1 + } + } + (newKey, entry._2) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index 5ddf41c38de09..eaf44ba601e98 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.sources.v2 +import java.util import java.util.{ArrayList, List => JList} import test.org.apache.spark.sql.sources.v2._ @@ -47,16 +48,16 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { test("simple implementation with config support") { withSQLConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "false", - SQLConf.PARQUET_INT96_AS_TIMESTAMP.key -> "true", + SQLConf.PARQUET_COMPRESSION.key -> "uncompressed", SQLConf.PARALLEL_PARTITION_DISCOVERY_THRESHOLD.key -> "32", SQLConf.PARALLEL_PARTITION_DISCOVERY_PARALLELISM.key -> "10000") { val cs = classOf[DataSourceV2WithConfig].newInstance().asInstanceOf[ConfigSupport] - val confs = DataSourceV2ConfigSupport.withSessionConfig(cs, SQLConf.get) + val confs = DataSourceV2ConfigSupport.withSessionConfig(cs, "parquet", SQLConf.get) assert(confs.size == 3) - assert(confs.keySet.filter(_.startsWith("spark.sql.parquet")).size == 2) - assert(confs.keySet.filter( - _.startsWith("spark.sql.sources.parallelPartitionDiscovery.threshold")).size == 1) + assert(confs.keySet.filter(_.startsWith("spark.sql.parquet")).size == 0) assert(confs.keySet.filter(_.startsWith("not.exist.prefix")).size == 0) + assert(confs.keySet.contains("compressionCodec")) + assert(confs.keySet.contains("sources.parallelPartitionDiscovery.threshold")) } } @@ -203,6 +204,12 @@ class DataSourceV2WithConfig extends SimpleDataSourceV2 with ConfigSupport { "spark.sql.parquet", "spark.sql.sources.parallelPartitionDiscovery.threshold") } + + override def getConfigMapping: util.Map[String, String] = { + val configMap = new util.HashMap[String, String]() + configMap.put("spark.sql.parquet.compression.codec", "compressionCodec") + configMap + } } class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport { From 0dd7f2ef314da480f32f30878bff3f1b5942aa03 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Thu, 7 Dec 2017 00:07:46 +0800 Subject: [PATCH 06/14] add method getValidOptions --- .../spark/sql/sources/v2/ConfigSupport.java | 8 ++++++ .../v2/DataSourceV2ConfigSupport.scala | 23 ++++++++++++++-- .../sql/sources/v2/DataSourceV2Suite.scala | 27 +++++++++++++++++++ 3 files changed, 56 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ConfigSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ConfigSupport.java index cf221997455ab..cdc334bce273d 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ConfigSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ConfigSupport.java @@ -51,4 +51,12 @@ public interface ConfigSupport { * [[DataSourceV2Options]]. */ Map getConfigMapping(); + + /** + * Create a list of valid data source option names. When the list is specified, a session + * config will NOT be propagated if its corresponding option name is not in the list. + * + * If the returned list is empty, don't check the option names. 
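+   *
+   * For example, a hypothetical implementation could return
+   * {@code Arrays.asList("compressionCodec", "mergeSchema")}, so that only propagated
+   * options with those names are kept.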
+ */ + List getValidOptions(); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ConfigSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ConfigSupport.scala index 96610c172d4d4..cbea1a1c0f586 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ConfigSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ConfigSupport.scala @@ -45,11 +45,14 @@ private[sql] object DataSourceV2ConfigSupport extends Logging { val prefixes = cs.getConfigPrefixes require(prefixes != null, "The config key-prefixes cann't be null.") val mapping = cs.getConfigMapping.asScala + val validOptions = cs.getValidOptions + require(validOptions != null, "The valid options list cann't be null.") val pattern = Pattern.compile(s"spark\\.sql(\\.$source)?\\.(.*)") - conf.getAllConfs.filterKeys { confKey => + val filteredConfigs = conf.getAllConfs.filterKeys { confKey => prefixes.asScala.exists(confKey.startsWith(_)) - }.map{ entry => + } + val convertedConfigs = filteredConfigs.map{ entry => val newKey = mapping.get(entry._1).getOrElse { val m = pattern.matcher(entry._1) if (m.matches()) { @@ -62,5 +65,21 @@ private[sql] object DataSourceV2ConfigSupport extends Logging { } (newKey, entry._2) } + if (validOptions.size == 0) { + convertedConfigs + } else { + // Check whether all the valid options are propagated. + validOptions.asScala.foreach { optionName => + if (!convertedConfigs.keySet.contains(optionName)) { + logWarning(s"Data source option '$optionName' is required, but not propagated from " + + "session config, please check the config settings.") + } + } + + // Filter the valid options. + convertedConfigs.filterKeys { optionName => + validOptions.contains(optionName) + } + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index eaf44ba601e98..a7e82fecc5f45 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -61,6 +61,21 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } } + test("config support with validOptions") { + withSQLConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "false", + SQLConf.PARQUET_COMPRESSION.key -> "uncompressed", + SQLConf.PARALLEL_PARTITION_DISCOVERY_THRESHOLD.key -> "32", + SQLConf.PARALLEL_PARTITION_DISCOVERY_PARALLELISM.key -> "10000") { + val cs = classOf[DataSourceV2WithValidOptions].newInstance().asInstanceOf[ConfigSupport] + val confs = DataSourceV2ConfigSupport.withSessionConfig(cs, "parquet", SQLConf.get) + assert(confs.size == 2) + assert(confs.keySet.filter(_.startsWith("spark.sql.parquet")).size == 0) + assert(confs.keySet.filter(_.startsWith("not.exist.prefix")).size == 0) + assert(confs.keySet.contains("compressionCodec")) + assert(confs.keySet.contains("sources.parallelPartitionDiscovery.threshold")) + } + } + test("advanced implementation") { Seq(classOf[AdvancedDataSourceV2], classOf[JavaAdvancedDataSourceV2]).foreach { cls => withClue(cls.getName) { @@ -210,6 +225,18 @@ class DataSourceV2WithConfig extends SimpleDataSourceV2 with ConfigSupport { configMap.put("spark.sql.parquet.compression.codec", "compressionCodec") configMap } + + override def getValidOptions: JList[String] = new util.ArrayList[String]() +} + +class 
DataSourceV2WithValidOptions extends DataSourceV2WithConfig { + + override def getValidOptions: JList[String] = { + java.util.Arrays.asList( + "sources.parallelPartitionDiscovery.threshold", + "compressionCodec", + "not.exist.option") + } } class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport { From 8329a6be03a0426f9e9e53507766117bbd8efb71 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Thu, 7 Dec 2017 00:32:33 +0800 Subject: [PATCH 07/14] update comments. --- .../java/org/apache/spark/sql/sources/v2/ConfigSupport.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ConfigSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ConfigSupport.java index cdc334bce273d..c9bdbfd9592b8 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ConfigSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ConfigSupport.java @@ -48,7 +48,7 @@ public interface ConfigSupport { * If the mapping is specified, for example, the returned map contains an entry * ("spark.sql.columnNameOfCorruptRecord" -> "colNameCorrupt"), then the session config * "spark.sql.columnNameOfCorruptRecord" will be converted to "colNameCorrupt" in - * [[DataSourceV2Options]]. + * [[org.apache.spark.sql.sources.v2.DataSourceV2Options]]. */ Map getConfigMapping(); From 6b4fcab95614f48d289f9f411cd33ba975569ea4 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Thu, 7 Dec 2017 22:05:40 +0800 Subject: [PATCH 08/14] update comments --- .../org/apache/spark/sql/sources/v2/ConfigSupport.java | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ConfigSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ConfigSupport.java index c9bdbfd9592b8..1e764a25cb5be 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ConfigSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ConfigSupport.java @@ -41,14 +41,14 @@ public interface ConfigSupport { * session config's key doesn't exist in this mapping, the "spark.sql.${source}" prefix will * be trimmed. For example, if the data source name is "parquet", perform the following config * key mapping by default: - * "spark.sql.parquet.int96AsTimestamp" -> "int96AsTimestamp", - * "spark.sql.parquet.compression.codec" -> "compression.codec", - * "spark.sql.columnNameOfCorruptRecord" -> "columnNameOfCorruptRecord". + * "spark.sql.parquet.int96AsTimestamp" -> "int96AsTimestamp", + * "spark.sql.parquet.compression.codec" -> "compression.codec", + * "spark.sql.columnNameOfCorruptRecord" -> "columnNameOfCorruptRecord". * * If the mapping is specified, for example, the returned map contains an entry - * ("spark.sql.columnNameOfCorruptRecord" -> "colNameCorrupt"), then the session config + * ("spark.sql.columnNameOfCorruptRecord" -> "colNameCorrupt"), then the session config * "spark.sql.columnNameOfCorruptRecord" will be converted to "colNameCorrupt" in - * [[org.apache.spark.sql.sources.v2.DataSourceV2Options]]. + * {@link DataSourceV2Options}. */ Map getConfigMapping(); From ec9a717d218e966f38068f1a407b749debea4f35 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Mon, 11 Dec 2017 22:54:34 +0800 Subject: [PATCH 09/14] simplify ConfigSupport interface. 
--- .../spark/sql/sources/v2/ConfigSupport.java | 32 ++-------- .../apache/spark/sql/DataFrameReader.scala | 2 +- .../v2/DataSourceV2ConfigSupport.scala | 58 +++++------------- .../sql/sources/v2/DataSourceV2Suite.scala | 59 ++++--------------- 4 files changed, 32 insertions(+), 119 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ConfigSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ConfigSupport.java index 1e764a25cb5be..33ae1449afcbd 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ConfigSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ConfigSupport.java @@ -30,33 +30,9 @@ public interface ConfigSupport { /** - * Create a list of key-prefixes, all session configs that match at least one of the prefixes - * will be propagated to the data source options. - * If the returned list is empty, no session config will be propagated. + * Name for the specified data source, will extract all session configs that starts with + * `spark.datasource.$name`, turn `spark.datasource.$name.xxx -> yyy` into + * `xxx -> yyy`, and propagate them to all data source operations in this session. */ - List getConfigPrefixes(); - - /** - * Create a mapping from session config names to data source option names. If a propagated - * session config's key doesn't exist in this mapping, the "spark.sql.${source}" prefix will - * be trimmed. For example, if the data source name is "parquet", perform the following config - * key mapping by default: - * "spark.sql.parquet.int96AsTimestamp" -> "int96AsTimestamp", - * "spark.sql.parquet.compression.codec" -> "compression.codec", - * "spark.sql.columnNameOfCorruptRecord" -> "columnNameOfCorruptRecord". - * - * If the mapping is specified, for example, the returned map contains an entry - * ("spark.sql.columnNameOfCorruptRecord" -> "colNameCorrupt"), then the session config - * "spark.sql.columnNameOfCorruptRecord" will be converted to "colNameCorrupt" in - * {@link DataSourceV2Options}. - */ - Map getConfigMapping(); - - /** - * Create a list of valid data source option names. When the list is specified, a session - * config will NOT be propagated if its corresponding option name is not in the list. - * - * If the returned list is empty, don't check the option names. 
- */ - List getValidOptions(); + String name(); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 59204dc68b96a..351c7bf064e3e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -190,7 +190,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { val dataSource = cls.newInstance() val options = dataSource match { case cs: ConfigSupport => - val confs = withSessionConfig(cs, source, sparkSession.sessionState.conf) + val confs = withSessionConfig(cs.name, sparkSession.sessionState.conf) new DataSourceV2Options((confs ++ extraOptions).asJava) case _ => new DataSourceV2Options(extraOptions.asJava) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ConfigSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ConfigSupport.scala index cbea1a1c0f586..bbe1b7a4cab82 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ConfigSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ConfigSupport.scala @@ -29,57 +29,29 @@ import org.apache.spark.sql.sources.v2.ConfigSupport private[sql] object DataSourceV2ConfigSupport extends Logging { /** - * Helper method to propagate session configs with config key that matches at least one of the - * given prefixes to the corresponding data source options. + * Helper method that turns session configs with config keys that start with + * `spark.datasource.$name` into k/v pairs, the k/v pairs will be used to create data source + * options. + * A session config `spark.datasource.$name.xxx -> yyy` will be transformed into + * `xxx -> yyy`. * - * @param cs the session config propagate help class - * @param source the data source format + * @param name the data source name * @param conf the session conf - * @return an immutable map that contains all the session configs that should be propagated to - * the data source. + * @return an immutable map that contains all the extracted and transformed k/v pairs. */ def withSessionConfig( - cs: ConfigSupport, - source: String, + name: String, conf: SQLConf): immutable.Map[String, String] = { - val prefixes = cs.getConfigPrefixes - require(prefixes != null, "The config key-prefixes cann't be null.") - val mapping = cs.getConfigMapping.asScala - val validOptions = cs.getValidOptions - require(validOptions != null, "The valid options list cann't be null.") + require(name != null, "The data source name can't be null.") - val pattern = Pattern.compile(s"spark\\.sql(\\.$source)?\\.(.*)") + val pattern = Pattern.compile(s"spark\\.datasource\\.$name\\.(.*)") val filteredConfigs = conf.getAllConfs.filterKeys { confKey => - prefixes.asScala.exists(confKey.startsWith(_)) + confKey.startsWith(s"spark.datasource.$name") } - val convertedConfigs = filteredConfigs.map{ entry => - val newKey = mapping.get(entry._1).getOrElse { - val m = pattern.matcher(entry._1) - if (m.matches()) { - m.group(2) - } else { - // Unable to recognize the session config key. - logWarning(s"Unrecognizable session config name ${entry._1}.") - entry._1 - } - } - (newKey, entry._2) - } - if (validOptions.size == 0) { - convertedConfigs - } else { - // Check whether all the valid options are propagated. 
- validOptions.asScala.foreach { optionName => - if (!convertedConfigs.keySet.contains(optionName)) { - logWarning(s"Data source option '$optionName' is required, but not propagated from " + - "session config, please check the config settings.") - } - } - - // Filter the valid options. - convertedConfigs.filterKeys { optionName => - validOptions.contains(optionName) - } + filteredConfigs.map { entry => + val m = pattern.matcher(entry._1) + require(m.matches() && m.groupCount() > 0, s"Fail in matching ${entry._1} with $pattern.") + (m.group(1), entry._2) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index a7e82fecc5f45..49e14363a123f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.sources.v2 -import java.util import java.util.{ArrayList, List => JList} import test.org.apache.spark.sql.sources.v2._ @@ -35,6 +34,8 @@ import org.apache.spark.sql.types.StructType class DataSourceV2Suite extends QueryTest with SharedSQLContext { import testImplicits._ + private val dsName = "userDefinedDataSource" + test("simplest implementation") { Seq(classOf[SimpleDataSourceV2], classOf[JavaSimpleDataSourceV2]).foreach { cls => withClue(cls.getName) { @@ -47,32 +48,18 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } test("simple implementation with config support") { - withSQLConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "false", - SQLConf.PARQUET_COMPRESSION.key -> "uncompressed", - SQLConf.PARALLEL_PARTITION_DISCOVERY_THRESHOLD.key -> "32", - SQLConf.PARALLEL_PARTITION_DISCOVERY_PARALLELISM.key -> "10000") { + // Only match configs with keys start with "spark.datasource.${dsName}". 
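+    // Keys under other prefixes, e.g. "spark.sql." or "spark.datasource.another.", are
+    // expected to be filtered out.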
+ withSQLConf(s"spark.datasource.$dsName.foo.bar" -> "false", + s"spark.datasource.$dsName.whateverConfigName" -> "123", + s"spark.sql.$dsName.config.name" -> "false", + s"spark.datasource.another.config.name" -> "123") { val cs = classOf[DataSourceV2WithConfig].newInstance().asInstanceOf[ConfigSupport] - val confs = DataSourceV2ConfigSupport.withSessionConfig(cs, "parquet", SQLConf.get) - assert(confs.size == 3) - assert(confs.keySet.filter(_.startsWith("spark.sql.parquet")).size == 0) - assert(confs.keySet.filter(_.startsWith("not.exist.prefix")).size == 0) - assert(confs.keySet.contains("compressionCodec")) - assert(confs.keySet.contains("sources.parallelPartitionDiscovery.threshold")) - } - } - - test("config support with validOptions") { - withSQLConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "false", - SQLConf.PARQUET_COMPRESSION.key -> "uncompressed", - SQLConf.PARALLEL_PARTITION_DISCOVERY_THRESHOLD.key -> "32", - SQLConf.PARALLEL_PARTITION_DISCOVERY_PARALLELISM.key -> "10000") { - val cs = classOf[DataSourceV2WithValidOptions].newInstance().asInstanceOf[ConfigSupport] - val confs = DataSourceV2ConfigSupport.withSessionConfig(cs, "parquet", SQLConf.get) + val confs = DataSourceV2ConfigSupport.withSessionConfig(cs.name, SQLConf.get) assert(confs.size == 2) - assert(confs.keySet.filter(_.startsWith("spark.sql.parquet")).size == 0) + assert(confs.keySet.filter(_.startsWith("spark.datasource")).size == 0) assert(confs.keySet.filter(_.startsWith("not.exist.prefix")).size == 0) - assert(confs.keySet.contains("compressionCodec")) - assert(confs.keySet.contains("sources.parallelPartitionDiscovery.threshold")) + assert(confs.keySet.contains("foo.bar")) + assert(confs.keySet.contains("whateverConfigName")) } } @@ -214,29 +201,7 @@ class SimpleReadTask(start: Int, end: Int) extends ReadTask[Row] with DataReader class DataSourceV2WithConfig extends SimpleDataSourceV2 with ConfigSupport { - override def getConfigPrefixes: JList[String] = { - java.util.Arrays.asList( - "spark.sql.parquet", - "spark.sql.sources.parallelPartitionDiscovery.threshold") - } - - override def getConfigMapping: util.Map[String, String] = { - val configMap = new util.HashMap[String, String]() - configMap.put("spark.sql.parquet.compression.codec", "compressionCodec") - configMap - } - - override def getValidOptions: JList[String] = new util.ArrayList[String]() -} - -class DataSourceV2WithValidOptions extends DataSourceV2WithConfig { - - override def getValidOptions: JList[String] = { - java.util.Arrays.asList( - "sources.parallelPartitionDiscovery.threshold", - "compressionCodec", - "not.exist.option") - } + override def name: String = "userDefinedDataSource" } class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport { From ebb8d86b9e681ebb82a9d44d371fab285de09ec6 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Thu, 14 Dec 2017 15:57:17 +0800 Subject: [PATCH 10/14] refactor --- ...Support.java => SessionConfigSupport.java} | 9 ++--- .../apache/spark/sql/DataFrameReader.scala | 8 ++--- ...ourceV2ConfigSupport.scala => Utils.scala} | 35 +++++++++---------- .../sql/sources/v2/DataSourceV2Suite.scala | 11 +++--- 4 files changed, 31 insertions(+), 32 deletions(-) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/{ConfigSupport.java => SessionConfigSupport.java} (82%) rename sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/{DataSourceV2ConfigSupport.scala => Utils.scala} (57%) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ConfigSupport.java 
b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java similarity index 82% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/ConfigSupport.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java index 33ae1449afcbd..d130d7ab7afeb 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ConfigSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java @@ -24,15 +24,16 @@ /** * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to - * propagate session configs with chosen key-prefixes to the particular data source. + * propagate session configs with the specified key-prefix to all data source operations in this + * session. */ @InterfaceStability.Evolving -public interface ConfigSupport { +public interface SessionConfigSupport { /** * Name for the specified data source, will extract all session configs that starts with - * `spark.datasource.$name`, turn `spark.datasource.$name.xxx -> yyy` into + * `spark.datasource.$keyPrefix`, turn `spark.datasource.$keyPrefix.xxx -> yyy` into * `xxx -> yyy`, and propagate them to all data source operations in this session. */ - String name(); + String keyPrefix(); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 351c7bf064e3e..4e66d417c91c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -32,8 +32,8 @@ import org.apache.spark.sql.execution.datasources.{DataSource, FailureSafeParser import org.apache.spark.sql.execution.datasources.csv._ import org.apache.spark.sql.execution.datasources.jdbc._ import org.apache.spark.sql.execution.datasources.json.TextInputJsonDataSource -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ConfigSupport import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.execution.datasources.v2.Utils import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.unsafe.types.UTF8String @@ -170,7 +170,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { option("path", path).load(Seq.empty: _*) // force invocation of `load(...varargs...)` } - import DataSourceV2ConfigSupport._ + import Utils._ /** * Loads input in as a `DataFrame`, for data sources that support multiple paths. 
@@ -189,8 +189,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { if (classOf[DataSourceV2].isAssignableFrom(cls)) { val dataSource = cls.newInstance() val options = dataSource match { - case cs: ConfigSupport => - val confs = withSessionConfig(cs.name, sparkSession.sessionState.conf) + case cs: SessionConfigSupport => + val confs = withSessionConfig(cs.keyPrefix, sparkSession.sessionState.conf) new DataSourceV2Options((confs ++ extraOptions).asJava) case _ => new DataSourceV2Options(extraOptions.asJava) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ConfigSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/Utils.scala similarity index 57% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ConfigSupport.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/Utils.scala index bbe1b7a4cab82..b5fa4c597ed74 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ConfigSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/Utils.scala @@ -19,39 +19,36 @@ package org.apache.spark.sql.execution.datasources.v2 import java.util.regex.Pattern -import scala.collection.JavaConverters._ -import scala.collection.immutable - import org.apache.spark.internal.Logging import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.ConfigSupport -private[sql] object DataSourceV2ConfigSupport extends Logging { +private[sql] object Utils extends Logging { /** * Helper method that turns session configs with config keys that start with - * `spark.datasource.$name` into k/v pairs, the k/v pairs will be used to create data source + * `spark.datasource.$keyPrefix` into k/v pairs, the k/v pairs will be used to create data source * options. - * A session config `spark.datasource.$name.xxx -> yyy` will be transformed into + * A session config `spark.datasource.$keyPrefix.xxx -> yyy` will be transformed into * `xxx -> yyy`. * - * @param name the data source name + * @param keyPrefix the data source config key prefix to be matched * @param conf the session conf * @return an immutable map that contains all the extracted and transformed k/v pairs. 
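+   *
+   * For example, with `keyPrefix = "userDefinedDataSource"`, a (hypothetical) session config
+   * entry `spark.datasource.userDefinedDataSource.fetchSize -> 100` is returned as
+   * `fetchSize -> 100`.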
*/ def withSessionConfig( - name: String, - conf: SQLConf): immutable.Map[String, String] = { - require(name != null, "The data source name can't be null.") + keyPrefix: String, + conf: SQLConf): Map[String, String] = { + require(keyPrefix != null, "The data source config key prefix can't be null.") - val pattern = Pattern.compile(s"spark\\.datasource\\.$name\\.(.*)") - val filteredConfigs = conf.getAllConfs.filterKeys { confKey => - confKey.startsWith(s"spark.datasource.$name") - } - filteredConfigs.map { entry => - val m = pattern.matcher(entry._1) - require(m.matches() && m.groupCount() > 0, s"Fail in matching ${entry._1} with $pattern.") - (m.group(1), entry._2) + val pattern = Pattern.compile(s"^spark\\.datasource\\.$keyPrefix\\.(.*)") + + conf.getAllConfs.flatMap { case (key, value) => + val m = pattern.matcher(key) + if (m.matches() && m.groupCount() > 0) { + Seq((m.group(1), value)) + } else { + Seq.empty + } } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index 49e14363a123f..03aa344d231f0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -24,7 +24,7 @@ import test.org.apache.spark.sql.sources.v2._ import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ConfigSupport +import org.apache.spark.sql.execution.datasources.v2.Utils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.{Filter, GreaterThan} import org.apache.spark.sql.sources.v2.reader._ @@ -53,8 +53,9 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { s"spark.datasource.$dsName.whateverConfigName" -> "123", s"spark.sql.$dsName.config.name" -> "false", s"spark.datasource.another.config.name" -> "123") { - val cs = classOf[DataSourceV2WithConfig].newInstance().asInstanceOf[ConfigSupport] - val confs = DataSourceV2ConfigSupport.withSessionConfig(cs.name, SQLConf.get) + val cs = classOf[DataSourceV2WithSessionConfig].newInstance() + .asInstanceOf[SessionConfigSupport] + val confs = Utils.withSessionConfig(cs.keyPrefix, SQLConf.get) assert(confs.size == 2) assert(confs.keySet.filter(_.startsWith("spark.datasource")).size == 0) assert(confs.keySet.filter(_.startsWith("not.exist.prefix")).size == 0) @@ -199,9 +200,9 @@ class SimpleReadTask(start: Int, end: Int) extends ReadTask[Row] with DataReader override def close(): Unit = {} } -class DataSourceV2WithConfig extends SimpleDataSourceV2 with ConfigSupport { +class DataSourceV2WithSessionConfig extends SimpleDataSourceV2 with SessionConfigSupport { - override def name: String = "userDefinedDataSource" + override def keyPrefix: String = "userDefinedDataSource" } class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport { From 52aaf51a9ef0d3b2517ee26cff58d7f281433881 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Thu, 14 Dec 2017 21:27:25 +0800 Subject: [PATCH 11/14] refactor --- .../sql/sources/v2/SessionConfigSupport.java | 6 +-- .../apache/spark/sql/DataFrameReader.scala | 8 ++-- .../apache/spark/sql/DataFrameWriter.scala | 14 +++++- .../{Utils.scala => DataSourceV2Utils.scala} | 2 +- .../sql/sources/v2/DataSourceV2Suite.scala | 24 ----------- .../sources/v2/DataSourceV2UtilsSuite.scala | 43 
+++++++++++++++++++ 6 files changed, 63 insertions(+), 34 deletions(-) rename sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/{Utils.scala => DataSourceV2Utils.scala} (97%) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2UtilsSuite.scala diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java index d130d7ab7afeb..0b5b6ac675f2c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java @@ -31,9 +31,9 @@ public interface SessionConfigSupport { /** - * Name for the specified data source, will extract all session configs that starts with - * `spark.datasource.$keyPrefix`, turn `spark.datasource.$keyPrefix.xxx -> yyy` into - * `xxx -> yyy`, and propagate them to all data source operations in this session. + * Key prefix of the session configs to propagate. Spark will extract all session configs that + * starts with `spark.datasource.$keyPrefix`, turn `spark.datasource.$keyPrefix.xxx -> yyy` + * into `xxx -> yyy`, and propagate them to all data source operations in this session. */ String keyPrefix(); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 4e66d417c91c4..ad0cf5b275bd5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.execution.datasources.csv._ import org.apache.spark.sql.execution.datasources.jdbc._ import org.apache.spark.sql.execution.datasources.json.TextInputJsonDataSource import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation -import org.apache.spark.sql.execution.datasources.v2.Utils +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.unsafe.types.UTF8String @@ -170,8 +170,6 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { option("path", path).load(Seq.empty: _*) // force invocation of `load(...varargs...)` } - import Utils._ - /** * Loads input in as a `DataFrame`, for data sources that support multiple paths. * Only works if the source is a HadoopFsRelationProvider. 
@@ -190,7 +188,9 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { val dataSource = cls.newInstance() val options = dataSource match { case cs: SessionConfigSupport => - val confs = withSessionConfig(cs.keyPrefix, sparkSession.sessionState.conf) + val confs = DataSourceV2Utils.withSessionConfig( + keyPrefix = cs.keyPrefix, + conf = sparkSession.sessionState.conf) new DataSourceV2Options((confs ++ extraOptions).asJava) case _ => new DataSourceV2Options(extraOptions.asJava) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 8d95b24c00619..091aa774f3861 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -30,9 +30,10 @@ import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, LogicalRelation} +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils import org.apache.spark.sql.execution.datasources.v2.WriteToDataSourceV2 import org.apache.spark.sql.sources.BaseRelation -import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options, WriteSupport} +import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.types.StructType /** @@ -238,7 +239,16 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { if (classOf[DataSourceV2].isAssignableFrom(cls)) { cls.newInstance() match { case ds: WriteSupport => - val options = new DataSourceV2Options(extraOptions.asJava) + val dataSource = cls.newInstance() + val options = dataSource match { + case cs: SessionConfigSupport => + val confs = DataSourceV2Utils.withSessionConfig( + keyPrefix = cs.keyPrefix, + conf = df.sparkSession.sessionState.conf) + new DataSourceV2Options((confs ++ extraOptions).asJava) + case _ => + new DataSourceV2Options(extraOptions.asJava) + } // Using a timestamp and a random UUID to distinguish different writing jobs. This is good // enough as there won't be tons of writing jobs created at the same second. 
val jobId = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/Utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala similarity index 97% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/Utils.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala index b5fa4c597ed74..96ed6a63c22f7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/Utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala @@ -22,7 +22,7 @@ import java.util.regex.Pattern import org.apache.spark.internal.Logging import org.apache.spark.sql.internal.SQLConf -private[sql] object Utils extends Logging { +private[sql] object DataSourceV2Utils extends Logging { /** * Helper method that turns session configs with config keys that start with diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index 03aa344d231f0..ab37e4984bd1f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -24,8 +24,6 @@ import test.org.apache.spark.sql.sources.v2._ import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.execution.datasources.v2.Utils -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.{Filter, GreaterThan} import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.test.SharedSQLContext @@ -34,8 +32,6 @@ import org.apache.spark.sql.types.StructType class DataSourceV2Suite extends QueryTest with SharedSQLContext { import testImplicits._ - private val dsName = "userDefinedDataSource" - test("simplest implementation") { Seq(classOf[SimpleDataSourceV2], classOf[JavaSimpleDataSourceV2]).foreach { cls => withClue(cls.getName) { @@ -47,23 +43,6 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } } - test("simple implementation with config support") { - // Only match configs with keys start with "spark.datasource.${dsName}". 
- withSQLConf(s"spark.datasource.$dsName.foo.bar" -> "false", - s"spark.datasource.$dsName.whateverConfigName" -> "123", - s"spark.sql.$dsName.config.name" -> "false", - s"spark.datasource.another.config.name" -> "123") { - val cs = classOf[DataSourceV2WithSessionConfig].newInstance() - .asInstanceOf[SessionConfigSupport] - val confs = Utils.withSessionConfig(cs.keyPrefix, SQLConf.get) - assert(confs.size == 2) - assert(confs.keySet.filter(_.startsWith("spark.datasource")).size == 0) - assert(confs.keySet.filter(_.startsWith("not.exist.prefix")).size == 0) - assert(confs.keySet.contains("foo.bar")) - assert(confs.keySet.contains("whateverConfigName")) - } - } - test("advanced implementation") { Seq(classOf[AdvancedDataSourceV2], classOf[JavaAdvancedDataSourceV2]).foreach { cls => withClue(cls.getName) { @@ -200,10 +179,7 @@ class SimpleReadTask(start: Int, end: Int) extends ReadTask[Row] with DataReader override def close(): Unit = {} } -class DataSourceV2WithSessionConfig extends SimpleDataSourceV2 with SessionConfigSupport { - override def keyPrefix: String = "userDefinedDataSource" -} class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2UtilsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2UtilsSuite.scala new file mode 100644 index 0000000000000..0f0f07e8466c6 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2UtilsSuite.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2 + +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSQLContext + +class DataSourceV2UtilsSuite extends QueryTest with SharedSQLContext { + + private val keyPrefix = "userDefinedDataSource" + + test("method withSessionConfig() should propagate session configs correctly") { + // Only match configs with keys start with "spark.datasource.${keyPrefix}". 
+ withSQLConf(s"spark.datasource.$keyPrefix.foo.bar" -> "false", + s"spark.datasource.$keyPrefix.whateverConfigName" -> "123", + s"spark.sql.$keyPrefix.config.name" -> "false", + s"spark.datasource.another.config.name" -> "123") { + val confs = DataSourceV2Utils.withSessionConfig(keyPrefix, SQLConf.get) + assert(confs.size == 2) + assert(confs.keySet.filter(_.startsWith("spark.datasource")).size == 0) + assert(confs.keySet.filter(_.startsWith("not.exist.prefix")).size == 0) + assert(confs.keySet.contains("foo.bar")) + assert(confs.keySet.contains("whateverConfigName")) + } + } +} From d964158c4c18053b95ccd6ea4a20a30cc0db3233 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Fri, 15 Dec 2017 01:07:47 +0800 Subject: [PATCH 12/14] refactor DataSourceV2Utils --- .../apache/spark/sql/DataFrameReader.scala | 17 +++---- .../apache/spark/sql/DataFrameWriter.scala | 21 ++++----- .../datasources/v2/DataSourceV2Utils.scala | 44 ++++++++++--------- .../sources/v2/DataSourceV2UtilsSuite.scala | 35 +++++++++------ 4 files changed, 59 insertions(+), 58 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index ad0cf5b275bd5..7db28c5ccd74a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -185,18 +185,13 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { val cls = DataSource.lookupDataSource(source) if (classOf[DataSourceV2].isAssignableFrom(cls)) { - val dataSource = cls.newInstance() - val options = dataSource match { - case cs: SessionConfigSupport => - val confs = DataSourceV2Utils.withSessionConfig( - keyPrefix = cs.keyPrefix, - conf = sparkSession.sessionState.conf) - new DataSourceV2Options((confs ++ extraOptions).asJava) - case _ => - new DataSourceV2Options(extraOptions.asJava) - } + val ds = cls.newInstance() + val options = new DataSourceV2Options((extraOptions ++ + DataSourceV2Utils.extractSessionConfigs( + ds = ds.asInstanceOf[DataSourceV2], + conf = sparkSession.sessionState.conf)).asJava) - val reader = (dataSource, userSpecifiedSchema) match { + val reader = (ds, userSpecifiedSchema) match { case (ds: ReadSupportWithSchema, Some(schema)) => ds.createReader(schema, options) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 091aa774f3861..0a0517a5b67b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -237,23 +237,18 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { val cls = DataSource.lookupDataSource(source) if (classOf[DataSourceV2].isAssignableFrom(cls)) { - cls.newInstance() match { - case ds: WriteSupport => - val dataSource = cls.newInstance() - val options = dataSource match { - case cs: SessionConfigSupport => - val confs = DataSourceV2Utils.withSessionConfig( - keyPrefix = cs.keyPrefix, - conf = df.sparkSession.sessionState.conf) - new DataSourceV2Options((confs ++ extraOptions).asJava) - case _ => - new DataSourceV2Options(extraOptions.asJava) - } + val ds = cls.newInstance() + ds match { + case ws: WriteSupport => + val options = new DataSourceV2Options((extraOptions ++ + DataSourceV2Utils.extractSessionConfigs( + ds = ds.asInstanceOf[DataSourceV2], + conf = df.sparkSession.sessionState.conf)).asJava) // Using a 
timestamp and a random UUID to distinguish different writing jobs. This is good // enough as there won't be tons of writing jobs created at the same second. val jobId = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US) .format(new Date()) + "-" + UUID.randomUUID() - val writer = ds.createWriter(jobId, df.logicalPlan.schema, mode, options) + val writer = ws.createWriter(jobId, df.logicalPlan.schema, mode, options) if (writer.isPresent) { runCommand(df.sparkSession, "save") { WriteToDataSourceV2(writer.get(), df.logicalPlan) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala index 96ed6a63c22f7..4794a8a18d2d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala @@ -21,34 +21,38 @@ import java.util.regex.Pattern import org.apache.spark.internal.Logging import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.v2.{DataSourceV2, SessionConfigSupport} private[sql] object DataSourceV2Utils extends Logging { /** - * Helper method that turns session configs with config keys that start with - * `spark.datasource.$keyPrefix` into k/v pairs, the k/v pairs will be used to create data source - * options. - * A session config `spark.datasource.$keyPrefix.xxx -> yyy` will be transformed into - * `xxx -> yyy`. + * Helper method that extracts and transforms session configs into k/v pairs, the k/v pairs will + * be used to create data source options. + * Only extract when `ds` implements [[SessionConfigSupport]], in this case we may fetch the + * specified key-prefix from `ds`, and extract session configs with config keys that start with + * `spark.datasource.$keyPrefix`. A session config `spark.datasource.$keyPrefix.xxx -> yyy` will + * be transformed into `xxx -> yyy`. * - * @param keyPrefix the data source config key prefix to be matched + * @param ds a [[DataSourceV2]] object * @param conf the session conf * @return an immutable map that contains all the extracted and transformed k/v pairs. 
*/ - def withSessionConfig( - keyPrefix: String, - conf: SQLConf): Map[String, String] = { - require(keyPrefix != null, "The data source config key prefix can't be null.") - - val pattern = Pattern.compile(s"^spark\\.datasource\\.$keyPrefix\\.(.*)") - - conf.getAllConfs.flatMap { case (key, value) => - val m = pattern.matcher(key) - if (m.matches() && m.groupCount() > 0) { - Seq((m.group(1), value)) - } else { - Seq.empty + def extractSessionConfigs(ds: DataSourceV2, conf: SQLConf): Map[String, String] = ds match { + case cs: SessionConfigSupport => + val keyPrefix = cs.keyPrefix() + require(keyPrefix != null, "The data source config key prefix can't be null.") + + val pattern = Pattern.compile(s"^spark\\.datasource\\.$keyPrefix\\.(.*)") + + conf.getAllConfs.flatMap { case (key, value) => + val m = pattern.matcher(key) + if (m.matches() && m.groupCount() > 0) { + Seq((m.group(1), value)) + } else { + Seq.empty + } } - } + + case _ => Map.empty } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2UtilsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2UtilsSuite.scala index 0f0f07e8466c6..5d971892d60d0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2UtilsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2UtilsSuite.scala @@ -17,27 +17,34 @@ package org.apache.spark.sql.sources.v2 -import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.{QueryTest, SparkSession} import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.test.SharedSQLContext -class DataSourceV2UtilsSuite extends QueryTest with SharedSQLContext { +class DataSourceV2UtilsSuite extends QueryTest { + + protected var spark: SparkSession = null private val keyPrefix = "userDefinedDataSource" test("method withSessionConfig() should propagate session configs correctly") { // Only match configs with keys start with "spark.datasource.${keyPrefix}". 
- withSQLConf(s"spark.datasource.$keyPrefix.foo.bar" -> "false", - s"spark.datasource.$keyPrefix.whateverConfigName" -> "123", - s"spark.sql.$keyPrefix.config.name" -> "false", - s"spark.datasource.another.config.name" -> "123") { - val confs = DataSourceV2Utils.withSessionConfig(keyPrefix, SQLConf.get) - assert(confs.size == 2) - assert(confs.keySet.filter(_.startsWith("spark.datasource")).size == 0) - assert(confs.keySet.filter(_.startsWith("not.exist.prefix")).size == 0) - assert(confs.keySet.contains("foo.bar")) - assert(confs.keySet.contains("whateverConfigName")) - } + val conf = new SQLConf + conf.setConfString(s"spark.datasource.$keyPrefix.foo.bar", "false") + conf.setConfString(s"spark.datasource.$keyPrefix.whateverConfigName", "123") + conf.setConfString(s"spark.sql.$keyPrefix.config.name", "false") + conf.setConfString("spark.datasource.another.config.name", "123") + val cs = classOf[DataSourceV2WithSessionConfig].newInstance() + val confs = DataSourceV2Utils.extractSessionConfigs(cs.asInstanceOf[DataSourceV2], conf) + assert(confs.size == 2) + assert(confs.keySet.filter(_.startsWith("spark.datasource")).size == 0) + assert(confs.keySet.filter(_.startsWith("not.exist.prefix")).size == 0) + assert(confs.keySet.contains("foo.bar")) + assert(confs.keySet.contains("whateverConfigName")) } } + +class DataSourceV2WithSessionConfig extends SimpleDataSourceV2 with SessionConfigSupport { + + override def keyPrefix: String = "userDefinedDataSource" +} From f7d5a4dfce26f2d8d79f8b2529b9676fdf03c917 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Fri, 15 Dec 2017 12:53:38 +0800 Subject: [PATCH 13/14] update DataSourceV2UtilsSuite --- .../spark/sql/sources/v2/DataSourceV2UtilsSuite.scala | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2UtilsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2UtilsSuite.scala index 5d971892d60d0..42fa5a2a0282f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2UtilsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2UtilsSuite.scala @@ -17,15 +17,13 @@ package org.apache.spark.sql.sources.v2 -import org.apache.spark.sql.{QueryTest, SparkSession} +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils import org.apache.spark.sql.internal.SQLConf -class DataSourceV2UtilsSuite extends QueryTest { +class DataSourceV2UtilsSuite extends SparkFunSuite { - protected var spark: SparkSession = null - - private val keyPrefix = "userDefinedDataSource" + private val keyPrefix = new DataSourceV2WithSessionConfig().keyPrefix test("method withSessionConfig() should propagate session configs correctly") { // Only match configs with keys start with "spark.datasource.${keyPrefix}". 
From 52923296a946ac734c988fe10725921ea3c2b313 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Fri, 15 Dec 2017 17:39:52 +0800 Subject: [PATCH 14/14] update regex pattern --- .../spark/sql/execution/datasources/v2/DataSourceV2Utils.scala | 2 +- .../apache/spark/sql/sources/v2/DataSourceV2UtilsSuite.scala | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala index 4794a8a18d2d1..5267f5f1580c3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala @@ -42,7 +42,7 @@ private[sql] object DataSourceV2Utils extends Logging { val keyPrefix = cs.keyPrefix() require(keyPrefix != null, "The data source config key prefix can't be null.") - val pattern = Pattern.compile(s"^spark\\.datasource\\.$keyPrefix\\.(.*)") + val pattern = Pattern.compile(s"^spark\\.datasource\\.$keyPrefix\\.(.+)") conf.getAllConfs.flatMap { case (key, value) => val m = pattern.matcher(key) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2UtilsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2UtilsSuite.scala index 42fa5a2a0282f..4911e3225552d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2UtilsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2UtilsSuite.scala @@ -32,6 +32,7 @@ class DataSourceV2UtilsSuite extends SparkFunSuite { conf.setConfString(s"spark.datasource.$keyPrefix.whateverConfigName", "123") conf.setConfString(s"spark.sql.$keyPrefix.config.name", "false") conf.setConfString("spark.datasource.another.config.name", "123") + conf.setConfString(s"spark.datasource.$keyPrefix.", "123") val cs = classOf[DataSourceV2WithSessionConfig].newInstance() val confs = DataSourceV2Utils.extractSessionConfigs(cs.asInstanceOf[DataSourceV2], conf) assert(confs.size == 2)
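
Note on the behavior the series converges on: after PATCH 14/14, only session configs whose keys have the exact form spark.datasource.$keyPrefix.<non-empty suffix> are copied into the data source options, with the prefix stripped. The switch from "(.*)" to "(.+)" keeps a bare "spark.datasource.$keyPrefix." key (added to the test in that same patch) from producing an option with an empty name, which is why the expected size in DataSourceV2UtilsSuite stays at 2.

The standalone Scala sketch below mirrors that matching rule outside of Spark. It reuses the key prefix and config names from DataSourceV2UtilsSuite, but it does not call the real DataSourceV2Utils.extractSessionConfigs or any Spark API; the object and method names here are illustrative only.

    import java.util.regex.Pattern

    // Standalone illustration of the final matching rule from PATCH 14/14:
    // keep only keys of the form "spark.datasource.$keyPrefix.<non-empty suffix>"
    // and use the suffix as the option key.
    object SessionConfigExtractionSketch {

      def extract(keyPrefix: String, sessionConfs: Map[String, String]): Map[String, String] = {
        require(keyPrefix != null, "The data source config key prefix can't be null.")
        // "(.+)" rather than "(.*)", so a bare "spark.datasource.$keyPrefix." key is ignored.
        val pattern = Pattern.compile(s"^spark\\.datasource\\.$keyPrefix\\.(.+)")
        sessionConfs.flatMap { case (key, value) =>
          val m = pattern.matcher(key)
          if (m.matches() && m.groupCount() > 0) Seq(m.group(1) -> value) else Seq.empty
        }
      }

      def main(args: Array[String]): Unit = {
        val keyPrefix = "userDefinedDataSource"
        val confs = extract(keyPrefix, Map(
          s"spark.datasource.$keyPrefix.foo.bar" -> "false",
          s"spark.datasource.$keyPrefix.whateverConfigName" -> "123",
          s"spark.sql.$keyPrefix.config.name" -> "false",
          "spark.datasource.another.config.name" -> "123",
          s"spark.datasource.$keyPrefix." -> "123"))
        // Only the two properly prefixed keys survive, with the prefix stripped.
        assert(confs == Map("foo.bar" -> "false", "whateverConfigName" -> "123"))
        println(confs)
      }
    }

Running the sketch leaves exactly the two entries foo.bar and whateverConfigName, matching the assertions in the suite; the other three keys are dropped because they use the wrong namespace, a different prefix, or an empty suffix.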