diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala
index 4ffa70f9f31d..a5e5d01152db 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala
@@ -36,6 +36,7 @@ import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, MicroBat
import org.apache.spark.sql.connector.write.{BatchWrite, LogicalWriteInfo, WriteBuilder}
import org.apache.spark.sql.connector.write.streaming.StreamingWrite
import org.apache.spark.sql.execution.streaming.{Sink, Source}
+import org.apache.spark.sql.internal.connector.SimpleTableProvider
import org.apache.spark.sql.sources._
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.StructType
@@ -51,7 +52,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
with StreamSinkProvider
with RelationProvider
with CreatableRelationProvider
- with TableProvider
+ with SimpleTableProvider
with Logging {
import KafkaSourceProvider._
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableProvider.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableProvider.java
index e9fd87d0e2d4..732c5352a15a 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableProvider.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableProvider.java
@@ -17,7 +17,10 @@
package org.apache.spark.sql.connector.catalog;
+import java.util.Map;
+
import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.connector.expressions.Transform;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.util.CaseInsensitiveStringMap;
@@ -36,26 +39,50 @@
public interface TableProvider {
/**
- * Return a {@link Table} instance to do read/write with user-specified options.
+ * Infer the schema of the table identified by the given options.
+ *
+ * @param options an immutable case-insensitive string-to-string map that can identify a table,
+ * e.g. file path, Kafka topic name, etc.
+ */
+ StructType inferSchema(CaseInsensitiveStringMap options);
+
+ /**
+ * Infer the partitioning of the table identified by the given options.
+ *
+ * By default this method returns empty partitioning. Please override it if this source
+ * supports partitioning.
+ *
+ * @param options an immutable case-insensitive string-to-string map that can identify a table,
+ * e.g. file path, Kafka topic name, etc.
+ */
+ default Transform[] inferPartitioning(CaseInsensitiveStringMap options) {
+ return new Transform[0];
+ }
+
+ /**
+ * Return a {@link Table} instance with the specified table schema, partitioning and properties
+ * to do read/write. The returned table should report the same schema and partitioning as the
+ * specified ones, or Spark may fail the operation.
*
- * @param options the user-specified options that can identify a table, e.g. file path, Kafka
- * topic name, etc. It's an immutable case-insensitive string-to-string map.
+ * @param schema The specified table schema.
+ * @param partitioning The specified table partitioning.
+ * @param properties The specified table properties. It's case preserving (contains exactly what
+ * users specified) and implementations are free to use it case sensitively or
+ * insensitively. It should be able to identify a table, e.g. file path, Kafka
+ * topic name, etc.
*/
- Table getTable(CaseInsensitiveStringMap options);
+ Table getTable(StructType schema, Transform[] partitioning, Map<String, String> properties);
/**
- * Return a {@link Table} instance to do read/write with user-specified schema and options.
+ * Returns true if the source can accept external table metadata when getting tables.
+ * The external table metadata includes the user-specified schema from
+ * `DataFrameReader`/`DataStreamReader` and the schema/partitioning stored in the Spark catalog.
*
- * By default this method throws {@link UnsupportedOperationException}, implementations should
- * override this method to handle user-specified schema.
- *
- * @param options the user-specified options that can identify a table, e.g. file path, Kafka
- * topic name, etc. It's an immutable case-insensitive string-to-string map.
- * @param schema the user-specified schema.
- * @throws UnsupportedOperationException
+ * By default this method returns false, which means the schema and partitioning passed to
+ * `getTable` come from the infer methods. Please override it if this source has expensive
+ * schema/partitioning inference and wants to use external table metadata to avoid it.
*/
- default Table getTable(CaseInsensitiveStringMap options, StructType schema) {
- throw new UnsupportedOperationException(
- this.getClass().getSimpleName() + " source does not support user-specified schema");
+ default boolean supportsExternalMetadata() {
+ return false;
}
}
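
For orientation, here is a minimal sketch of a third-party connector written against the reworked
TableProvider contract above. It is not part of this patch; the class, format and column names are
illustrative only.

import java.util

import org.apache.spark.sql.connector.catalog.{Table, TableCapability, TableProvider}
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap

// Hypothetical connector whose schema inference is expensive, so it opts into
// external metadata (user-specified schema or schema/partitioning stored in a catalog).
class ExampleTableProvider extends TableProvider {

  override def supportsExternalMetadata(): Boolean = true

  override def inferSchema(options: CaseInsensitiveStringMap): StructType = {
    // Only called when no external schema is available, per the contract above.
    new StructType().add("id", "long").add("value", "string")
  }

  // inferPartitioning keeps its default (empty) implementation.

  override def getTable(
      schema: StructType,
      partitioning: Array[Transform],
      properties: util.Map[String, String]): Table = {
    new ExampleTable(schema, partitioning, properties)
  }
}

// Reports back exactly the schema/partitioning it was given, as the Javadoc requires.
class ExampleTable(
    tableSchema: StructType,
    transforms: Array[Transform],
    props: util.Map[String, String]) extends Table {
  override def name(): String = "example"
  override def schema(): StructType = tableSchema
  override def partitioning(): Array[Transform] = transforms
  override def properties(): util.Map[String, String] = props
  override def capabilities(): util.Set[TableCapability] = util.Collections.emptySet[TableCapability]()
}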
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala
index 16aec23521f9..3478af8783af 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala
@@ -21,7 +21,6 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.connector.expressions.{BucketTransform, IdentityTransform, LogicalExpressions, Transform}
-import org.apache.spark.sql.types.StructType
/**
* Conversion helpers for working with v2 [[CatalogPlugin]].
@@ -29,9 +28,9 @@ import org.apache.spark.sql.types.StructType
private[sql] object CatalogV2Implicits {
import LogicalExpressions._
- implicit class PartitionTypeHelper(partitionType: StructType) {
+ implicit class PartitionTypeHelper(colNames: Seq[String]) {
def asTransforms: Array[Transform] = {
- partitionType.names.map(col => identity(reference(Seq(col)))).toArray
+ colNames.map(col => identity(reference(Seq(col)))).toArray
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/SimpleTableProvider.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/SimpleTableProvider.scala
new file mode 100644
index 000000000000..7bfe1df1117a
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/SimpleTableProvider.scala
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.internal.connector
+
+import java.util
+
+import org.apache.spark.sql.connector.catalog.{Table, TableProvider}
+import org.apache.spark.sql.connector.expressions.Transform
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+// A simple version of `TableProvider` that doesn't support user-specified table schema/partitioning
+// and treats table properties case-insensitively. This is private and only used in built-in sources.
+trait SimpleTableProvider extends TableProvider {
+
+ def getTable(options: CaseInsensitiveStringMap): Table
+
+ private[this] var loadedTable: Table = _
+ private def getOrLoadTable(options: CaseInsensitiveStringMap): Table = {
+ if (loadedTable == null) loadedTable = getTable(options)
+ loadedTable
+ }
+
+ override def inferSchema(options: CaseInsensitiveStringMap): StructType = {
+ getOrLoadTable(options).schema()
+ }
+
+ override def getTable(
+ schema: StructType,
+ partitioning: Array[Transform],
+ properties: util.Map[String, String]): Table = {
+ assert(partitioning.isEmpty)
+ getOrLoadTable(new CaseInsensitiveStringMap(properties))
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index b5d7bbca9064..6cce7203127f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -220,10 +220,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
(catalog.loadTable(ident), Some(catalog), Some(ident))
case _ =>
// TODO: Non-catalog paths for DSV2 are currently not well defined.
- userSpecifiedSchema match {
- case Some(schema) => (provider.getTable(dsOptions, schema), None, None)
- case _ => (provider.getTable(dsOptions), None, None)
- }
+ val tbl = DataSourceV2Utils.getTableFromProvider(provider, dsOptions, userSpecifiedSchema)
+ (tbl, None, None)
}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._
table match {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index c041d14c8b8d..4557219abeb1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -257,6 +257,21 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
val options = sessionOptions ++ extraOptions
val dsOptions = new CaseInsensitiveStringMap(options.asJava)
+ def getTable: Table = {
+ // For file sources, it's expensive to infer the schema/partitioning on each write. Here we
+ // pass the schema of the input query and the user-specified partitioning to `getTable`. If
+ // the query schema is not compatible with the existing data, the write can still succeed but
+ // subsequent reads would fail.
+ if (provider.isInstanceOf[FileDataSourceV2]) {
+ provider.getTable(
+ df.schema.asNullable,
+ partitioningAsV2.toArray,
+ dsOptions.asCaseSensitiveMap())
+ } else {
+ DataSourceV2Utils.getTableFromProvider(provider, dsOptions, userSpecifiedSchema = None)
+ }
+ }
+
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._
val catalogManager = df.sparkSession.sessionState.catalogManager
mode match {
@@ -268,8 +283,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
supportsExtract, catalogManager, dsOptions)
(catalog.loadTable(ident), Some(catalog), Some(ident))
- case tableProvider: TableProvider =>
- val t = tableProvider.getTable(dsOptions)
+ case _: TableProvider =>
+ val t = getTable
if (t.supports(BATCH_WRITE)) {
(t, None, None)
} else {
@@ -314,8 +329,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
extraOptions.toMap,
ignoreIfExists = createMode == SaveMode.Ignore)
}
- case tableProvider: TableProvider =>
- if (tableProvider.getTable(dsOptions).supports(BATCH_WRITE)) {
+ case _: TableProvider =>
+ if (getTable.supports(BATCH_WRITE)) {
throw new AnalysisException(s"TableProvider implementation $source cannot be " +
s"written with $createMode mode, please use Append or Overwrite " +
"modes instead.")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala
index b6149ce7290b..4fad0a2484cd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala
@@ -22,9 +22,10 @@ import java.util
import scala.collection.JavaConverters._
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table, TableCapability, TableProvider}
+import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table, TableCapability}
import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, LogicalWriteInfo, PhysicalWriteInfo, SupportsTruncate, WriteBuilder, WriterCommitMessage}
import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite}
+import org.apache.spark.sql.internal.connector.SimpleTableProvider
import org.apache.spark.sql.sources.DataSourceRegister
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -33,7 +34,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
* This is no-op datasource. It does not do anything besides consuming its input.
* This can be useful for benchmarking or to cache data without any additional overhead.
*/
-class NoopDataSource extends TableProvider with DataSourceRegister {
+class NoopDataSource extends SimpleTableProvider with DataSourceRegister {
override def shortName(): String = "noop"
override def getTable(options: CaseInsensitiveStringMap): Table = NoopTable
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala
index 52294ae2cb85..b50b8295463e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala
@@ -20,8 +20,10 @@ package org.apache.spark.sql.execution.datasources.v2
import java.util.regex.Pattern
import org.apache.spark.internal.Logging
-import org.apache.spark.sql.connector.catalog.{SessionConfigSupport, TableProvider}
+import org.apache.spark.sql.connector.catalog.{SessionConfigSupport, Table, TableProvider}
import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
private[sql] object DataSourceV2Utils extends Logging {
@@ -57,4 +59,28 @@ private[sql] object DataSourceV2Utils extends Logging {
case _ => Map.empty
}
}
+
+ def getTableFromProvider(
+ provider: TableProvider,
+ options: CaseInsensitiveStringMap,
+ userSpecifiedSchema: Option[StructType]): Table = {
+ userSpecifiedSchema match {
+ case Some(schema) =>
+ if (provider.supportsExternalMetadata()) {
+ provider.getTable(
+ schema,
+ provider.inferPartitioning(options),
+ options.asCaseSensitiveMap())
+ } else {
+ throw new UnsupportedOperationException(
+ s"${provider.getClass.getSimpleName} source does not support user-specified schema.")
+ }
+
+ case None =>
+ provider.getTable(
+ provider.inferSchema(options),
+ provider.inferPartitioning(options),
+ options.asCaseSensitiveMap())
+ }
+ }
}
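
To show how the two branches of `getTableFromProvider` surface to end users, a sketch of the read
paths follows; the format name and schema are made up for illustration and not part of this patch.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.StructType

val spark = SparkSession.builder().master("local[*]").getOrCreate()

// No user-specified schema: Spark calls inferSchema/inferPartitioning on the provider.
val inferred = spark.read.format("com.example.source").load()

// User-specified schema: only allowed when supportsExternalMetadata() returns true;
// otherwise getTableFromProvider throws UnsupportedOperationException.
val explicit = spark.read
  .format("com.example.source")
  .schema(new StructType().add("id", "long").add("value", "string"))
  .load()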
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala
index e0091293d166..30a964d7e643 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala
@@ -16,13 +16,17 @@
*/
package org.apache.spark.sql.execution.datasources.v2
+import java.util
+
import com.fasterxml.jackson.databind.ObjectMapper
import org.apache.hadoop.fs.Path
import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.connector.catalog.TableProvider
+import org.apache.spark.sql.connector.catalog.{Table, TableProvider}
+import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.sources.DataSourceRegister
+import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.util.Utils
@@ -59,4 +63,40 @@ trait FileDataSourceV2 extends TableProvider with DataSourceRegister {
val fs = hdfsPath.getFileSystem(sparkSession.sessionState.newHadoopConf())
hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory).toString
}
+
+ // TODO: To reduce the code diff of SPARK-29665, we create stub implementations for file source
+ // v2, so that we don't need to touch all the file source v2 classes. We should remove the stub
+ // implementations and directly implement the TableProvider APIs.
+ protected def getTable(options: CaseInsensitiveStringMap): Table
+ protected def getTable(options: CaseInsensitiveStringMap, schema: StructType): Table = {
+ throw new UnsupportedOperationException("user-specified schema")
+ }
+
+ override def supportsExternalMetadata(): Boolean = true
+
+ private var t: Table = null
+
+ override def inferSchema(options: CaseInsensitiveStringMap): StructType = {
+ if (t == null) t = getTable(options)
+ t.schema()
+ }
+
+ // TODO: implement a lightweight partition inference which only looks at the path of one leaf
+ // file and returns the partition column names. For now, partition inference happens in
+ // `getTable`, because we don't know the user-specified schema here.
+ override def inferPartitioning(options: CaseInsensitiveStringMap): Array[Transform] = {
+ Array.empty
+ }
+
+ override def getTable(
+ schema: StructType,
+ partitioning: Array[Transform],
+ properties: util.Map[String, String]): Table = {
+ // If the table is already loaded during schema inference, return it directly.
+ if (t != null) {
+ t
+ } else {
+ getTable(new CaseInsensitiveStringMap(properties), schema)
+ }
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala
index 5329e09916bd..59dc3ae56bf2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala
@@ -102,7 +102,7 @@ abstract class FileTable(
StructType(fields)
}
- override def partitioning: Array[Transform] = fileIndex.partitionSchema.asTransforms
+ override def partitioning: Array[Transform] = fileIndex.partitionSchema.names.toSeq.asTransforms
override def properties: util.Map[String, String] = options.asCaseSensitiveMap
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala
index 63e40891942a..e471e6c601d1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala
@@ -22,10 +22,11 @@ import java.util
import scala.collection.JavaConverters._
import org.apache.spark.sql._
-import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table, TableCapability, TableProvider}
+import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table, TableCapability}
import org.apache.spark.sql.connector.write.{LogicalWriteInfo, SupportsTruncate, WriteBuilder}
import org.apache.spark.sql.connector.write.streaming.StreamingWrite
import org.apache.spark.sql.execution.streaming.sources.ConsoleWrite
+import org.apache.spark.sql.internal.connector.SimpleTableProvider
import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -35,7 +36,7 @@ case class ConsoleRelation(override val sqlContext: SQLContext, data: DataFrame)
override def schema: StructType = data.schema
}
-class ConsoleSinkProvider extends TableProvider
+class ConsoleSinkProvider extends SimpleTableProvider
with DataSourceRegister
with CreatableRelationProvider {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
index 911a526428cf..395811b72d32 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
@@ -31,10 +31,11 @@ import org.apache.spark.sql.catalyst.encoders.encoderFor
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util.truncatedString
-import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability, TableProvider}
+import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability}
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory, Scan, ScanBuilder}
import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, MicroBatchStream, Offset => OffsetV2, SparkDataStream}
import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.connector.SimpleTableProvider
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -94,7 +95,7 @@ abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext) extends Spa
// This class is used to indicate the memory stream data source. We don't actually use it, as
// memory stream is for test only and we never look it up by name.
-object MemoryStreamTableProvider extends TableProvider {
+object MemoryStreamTableProvider extends SimpleTableProvider {
override def getTable(options: CaseInsensitiveStringMap): Table = {
throw new IllegalStateException("MemoryStreamTableProvider should not be used.")
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala
index 3f7b0377f1ea..a093bf54b210 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala
@@ -23,10 +23,11 @@ import scala.collection.JavaConverters._
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability, TableProvider}
+import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability}
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder}
import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, MicroBatchStream}
import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousStream
+import org.apache.spark.sql.internal.connector.SimpleTableProvider
import org.apache.spark.sql.sources.DataSourceRegister
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -45,7 +46,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
* generated rows. The source will try its best to reach `rowsPerSecond`, but the query may
* be resource constrained, and `numPartitions` can be tweaked to help reach the desired speed.
*/
-class RateStreamProvider extends TableProvider with DataSourceRegister {
+class RateStreamProvider extends SimpleTableProvider with DataSourceRegister {
import RateStreamProvider._
override def getTable(options: CaseInsensitiveStringMap): Table = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala
index fae3cb765c0c..a4dcb2049eb8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala
@@ -26,15 +26,16 @@ import scala.util.{Failure, Success, Try}
import org.apache.spark.internal.Logging
import org.apache.spark.sql._
-import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability, TableProvider}
+import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability}
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder}
import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, MicroBatchStream}
import org.apache.spark.sql.execution.streaming.continuous.TextSocketContinuousStream
+import org.apache.spark.sql.internal.connector.SimpleTableProvider
import org.apache.spark.sql.sources.DataSourceRegister
import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
-class TextSocketSourceProvider extends TableProvider with DataSourceRegister with Logging {
+class TextSocketSourceProvider extends SimpleTableProvider with DataSourceRegister with Logging {
private def checkParameters(params: CaseInsensitiveStringMap): Unit = {
logWarning("The socket source should not be used for production applications! " +
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
index cfe6192e7d5c..0eb4776988d9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.connector.catalog.{SupportsRead, TableProvider}
import org.apache.spark.sql.connector.catalog.TableCapability._
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.DataSource
-import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils
+import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Utils, FileDataSourceV2}
import org.apache.spark.sql.execution.streaming.{StreamingRelation, StreamingRelationV2}
import org.apache.spark.sql.sources.StreamSourceProvider
import org.apache.spark.sql.types.StructType
@@ -173,15 +173,13 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
case _ => None
}
ds match {
- case provider: TableProvider =>
+ // file source v2 does not support streaming yet.
+ case provider: TableProvider if !provider.isInstanceOf[FileDataSourceV2] =>
val sessionOptions = DataSourceV2Utils.extractSessionConfigs(
source = provider, conf = sparkSession.sessionState.conf)
val options = sessionOptions ++ extraOptions
val dsOptions = new CaseInsensitiveStringMap(options.asJava)
- val table = userSpecifiedSchema match {
- case Some(schema) => provider.getTable(dsOptions, schema)
- case _ => provider.getTable(dsOptions)
- }
+ val table = DataSourceV2Utils.getTableFromProvider(provider, dsOptions, userSpecifiedSchema)
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._
table match {
case _: SupportsRead if table.supportsAny(MICRO_BATCH_READ, CONTINUOUS_READ) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
index 62a1add8b6d9..1c21a30dd5bd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.connector.catalog.{SupportsWrite, TableProvider}
import org.apache.spark.sql.connector.catalog.TableCapability._
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.DataSource
-import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils
+import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Utils, FileDataSourceV2}
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.sources._
import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -308,7 +308,9 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
} else {
val cls = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf)
val disabledSources = df.sparkSession.sqlContext.conf.disabledV2StreamingWriters.split(",")
- val useV1Source = disabledSources.contains(cls.getCanonicalName)
+ val useV1Source = disabledSources.contains(cls.getCanonicalName) ||
+ // file source v2 does not support streaming yet.
+ classOf[FileDataSourceV2].isAssignableFrom(cls)
val sink = if (classOf[TableProvider].isAssignableFrom(cls) && !useV1Source) {
val provider = cls.getConstructor().newInstance().asInstanceOf[TableProvider]
@@ -316,8 +318,10 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
source = provider, conf = df.sparkSession.sessionState.conf)
val options = sessionOptions ++ extraOptions
val dsOptions = new CaseInsensitiveStringMap(options.asJava)
+ val table = DataSourceV2Utils.getTableFromProvider(
+ provider, dsOptions, userSpecifiedSchema = None)
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._
- provider.getTable(dsOptions) match {
+ table match {
case table: SupportsWrite if table.supports(STREAMING_WRITE) =>
table
case _ => createV1Sink()
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaAdvancedDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaAdvancedDataSourceV2.java
index 9386ab51d64f..1a55d198361e 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaAdvancedDataSourceV2.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaAdvancedDataSourceV2.java
@@ -22,15 +22,15 @@
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
+import org.apache.spark.sql.connector.TestingV2Source;
import org.apache.spark.sql.connector.catalog.Table;
-import org.apache.spark.sql.connector.catalog.TableProvider;
import org.apache.spark.sql.connector.read.*;
import org.apache.spark.sql.sources.Filter;
import org.apache.spark.sql.sources.GreaterThan;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.util.CaseInsensitiveStringMap;
-public class JavaAdvancedDataSourceV2 implements TableProvider {
+public class JavaAdvancedDataSourceV2 implements TestingV2Source {
@Override
public Table getTable(CaseInsensitiveStringMap options) {
@@ -45,7 +45,7 @@ public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) {
static class AdvancedScanBuilder implements ScanBuilder, Scan,
SupportsPushDownFilters, SupportsPushDownRequiredColumns {
- private StructType requiredSchema = new StructType().add("i", "int").add("j", "int");
+ private StructType requiredSchema = TestingV2Source.schema();
private Filter[] filters = new Filter[0];
@Override
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaColumnarDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaColumnarDataSourceV2.java
index 76da45e182b3..2f10c84c999f 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaColumnarDataSourceV2.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaColumnarDataSourceV2.java
@@ -20,8 +20,8 @@
import java.io.IOException;
import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.connector.TestingV2Source;
import org.apache.spark.sql.connector.catalog.Table;
-import org.apache.spark.sql.connector.catalog.TableProvider;
import org.apache.spark.sql.connector.read.InputPartition;
import org.apache.spark.sql.connector.read.PartitionReader;
import org.apache.spark.sql.connector.read.PartitionReaderFactory;
@@ -33,7 +33,7 @@
import org.apache.spark.sql.vectorized.ColumnarBatch;
-public class JavaColumnarDataSourceV2 implements TableProvider {
+public class JavaColumnarDataSourceV2 implements TestingV2Source {
class MyScanBuilder extends JavaSimpleScanBuilder {
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaPartitionAwareDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaPartitionAwareDataSource.java
index fbbc457b2945..9c1db7a37960 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaPartitionAwareDataSource.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaPartitionAwareDataSource.java
@@ -22,17 +22,17 @@
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
+import org.apache.spark.sql.connector.TestingV2Source;
import org.apache.spark.sql.connector.expressions.Expressions;
import org.apache.spark.sql.connector.expressions.Transform;
import org.apache.spark.sql.connector.catalog.Table;
-import org.apache.spark.sql.connector.catalog.TableProvider;
import org.apache.spark.sql.connector.read.*;
import org.apache.spark.sql.connector.read.partitioning.ClusteredDistribution;
import org.apache.spark.sql.connector.read.partitioning.Distribution;
import org.apache.spark.sql.connector.read.partitioning.Partitioning;
import org.apache.spark.sql.util.CaseInsensitiveStringMap;
-public class JavaPartitionAwareDataSource implements TableProvider {
+public class JavaPartitionAwareDataSource implements TestingV2Source {
class MyScanBuilder extends JavaSimpleScanBuilder implements SupportsReportPartitioning {
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaReportStatisticsDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaReportStatisticsDataSource.java
index 49438fe668d5..9a787c3d2d92 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaReportStatisticsDataSource.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaReportStatisticsDataSource.java
@@ -19,15 +19,15 @@
import java.util.OptionalLong;
+import org.apache.spark.sql.connector.TestingV2Source;
import org.apache.spark.sql.connector.catalog.Table;
-import org.apache.spark.sql.connector.catalog.TableProvider;
import org.apache.spark.sql.connector.read.InputPartition;
import org.apache.spark.sql.connector.read.ScanBuilder;
import org.apache.spark.sql.connector.read.Statistics;
import org.apache.spark.sql.connector.read.SupportsReportStatistics;
import org.apache.spark.sql.util.CaseInsensitiveStringMap;
-public class JavaReportStatisticsDataSource implements TableProvider {
+public class JavaReportStatisticsDataSource implements TestingV2Source {
class MyScanBuilder extends JavaSimpleScanBuilder implements SupportsReportStatistics {
@Override
public Statistics estimateStatistics() {
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSchemaRequiredDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSchemaRequiredDataSource.java
index 2181887ae54e..5f73567ade02 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSchemaRequiredDataSource.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSchemaRequiredDataSource.java
@@ -17,8 +17,11 @@
package test.org.apache.spark.sql.connector;
+import java.util.Map;
+
import org.apache.spark.sql.connector.catalog.Table;
import org.apache.spark.sql.connector.catalog.TableProvider;
+import org.apache.spark.sql.connector.expressions.Transform;
import org.apache.spark.sql.connector.read.InputPartition;
import org.apache.spark.sql.connector.read.ScanBuilder;
import org.apache.spark.sql.types.StructType;
@@ -46,7 +49,18 @@ public InputPartition[] planInputPartitions() {
}
@Override
- public Table getTable(CaseInsensitiveStringMap options, StructType schema) {
+ public boolean supportsExternalMetadata() {
+ return true;
+ }
+
+ @Override
+ public StructType inferSchema(CaseInsensitiveStringMap options) {
+ throw new IllegalArgumentException("requires a user-supplied schema");
+ }
+
+ @Override
+ public Table getTable(
+ StructType schema, Transform[] partitioning, Map<String, String> properties) {
return new JavaSimpleBatchTable() {
@Override
@@ -60,9 +74,4 @@ public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) {
}
};
}
-
- @Override
- public Table getTable(CaseInsensitiveStringMap options) {
- throw new IllegalArgumentException("requires a user-supplied schema");
- }
}
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleBatchTable.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleBatchTable.java
index 97b00477e176..71cf97b56fe5 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleBatchTable.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleBatchTable.java
@@ -21,6 +21,7 @@
import java.util.HashSet;
import java.util.Set;
+import org.apache.spark.sql.connector.TestingV2Source;
import org.apache.spark.sql.connector.catalog.SupportsRead;
import org.apache.spark.sql.connector.catalog.Table;
import org.apache.spark.sql.connector.catalog.TableCapability;
@@ -34,7 +35,7 @@ abstract class JavaSimpleBatchTable implements Table, SupportsRead {
@Override
public StructType schema() {
- return new StructType().add("i", "int").add("j", "int");
+ return TestingV2Source.schema();
}
@Override
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleDataSourceV2.java
index 8b6d71b986ff..8852249d8a01 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleDataSourceV2.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleDataSourceV2.java
@@ -17,13 +17,13 @@
package test.org.apache.spark.sql.connector;
+import org.apache.spark.sql.connector.TestingV2Source;
import org.apache.spark.sql.connector.catalog.Table;
-import org.apache.spark.sql.connector.catalog.TableProvider;
import org.apache.spark.sql.connector.read.InputPartition;
import org.apache.spark.sql.connector.read.ScanBuilder;
import org.apache.spark.sql.util.CaseInsensitiveStringMap;
-public class JavaSimpleDataSourceV2 implements TableProvider {
+public class JavaSimpleDataSourceV2 implements TestingV2Source {
class MyScanBuilder extends JavaSimpleScanBuilder {
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleScanBuilder.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleScanBuilder.java
index 7cbba0042092..bdd9dd3ea0ce 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleScanBuilder.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleScanBuilder.java
@@ -17,6 +17,7 @@
package test.org.apache.spark.sql.connector;
+import org.apache.spark.sql.connector.TestingV2Source;
import org.apache.spark.sql.connector.read.Batch;
import org.apache.spark.sql.connector.read.PartitionReaderFactory;
import org.apache.spark.sql.connector.read.Scan;
@@ -37,7 +38,7 @@ public Batch toBatch() {
@Override
public StructType readSchema() {
- return new StructType().add("i", "int").add("j", "int");
+ return TestingV2Source.schema();
}
@Override
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala
index 08627e681f9e..4c67888cbdc4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala
@@ -92,12 +92,6 @@ class DataSourceV2DataFrameSessionCatalogSuite
}
}
-class InMemoryTableProvider extends TableProvider {
- override def getTable(options: CaseInsensitiveStringMap): Table = {
- throw new UnsupportedOperationException("D'oh!")
- }
-}
-
class InMemoryTableSessionCatalog extends TestV2SessionCatalogBase[InMemoryTable] {
override def newTable(
name: String,
@@ -140,7 +134,7 @@ private [connector] trait SessionCatalogTest[T <: Table, Catalog <: TestV2Sessio
spark.sessionState.catalogManager.catalog(name)
}
- protected val v2Format: String = classOf[InMemoryTableProvider].getName
+ protected val v2Format: String = classOf[FakeV2Provider].getName
protected val catalogClassName: String = classOf[InMemoryTableSessionCatalog].getName
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
index 04e5a8dfd78b..2c8349a0e6a7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAM
import org.apache.spark.sql.connector.catalog.CatalogV2Util.withDefaultOwnership
import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
import org.apache.spark.sql.internal.SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION
+import org.apache.spark.sql.internal.connector.SimpleTableProvider
import org.apache.spark.sql.sources.SimpleScanSource
import org.apache.spark.sql.types.{BooleanType, LongType, StringType, StructField, StructType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -2230,7 +2231,7 @@ class DataSourceV2SQLSuite
/** Used as a V2 DataSource for V2SessionCatalog DDL */
-class FakeV2Provider extends TableProvider {
+class FakeV2Provider extends SimpleTableProvider {
override def getTable(options: CaseInsensitiveStringMap): Table = {
throw new UnsupportedOperationException("Unnecessary for DDL tests")
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala
index 85ff86ef3fc5..2d8761f872da 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala
@@ -30,6 +30,7 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability, TableProvider}
import org.apache.spark.sql.connector.catalog.TableCapability._
+import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.connector.read._
import org.apache.spark.sql.connector.read.partitioning.{ClusteredDistribution, Distribution, Partitioning}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
@@ -418,7 +419,7 @@ object SimpleReaderFactory extends PartitionReaderFactory {
abstract class SimpleBatchTable extends Table with SupportsRead {
- override def schema(): StructType = new StructType().add("i", "int").add("j", "int")
+ override def schema(): StructType = TestingV2Source.schema
override def name(): String = this.getClass.toString
@@ -432,12 +433,31 @@ abstract class SimpleScanBuilder extends ScanBuilder
override def toBatch: Batch = this
- override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int")
+ override def readSchema(): StructType = TestingV2Source.schema
override def createReaderFactory(): PartitionReaderFactory = SimpleReaderFactory
}
-class SimpleSinglePartitionSource extends TableProvider {
+trait TestingV2Source extends TableProvider {
+ override def inferSchema(options: CaseInsensitiveStringMap): StructType = {
+ TestingV2Source.schema
+ }
+
+ override def getTable(
+ schema: StructType,
+ partitioning: Array[Transform],
+ properties: util.Map[String, String]): Table = {
+ getTable(new CaseInsensitiveStringMap(properties))
+ }
+
+ def getTable(options: CaseInsensitiveStringMap): Table
+}
+
+object TestingV2Source {
+ val schema = new StructType().add("i", "int").add("j", "int")
+}
+
+class SimpleSinglePartitionSource extends TestingV2Source {
class MyScanBuilder extends SimpleScanBuilder {
override def planInputPartitions(): Array[InputPartition] = {
@@ -452,9 +472,10 @@ class SimpleSinglePartitionSource extends TableProvider {
}
}
+
// This class is used by pyspark tests. If this class is modified/moved, make sure pyspark
// tests still pass.
-class SimpleDataSourceV2 extends TableProvider {
+class SimpleDataSourceV2 extends TestingV2Source {
class MyScanBuilder extends SimpleScanBuilder {
override def planInputPartitions(): Array[InputPartition] = {
@@ -469,7 +490,7 @@ class SimpleDataSourceV2 extends TableProvider {
}
}
-class AdvancedDataSourceV2 extends TableProvider {
+class AdvancedDataSourceV2 extends TestingV2Source {
override def getTable(options: CaseInsensitiveStringMap): Table = new SimpleBatchTable {
override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
@@ -481,7 +502,7 @@ class AdvancedDataSourceV2 extends TableProvider {
class AdvancedScanBuilder extends ScanBuilder
with Scan with SupportsPushDownFilters with SupportsPushDownRequiredColumns {
- var requiredSchema = new StructType().add("i", "int").add("j", "int")
+ var requiredSchema = TestingV2Source.schema
var filters = Array.empty[Filter]
override def pruneColumns(requiredSchema: StructType): Unit = {
@@ -567,11 +588,16 @@ class SchemaRequiredDataSource extends TableProvider {
override def readSchema(): StructType = schema
}
- override def getTable(options: CaseInsensitiveStringMap): Table = {
+ override def supportsExternalMetadata(): Boolean = true
+
+ override def inferSchema(options: CaseInsensitiveStringMap): StructType = {
throw new IllegalArgumentException("requires a user-supplied schema")
}
- override def getTable(options: CaseInsensitiveStringMap, schema: StructType): Table = {
+ override def getTable(
+ schema: StructType,
+ partitioning: Array[Transform],
+ properties: util.Map[String, String]): Table = {
val userGivenSchema = schema
new SimpleBatchTable {
override def schema(): StructType = userGivenSchema
@@ -583,7 +609,7 @@ class SchemaRequiredDataSource extends TableProvider {
}
}
-class ColumnarDataSourceV2 extends TableProvider {
+class ColumnarDataSourceV2 extends TestingV2Source {
class MyScanBuilder extends SimpleScanBuilder {
@@ -648,7 +674,7 @@ object ColumnarReaderFactory extends PartitionReaderFactory {
}
}
-class PartitionAwareDataSource extends TableProvider {
+class PartitionAwareDataSource extends TestingV2Source {
class MyScanBuilder extends SimpleScanBuilder
with SupportsReportPartitioning{
@@ -716,7 +742,7 @@ class SimpleWriteOnlyDataSource extends SimpleWritableDataSource {
}
}
-class ReportStatisticsDataSource extends TableProvider {
+class ReportStatisticsDataSource extends SimpleWritableDataSource {
class MyScanBuilder extends SimpleScanBuilder
with SupportsReportStatistics {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala
index 0070076459f1..f9306ba28e7f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala
@@ -27,10 +27,11 @@ import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.SparkContext
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.connector.catalog.{SessionConfigSupport, SupportsWrite, Table, TableCapability, TableProvider}
+import org.apache.spark.sql.connector.catalog.{SessionConfigSupport, SupportsWrite, Table, TableCapability}
import org.apache.spark.sql.connector.catalog.TableCapability._
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory, ScanBuilder}
import org.apache.spark.sql.connector.write._
+import org.apache.spark.sql.internal.connector.SimpleTableProvider
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.util.SerializableConfiguration
@@ -40,7 +41,7 @@ import org.apache.spark.util.SerializableConfiguration
* Each task writes data to `target/_temporary/uniqueId/$jobId-$partitionId-$attemptNumber`.
* Each job moves files from `target/_temporary/uniqueId/` to `target`.
*/
-class SimpleWritableDataSource extends TableProvider with SessionConfigSupport {
+class SimpleWritableDataSource extends SimpleTableProvider with SessionConfigSupport {
private val tableSchema = new StructType().add("i", "long").add("j", "long")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala
index cec48bb368ae..7bff955b1836 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala
@@ -17,8 +17,6 @@
package org.apache.spark.sql.connector
-import java.util
-
import scala.language.implicitConversions
import scala.util.Try
@@ -275,7 +273,7 @@ class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with
}
class CatalogSupportingInMemoryTableProvider
- extends InMemoryTableProvider
+ extends FakeV2Provider
with SupportsCatalogOptions {
override def extractIdentifier(options: CaseInsensitiveStringMap): Identifier = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala
index 5196ca65276e..23e4c293cbc2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala
@@ -40,7 +40,7 @@ class TableCapabilityCheckSuite extends AnalysisSuite with SharedSparkSession {
private val emptyMap = CaseInsensitiveStringMap.empty
private def createStreamingRelation(table: Table, v1Relation: Option[StreamingRelation]) = {
StreamingRelationV2(
- TestTableProvider,
+ new FakeV2Provider,
"fake",
table,
CaseInsensitiveStringMap.empty(),
@@ -211,12 +211,6 @@ private case object TestRelation extends LeafNode with NamedRelation {
override def output: Seq[AttributeReference] = TableCapabilityCheckSuite.schema.toAttributes
}
-private object TestTableProvider extends TableProvider {
- override def getTable(options: CaseInsensitiveStringMap): Table = {
- throw new UnsupportedOperationException
- }
-}
-
private case class CapabilityTable(_capabilities: TableCapability*) extends Table {
override def name(): String = "capability_test_table"
override def schema(): StructType = TableCapabilityCheckSuite.schema
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala
index 8e2c63417b37..74f2ca14234d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.connector.catalog.{Identifier, SupportsRead, Table,
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownRequiredColumns, V1Scan}
import org.apache.spark.sql.execution.RowDataSourceScanExec
+import org.apache.spark.sql.internal.connector.SimpleTableProvider
import org.apache.spark.sql.sources.{BaseRelation, Filter, GreaterThan, TableScan}
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.StructType
@@ -120,7 +121,7 @@ object V1ReadFallbackCatalog {
val schema = new StructType().add("i", "int").add("j", "int")
}
-class V1ReadFallbackTableProvider extends TableProvider {
+class V1ReadFallbackTableProvider extends SimpleTableProvider {
override def getTable(options: CaseInsensitiveStringMap): Table = {
new TableWithV1ReadFallback("v1-read-fallback")
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala
index a36e8dbdec50..10ed2048dbf6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala
@@ -25,10 +25,11 @@ import scala.collection.mutable
import org.scalatest.BeforeAndAfter
import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, SaveMode, SparkSession, SQLContext}
-import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table, TableCapability, TableProvider}
+import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table, TableCapability}
import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform, Transform}
import org.apache.spark.sql.connector.write.{LogicalWriteInfo, LogicalWriteInfoImpl, SupportsOverwrite, SupportsTruncate, V1WriteBuilder, WriteBuilder}
import org.apache.spark.sql.execution.datasources.DataSourceUtils
+import org.apache.spark.sql.internal.connector.SimpleTableProvider
import org.apache.spark.sql.sources._
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.StructType
@@ -173,7 +174,7 @@ private object InMemoryV1Provider {
}
class InMemoryV1Provider
- extends TableProvider
+ extends SimpleTableProvider
with DataSourceRegister
with CreatableRelationProvider {
override def getTable(options: CaseInsensitiveStringMap): Table = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala
index 70b9b7ec12ea..30b7e93a4beb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat,
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Expression, InSubquery, IntegerLiteral, ListQuery, StringLiteral}
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.plans.logical.{AlterTable, Assignment, CreateTableAsSelect, CreateV2Table, DeleteAction, DeleteFromTable, DescribeRelation, DropTable, InsertAction, InsertIntoStatement, LocalRelation, LogicalPlan, MergeIntoTable, OneRowRelation, Project, ShowTableProperties, SubqueryAlias, UpdateAction, UpdateTable}
-import org.apache.spark.sql.connector.InMemoryTableProvider
+import org.apache.spark.sql.connector.FakeV2Provider
import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogNotFoundException, Identifier, Table, TableCapability, TableCatalog, TableChange, V1Table}
import org.apache.spark.sql.execution.datasources.CreateTable
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
@@ -41,7 +41,7 @@ import org.apache.spark.sql.types.{CharType, DoubleType, HIVE_TYPE_STRING, Integ
class PlanResolutionSuite extends AnalysisTest {
import CatalystSqlParser._
- private val v2Format = classOf[InMemoryTableProvider].getName
+ private val v2Format = classOf[FakeV2Provider].getName
private val table: Table = {
val t = mock(classOf[Table])
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala
index 0f80e2d431bb..5c66fc52592b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala
@@ -194,13 +194,12 @@ class TextSocketStreamSuite extends StreamTest with SharedSparkSession {
}
test("user-specified schema given") {
- val provider = new TextSocketSourceProvider
val userSpecifiedSchema = StructType(
StructField("name", StringType) ::
StructField("area", StringType) :: Nil)
val params = Map("host" -> "localhost", "port" -> "1234")
val exception = intercept[UnsupportedOperationException] {
- provider.getTable(new CaseInsensitiveStringMap(params.asJava), userSpecifiedSchema)
+ spark.readStream.schema(userSpecifiedSchema).format("socket").options(params).load()
}
assert(exception.getMessage.contains(
"TextSocketSourceProvider source does not support user-specified schema"))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala
index 13bc811a8fe9..05cf324f8d49 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactor
import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.execution.streaming.{ContinuousTrigger, RateStreamOffset, Sink, StreamingQueryWrapper}
import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.connector.SimpleTableProvider
import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider}
import org.apache.spark.sql.streaming.{OutputMode, StreamingQuery, StreamTest, Trigger}
import org.apache.spark.sql.types.StructType
@@ -93,7 +94,7 @@ trait FakeStreamingWriteTable extends Table with SupportsWrite {
class FakeReadMicroBatchOnly
extends DataSourceRegister
- with TableProvider
+ with SimpleTableProvider
with SessionConfigSupport {
override def shortName(): String = "fake-read-microbatch-only"
@@ -116,7 +117,7 @@ class FakeReadMicroBatchOnly
class FakeReadContinuousOnly
extends DataSourceRegister
- with TableProvider
+ with SimpleTableProvider
with SessionConfigSupport {
override def shortName(): String = "fake-read-continuous-only"
@@ -137,7 +138,7 @@ class FakeReadContinuousOnly
}
}
-class FakeReadBothModes extends DataSourceRegister with TableProvider {
+class FakeReadBothModes extends DataSourceRegister with SimpleTableProvider {
override def shortName(): String = "fake-read-microbatch-continuous"
override def getTable(options: CaseInsensitiveStringMap): Table = {
@@ -154,7 +155,7 @@ class FakeReadBothModes extends DataSourceRegister with TableProvider {
}
}
-class FakeReadNeitherMode extends DataSourceRegister with TableProvider {
+class FakeReadNeitherMode extends DataSourceRegister with SimpleTableProvider {
override def shortName(): String = "fake-read-neither-mode"
override def getTable(options: CaseInsensitiveStringMap): Table = {
@@ -168,7 +169,7 @@ class FakeReadNeitherMode extends DataSourceRegister with TableProvider {
class FakeWriteOnly
extends DataSourceRegister
- with TableProvider
+ with SimpleTableProvider
with SessionConfigSupport {
override def shortName(): String = "fake-write-microbatch-continuous"
@@ -183,7 +184,7 @@ class FakeWriteOnly
}
}
-class FakeNoWrite extends DataSourceRegister with TableProvider {
+class FakeNoWrite extends DataSourceRegister with SimpleTableProvider {
override def shortName(): String = "fake-write-neither-mode"
override def getTable(options: CaseInsensitiveStringMap): Table = {
new Table {
@@ -201,7 +202,7 @@ class FakeSink extends Sink {
}
class FakeWriteSupportProviderV1Fallback extends DataSourceRegister
- with TableProvider with StreamSinkProvider {
+ with SimpleTableProvider with StreamSinkProvider {
override def createSink(
sqlContext: SQLContext,
@@ -378,10 +379,10 @@ class StreamingDataSourceV2Suite extends StreamTest {
for ((read, write, trigger) <- cases) {
testQuietly(s"stream with read format $read, write format $write, trigger $trigger") {
val sourceTable = DataSource.lookupDataSource(read, spark.sqlContext.conf).getConstructor()
- .newInstance().asInstanceOf[TableProvider].getTable(CaseInsensitiveStringMap.empty())
+ .newInstance().asInstanceOf[SimpleTableProvider].getTable(CaseInsensitiveStringMap.empty())
val sinkTable = DataSource.lookupDataSource(write, spark.sqlContext.conf).getConstructor()
- .newInstance().asInstanceOf[TableProvider].getTable(CaseInsensitiveStringMap.empty())
+ .newInstance().asInstanceOf[SimpleTableProvider].getTable(CaseInsensitiveStringMap.empty())
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._
trigger match {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/BlockOnStopSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/BlockOnStopSource.scala
index f25758c52069..c594a8523d15 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/BlockOnStopSource.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/BlockOnStopSource.scala
@@ -25,11 +25,12 @@ import scala.collection.JavaConverters._
import org.apache.zookeeper.KeeperException.UnimplementedException
import org.apache.spark.sql.{DataFrame, Row, SparkSession, SQLContext}
-import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability, TableProvider}
+import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability}
import org.apache.spark.sql.connector.catalog.TableCapability.CONTINUOUS_READ
import org.apache.spark.sql.connector.read.{streaming, InputPartition, Scan, ScanBuilder}
import org.apache.spark.sql.connector.read.streaming.{ContinuousPartitionReaderFactory, ContinuousStream, PartitionOffset}
import org.apache.spark.sql.execution.streaming.{LongOffset, Offset, Source}
+import org.apache.spark.sql.internal.connector.SimpleTableProvider
import org.apache.spark.sql.sources.StreamSourceProvider
import org.apache.spark.sql.types.{LongType, StructType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -54,7 +55,7 @@ object BlockOnStopSourceProvider {
}
}
-class BlockOnStopSourceProvider extends StreamSourceProvider with TableProvider {
+class BlockOnStopSourceProvider extends StreamSourceProvider with SimpleTableProvider {
override def getTable(options: CaseInsensitiveStringMap): Table = {
new BlockOnStopSourceTable(BlockOnStopSourceProvider._latch)
}